Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 82 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,36 @@ import (
"fmt"
"io"
"net"
"time"
)

type transmitFunc func(buf []byte) (bool, error)

type Conn struct {
addr net.Addr
conn net.PacketConn
protocol *Protocol
addr net.Addr
conn net.PacketConn

p *Protocol

updatePeriod time.Duration // how often time-dependant parts of the p get checked
resendTimeout time.Duration // how long we wait until unacked packets should be resent

exit chan struct{} // signal channel to close the conn
}

func NewConn(addr net.Addr, conn net.PacketConn, opts ...ProtocolOption) *Conn {
p := NewProtocol(opts...)
return &Conn{addr: addr, conn: conn, protocol: p}
return &Conn{
addr: addr,
conn: conn,
p: NewProtocol(opts...),
updatePeriod: DefaultUpdatePeriod,
resendTimeout: DefaultResendTimeout,
exit: make(chan struct{}),
}
}

func (c *Conn) WriteReliablePacket(buf []byte) error {
buf, err := c.protocol.WritePacket(true, buf)
buf, err := c.p.WritePacket(true, buf)
if err != nil {
return err
}
Expand All @@ -30,7 +43,7 @@ func (c *Conn) WriteReliablePacket(buf []byte) error {
}

func (c *Conn) WriteUnreliablePacket(buf []byte) error {
buf, err := c.protocol.WritePacket(false, buf)
buf, err := c.p.WritePacket(false, buf)
if err != nil {
return err
}
Expand All @@ -40,22 +53,33 @@ func (c *Conn) WriteUnreliablePacket(buf []byte) error {
}

func (c *Conn) Read(header PacketHeader, buf []byte) error {
buf = c.protocol.ReadPacket(header, buf)

if len(buf) != 0 {
_, err := c.transmit(buf)
return err
needed := c.p.ReadPacket(header, buf)
if !needed {
return nil
}

return nil
return c.writeAcks()
}

func (c *Conn) Close() {
c.protocol.Close()
close(c.exit)
c.p.Close()
}

func (c *Conn) Run() {
c.protocol.Run(c.transmit)
ticker := time.NewTicker(c.updatePeriod)
defer ticker.Stop()

for {
select {
case <-c.exit:
return
case <-ticker.C:
if err := c.retransmitUnackedPackets(); err != nil {
c.p.callErrorHandler(err)
}
}
}
}

func (c *Conn) transmit(buf []byte) (EOF bool, err error) {
Expand All @@ -74,3 +98,46 @@ func (c *Conn) transmit(buf []byte) (EOF bool, err error) {

return
}

func (c *Conn) writeAcks() error {
c.p.mu.Lock()
defer c.p.mu.Unlock()

for {
needed := c.p.checkIfAck()
if !needed {
break
}

header := c.p.createAck()

buf := c.p.write(header, nil)
if _, err := c.transmit(buf); err != nil {
return fmt.Errorf("failed to transmit acks: %w", err)
}
}

return nil
}

func (c *Conn) retransmitUnackedPackets() error {
c.p.mu.Lock()
defer c.p.mu.Unlock()

for idx := uint16(0); idx < c.p.writeQueueLen(); idx++ {
buf, needed := c.p.checkIfRetransmit(idx, c.resendTimeout)
if !needed {
continue
}

if isEOF, err := c.transmit(buf); err != nil {
return fmt.Errorf("failed to retransmit unacked packet: %w", err)
} else if isEOF {
break
}

c.p.incrementWqe(idx)
}

return nil
}
16 changes: 11 additions & 5 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,28 @@ func TestConnWriteReliablePacket(t *testing.T) {
require.EqualValues(t, data, buf)
}

ca := NewConn(a.LocalAddr(), a, WithProtocolPacketHandler(handler))
cb := NewConn(b.LocalAddr(), b, WithProtocolPacketHandler(handler))
ca := NewConn(b.LocalAddr(), a, WithProtocolPacketHandler(handler))
cb := NewConn(a.LocalAddr(), b, WithProtocolPacketHandler(handler))

go readLoop(t, a, ca)
go readLoop(t, b, cb)

go ca.Run()
go cb.Run()

defer func() {
// Note: Guarantee that all messages are deliverd
time.Sleep(1 * time.Second)

require.NoError(t, a.SetDeadline(time.Now().Add(1*time.Millisecond)))
require.NoError(t, b.SetDeadline(time.Now().Add(1*time.Millisecond)))

require.NoError(t, a.Close())
require.NoError(t, b.Close())

ca.Close()
cb.Close()

require.NoError(t, a.Close())
require.NoError(t, b.Close())

require.EqualValues(t, expected, atomic.LoadUint64(&actual))
}()

Expand Down
34 changes: 17 additions & 17 deletions endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net"
"sync"
"sync/atomic"
"time"
)

type EndpointPacketHandler func(buf []byte, addr net.Addr)
Expand All @@ -16,9 +15,6 @@ type Endpoint struct {
writeBufferSize uint16 // write buffer size that must be a divisor of 65536
readBufferSize uint16 // read buffer size that must be a divisor of 65536

updatePeriod time.Duration // how often time-dependant parts of the protocol get checked
resendTimeout time.Duration // how long we wait until unacked packets should be resent

mu sync.Mutex
wg sync.WaitGroup

Expand Down Expand Up @@ -49,14 +45,6 @@ func NewEndpoint(conn net.PacketConn, opts ...EndpointOption) *Endpoint {
e.readBufferSize = DefaultReadBufferSize
}

if e.resendTimeout == 0 {
e.resendTimeout = DefaultResendTimeout
}

if e.updatePeriod == 0 {
e.updatePeriod = DefaultUpdatePeriod
}

if e.pool == nil {
e.pool = new(Pool)
}
Expand All @@ -81,8 +69,6 @@ func (e *Endpoint) getConn(addr net.Addr) *Conn {
e.conn,
WithWriteBufferSize(e.writeBufferSize),
WithReadBufferSize(e.readBufferSize),
WithUpdatePeriod(e.updatePeriod),
WithResendTimeout(e.resendTimeout),
WithBufferPool(e.pool),
)

Expand Down Expand Up @@ -127,23 +113,33 @@ func (e *Endpoint) Addr() net.Addr {
return e.addr
}

func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr) error {
func (e *Endpoint) WriteReliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error {
conn := e.getConn(addr)
if conn == nil {
return io.EOF
}

for _, opt := range opts {
opt.applyConn(conn)
}

return conn.WriteReliablePacket(buf)
}

func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr) error {
func (e *Endpoint) WriteUnreliablePacket(buf []byte, addr net.Addr, opts ...ConnOption) error {
conn := e.getConn(addr)
if conn == nil {
return io.EOF
}

for _, opt := range opts {
opt.applyConn(conn)
}

return conn.WriteUnreliablePacket(buf)
}

func (e *Endpoint) Listen() {
func (e *Endpoint) Listen(opts ...ConnOption) {
e.mu.Lock()
e.wg.Add(1)
e.mu.Unlock()
Expand All @@ -168,6 +164,10 @@ func (e *Endpoint) Listen() {
break
}

for _, opt := range opts {
opt.applyConn(conn)
}

header, buf, err := UnmarshalPacketHeader(buf[:n])
if err != nil {
e.clearConn(addr)
Expand Down
61 changes: 31 additions & 30 deletions endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"sort"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -89,26 +88,18 @@ func BenchmarkEndpointWriteUnreliablePacket(b *testing.B) {
func TestEndpointWriteReliablePacket(t *testing.T) {
defer goleak.VerifyNone(t)

var mu sync.Mutex

values := make(map[string]struct{})

actual := uint64(0)
expected := uint64(65536)
var (
expected []int
actual []int
loop uint64 = 65536
)

handler := func(buf []byte, _ net.Addr) {
if len(buf) == 0 {
return
}

atomic.AddUint64(&actual, 1)

mu.Lock()
_, exists := values[string(buf)]
delete(values, string(buf))
mu.Unlock()

require.True(t, exists)
num, _ := strconv.Atoi(string(buf))
actual = append(actual, num)
}

ca := newPacketConn(t, "127.0.0.1:0")
Expand All @@ -121,6 +112,9 @@ func TestEndpointWriteReliablePacket(t *testing.T) {
go b.Listen()

defer func() {
// Note: Guarantee that all messages are deliverd
time.Sleep(1 * time.Second)

require.NoError(t, ca.SetDeadline(time.Now().Add(1*time.Millisecond)))
require.NoError(t, cb.SetDeadline(time.Now().Add(1*time.Millisecond)))

Expand All @@ -130,15 +124,12 @@ func TestEndpointWriteReliablePacket(t *testing.T) {
require.NoError(t, ca.Close())
require.NoError(t, cb.Close())

require.EqualValues(t, expected, atomic.LoadUint64(&actual))
require.EqualValues(t, expected, uniqSort(actual))
}()

for i := uint64(0); i < expected; i++ {
for i := uint64(0); i < loop; i++ {
data := strconv.AppendUint(nil, i, 10)

mu.Lock()
values[string(data)] = struct{}{}
mu.Unlock()
expected = append(expected, int(i))

require.NoError(t, a.WriteReliablePacket(data, b.Addr()))
}
Expand All @@ -147,14 +138,21 @@ func TestEndpointWriteReliablePacket(t *testing.T) {
func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) {
defer goleak.VerifyNone(t)

actual := uint64(0)
expected := uint64(512)
var (
expected []int
actual []int
loop uint64 = 512
mu sync.Mutex
)

handler := func(buf []byte, _ net.Addr) {
if len(buf) == 0 {
return
}
atomic.AddUint64(&actual, 1)
mu.Lock()
num, _ := strconv.Atoi(string(buf))
actual = append(actual, num)
mu.Unlock()
}

ca := newPacketConn(t, "127.0.0.1:0")
Expand All @@ -176,14 +174,17 @@ func TestEndpointWriteReliablePacketEndToEnd(t *testing.T) {
require.NoError(t, ca.Close())
require.NoError(t, cb.Close())

require.EqualValues(t, expected*2, atomic.LoadUint64(&actual))
sort.Ints(expected)
require.EqualValues(t, expected, uniqSort(actual))
}()

for i := uint64(0); i < expected; i++ {
data := strconv.AppendUint(nil, i, 10)
for i := uint64(0); i < loop; i++ {
dataA := strconv.AppendUint(nil, i, 10)
dataB := strconv.AppendUint(nil, i+loop, 10)
expected = append(expected, int(i), int(i+loop))

require.NoError(t, a.WriteReliablePacket(data, b.Addr()))
require.NoError(t, b.WriteReliablePacket(data, a.Addr()))
require.NoError(t, a.WriteReliablePacket(dataA, b.Addr()))
require.NoError(t, b.WriteReliablePacket(dataB, a.Addr()))
}
}

Expand Down
Loading