diff --git a/conn.go b/conn.go index 86f718b..a860d40 100644 --- a/conn.go +++ b/conn.go @@ -4,23 +4,36 @@ import ( "fmt" "io" "net" + "time" ) type transmitFunc func(buf []byte) (bool, error) type Conn struct { - addr net.Addr - conn net.PacketConn - protocol *Protocol + addr net.Addr + conn net.PacketConn + + p *Protocol + + updatePeriod time.Duration // how often time-dependant parts of the p get checked + resendTimeout time.Duration // how long we wait until unacked packets should be resent + + exit chan struct{} // signal channel to close the conn } func NewConn(addr net.Addr, conn net.PacketConn, opts ...ProtocolOption) *Conn { - p := NewProtocol(opts...) - return &Conn{addr: addr, conn: conn, protocol: p} + return &Conn{ + addr: addr, + conn: conn, + p: NewProtocol(opts...), + updatePeriod: DefaultUpdatePeriod, + resendTimeout: DefaultResendTimeout, + exit: make(chan struct{}), + } } func (c *Conn) WriteReliablePacket(buf []byte) error { - buf, err := c.protocol.WritePacket(true, buf) + buf, err := c.p.WritePacket(true, buf) if err != nil { return err } @@ -30,7 +43,7 @@ func (c *Conn) WriteReliablePacket(buf []byte) error { } func (c *Conn) WriteUnreliablePacket(buf []byte) error { - buf, err := c.protocol.WritePacket(false, buf) + buf, err := c.p.WritePacket(false, buf) if err != nil { return err } @@ -40,22 +53,33 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error { } func (c *Conn) Read(header PacketHeader, buf []byte) error { - buf = c.protocol.ReadPacket(header, buf) - - if len(buf) != 0 { - _, err := c.transmit(buf) - return err + needed := c.p.ReadPacket(header, buf) + if !needed { + return nil } - return nil + return c.writeAcks() } func (c *Conn) Close() { - c.protocol.Close() + close(c.exit) + c.p.Close() } func (c *Conn) Run() { - c.protocol.Run(c.transmit) + ticker := time.NewTicker(c.updatePeriod) + defer ticker.Stop() + + for { + select { + case <-c.exit: + return + case <-ticker.C: + if err := c.retransmitUnackedPackets(); err != nil { + c.p.callErrorHandler(err) + } + } + } } func (c *Conn) transmit(buf []byte) (EOF bool, err error) { @@ -74,3 +98,46 @@ func (c *Conn) transmit(buf []byte) (EOF bool, err error) { return } + +func (c *Conn) writeAcks() error { + c.p.mu.Lock() + defer c.p.mu.Unlock() + + for { + needed := c.p.checkIfAck() + if !needed { + break + } + + header := c.p.createAck() + + buf := c.p.write(header, nil) + if _, err := c.transmit(buf); err != nil { + return fmt.Errorf("failed to transmit acks: %w", err) + } + } + + return nil +} + +func (c *Conn) retransmitUnackedPackets() error { + c.p.mu.Lock() + defer c.p.mu.Unlock() + + for idx := uint16(0); idx < c.p.writeQueueLen(); idx++ { + buf, needed := c.p.checkIfRetransmit(idx, c.resendTimeout) + if !needed { + continue + } + + if isEOF, err := c.transmit(buf); err != nil { + return fmt.Errorf("failed to retransmit unacked packet: %w", err) + } else if isEOF { + break + } + + c.p.incrementWqe(idx) + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index ae87e4a..19fd172 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,22 +27,28 @@ func TestConnWriteReliablePacket(t *testing.T) { require.EqualValues(t, data, buf) } - ca := NewConn(a.LocalAddr(), a, WithProtocolPacketHandler(handler)) - cb := NewConn(b.LocalAddr(), b, WithProtocolPacketHandler(handler)) + ca := NewConn(b.LocalAddr(), a, WithProtocolPacketHandler(handler)) + cb := NewConn(a.LocalAddr(), b, WithProtocolPacketHandler(handler)) go readLoop(t, a, ca) go readLoop(t, b, cb) + go ca.Run() + go cb.Run() + defer func() { + // Note: Guarantee that all messages are deliverd + time.Sleep(1 * time.Second) + require.NoError(t, a.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, b.SetDeadline(time.Now().Add(1*time.Millisecond))) - require.NoError(t, a.Close()) - require.NoError(t, b.Close()) - ca.Close() cb.Close() + require.NoError(t, a.Close()) + require.NoError(t, b.Close()) + require.EqualValues(t, expected, atomic.LoadUint64(&actual)) }() diff --git a/endpoint.go b/endpoint.go index 22ed72c..93be8c5 100644 --- a/endpoint.go +++ b/endpoint.go @@ -6,7 +6,6 @@ import ( "net" "sync" "sync/atomic" - "time" ) type EndpointPacketHandler func(buf []byte, addr net.Addr) @@ -16,9 +15,6 @@ type Endpoint struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 - updatePeriod time.Duration // how often time-dependant parts of the protocol get checked - resendTimeout time.Duration // how long we wait until unacked packets should be resent - mu sync.Mutex wg sync.WaitGroup @@ -49,14 +45,6 @@ func NewEndpoint(conn net.PacketConn, opts ...EndpointOption) *Endpoint { e.readBufferSize = DefaultReadBufferSize } - if e.resendTimeout == 0 { - e.resendTimeout = DefaultResendTimeout - } - - if e.updatePeriod == 0 { - e.updatePeriod = DefaultUpdatePeriod - } - if e.pool == nil { e.pool = new(Pool) } @@ -81,8 +69,6 @@ func (e *Endpoint) getConn(addr net.Addr) *Conn { e.conn, WithWriteBufferSize(e.writeBufferSize), WithReadBufferSize(e.readBufferSize), - WithUpdatePeriod(e.updatePeriod), - WithResendTimeout(e.resendTimeout), WithBufferPool(e.pool), ) @@ -127,23 +113,33 @@ func (e *Endpoint) Addr() net.Addr { return e.addr } -func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr) error { +func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error { conn := e.getConn(addr) if conn == nil { return io.EOF } + + for _, opt := range opts { + opt.applyConn(conn) + } + return conn.WriteReliablePacket(buf) } -func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr) error { +func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error { conn := e.getConn(addr) if conn == nil { return io.EOF } + + for _, opt := range opts { + opt.applyConn(conn) + } + return conn.WriteUnreliablePacket(buf) } -func (e *Endpoint) Listen() { +func (e *Endpoint) Listen(opts ...ConnOption) { e.mu.Lock() e.wg.Add(1) e.mu.Unlock() @@ -168,6 +164,10 @@ func (e *Endpoint) Listen() { break } + for _, opt := range opts { + opt.applyConn(conn) + } + header, buf, err := UnmarshalPacketHeader(buf[:n]) if err != nil { e.clearConn(addr) diff --git a/endpoint_test.go b/endpoint_test.go index f572fc6..31a2a6a 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -8,7 +8,6 @@ import ( "sort" "strconv" "sync" - "sync/atomic" "testing" "time" ) @@ -89,26 +88,18 @@ func BenchmarkEndpointWriteUnreliablePacket(b *testing.B) { func TestEndpointWriteReliablePacket(t *testing.T) { defer goleak.VerifyNone(t) - var mu sync.Mutex - - values := make(map[string]struct{}) - - actual := uint64(0) - expected := uint64(65536) + var ( + expected []int + actual []int + loop uint64 = 65536 + ) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } - - atomic.AddUint64(&actual, 1) - - mu.Lock() - _, exists := values[string(buf)] - delete(values, string(buf)) - mu.Unlock() - - require.True(t, exists) + num, _ := strconv.Atoi(string(buf)) + actual = append(actual, num) } ca := newPacketConn(t, "127.0.0.1:0") @@ -121,6 +112,9 @@ func TestEndpointWriteReliablePacket(t *testing.T) { go b.Listen() defer func() { + // Note: Guarantee that all messages are deliverd + time.Sleep(1 * time.Second) + require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) @@ -130,15 +124,12 @@ func TestEndpointWriteReliablePacket(t *testing.T) { require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) - require.EqualValues(t, expected, atomic.LoadUint64(&actual)) + require.EqualValues(t, expected, uniqSort(actual)) }() - for i := uint64(0); i < expected; i++ { + for i := uint64(0); i < loop; i++ { data := strconv.AppendUint(nil, i, 10) - - mu.Lock() - values[string(data)] = struct{}{} - mu.Unlock() + expected = append(expected, int(i)) require.NoError(t, a.WriteReliablePacket(data, b.Addr())) } @@ -147,14 +138,21 @@ func TestEndpointWriteReliablePacket(t *testing.T) { func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) { defer goleak.VerifyNone(t) - actual := uint64(0) - expected := uint64(512) + var ( + expected []int + actual []int + loop uint64 = 512 + mu sync.Mutex + ) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } - atomic.AddUint64(&actual, 1) + mu.Lock() + num, _ := strconv.Atoi(string(buf)) + actual = append(actual, num) + mu.Unlock() } ca := newPacketConn(t, "127.0.0.1:0") @@ -176,14 +174,17 @@ func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) { require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) - require.EqualValues(t, expected*2, atomic.LoadUint64(&actual)) + sort.Ints(expected) + require.EqualValues(t, expected, uniqSort(actual)) }() - for i := uint64(0); i < expected; i++ { - data := strconv.AppendUint(nil, i, 10) + for i := uint64(0); i < loop; i++ { + dataA := strconv.AppendUint(nil, i, 10) + dataB := strconv.AppendUint(nil, i+loop, 10) + expected = append(expected, int(i), int(i+loop)) - require.NoError(t, a.WriteReliablePacket(data, b.Addr())) - require.NoError(t, b.WriteReliablePacket(data, a.Addr())) + require.NoError(t, a.WriteReliablePacket(dataA, b.Addr())) + require.NoError(t, b.WriteReliablePacket(dataB, a.Addr())) } } diff --git a/fuzz.go b/fuzz.go index 109098e..c99ca6c 100644 --- a/fuzz.go +++ b/fuzz.go @@ -28,13 +28,13 @@ func Fuzz(data []byte) int { chErr <- errors.New("data miss match") } - ea := NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) - eb := NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) + ea := NewEndpoint(ca, WithEndpointPacketHandler(handler)) + eb := NewEndpoint(cb, WithEndpointPacketHandler(handler)) go ea.Listen() go eb.Listen() - for i := 0; i < 65536; i++ { + for i := 0; i < 4096; i++ { select { case <-chErr: return 0 diff --git a/options.go b/options.go index 71f0681..47331bc 100644 --- a/options.go +++ b/options.go @@ -18,6 +18,10 @@ type EndpointOption interface { applyEndpoint(e *Endpoint) } +type ConnOption interface { + applyConn(c *Conn) +} + type Option interface { ProtocolOption EndpointOption @@ -82,10 +86,9 @@ func WithEndpointErrorHandler(eh EndpointErrorHandler) EndpointOption { type withUpdatePeriod struct{ updatePeriod time.Duration } -func (o withUpdatePeriod) applyProtocol(p *Protocol) { p.updatePeriod = o.updatePeriod } -func (o withUpdatePeriod) applyEndpoint(e *Endpoint) { e.updatePeriod = o.updatePeriod } +func (o withUpdatePeriod) applyConn(c *Conn) { c.updatePeriod = o.updatePeriod } -func WithUpdatePeriod(updatePeriod time.Duration) Option { +func WithUpdatePeriod(updatePeriod time.Duration) ConnOption { if updatePeriod == 0 { panic("update period of zero is not supported yet") } @@ -94,10 +97,9 @@ func WithUpdatePeriod(updatePeriod time.Duration) Option { type withResendTimeout struct{ resendTimeout time.Duration } -func (o withResendTimeout) applyProtocol(p *Protocol) { p.resendTimeout = o.resendTimeout } -func (o withResendTimeout) applyEndpoint(e *Endpoint) { e.resendTimeout = o.resendTimeout } +func (o withResendTimeout) applyConn(c *Conn) { c.resendTimeout = o.resendTimeout } -func WithResendTimeout(resendTimeout time.Duration) Option { +func WithResendTimeout(resendTimeout time.Duration) ConnOption { if resendTimeout == 0 { panic("ack timeout of zero is not supported yet") } diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..9eb1414 --- /dev/null +++ b/options_test.go @@ -0,0 +1,27 @@ +package reliable + +import ( + "github.com/stretchr/testify/require" + "net" + "testing" + "time" +) + +func TestConnApply(t *testing.T) { + a, _ := net.ListenPacket("udp", "127.0.0.1:0") + b, _ := net.ListenPacket("udp", "127.0.0.1:0") + + ca := NewConn(b.LocalAddr(), a) + + expectedUpdatePeriod := 1 * time.Second + expectedResendTimeout := 2 * time.Second + + uOpt := WithUpdatePeriod(expectedUpdatePeriod) + rOpt := WithResendTimeout(expectedResendTimeout) + + uOpt.applyConn(ca) + require.EqualValues(t, expectedUpdatePeriod, ca.updatePeriod) + + rOpt.applyConn(ca) + require.EqualValues(t, expectedResendTimeout, ca.resendTimeout) +} diff --git a/protocol.go b/protocol.go index 43b6b2e..cefb0c8 100644 --- a/protocol.go +++ b/protocol.go @@ -1,7 +1,6 @@ package reliable import ( - "fmt" "github.com/lithdew/seq" "io" "sync" @@ -15,9 +14,6 @@ type Protocol struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 - updatePeriod time.Duration // how often time-dependant parts of the protocol get checked - resendTimeout time.Duration // how long we wait until unacked packets should be resent - pool *Pool ph ProtocolPacketHandler @@ -56,14 +52,6 @@ func NewProtocol(opts ...ProtocolOption) *Protocol { p.readBufferSize = DefaultReadBufferSize } - if p.resendTimeout == 0 { - p.resendTimeout = DefaultResendTimeout - } - - if p.updatePeriod == 0 { - p.updatePeriod = DefaultUpdatePeriod - } - if p.pool == nil { p.pool = new(Pool) } @@ -104,7 +92,7 @@ func (p *Protocol) WritePacket(reliable bool, buf []byte) ([]byte, error) { p.trackAcked(ack) - // log.Printf("%v: send (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, idx, ack, ackBits, len(buf), reliable) + // log.Printf("%p: send (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, idx, ack, ackBits, len(buf), reliable) return p.write(PacketHeader{Sequence: idx, ACK: ack, ACKBits: ackBits, Unordered: !reliable}, buf), nil } @@ -201,41 +189,46 @@ func (p *Protocol) clearWrites(start, end uint16) { emptyBufferIndices(second) } -func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) []byte { +func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) (needed bool) { p.mu.Lock() defer p.mu.Unlock() p.readAckBits(header.ACK, header.ACKBits) if !header.Unordered && !p.trackRead(header.Sequence) { - return nil + return } p.trackUnacked() if header.Empty { - return nil + return } if p.ph != nil { p.ph(buf, header.Sequence) } - // log.Printf("%v: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) + // log.Printf("%p: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) - return p.writeAcksIfNecessary() + needed = p.checkIfAck() + return } -func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { +func (p *Protocol) checkIfAck() bool { lui := p.lui for i := uint16(0); i < ACKBitsetSize; i++ { if p.rq[(lui+i)%uint16(len(p.rq))] != uint32(lui+i) { - return header, needed + return false } } - lui += ACKBitsetSize + return !p.die +} + +func (p *Protocol) createAck() (header PacketHeader) { + lui := p.lui + ACKBitsetSize p.lui = lui p.ls = time.Now() @@ -245,22 +238,7 @@ func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { header.ACKBits = p.prepareAckBits(header.ACK) header.Empty = true - needed = !p.die - - return header, needed -} - -func (p *Protocol) writeAcksIfNecessary() []byte { - for { - header, needed := p.createAckIfNecessary() - if !needed { - return nil - } - - // log.Printf("%v: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", &p, header.Sequence, header.ACK, header.ACKBits) - - return p.write(header, nil) - } + return header } func (p *Protocol) readAckBits(ack uint16, ackBits uint32) { @@ -355,7 +333,6 @@ func (p *Protocol) close() bool { if p.die { return false } - close(p.exit) p.die = true p.ouc.Broadcast() @@ -371,43 +348,26 @@ func (p *Protocol) Close() { } } -func (p *Protocol) Run(transmit transmitFunc) { - ticker := time.NewTicker(p.updatePeriod) - defer ticker.Stop() - - for { - select { - case <-p.exit: - return - case <-ticker.C: - if err := p.retransmitUnackedPackets(transmit); err != nil && p.eh != nil { - p.eh(err) - } - } +func (p *Protocol) checkIfRetransmit(idx uint16, resendTimeout time.Duration) ([]byte, bool) { + i := (p.oui + idx) % uint16(len(p.wq)) + if p.wq[i] != uint32(p.oui+idx) || !p.wqe[i].shouldResend(time.Now(), resendTimeout) { + return nil, false } + return p.wqe[i].buf.B, true } -func (p *Protocol) retransmitUnackedPackets(transmit transmitFunc) error { - p.mu.Lock() - defer p.mu.Unlock() - - for idx := uint16(0); idx < uint16(len(p.wq)); idx++ { - i := (p.oui + idx) % uint16(len(p.wq)) - if p.wq[i] != uint32(p.oui+idx) || !p.wqe[i].shouldResend(time.Now(), p.resendTimeout) { - continue - } - - // log.Printf("%v: resend (seq=%d)", &p, p.oui+idx) - - if isEOF, err := transmit(p.wqe[i].buf.B); err != nil { - return fmt.Errorf("failed to retransmit unacked packet: %w", err) - } else if isEOF { - break - } +func (p *Protocol) incrementWqe(idx uint16) { + i := (p.oui + idx) % uint16(len(p.wq)) + p.wqe[i].written = time.Now() + p.wqe[i].resent++ +} - p.wqe[i].written = time.Now() - p.wqe[i].resent++ +func (p *Protocol) callErrorHandler(err error) { + if p.eh != nil { + p.eh(err) } +} - return nil +func (p *Protocol) writeQueueLen() uint16 { + return uint16(len(p.wq)) }