From 86562ac72cda90f4141698d50bef8a4a4600d50e Mon Sep 17 00:00:00 2001 From: tak1827 Date: Fri, 29 May 2020 07:49:03 +0900 Subject: [PATCH 1/6] fix conn_test bug --- conn.go | 9 +---- conn_test.go | 16 +++++--- endpoint_test.go | 61 ++++++++++++++-------------- examples/basic/debug.go | 90 +++++++++++++++++++++++++++++++++++++++++ fuzz.go | 4 +- protocol.go | 19 +++++---- 6 files changed, 146 insertions(+), 53 deletions(-) create mode 100644 examples/basic/debug.go diff --git a/conn.go b/conn.go index 86f718b..a13c1b9 100644 --- a/conn.go +++ b/conn.go @@ -40,14 +40,7 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error { } func (c *Conn) Read(header PacketHeader, buf []byte) error { - buf = c.protocol.ReadPacket(header, buf) - - if len(buf) != 0 { - _, err := c.transmit(buf) - return err - } - - return nil + return c.protocol.ReadPacket(header, buf, c.transmit) } func (c *Conn) Close() { diff --git a/conn_test.go b/conn_test.go index ae87e4a..19fd172 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,22 +27,28 @@ func TestConnWriteReliablePacket(t *testing.T) { require.EqualValues(t, data, buf) } - ca := NewConn(a.LocalAddr(), a, WithProtocolPacketHandler(handler)) - cb := NewConn(b.LocalAddr(), b, WithProtocolPacketHandler(handler)) + ca := NewConn(b.LocalAddr(), a, WithProtocolPacketHandler(handler)) + cb := NewConn(a.LocalAddr(), b, WithProtocolPacketHandler(handler)) go readLoop(t, a, ca) go readLoop(t, b, cb) + go ca.Run() + go cb.Run() + defer func() { + // Note: Guarantee that all messages are deliverd + time.Sleep(1 * time.Second) + require.NoError(t, a.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, b.SetDeadline(time.Now().Add(1*time.Millisecond))) - require.NoError(t, a.Close()) - require.NoError(t, b.Close()) - ca.Close() cb.Close() + require.NoError(t, a.Close()) + require.NoError(t, b.Close()) + require.EqualValues(t, expected, atomic.LoadUint64(&actual)) }() diff --git a/endpoint_test.go b/endpoint_test.go index f572fc6..31a2a6a 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -8,7 +8,6 @@ import ( "sort" "strconv" "sync" - "sync/atomic" "testing" "time" ) @@ -89,26 +88,18 @@ func BenchmarkEndpointWriteUnreliablePacket(b *testing.B) { func TestEndpointWriteReliablePacket(t *testing.T) { defer goleak.VerifyNone(t) - var mu sync.Mutex - - values := make(map[string]struct{}) - - actual := uint64(0) - expected := uint64(65536) + var ( + expected []int + actual []int + loop uint64 = 65536 + ) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } - - atomic.AddUint64(&actual, 1) - - mu.Lock() - _, exists := values[string(buf)] - delete(values, string(buf)) - mu.Unlock() - - require.True(t, exists) + num, _ := strconv.Atoi(string(buf)) + actual = append(actual, num) } ca := newPacketConn(t, "127.0.0.1:0") @@ -121,6 +112,9 @@ func TestEndpointWriteReliablePacket(t *testing.T) { go b.Listen() defer func() { + // Note: Guarantee that all messages are deliverd + time.Sleep(1 * time.Second) + require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond))) require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond))) @@ -130,15 +124,12 @@ func TestEndpointWriteReliablePacket(t *testing.T) { require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) - require.EqualValues(t, expected, atomic.LoadUint64(&actual)) + require.EqualValues(t, expected, uniqSort(actual)) }() - for i := uint64(0); i < expected; i++ { + for i := uint64(0); i < loop; i++ { data := strconv.AppendUint(nil, i, 10) - - mu.Lock() - values[string(data)] = struct{}{} - mu.Unlock() + expected = append(expected, int(i)) require.NoError(t, a.WriteReliablePacket(data, b.Addr())) } @@ -147,14 +138,21 @@ func TestEndpointWriteReliablePacket(t *testing.T) { func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) { defer goleak.VerifyNone(t) - actual := uint64(0) - expected := uint64(512) + var ( + expected []int + actual []int + loop uint64 = 512 + mu sync.Mutex + ) handler := func(buf []byte, _ net.Addr) { if len(buf) == 0 { return } - atomic.AddUint64(&actual, 1) + mu.Lock() + num, _ := strconv.Atoi(string(buf)) + actual = append(actual, num) + mu.Unlock() } ca := newPacketConn(t, "127.0.0.1:0") @@ -176,14 +174,17 @@ func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) { require.NoError(t, ca.Close()) require.NoError(t, cb.Close()) - require.EqualValues(t, expected*2, atomic.LoadUint64(&actual)) + sort.Ints(expected) + require.EqualValues(t, expected, uniqSort(actual)) }() - for i := uint64(0); i < expected; i++ { - data := strconv.AppendUint(nil, i, 10) + for i := uint64(0); i < loop; i++ { + dataA := strconv.AppendUint(nil, i, 10) + dataB := strconv.AppendUint(nil, i+loop, 10) + expected = append(expected, int(i), int(i+loop)) - require.NoError(t, a.WriteReliablePacket(data, b.Addr())) - require.NoError(t, b.WriteReliablePacket(data, a.Addr())) + require.NoError(t, a.WriteReliablePacket(dataA, b.Addr())) + require.NoError(t, b.WriteReliablePacket(dataB, a.Addr())) } } diff --git a/examples/basic/debug.go b/examples/basic/debug.go new file mode 100644 index 0000000..18edd01 --- /dev/null +++ b/examples/basic/debug.go @@ -0,0 +1,90 @@ +package main + +import ( + "bytes" + "errors" + "net" + "time" + "io" + "github.com/lithdew/reliable" +) + +func main() { + var data []byte = []byte{} + + ca, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic("hoge") + } + cb, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + panic("hoge") + } + + chErr := make(chan error) + + handler := func(buf []byte, _ net.Addr) { + if len(buf) == 0 || bytes.Equal(buf, data) { + return + } + panic("hoge") + chErr <- errors.New("data miss match") + } + + ea := reliable.NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) + eb := reliable.NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) + + go ea.Listen() + go eb.Listen() + + for i := 0; i < 65536; i++ { + select { + case <-chErr: + panic("hoge") + default: + if err := ea.WriteReliablePacket(data, eb.Addr()); err != nil && !isEOF(err) { + panic("hoge") + } + } + } + + if err := ca.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + panic("hoge") + } + + if err := cb.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { + panic("hoge") + } + + if err := ea.Close(); err != nil { + panic("hoge") + } + if err := eb.Close(); err != nil { + panic("hoge") + } + + if err := ca.Close(); err != nil { + panic("hoge") + } + if err := cb.Close(); err != nil { + panic("hoge") + } +} + +func isEOF(err error) bool { + if errors.Is(err, io.EOF) { + return true + } + + var netErr *net.OpError + if errors.As(err, &netErr) { + if netErr.Err.Error() == "use of closed network connection" { + return true + } + if netErr.Timeout() { + return true + } + } + + return false +} diff --git a/fuzz.go b/fuzz.go index 109098e..94f9976 100644 --- a/fuzz.go +++ b/fuzz.go @@ -28,8 +28,8 @@ func Fuzz(data []byte) int { chErr <- errors.New("data miss match") } - ea := NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) - eb := NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) + ea := NewEndpoint(ca, WithEndpointPacketHandler(handler)) + eb := NewEndpoint(cb, WithEndpointPacketHandler(handler)) go ea.Listen() go eb.Listen() diff --git a/protocol.go b/protocol.go index 43b6b2e..3063862 100644 --- a/protocol.go +++ b/protocol.go @@ -104,7 +104,7 @@ func (p *Protocol) WritePacket(reliable bool, buf []byte) ([]byte, error) { p.trackAcked(ack) - // log.Printf("%v: send (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, idx, ack, ackBits, len(buf), reliable) + // log.Printf("%p: send (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, idx, ack, ackBits, len(buf), reliable) return p.write(PacketHeader{Sequence: idx, ACK: ack, ACKBits: ackBits, Unordered: !reliable}, buf), nil } @@ -201,7 +201,7 @@ func (p *Protocol) clearWrites(start, end uint16) { emptyBufferIndices(second) } -func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) []byte { +func (p *Protocol) ReadPacket(header PacketHeader, buf []byte, transmit transmitFunc) error { p.mu.Lock() defer p.mu.Unlock() @@ -221,9 +221,9 @@ func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) []byte { p.ph(buf, header.Sequence) } - // log.Printf("%v: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", &p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) + // log.Printf("%p: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) - return p.writeAcksIfNecessary() + return p.writeAcksIfNecessary(transmit) } func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { @@ -250,16 +250,19 @@ func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { return header, needed } -func (p *Protocol) writeAcksIfNecessary() []byte { +func (p *Protocol) writeAcksIfNecessary(transmit transmitFunc) error { for { header, needed := p.createAckIfNecessary() if !needed { return nil } - // log.Printf("%v: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", &p, header.Sequence, header.ACK, header.ACKBits) + // log.Printf("%p: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", p, header.Sequence, header.ACK, header.ACKBits) - return p.write(header, nil) + buf := p.write(header, nil) + if _, err := transmit(buf); err != nil { + return fmt.Errorf("failed to transmit acks: %w", err) + } } } @@ -397,7 +400,7 @@ func (p *Protocol) retransmitUnackedPackets(transmit transmitFunc) error { continue } - // log.Printf("%v: resend (seq=%d)", &p, p.oui+idx) + // log.Printf("%p: resend (seq=%d)", p, p.oui+idx) if isEOF, err := transmit(p.wqe[i].buf.B); err != nil { return fmt.Errorf("failed to retransmit unacked packet: %w", err) From ec6cff0d5ebdffb12b1c5da52f5dac88577cbd68 Mon Sep 17 00:00:00 2001 From: tak1827 Date: Fri, 29 May 2020 07:57:06 +0900 Subject: [PATCH 2/6] remove debug code --- examples/basic/debug.go | 90 ----------------------------------------- 1 file changed, 90 deletions(-) delete mode 100644 examples/basic/debug.go diff --git a/examples/basic/debug.go b/examples/basic/debug.go deleted file mode 100644 index 18edd01..0000000 --- a/examples/basic/debug.go +++ /dev/null @@ -1,90 +0,0 @@ -package main - -import ( - "bytes" - "errors" - "net" - "time" - "io" - "github.com/lithdew/reliable" -) - -func main() { - var data []byte = []byte{} - - ca, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - panic("hoge") - } - cb, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - panic("hoge") - } - - chErr := make(chan error) - - handler := func(buf []byte, _ net.Addr) { - if len(buf) == 0 || bytes.Equal(buf, data) { - return - } - panic("hoge") - chErr <- errors.New("data miss match") - } - - ea := reliable.NewEndpoint(ca, reliable.WithEndpointPacketHandler(handler)) - eb := reliable.NewEndpoint(cb, reliable.WithEndpointPacketHandler(handler)) - - go ea.Listen() - go eb.Listen() - - for i := 0; i < 65536; i++ { - select { - case <-chErr: - panic("hoge") - default: - if err := ea.WriteReliablePacket(data, eb.Addr()); err != nil && !isEOF(err) { - panic("hoge") - } - } - } - - if err := ca.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { - panic("hoge") - } - - if err := cb.SetDeadline(time.Now().Add(1 * time.Millisecond)); err != nil { - panic("hoge") - } - - if err := ea.Close(); err != nil { - panic("hoge") - } - if err := eb.Close(); err != nil { - panic("hoge") - } - - if err := ca.Close(); err != nil { - panic("hoge") - } - if err := cb.Close(); err != nil { - panic("hoge") - } -} - -func isEOF(err error) bool { - if errors.Is(err, io.EOF) { - return true - } - - var netErr *net.OpError - if errors.As(err, &netErr) { - if netErr.Err.Error() == "use of closed network connection" { - return true - } - if netErr.Timeout() { - return true - } - } - - return false -} From 0c138dcf7138f30d7432c043039b5d582b1dc9e5 Mon Sep 17 00:00:00 2001 From: tak1827 Date: Fri, 29 May 2020 13:48:58 +0900 Subject: [PATCH 3/6] fix fuzz bug --- fuzz.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fuzz.go b/fuzz.go index 94f9976..c99ca6c 100644 --- a/fuzz.go +++ b/fuzz.go @@ -34,7 +34,7 @@ func Fuzz(data []byte) int { go ea.Listen() go eb.Listen() - for i := 0; i < 65536; i++ { + for i := 0; i < 4096; i++ { select { case <-chErr: return 0 From 0f630c3c313d019771355969d6f6a79730015a35 Mon Sep 17 00:00:00 2001 From: tak1827 Date: Fri, 29 May 2020 20:58:43 +0900 Subject: [PATCH 4/6] don't pass trasmit to conn when read --- conn.go | 10 +++++++++- protocol.go | 14 +++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/conn.go b/conn.go index a13c1b9..ac0cb74 100644 --- a/conn.go +++ b/conn.go @@ -40,7 +40,15 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error { } func (c *Conn) Read(header PacketHeader, buf []byte) error { - return c.protocol.ReadPacket(header, buf, c.transmit) + bufs := c.protocol.ReadPacket(header, buf) + + for _, b := range bufs { + if _, err := c.transmit(b); err != nil { + return fmt.Errorf("failed to transmit acks: %w", err) + } + } + + return nil } func (c *Conn) Close() { diff --git a/protocol.go b/protocol.go index 3063862..ee040b0 100644 --- a/protocol.go +++ b/protocol.go @@ -201,7 +201,7 @@ func (p *Protocol) clearWrites(start, end uint16) { emptyBufferIndices(second) } -func (p *Protocol) ReadPacket(header PacketHeader, buf []byte, transmit transmitFunc) error { +func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) [][]byte { p.mu.Lock() defer p.mu.Unlock() @@ -223,7 +223,7 @@ func (p *Protocol) ReadPacket(header PacketHeader, buf []byte, transmit transmit // log.Printf("%p: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) - return p.writeAcksIfNecessary(transmit) + return p.writeAcksIfNecessary() } func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { @@ -250,20 +250,20 @@ func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { return header, needed } -func (p *Protocol) writeAcksIfNecessary(transmit transmitFunc) error { +func (p *Protocol) writeAcksIfNecessary() (bufs [][]byte) { for { header, needed := p.createAckIfNecessary() if !needed { - return nil + break } // log.Printf("%p: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", p, header.Sequence, header.ACK, header.ACKBits) buf := p.write(header, nil) - if _, err := transmit(buf); err != nil { - return fmt.Errorf("failed to transmit acks: %w", err) - } + bufs = append(bufs, buf) } + + return } func (p *Protocol) readAckBits(ack uint16, ackBits uint32) { From b1727178d63bf290e5d7a834a1de2873b6aee136 Mon Sep 17 00:00:00 2001 From: tak1827 Date: Sat, 30 May 2020 19:37:20 +0900 Subject: [PATCH 5/6] expose ack sending to protocol --- conn.go | 33 ++++++++++++++++++++++++++------- protocol.go | 39 +++++++++++++-------------------------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/conn.go b/conn.go index ac0cb74..82c2d65 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + // "sync" ) type transmitFunc func(buf []byte) (bool, error) @@ -40,15 +41,12 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error { } func (c *Conn) Read(header PacketHeader, buf []byte) error { - bufs := c.protocol.ReadPacket(header, buf) - - for _, b := range bufs { - if _, err := c.transmit(b); err != nil { - return fmt.Errorf("failed to transmit acks: %w", err) - } + needed := c.protocol.ReadPacket(header, buf) + if !needed { + return nil } - return nil + return c.writeAcks() } func (c *Conn) Close() { @@ -75,3 +73,24 @@ func (c *Conn) transmit(buf []byte) (EOF bool, err error) { return } + +func (c *Conn) writeAcks() error { + c.protocol.mu.Lock() + defer c.protocol.mu.Unlock() + + for { + needed := c.protocol.ackNeeded() + if !needed { + break + } + + header := c.protocol.createAck() + + buf := c.protocol.write(header, nil) + if _, err := c.transmit(buf); err != nil { + return fmt.Errorf("failed to transmit acks: %w", err) + } + } + + return nil +} diff --git a/protocol.go b/protocol.go index ee040b0..e8c61ed 100644 --- a/protocol.go +++ b/protocol.go @@ -201,20 +201,20 @@ func (p *Protocol) clearWrites(start, end uint16) { emptyBufferIndices(second) } -func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) [][]byte { +func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) (needed bool) { p.mu.Lock() defer p.mu.Unlock() p.readAckBits(header.ACK, header.ACKBits) if !header.Unordered && !p.trackRead(header.Sequence) { - return nil + return } p.trackUnacked() if header.Empty { - return nil + return } if p.ph != nil { @@ -223,19 +223,24 @@ func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) [][]byte { // log.Printf("%p: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) - return p.writeAcksIfNecessary() + needed = p.ackNeeded() + return } -func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { +func (p *Protocol) ackNeeded() bool { lui := p.lui for i := uint16(0); i < ACKBitsetSize; i++ { if p.rq[(lui+i)%uint16(len(p.rq))] != uint32(lui+i) { - return header, needed + return false } } - lui += ACKBitsetSize + return !p.die +} + +func (p *Protocol) createAck() (header PacketHeader) { + lui := p.lui + ACKBitsetSize p.lui = lui p.ls = time.Now() @@ -245,25 +250,7 @@ func (p *Protocol) createAckIfNecessary() (header PacketHeader, needed bool) { header.ACKBits = p.prepareAckBits(header.ACK) header.Empty = true - needed = !p.die - - return header, needed -} - -func (p *Protocol) writeAcksIfNecessary() (bufs [][]byte) { - for { - header, needed := p.createAckIfNecessary() - if !needed { - break - } - - // log.Printf("%p: ack (seq=%05d) (ack=%05d) (ack_bits=%032b)", p, header.Sequence, header.ACK, header.ACKBits) - - buf := p.write(header, nil) - bufs = append(bufs, buf) - } - - return + return header } func (p *Protocol) readAckBits(ack uint16, ackBits uint32) { From 7d0093e678495d7938a4418c966791f50e3fdae4 Mon Sep 17 00:00:00 2001 From: tak1827 Date: Sun, 31 May 2020 10:30:12 +0900 Subject: [PATCH 6/6] expose retransmit from protocol --- conn.go | 79 +++++++++++++++++++++++++++++++++++++++---------- endpoint.go | 34 ++++++++++----------- options.go | 14 +++++---- options_test.go | 27 +++++++++++++++++ protocol.go | 66 +++++++++++------------------------------ 5 files changed, 133 insertions(+), 87 deletions(-) create mode 100644 options_test.go diff --git a/conn.go b/conn.go index 82c2d65..a860d40 100644 --- a/conn.go +++ b/conn.go @@ -4,24 +4,36 @@ import ( "fmt" "io" "net" - // "sync" + "time" ) type transmitFunc func(buf []byte) (bool, error) type Conn struct { - addr net.Addr - conn net.PacketConn - protocol *Protocol + addr net.Addr + conn net.PacketConn + + p *Protocol + + updatePeriod time.Duration // how often time-dependant parts of the p get checked + resendTimeout time.Duration // how long we wait until unacked packets should be resent + + exit chan struct{} // signal channel to close the conn } func NewConn(addr net.Addr, conn net.PacketConn, opts ...ProtocolOption) *Conn { - p := NewProtocol(opts...) - return &Conn{addr: addr, conn: conn, protocol: p} + return &Conn{ + addr: addr, + conn: conn, + p: NewProtocol(opts...), + updatePeriod: DefaultUpdatePeriod, + resendTimeout: DefaultResendTimeout, + exit: make(chan struct{}), + } } func (c *Conn) WriteReliablePacket(buf []byte) error { - buf, err := c.protocol.WritePacket(true, buf) + buf, err := c.p.WritePacket(true, buf) if err != nil { return err } @@ -31,7 +43,7 @@ func (c *Conn) WriteReliablePacket(buf []byte) error { } func (c *Conn) WriteUnreliablePacket(buf []byte) error { - buf, err := c.protocol.WritePacket(false, buf) + buf, err := c.p.WritePacket(false, buf) if err != nil { return err } @@ -41,7 +53,7 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error { } func (c *Conn) Read(header PacketHeader, buf []byte) error { - needed := c.protocol.ReadPacket(header, buf) + needed := c.p.ReadPacket(header, buf) if !needed { return nil } @@ -50,11 +62,24 @@ func (c *Conn) Read(header PacketHeader, buf []byte) error { } func (c *Conn) Close() { - c.protocol.Close() + close(c.exit) + c.p.Close() } func (c *Conn) Run() { - c.protocol.Run(c.transmit) + ticker := time.NewTicker(c.updatePeriod) + defer ticker.Stop() + + for { + select { + case <-c.exit: + return + case <-ticker.C: + if err := c.retransmitUnackedPackets(); err != nil { + c.p.callErrorHandler(err) + } + } + } } func (c *Conn) transmit(buf []byte) (EOF bool, err error) { @@ -75,18 +100,18 @@ func (c *Conn) transmit(buf []byte) (EOF bool, err error) { } func (c *Conn) writeAcks() error { - c.protocol.mu.Lock() - defer c.protocol.mu.Unlock() + c.p.mu.Lock() + defer c.p.mu.Unlock() for { - needed := c.protocol.ackNeeded() + needed := c.p.checkIfAck() if !needed { break } - header := c.protocol.createAck() + header := c.p.createAck() - buf := c.protocol.write(header, nil) + buf := c.p.write(header, nil) if _, err := c.transmit(buf); err != nil { return fmt.Errorf("failed to transmit acks: %w", err) } @@ -94,3 +119,25 @@ func (c *Conn) writeAcks() error { return nil } + +func (c *Conn) retransmitUnackedPackets() error { + c.p.mu.Lock() + defer c.p.mu.Unlock() + + for idx := uint16(0); idx < c.p.writeQueueLen(); idx++ { + buf, needed := c.p.checkIfRetransmit(idx, c.resendTimeout) + if !needed { + continue + } + + if isEOF, err := c.transmit(buf); err != nil { + return fmt.Errorf("failed to retransmit unacked packet: %w", err) + } else if isEOF { + break + } + + c.p.incrementWqe(idx) + } + + return nil +} diff --git a/endpoint.go b/endpoint.go index 22ed72c..93be8c5 100644 --- a/endpoint.go +++ b/endpoint.go @@ -6,7 +6,6 @@ import ( "net" "sync" "sync/atomic" - "time" ) type EndpointPacketHandler func(buf []byte, addr net.Addr) @@ -16,9 +15,6 @@ type Endpoint struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 - updatePeriod time.Duration // how often time-dependant parts of the protocol get checked - resendTimeout time.Duration // how long we wait until unacked packets should be resent - mu sync.Mutex wg sync.WaitGroup @@ -49,14 +45,6 @@ func NewEndpoint(conn net.PacketConn, opts ...EndpointOption) *Endpoint { e.readBufferSize = DefaultReadBufferSize } - if e.resendTimeout == 0 { - e.resendTimeout = DefaultResendTimeout - } - - if e.updatePeriod == 0 { - e.updatePeriod = DefaultUpdatePeriod - } - if e.pool == nil { e.pool = new(Pool) } @@ -81,8 +69,6 @@ func (e *Endpoint) getConn(addr net.Addr) *Conn { e.conn, WithWriteBufferSize(e.writeBufferSize), WithReadBufferSize(e.readBufferSize), - WithUpdatePeriod(e.updatePeriod), - WithResendTimeout(e.resendTimeout), WithBufferPool(e.pool), ) @@ -127,23 +113,33 @@ func (e *Endpoint) Addr() net.Addr { return e.addr } -func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr) error { +func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error { conn := e.getConn(addr) if conn == nil { return io.EOF } + + for _, opt := range opts { + opt.applyConn(conn) + } + return conn.WriteReliablePacket(buf) } -func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr) error { +func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error { conn := e.getConn(addr) if conn == nil { return io.EOF } + + for _, opt := range opts { + opt.applyConn(conn) + } + return conn.WriteUnreliablePacket(buf) } -func (e *Endpoint) Listen() { +func (e *Endpoint) Listen(opts ...ConnOption) { e.mu.Lock() e.wg.Add(1) e.mu.Unlock() @@ -168,6 +164,10 @@ func (e *Endpoint) Listen() { break } + for _, opt := range opts { + opt.applyConn(conn) + } + header, buf, err := UnmarshalPacketHeader(buf[:n]) if err != nil { e.clearConn(addr) diff --git a/options.go b/options.go index 71f0681..47331bc 100644 --- a/options.go +++ b/options.go @@ -18,6 +18,10 @@ type EndpointOption interface { applyEndpoint(e *Endpoint) } +type ConnOption interface { + applyConn(c *Conn) +} + type Option interface { ProtocolOption EndpointOption @@ -82,10 +86,9 @@ func WithEndpointErrorHandler(eh EndpointErrorHandler) EndpointOption { type withUpdatePeriod struct{ updatePeriod time.Duration } -func (o withUpdatePeriod) applyProtocol(p *Protocol) { p.updatePeriod = o.updatePeriod } -func (o withUpdatePeriod) applyEndpoint(e *Endpoint) { e.updatePeriod = o.updatePeriod } +func (o withUpdatePeriod) applyConn(c *Conn) { c.updatePeriod = o.updatePeriod } -func WithUpdatePeriod(updatePeriod time.Duration) Option { +func WithUpdatePeriod(updatePeriod time.Duration) ConnOption { if updatePeriod == 0 { panic("update period of zero is not supported yet") } @@ -94,10 +97,9 @@ func WithUpdatePeriod(updatePeriod time.Duration) Option { type withResendTimeout struct{ resendTimeout time.Duration } -func (o withResendTimeout) applyProtocol(p *Protocol) { p.resendTimeout = o.resendTimeout } -func (o withResendTimeout) applyEndpoint(e *Endpoint) { e.resendTimeout = o.resendTimeout } +func (o withResendTimeout) applyConn(c *Conn) { c.resendTimeout = o.resendTimeout } -func WithResendTimeout(resendTimeout time.Duration) Option { +func WithResendTimeout(resendTimeout time.Duration) ConnOption { if resendTimeout == 0 { panic("ack timeout of zero is not supported yet") } diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..9eb1414 --- /dev/null +++ b/options_test.go @@ -0,0 +1,27 @@ +package reliable + +import ( + "github.com/stretchr/testify/require" + "net" + "testing" + "time" +) + +func TestConnApply(t *testing.T) { + a, _ := net.ListenPacket("udp", "127.0.0.1:0") + b, _ := net.ListenPacket("udp", "127.0.0.1:0") + + ca := NewConn(b.LocalAddr(), a) + + expectedUpdatePeriod := 1 * time.Second + expectedResendTimeout := 2 * time.Second + + uOpt := WithUpdatePeriod(expectedUpdatePeriod) + rOpt := WithResendTimeout(expectedResendTimeout) + + uOpt.applyConn(ca) + require.EqualValues(t, expectedUpdatePeriod, ca.updatePeriod) + + rOpt.applyConn(ca) + require.EqualValues(t, expectedResendTimeout, ca.resendTimeout) +} diff --git a/protocol.go b/protocol.go index e8c61ed..cefb0c8 100644 --- a/protocol.go +++ b/protocol.go @@ -1,7 +1,6 @@ package reliable import ( - "fmt" "github.com/lithdew/seq" "io" "sync" @@ -15,9 +14,6 @@ type Protocol struct { writeBufferSize uint16 // write buffer size that must be a divisor of 65536 readBufferSize uint16 // read buffer size that must be a divisor of 65536 - updatePeriod time.Duration // how often time-dependant parts of the protocol get checked - resendTimeout time.Duration // how long we wait until unacked packets should be resent - pool *Pool ph ProtocolPacketHandler @@ -56,14 +52,6 @@ func NewProtocol(opts ...ProtocolOption) *Protocol { p.readBufferSize = DefaultReadBufferSize } - if p.resendTimeout == 0 { - p.resendTimeout = DefaultResendTimeout - } - - if p.updatePeriod == 0 { - p.updatePeriod = DefaultUpdatePeriod - } - if p.pool == nil { p.pool = new(Pool) } @@ -223,11 +211,11 @@ func (p *Protocol) ReadPacket(header PacketHeader, buf []byte) (needed bool) { // log.Printf("%p: recv (seq=%05d) (ack=%05d) (ack_bits=%032b) (size=%d) (reliable=%t)", p, header.Sequence, header.ACK, header.ACKBits, len(buf), !header.Unordered) - needed = p.ackNeeded() + needed = p.checkIfAck() return } -func (p *Protocol) ackNeeded() bool { +func (p *Protocol) checkIfAck() bool { lui := p.lui for i := uint16(0); i < ACKBitsetSize; i++ { @@ -345,7 +333,6 @@ func (p *Protocol) close() bool { if p.die { return false } - close(p.exit) p.die = true p.ouc.Broadcast() @@ -361,43 +348,26 @@ func (p *Protocol) Close() { } } -func (p *Protocol) Run(transmit transmitFunc) { - ticker := time.NewTicker(p.updatePeriod) - defer ticker.Stop() - - for { - select { - case <-p.exit: - return - case <-ticker.C: - if err := p.retransmitUnackedPackets(transmit); err != nil && p.eh != nil { - p.eh(err) - } - } +func (p *Protocol) checkIfRetransmit(idx uint16, resendTimeout time.Duration) ([]byte, bool) { + i := (p.oui + idx) % uint16(len(p.wq)) + if p.wq[i] != uint32(p.oui+idx) || !p.wqe[i].shouldResend(time.Now(), resendTimeout) { + return nil, false } + return p.wqe[i].buf.B, true } -func (p *Protocol) retransmitUnackedPackets(transmit transmitFunc) error { - p.mu.Lock() - defer p.mu.Unlock() - - for idx := uint16(0); idx < uint16(len(p.wq)); idx++ { - i := (p.oui + idx) % uint16(len(p.wq)) - if p.wq[i] != uint32(p.oui+idx) || !p.wqe[i].shouldResend(time.Now(), p.resendTimeout) { - continue - } - - // log.Printf("%p: resend (seq=%d)", p, p.oui+idx) - - if isEOF, err := transmit(p.wqe[i].buf.B); err != nil { - return fmt.Errorf("failed to retransmit unacked packet: %w", err) - } else if isEOF { - break - } +func (p *Protocol) incrementWqe(idx uint16) { + i := (p.oui + idx) % uint16(len(p.wq)) + p.wqe[i].written = time.Now() + p.wqe[i].resent++ +} - p.wqe[i].written = time.Now() - p.wqe[i].resent++ +func (p *Protocol) callErrorHandler(err error) { + if p.eh != nil { + p.eh(err) } +} - return nil +func (p *Protocol) writeQueueLen() uint16 { + return uint16(len(p.wq)) }