diff --git a/codel.go b/codel.go index 394ec51..c763714 100644 --- a/codel.go +++ b/codel.go @@ -1,212 +1,134 @@ package simnet import ( - "container/heap" "math" - "sync" "time" ) -// codelQueue is a FIFO queue with CoDel bufferbloat control -type codelQueue struct { - mu sync.Mutex - packets packetHeap - newPacket chan struct{} - closed bool - pushCount int +// coDelQueue is a FIFO queue with CoDel bufferbloat control. +// Refer to RFC 8289 +type coDelQueue struct { + q ringBuffer[packetWithTimestamp] + + byteCount uint64 + MTU uint16 // CoDel state - dropping bool - firstAbove time.Time - dropNext time.Time - count int - target time.Duration // target queue delay (e.g., 5ms) - interval time.Duration // interval for sustained bad queue (e.g., 100ms) - lastDropTime time.Time + target time.Duration // target queue delay (e.g., 5ms) + interval time.Duration // interval for sustained bad queue (e.g., 100ms) + + dropping bool + firstAbove time.Time + dropNext time.Time + count int + lastCount int } -func newCodelQueue(target, interval time.Duration) *codelQueue { - q := &codelQueue{ - target: target, - interval: interval, - newPacket: make(chan struct{}, 1), - } - heap.Init(&q.packets) - return q +type packetWithTimestamp struct { + Packet + ts time.Time } -// Enqueue adds a packet to the queue -func (q *codelQueue) Enqueue(p *packetWithDeliveryTime) { - q.mu.Lock() - defer q.mu.Unlock() - if q.closed { - return +func newCoDelQueue(target, interval time.Duration) coDelQueue { + return coDelQueue{ + target: target, + interval: interval, + q: newRingBuffer[packetWithTimestamp](128), } - q.pushCount++ - heap.Push(&q.packets, packetWithDeliveryTimeAndOrder{packetWithDeliveryTime: p, count: q.pushCount}) +} - // Signal that a new packet arrived (non-blocking) - select { - case q.newPacket <- struct{}{}: - default: - } +// Enqueue adds a packet to the queue +func (q *coDelQueue) Enqueue(p Packet) { + q.byteCount += uint64(len(p.buf)) + q.q.PushBack(packetWithTimestamp{p, time.Now()}) } // Dequeue removes and returns the next packet when it's ready for delivery // This blocks until a packet is available AND its delivery time has been reached // Uses a timer that can be reset if a packet with earlier delivery time arrives -func (q *codelQueue) Dequeue() (*packetWithDeliveryTime, bool) { - timer := time.NewTimer(time.Hour) - timer.Stop() - - for { - q.mu.Lock() +func (q *coDelQueue) Dequeue() (Packet, bool) { + now := time.Now() + p, okayToDrop := q.doDequeue(now) - if q.closed { - q.mu.Unlock() - timer.Stop() - return nil, false - } - - if len(q.packets) == 0 { - // No packets, wait for one to arrive - q.mu.Unlock() - select { - case <-q.newPacket: - timer.Stop() - continue - case <-timer.C: - continue - } + if q.dropping { + if !okayToDrop { + // sojourn below target leave dropping + q.dropping = false } - earliest := q.packets[0] - earliestTime := earliest.DeliveryTime - - now := time.Now() - if now.Before(earliestTime) { - // Not ready yet, wait until delivery time or new packet - waitDuration := earliestTime.Sub(now) - timer.Reset(waitDuration) - q.mu.Unlock() - - select { - case <-timer.C: - // Timer expired, check again - continue - case <-q.newPacket: - // New packet arrived, might have earlier delivery time - timer.Stop() - continue + for !now.Before(q.dropNext) && q.dropping { + // implicitly drop the packet + q.drop(p) + q.count++ + p, okayToDrop = q.doDequeue(now) + if !okayToDrop { + // leave drop state + q.dropping = false + } else { + // schedule next drop + q.dropNext = controlLaw(q.dropNext, q.interval, q.count) } } + } else if okayToDrop { + // If we get here, we're not in drop state. The `okToDrop` + // return from doDequeue means that the sojourn time has been above + // 'TARGET' for 'INTERVAL', so enter drop state. + q.drop(p) + p, _ = q.doDequeue(now) + q.dropping = true - // Packet is ready, remove from queue and return it - po := heap.Pop(&q.packets).(packetWithDeliveryTimeAndOrder) - p := po.packetWithDeliveryTime - - // Reset CoDel state when queue becomes empty - if len(q.packets) == 0 { - q.dropping = false - q.firstAbove = time.Time{} + // If min went above TARGET close to when it last went + // below, assume that the drop rate that controlled the + // queue on the last cycle is a good starting point to + // control it now. (`dropNext` will be at most 'INTERVAL' + // later than the time of the last drop, so 'now - dropNext' + // is a good approximation of the time from the last drop + // until now.) Implementations vary slightly here; this is + // the Linux version, which is more widely deployed and + // tested. + delta := q.count - q.lastCount + q.count = 1 + if delta > 1 && now.Sub(q.dropNext) < 16*q.interval { + q.count = delta } - q.mu.Unlock() - - return p, true + q.dropNext = controlLaw(now, q.interval, q.count) + q.lastCount = q.count } -} -// shouldDrop implements the CoDel dropping decision (thread-safe version) -func (q *codelQueue) shouldDrop(sojournTime time.Duration) bool { - q.mu.Lock() - defer q.mu.Unlock() - return q.codelShouldDrop(sojournTime, time.Now()) + return p.Packet, len(p.Packet.buf) > 0 } -// codelShouldDrop implements the CoDel dropping decision -func (q *codelQueue) codelShouldDrop(sojournTime time.Duration, now time.Time) bool { - // Reset CoDel state when queue is empty (checked by caller before dequeue) - // This is handled by resetting state when queue becomes empty in Dequeue +func (q *coDelQueue) drop(p packetWithTimestamp) { + // TODO add stats +} - if sojournTime < q.target { - // Queue is good, reset state +func (q *coDelQueue) doDequeue(now time.Time) (p packetWithTimestamp, okToDrop bool) { + if q.q.Empty() { q.firstAbove = time.Time{} - q.dropping = false - return false - } - - // Queue delay is above target - if q.firstAbove.IsZero() { - // First time above target, start tracking - q.firstAbove = now.Add(q.interval) - return false - } - - if now.Before(q.firstAbove) { - // Haven't been above target for long enough - return false - } - - // We've been above target for the full interval - if !q.dropping { - // Enter dropping state - q.dropping = true - q.count = 1 - q.dropNext = now - q.lastDropTime = now - return true + return } - // Already in dropping state - if now.After(q.dropNext) { - // Time to drop another packet - q.count++ - // Calculate next drop time using control law: interval / sqrt(count) - delta := time.Duration(float64(q.interval) / math.Sqrt(float64(q.count))) - q.dropNext = now.Add(delta) - q.lastDropTime = now - return true + p = q.q.PopFront() + q.byteCount -= uint64(len(p.Packet.buf)) + sojournTime := now.Sub(p.ts) + if sojournTime < q.target || q.byteCount < uint64(q.MTU) { + q.firstAbove = time.Time{} + } else { + if q.firstAbove.IsZero() { + // Just went above from below. If still above later will say it's + // okay to drop + q.firstAbove = now.Add(q.interval) + } else if !now.Before(q.firstAbove) { + okToDrop = true + } } - return false -} - -// Close closes the queue -func (q *codelQueue) Close() { - q.mu.Lock() - defer q.mu.Unlock() - q.closed = true - close(q.newPacket) -} - -type packetWithDeliveryTimeAndOrder struct { - count int - *packetWithDeliveryTime -} - -// packetHeap implements heap.Interface ordered by packet delivery time. -type packetHeap []packetWithDeliveryTimeAndOrder - -func (h packetHeap) Len() int { return len(h) } - -func (h packetHeap) Less(i, j int) bool { - return (h[i].DeliveryTime.Before(h[j].DeliveryTime) || - h[i].DeliveryTime.Equal(h[j].DeliveryTime) && h[i].count < h[j].count) -} - -func (h packetHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *packetHeap) Push(x any) { - *h = append(*h, x.(packetWithDeliveryTimeAndOrder)) + return } -func (h *packetHeap) Pop() any { - old := *h - n := len(old) - item := old[n-1] - *h = old[:n-1] - return item +func controlLaw(t time.Time, interval time.Duration, count int) time.Time { + return t.Add(time.Duration( + float64(time.Second) * + (interval.Seconds() / math.Sqrt(float64(count))))) } diff --git a/codel_test.go b/codel_test.go index d8aa7da..19e796d 100644 --- a/codel_test.go +++ b/codel_test.go @@ -1,28 +1,107 @@ +//go:build go1.25 + package simnet import ( - "math/rand" "testing" + "testing/synctest" "time" ) +func TestCodelQueueDropsPersistentBadQueue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 5 * time.Millisecond + interval = 100 * time.Millisecond + ) + q := newCoDelQueue(target, interval) + q.MTU = 1 + + // Build a queue that stays full for more than one interval. + for i := range 4 { + q.Enqueue(Packet{buf: []byte{byte(i)}}) + } + + // Allow in-flight packets to accumulate sojourn time. + time.Sleep(3 * q.interval) + + pkt, ok := q.Dequeue() + if !ok || len(pkt.buf) == 0 { + t.Fatal("expected packet before CoDel enters drop state") + } + if got := pkt.buf[0]; got != 0 { + t.Fatalf("first dequeue returned %d, want 0", got) + } + if q.dropping { + t.Fatal("queue should not drop until delay persists beyond interval") + } + + // Keep the queue persistently bad so dropping kicks in. + time.Sleep(3 * q.interval) + + pkt, ok = q.Dequeue() + if !ok || len(pkt.buf) == 0 { + t.Fatal("expected packet after CoDel begins dropping") + } + if got := pkt.buf[0]; got != 2 { + t.Fatalf("persistent queue should drop packet 1, got %d", got) + } + if !q.dropping { + t.Fatal("persistent bad queue should enter drop state") + } + }) +} + +func TestCodelQueueNoDropOnTransientQueue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 5 * time.Millisecond + interval = 100 * time.Millisecond + ) + q := newCoDelQueue(target, interval) + q.MTU = 1 + + for i := range 2 { + q.Enqueue(Packet{buf: []byte{byte(i)}}) + } + + // The queue goes above target but drains in less than one interval. + time.Sleep(q.target + q.interval/10) + + pkt, ok := q.Dequeue() + if !ok || len(pkt.buf) == 0 { + t.Fatal("expected first packet delivery") + } + if got := pkt.buf[0]; got != 0 { + t.Fatalf("expected packet 0, got %d", got) + } + + time.Sleep(q.interval / 2) + + pkt, ok = q.Dequeue() + if !ok || len(pkt.buf) == 0 { + t.Fatal("expected second packet delivery") + } + if got := pkt.buf[0]; got != 1 { + t.Fatalf("expected packet 1, got %d", got) + } + if q.dropping { + t.Fatal("transient queue should not enter drop state") + } + }) +} + func BenchmarkCodelQueueEnqueueDequeue(b *testing.B) { const initSize = 30000 const queueSize = 50000 - packets := make([]*packetWithDeliveryTime, queueSize) - base := time.Now().Add(-time.Second) + packets := make([]Packet, queueSize) + data := []byte("test") for i := range queueSize { - packets[i] = &packetWithDeliveryTime{ - DeliveryTime: base.Add(time.Duration(i) * time.Microsecond), - } + packets[i] = Packet{buf: data} } - r := rand.New(rand.NewSource(42)) - r.Shuffle(queueSize, func(i, j int) { - packets[i], packets[j] = packets[j], packets[i] - }) - q := newCodelQueue(5*time.Millisecond, 100*time.Millisecond) + q := newCoDelQueue(5*time.Millisecond, 100*time.Millisecond) for _, p := range packets { q.Enqueue(p) } @@ -33,7 +112,7 @@ func BenchmarkCodelQueueEnqueueDequeue(b *testing.B) { i = (i + 1) % queueSize pkt, ok := q.Dequeue() - if !ok || pkt == nil { + if !ok || len(pkt.buf) == 0 { b.Fatal("unexpected empty dequeue") } } diff --git a/fq_codel.go b/fq_codel.go new file mode 100644 index 0000000..06f7741 --- /dev/null +++ b/fq_codel.go @@ -0,0 +1,199 @@ +package simnet + +import ( + "hash/maphash" + "time" +) + +type coDelQueueWithCredits struct { + coDelQueue + credits int + bucket int +} + +// fqCoDel is an implementation of FQ-CoDel per RFC 8290. +type fqCoDel struct { + flows []*coDelQueueWithCredits + + newFlows linkedList[*coDelQueueWithCredits] + oldFlows linkedList[*coDelQueueWithCredits] + activeFlowCount int + + quantum int + target time.Duration + interval time.Duration + + hash maphash.Hash + + // Note: No packet limit is implemented in this fqCoDel implementation. This + // is not for production use. +} + +func newFqCoDel(target, interval time.Duration, quantum, flowCount int) fqCoDel { + return fqCoDel{ + flows: make([]*coDelQueueWithCredits, flowCount), + activeFlowCount: 0, + + quantum: quantum, + target: target, + interval: interval, + } +} + +func (q *fqCoDel) Enqueue(p Packet) { + bucket := int(p.Hash(&q.hash) % uint64(len(q.flows))) + + fq := q.flows[bucket] + if fq == nil { + fq = &coDelQueueWithCredits{ + coDelQueue: newCoDelQueue(q.target, q.interval), + credits: q.quantum, + bucket: bucket, + } + q.flows[bucket] = fq + q.newFlows.append(&listNode[*coDelQueueWithCredits]{v: fq}) + } + + fq.Enqueue(p) +} + +// Dequeue implements the Fq-CoDel Dequeue algorithm. The state transition of +// queues between new, old, and empty are represented by this diagram. +// +/* +-----------------+ +------------------+ +/* | | Empty | | +/* | Empty |<---------------+ Old +----+ +/* | | | | | +/* +-------+---------+ +------------------+ | +/* | ^ ^ |Credits +/* |Arrival | | |Exhausted +/* v | | | +/* +-----------------+ | | | +/* | | Empty or | | | +/* | New +-------------------+ +-------+ +/* | | Credits Exhausted +/* +-----------------+ +*/ +func (q *fqCoDel) Dequeue() (Packet, bool) { + + for !q.newFlows.empty() { + fq := q.newFlows.peek() + if fq.credits < 0 { + // For the first part, the scheduler first looks at the list of new + // queues; for the queue at the head of that list, if that queue has a + // negative number of credits (i.e., it has already dequeued at least a + // quantum of bytes), it is given an additional quantum of credits, the + // queue is put onto _the end of_ the list of old queues, and the + // routine selects the next queue and starts again. + fq.credits += q.quantum + n := q.newFlows.removeFirst() + q.oldFlows.append(n) + + continue + } + + // Otherwise, that queue is selected for dequeue. + if p, ok := fq.Dequeue(); ok { + fq.credits -= len(p.buf) + return p, true + } else { + // If the CoDel algorithm does not return a packet, then the + // queue must be empty, and the scheduler does one of two things. If + // the queue selected for dequeue came from the list of new queues, it + // is moved to _the end of_ the list of old queues. + // + // The step that moves an empty queue from the list of new queues to the + // end of the list of old queues before it is removed is crucial to + // prevent starvation. Otherwise, the queue could reappear (the next + // time a packet arrives for it) before the list of old queues is + // visited; this can go on indefinitely, even with a small number of + // active flows, if the flow providing packets to the queue in question + // transmits at just the right rate. This is prevented by first moving + // the queue to the end of the list of old queues, forcing the scheduler + // to service all old queues before the empty queue is removed and thus + // preventing starvation. + n := q.newFlows.removeFirst() + q.oldFlows.append(n) + continue + } + } + + // If the list of new queues is empty, the scheduler proceeds down the + // list of old queues in the same fashion (checking the credits and + // either selecting the queue for dequeueing or adding credits and + // putting the queue back at the end of the list). + for !q.oldFlows.empty() { + fq := q.oldFlows.peek() + if fq.credits < 0 { + fq.credits += q.quantum + n := q.oldFlows.removeFirst() + q.oldFlows.append(n) + continue + } + + if p, ok := fq.Dequeue(); ok { + // If, instead, the scheduler _did_ get a packet back from the CoDel + // algorithm, it subtracts the size of the packet from the byte credits + // for the selected queue and returns the packet as the result of the + // dequeue operation. + fq.credits -= len(p.buf) + return p, true + } else { + // Finally, if the CoDel algorithm does not return a packet, then the + // queue must be empty, and the scheduler does one of two things. If + // the queue selected for dequeue came from the list of new queues, it + // is moved to _the end of_ the list of old queues. If instead it came + // from the list of old queues, that queue is removed from the list, to + // be added back (as a new queue) the next time a packet arrives that + // hashes to that queue. Then (since no packet was available for + // dequeue), the whole dequeue process is restarted from the beginning. + q.oldFlows.removeFirst() + q.flows[fq.bucket] = nil + continue + } + } + + return Packet{}, false +} + +type linkedList[T any] struct { + head *listNode[T] + tail *listNode[T] +} + +func (l *linkedList[T]) append(item *listNode[T]) { + item.next = nil + if l.tail == nil { + l.head = item + l.tail = l.head + } else { + l.tail.next = item + l.tail = l.tail.next + } +} + +func (l *linkedList[T]) empty() bool { + return l.head == nil +} + +func (l *linkedList[T]) peek() T { + return l.head.v +} + +func (l *linkedList[T]) removeFirst() *listNode[T] { + if l.head == nil { + return nil + } + node := l.head + l.head = l.head.next + if l.head == nil { + l.tail = nil + } + node.next = nil + return node +} + +type listNode[T any] struct { + v T + next *listNode[T] +} diff --git a/fq_codel_test.go b/fq_codel_test.go new file mode 100644 index 0000000..13ed3d9 --- /dev/null +++ b/fq_codel_test.go @@ -0,0 +1,284 @@ +//go:build go1.25 + +package simnet + +import ( + "net" + "testing" + "testing/synctest" + "time" +) + +func TestFqCoDelSingleFlow(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 5 * time.Millisecond + interval = 100 * time.Millisecond + quantum = 128 + flowCount = 32 + payload = 64 + ) + + q := newTestFqCoDel(target, interval, quantum, flowCount) + flow := mustNewTestFlow(t, q, 1, nil) + + for seq := range 3 { + q.Enqueue(flow.packet(seq, payload)) + } + + for seq := range 3 { + pkt, ok := q.Dequeue() + if !ok { + t.Fatalf("dequeue %d returned empty packet", seq) + } + flowID, gotSeq := decodeFlowAndSeq(pkt) + if flowID != flow.id || gotSeq != seq { + t.Fatalf("got flow %d seq %d, want flow %d seq %d", flowID, gotSeq, flow.id, seq) + } + } + + if pkt, ok := q.Dequeue(); ok { + flowID, seq := decodeFlowAndSeq(pkt) + t.Fatalf("queue should be empty, got flow %d seq %d", flowID, seq) + } + }) +} + +func TestFqCoDelMultipleFlowsNoStandingQueue(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 5 * time.Millisecond + interval = 100 * time.Millisecond + quantum = 64 + flowCount = 64 + payload = 96 + ) + + q := newTestFqCoDel(target, interval, quantum, flowCount) + used := make(map[int]struct{}) + flows := []testFlow{ + mustNewTestFlow(t, q, 1, used), + mustNewTestFlow(t, q, 2, used), + } + + // Each flow only has a single packet in flight (no standing queue) and + // packets should be dequeued in arrival order. + for _, f := range flows { + q.Enqueue(f.packet(0, payload)) + } + + for i, f := range flows { + pkt, ok := q.Dequeue() + if !ok { + t.Fatalf("unexpected empty dequeue at position %d", i) + } + flowID, seq := decodeFlowAndSeq(pkt) + if flowID != f.id || seq != 0 { + t.Fatalf("got flow %d seq %d, want flow %d seq 0", flowID, seq, f.id) + } + } + + if pkt, ok := q.Dequeue(); ok { + flowID, seq := decodeFlowAndSeq(pkt) + t.Fatalf("queue should be empty after draining, got flow %d seq %d", flowID, seq) + } + }) +} + +func TestFqCoDelPersistentBadQueueIsIsolated(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 1 * time.Millisecond + interval = 20 * time.Millisecond + quantum = 64 + flowCount = 128 + payload = 192 + ) + + q := newTestFqCoDel(target, interval, quantum, flowCount) + used := make(map[int]struct{}) + badFlow := mustNewTestFlow(t, q, 1, used) + goodFlow := mustNewTestFlow(t, q, 2, used) + + // Build a persistently full queue for the bad flow so CoDel enters drop state. + for seq := 0; seq < 10; seq++ { + q.Enqueue(badFlow.packet(seq, payload)) + } + + badQueue := flowQueue(t, q, badFlow) + if badQueue == nil { + t.Fatal("bad flow queue missing") + } + + time.Sleep(3 * interval) + + enteredDrop := false + for attempt := 0; attempt < 6; attempt++ { + pkt, ok := q.Dequeue() + if !ok { + t.Fatalf("bad flow should still have packets on attempt %d", attempt) + } + flowID, _ := decodeFlowAndSeq(pkt) + if flowID != badFlow.id { + t.Fatalf("unexpected flow %d while draining bad queue", flowID) + } + if badQueue.dropping { + enteredDrop = true + break + } + time.Sleep(interval) + } + if !enteredDrop { + t.Fatal("persistently bad flow never entered drop state") + } + + q.Enqueue(goodFlow.packet(0, payload)) + + next, ok := q.Dequeue() + if !ok { + t.Fatal("expected good flow packet") + } + nextFlow, seq := decodeFlowAndSeq(next) + if nextFlow != goodFlow.id || seq != 0 { + t.Fatalf("expected good flow packet 0, got flow %d seq %d", nextFlow, seq) + } + }) +} + +func TestFqCoDelThreeFlowsNoStarvation(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + target = 1 * time.Millisecond + interval = 20 * time.Millisecond + quantum = 64 + flowCount = 128 + payload = 160 + rounds = 3 + ) + + q := newTestFqCoDel(target, interval, quantum, flowCount) + used := make(map[int]struct{}) + flows := []testFlow{ + mustNewTestFlow(t, q, 1, used), + mustNewTestFlow(t, q, 2, used), + mustNewTestFlow(t, q, 3, used), + } + + for seq := range rounds { + for _, f := range flows { + q.Enqueue(f.packet(seq, payload)) + } + } + + for seq := range rounds { + seen := make(map[int]bool) + for range flows { + pkt, ok := q.Dequeue() + if !ok { + t.Fatalf("unexpected empty dequeue while draining seq %d", seq) + } + flowID, gotSeq := decodeFlowAndSeq(pkt) + if gotSeq != seq { + t.Fatalf("flow %d got seq %d, want %d", flowID, gotSeq, seq) + } + seen[flowID] = true + } + if len(seen) != len(flows) { + t.Fatalf("not all flows serviced for seq %d, got %v", seq, seen) + } + } + + if pkt, ok := q.Dequeue(); ok { + flowID, seq := decodeFlowAndSeq(pkt) + t.Fatalf("queue should be empty after all rounds, got flow %d seq %d", flowID, seq) + } + }) +} + +func newTestFqCoDel(target, interval time.Duration, quantum, flowCount int) *fqCoDel { + q := newFqCoDel(target, interval, quantum, flowCount) + return &q +} + +type testFlow struct { + id int + to net.UDPAddr + from net.UDPAddr +} + +func (f testFlow) packet(seq, size int) Packet { + to := f.to + from := f.from + return Packet{ + To: &to, + From: &from, + buf: flowPayload(f.id, seq, size), + } +} + +func (f testFlow) bucket(q *fqCoDel) int { + to := f.to + from := f.from + probe := Packet{To: &to, From: &from} + return int(probe.Hash(&q.hash) % uint64(len(q.flows))) +} + +func decodeFlowAndSeq(p Packet) (flowID, seq int) { + if len(p.buf) < 2 { + return -1, -1 + } + return int(p.buf[0]), int(p.buf[1]) +} + +func flowPayload(flowID, seq, size int) []byte { + if size < 2 { + size = 2 + } + buf := make([]byte, size) + buf[0] = byte(flowID) + buf[1] = byte(seq) + for i := 2; i < len(buf); i++ { + buf[i] = byte(flowID + seq) + } + return buf +} + +func flowQueue(t *testing.T, q *fqCoDel, f testFlow) *coDelQueueWithCredits { + t.Helper() + bucket := f.bucket(q) + if bucket < 0 || bucket >= len(q.flows) { + t.Fatalf("flow bucket %d out of range", bucket) + } + return q.flows[bucket] +} + +func mustNewTestFlow(t *testing.T, q *fqCoDel, id int, usedBuckets map[int]struct{}) testFlow { + t.Helper() + + for salt := range 512 { + to := net.UDPAddr{ + IP: net.IPv4(10, byte(id), byte(salt), byte(id+salt)), + Port: 10000 + id + salt*7, + } + from := net.UDPAddr{ + IP: net.IPv4(192, 0, byte(id+salt), byte(salt)), + Port: 20000 + id + salt*11, + } + flow := testFlow{ + id: id, + to: to, + from: from, + } + bucket := flow.bucket(q) + if usedBuckets != nil { + if _, exists := usedBuckets[bucket]; exists { + continue + } + usedBuckets[bucket] = struct{}{} + } + return flow + } + + t.Fatalf("unable to derive flow %d with unique bucket", id) + return testFlow{} +} diff --git a/internal/require/require.go b/internal/require/require.go index 00a379b..666b904 100644 --- a/internal/require/require.go +++ b/internal/require/require.go @@ -1,10 +1,10 @@ package require import ( - "errors" - "fmt" - "reflect" - "testing" + "errors" + "fmt" + "reflect" + "testing" ) // NoError fails the test if err is not nil. @@ -25,15 +25,16 @@ func ErrorIs(t *testing.T, err error, target error, msgAndArgs ...any) { // Equal fails the test if expected != actual using reflect.DeepEqual. func Equal(t *testing.T, expected, actual any, msgAndArgs ...any) { - t.Helper() - if !reflect.DeepEqual(expected, actual) { - failNow(t, fmt.Sprintf("not equal\nexpected: %#v\nactual: %#v", expected, actual), msgAndArgs...) - } + t.Helper() + if !reflect.DeepEqual(expected, actual) { + failNow(t, fmt.Sprintf("not equal\nexpected: %#v\nactual: %#v", expected, actual), msgAndArgs...) + } } func failNow(t *testing.T, baseMsg string, msgAndArgs ...any) { - if len(msgAndArgs) > 0 { - baseMsg = baseMsg + ": " + fmt.Sprint(msgAndArgs...) - } - t.Fatalf("%s", baseMsg) + t.Helper() + if len(msgAndArgs) > 0 { + baseMsg = baseMsg + ": " + fmt.Sprint(msgAndArgs...) + } + t.Fatalf("%s", baseMsg) } diff --git a/packetheap.go b/packetheap.go new file mode 100644 index 0000000..50f7ef5 --- /dev/null +++ b/packetheap.go @@ -0,0 +1,35 @@ +package simnet + +import "time" + +type packetWithDeliveryTimeAndOrder struct { + *Packet + order int + deliveryTime time.Time +} + +// packetHeap implements heap.Interface ordered by packet delivery time. +type packetHeap []packetWithDeliveryTimeAndOrder + +func (h packetHeap) Len() int { return len(h) } + +func (h packetHeap) Less(i, j int) bool { + return (h[i].deliveryTime.Before(h[j].deliveryTime) || + h[i].deliveryTime.Equal(h[j].deliveryTime) && h[i].order < h[j].order) +} + +func (h packetHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *packetHeap) Push(x any) { + *h = append(*h, x.(packetWithDeliveryTimeAndOrder)) +} + +func (h *packetHeap) Pop() any { + old := *h + n := len(old) + item := old[n-1] + *h = old[:n-1] + return item +} diff --git a/ratelink.go b/ratelink.go new file mode 100644 index 0000000..f3c2a02 --- /dev/null +++ b/ratelink.go @@ -0,0 +1,38 @@ +package simnet + +import ( + "time" + + "golang.org/x/time/rate" +) + +type RateLink struct { + *rate.Limiter + BitsPerSecond int + Receiver PacketReceiver +} + +// Creates a new RateLimiter with the following parameters: +// bandwidth (in bits/sec). +// burstSize is in Bytes +func newRateLimiter(bandwidth int, burstSize int) *rate.Limiter { + // Convert bandwidth from bits/sec to bytes/sec + bytesPerSecond := rate.Limit(float64(bandwidth) / 8.0) + return rate.NewLimiter(bytesPerSecond, burstSize) +} + +func NewRateLink(bandwidth int, burstSize int, receiver PacketReceiver) *RateLink { + return &RateLink{ + Limiter: newRateLimiter(bandwidth, burstSize), + Receiver: receiver, + } +} + +func (l *RateLink) Reserve(now time.Time, packetSize int) time.Duration { + r := l.Limiter.ReserveN(now, packetSize) + return r.DelayFrom(now) +} + +func (l *RateLink) RecvPacket(p Packet) { + l.Receiver.RecvPacket(p) +} diff --git a/ratelink_test.go b/ratelink_test.go new file mode 100644 index 0000000..3815b43 --- /dev/null +++ b/ratelink_test.go @@ -0,0 +1,52 @@ +package simnet + +import ( + "math" + "testing" + "testing/synctest" + "time" +) + +type countingReceiver struct { + totalBytes int +} + +func (c *countingReceiver) RecvPacket(p Packet) { + c.totalBytes += len(p.buf) +} + +func TestRateLinkObservedBandwidth(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const ( + mtu = 1500 + bandwidth = 50 * Mibps + burstSize = 10 * mtu + packets = 20_000 + totalBytes = packets * mtu + ) + + receiver := &countingReceiver{} + link := NewRateLink(bandwidth, burstSize, receiver) + + chunk := make([]byte, mtu) + + start := time.Now() + for range packets { + p := Packet{buf: chunk} + time.Sleep(link.Reserve(time.Now(), len(p.buf))) + link.RecvPacket(p) + } + duration := time.Since(start) + + if receiver.totalBytes != totalBytes { + t.Fatalf("expected receiver to get %d bytes, got %d", totalBytes, receiver.totalBytes) + } + + observedBandwidth := 8 * float64(totalBytes) / duration.Seconds() + diff := math.Abs(observedBandwidth - float64(bandwidth)) + allowedError := 0.10 * float64(bandwidth) + if diff > allowedError { + t.Fatalf("observed bandwidth %f bps differs from expected %d bps by %f bps (allowed %f)", observedBandwidth, bandwidth, diff, allowedError) + } + }) +} diff --git a/ringbuffer.go b/ringbuffer.go new file mode 100644 index 0000000..f0099bb --- /dev/null +++ b/ringbuffer.go @@ -0,0 +1,44 @@ +package simnet + +type ringBuffer[T any] struct { + head int + tail int + buf []T +} + +func newRingBuffer[T any](capacity int) ringBuffer[T] { + return ringBuffer[T]{ + head: 0, + tail: 0, + buf: make([]T, capacity), + } +} + +func (r *ringBuffer[T]) PushBack(value T) { + r.buf[r.tail] = value + r.tail = (r.tail + 1) % len(r.buf) + if r.tail == r.head { + // reallocate a larger buffer + newBuf := make([]T, len(r.buf)*2) + copy(newBuf, r.buf[r.head:]) + copy(newBuf[len(r.buf)-r.head:], r.buf[:r.head]) + oldLen := len(r.buf) + r.buf = newBuf + r.head = 0 + r.tail = oldLen + } +} + +func (r *ringBuffer[T]) PopFront() T { + value := r.buf[r.head] + r.head = (r.head + 1) % len(r.buf) + return value +} + +func (r *ringBuffer[T]) Peek() T { + return r.buf[r.head] +} + +func (r *ringBuffer[T]) Empty() bool { + return r.head == r.tail +} diff --git a/ringbuffer_test.go b/ringbuffer_test.go new file mode 100644 index 0000000..487c4ad --- /dev/null +++ b/ringbuffer_test.go @@ -0,0 +1,69 @@ +package simnet + +import "testing" + +func TestRingBufferPushPopOrder(t *testing.T) { + rb := newRingBuffer[int](2) + total := 10 + + for i := range total { + rb.PushBack(i) + } + + for i := range total { + got := rb.PopFront() + if got != i { + t.Fatalf("PopFront()=%d, want %d", got, i) + } + } +} + +func TestRingBufferWrapAndGrowth(t *testing.T) { + rb := newRingBuffer[int](4) + + for _, v := range []int{0, 1, 2} { + rb.PushBack(v) + } + + if got := rb.PopFront(); got != 0 { + t.Fatalf("PopFront()=%d, want 0", got) + } + + for _, v := range []int{3, 4, 5} { + rb.PushBack(v) + } + + want := []int{1, 2, 3, 4, 5} + for _, v := range want { + got := rb.PopFront() + if got != v { + t.Fatalf("PopFront()=%d, want %d", got, v) + } + } +} + +func TestRingBufferPeekAndEmpty(t *testing.T) { + rb := newRingBuffer[string](1) + + if !rb.Empty() { + t.Fatalf("empty()=false, want true") + } + + rb.PushBack("hello") + + if rb.Empty() { + t.Fatalf("empty()=true, want false") + } + + if got := rb.Peek(); got != "hello" { + t.Fatalf("Peek()=%q, want %q", got, "hello") + } + + if got := rb.PopFront(); got != "hello" { + t.Fatalf("PopFront()=%q, want %q", got, "hello") + } + + if !rb.Empty() { + t.Fatalf("empty()=false, want true after PopFront") + } +} diff --git a/router.go b/router.go index dcce4e3..0c641c8 100644 --- a/router.go +++ b/router.go @@ -1,14 +1,31 @@ package simnet import ( - "errors" + "container/heap" "fmt" + "log/slog" "net" "net/netip" "sync" "time" ) +type DropReason string + +const ( + DropReasonUnknownDestination DropReason = "unknown destination" + DropReasonUnknownSource DropReason = "unknown source" + DropReasonFirewalled DropReason = "Packet firewalled" +) + +type OnDrop func(packet Packet, reason DropReason) + +func LogOnDrop(logger *slog.Logger) OnDrop { + return func(packet Packet, reason DropReason) { + logger.Error("Dropping packet", "from", packet.From, "to", packet.To, "reason", reason) + } +} + type ipPortKey struct { ip string port uint16 @@ -98,18 +115,20 @@ func (m *addrMap[V]) Delete(addr net.Addr) error { // PerfectRouter is a router that has no latency or jitter and can route to // every node type PerfectRouter struct { - nodes addrMap[PacketReceiver] + OnDrop OnDrop + nodes addrMap[PacketReceiver] } -// SendPacket implements Router. -func (r *PerfectRouter) SendPacket(p Packet) error { +func (r *PerfectRouter) RecvPacket(p Packet) { conn, ok := r.nodes.Get(p.To) if !ok { - return errors.New("unknown destination") + if r.OnDrop != nil { + r.OnDrop(p, DropReasonUnknownDestination) + } + return } conn.RecvPacket(p) - return nil } func (r *PerfectRouter) AddNode(addr net.Addr, conn PacketReceiver) { @@ -122,33 +141,72 @@ func (r *PerfectRouter) RemoveNode(addr net.Addr) { var _ Router = &PerfectRouter{} -type DelayedPacketReciever struct { - inner PacketReceiver - delay time.Duration -} +type VariableLatencyRouter struct { + PerfectRouter + LatencyFunc func(packet *Packet) time.Duration + CloseSignal chan struct{} -func (r *DelayedPacketReciever) RecvPacket(p Packet) { - time.AfterFunc(r.delay, func() { r.inner.RecvPacket(p) }) + packets chan Packet + packetCount int + h packetHeap } -type FixedLatencyRouter struct { - PerfectRouter - latency time.Duration +func (r *VariableLatencyRouter) RecvPacket(p Packet) { + r.packets <- p } -func (r *FixedLatencyRouter) SendPacket(p Packet) error { - return r.PerfectRouter.SendPacket(p) +// wgGo is the same as Go 1.25's wg.Go. Remove this when Go 1.26 is out +func wgGo(wg *sync.WaitGroup, f func()) { + wg.Add(1) + go func() { + defer wg.Done() + f() + }() } -func (r *FixedLatencyRouter) AddNode(addr net.Addr, conn PacketReceiver) { - r.PerfectRouter.AddNode(addr, &DelayedPacketReciever{ - inner: conn, - delay: r.latency, +func (r *VariableLatencyRouter) Start(wg *sync.WaitGroup) { + r.packets = make(chan Packet, 128) + heap.Init(&r.h) + + wgGo(wg, func() { + var nextDelivery time.Time + deliveryTimer := time.NewTimer(0) + deliveryTimer.Stop() + + for { + select { + case <-r.CloseSignal: + return + case p := <-r.packets: + r.packetCount++ + latency := r.LatencyFunc(&p) + deliveryTime := time.Now().Add(latency) + heap.Push(&r.h, packetWithDeliveryTimeAndOrder{ + Packet: &p, + order: r.packetCount, + deliveryTime: deliveryTime, + }) + if nextDelivery.IsZero() || deliveryTime.Before(nextDelivery) { + nextDelivery = deliveryTime + deliveryTimer.Reset(latency) + } + case <-deliveryTimer.C: + now := time.Now() + for len(r.h) > 0 && !r.h[0].deliveryTime.After(now) { + p := heap.Pop(&r.h).(packetWithDeliveryTimeAndOrder).Packet + r.PerfectRouter.RecvPacket(*p) + } + if len(r.h) > 0 { + nextDelivery = r.h[0].deliveryTime + deliveryTimer.Reset(nextDelivery.Sub(now)) + } else { + nextDelivery = time.Time{} + } + } + } }) } -var _ Router = &FixedLatencyRouter{} - type simpleNodeFirewall struct { mu sync.Mutex publiclyReachable bool @@ -181,6 +239,7 @@ func (f *simpleNodeFirewall) String() string { } type SimpleFirewallRouter struct { + OnDrop OnDrop mu sync.Mutex nodes map[string]*simpleNodeFirewall publiclyReachableAddrs map[string]bool @@ -205,27 +264,35 @@ func (r *SimpleFirewallRouter) SetAddrPubliclyReachable(addr net.Addr) { r.publiclyReachableAddrs[addr.String()] = true } -func (r *SimpleFirewallRouter) SendPacket(p Packet) error { +func (r *SimpleFirewallRouter) RecvPacket(p Packet) { r.mu.Lock() defer r.mu.Unlock() toNode, exists := r.nodes[p.To.String()] if !exists { - return errors.New("unknown destination") + if r.OnDrop != nil { + r.OnDrop(p, DropReasonUnknownDestination) + } + return } // Record that this node is sending a packet to the destination fromNode, exists := r.nodes[p.From.String()] if !exists { - return errors.New("unknown source") + if r.OnDrop != nil { + r.OnDrop(p, DropReasonUnknownSource) + } + return } fromNode.MarkPacketSentOut(p) if !toNode.IsPacketInAllowed(p) { - return nil // Silently drop blocked packets + if r.OnDrop != nil { + r.OnDrop(p, DropReasonFirewalled) + } + return } toNode.node.RecvPacket(p) - return nil } func (r *SimpleFirewallRouter) AddNode(addr net.Addr, conn PacketReceiver) { diff --git a/router_test.go b/router_test.go new file mode 100644 index 0000000..23f357b --- /dev/null +++ b/router_test.go @@ -0,0 +1,139 @@ +package simnet + +import ( + "net" + "sync" + "testing" + "testing/synctest" + "time" +) + +func TestVariableLatencyRouterDelaysPackets(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const wantLatency = 25 * time.Millisecond + receiver := newRouterRecordingReceiver(1) + router, from, to := startVariableLatencyRouter(t, func(*Packet) time.Duration { + return wantLatency + }, receiver) + + sendTime := time.Now() + router.RecvPacket(Packet{From: from, To: to, buf: []byte{0x1}}) + + delivery := receiver.waitFor(t) + delay := delivery.arrival.Sub(sendTime) + const slop = time.Millisecond + if delay < wantLatency { + t.Fatalf("packet delivered too early: got %v want at least %v", delay, wantLatency) + } + if delay > wantLatency+slop { + t.Fatalf("packet delivered too late: got %v want no more than %v", delay, wantLatency+slop) + } + }) +} + +func TestVariableLatencyRouterAllowsReordering(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + receiver := newRouterRecordingReceiver(2) + router, from, to := startVariableLatencyRouter(t, func(p *Packet) time.Duration { + if len(p.buf) == 0 { + return 0 + } + if p.buf[0] == 1 { + return 60 * time.Millisecond + } + return 5 * time.Millisecond + }, receiver) + + router.RecvPacket(Packet{From: from, To: to, buf: []byte{1}}) + time.Sleep(10 * time.Millisecond) + router.RecvPacket(Packet{From: from, To: to, buf: []byte{2}}) + + first := receiver.waitFor(t) + second := receiver.waitFor(t) + if first.packet.buf[0] != 2 { + t.Fatalf("expected packet 2 to arrive first, got %d", first.packet.buf[0]) + } + if second.packet.buf[0] != 1 { + t.Fatalf("expected packet 1 to arrive second, got %d", second.packet.buf[0]) + } + }) +} + +func TestVariableLatencyRouterKeepsOrderWithEqualLatency(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const latency = 15 * time.Millisecond + receiver := newRouterRecordingReceiver(2) + router, from, to := startVariableLatencyRouter(t, func(*Packet) time.Duration { + return latency + }, receiver) + + router.RecvPacket(Packet{From: from, To: to, buf: []byte{1}}) + router.RecvPacket(Packet{From: from, To: to, buf: []byte{2}}) + + first := receiver.waitFor(t) + second := receiver.waitFor(t) + if first.packet.buf[0] != 1 { + t.Fatalf("expected packet 1 to arrive first, got %d", first.packet.buf[0]) + } + if second.packet.buf[0] != 2 { + t.Fatalf("expected packet 2 to arrive second, got %d", second.packet.buf[0]) + } + if second.arrival.Before(first.arrival) { + t.Fatalf("expected packets with same latency to preserve order, got first=%v second=%v", first.arrival, second.arrival) + } + }) +} + +type deliveredPacket struct { + packet Packet + arrival time.Time +} + +type routerRecordingReceiver struct { + deliveries chan deliveredPacket +} + +func newRouterRecordingReceiver(buffer int) *routerRecordingReceiver { + return &routerRecordingReceiver{ + deliveries: make(chan deliveredPacket, buffer), + } +} + +func (r *routerRecordingReceiver) RecvPacket(p Packet) { + r.deliveries <- deliveredPacket{ + packet: p, + arrival: time.Now(), + } +} + +func (r *routerRecordingReceiver) waitFor(t *testing.T) deliveredPacket { + t.Helper() + select { + case delivery := <-r.deliveries: + return delivery + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for packet delivery") + return deliveredPacket{} + } +} + +func startVariableLatencyRouter(t *testing.T, latency func(*Packet) time.Duration, receiver PacketReceiver) (*VariableLatencyRouter, net.Addr, net.Addr) { + t.Helper() + router := &VariableLatencyRouter{ + LatencyFunc: latency, + CloseSignal: make(chan struct{}), + } + + from := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 1), Port: 40000} + to := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 2), Port: 40001} + router.AddNode(to, receiver) + + var wg sync.WaitGroup + router.Start(&wg) + t.Cleanup(func() { + close(router.CloseSignal) + wg.Wait() + }) + + return router, from, to +} diff --git a/simconn.go b/simconn.go index a070669..38cf75c 100644 --- a/simconn.go +++ b/simconn.go @@ -2,6 +2,7 @@ package simnet import ( "errors" + "hash/maphash" "net" "slices" "sync" @@ -18,7 +19,7 @@ type PacketReceiver interface { // Router handles routing of packets between simulated connections. // Implementations are responsible for delivering packets to their destinations. type Router interface { - SendPacket(p Packet) error + PacketReceiver AddNode(addr net.Addr, receiver PacketReceiver) } @@ -28,6 +29,13 @@ type Packet struct { buf []byte } +func (p *Packet) Hash(h *maphash.Hash) uint64 { + h.Reset() + h.WriteString(p.To.String()) + h.WriteString(p.From.String()) + return h.Sum64() +} + // SimConn is a simulated network connection that implements net.PacketConn. It // provides packet-based communication through a Router for testing and // simulation purposes. All send/recv operations are handled through the @@ -38,12 +46,14 @@ type SimConn struct { closedChan chan struct{} deadlineUpdated chan struct{} + link Simlink + packetsSent atomic.Uint64 packetsRcvd atomic.Uint64 bytesSent atomic.Int64 bytesRcvd atomic.Int64 - router Router + upPacketReceiver PacketReceiver myAddr *net.UDPAddr myLocalAddr net.Addr @@ -59,29 +69,31 @@ type SimConn struct { // NewSimConn creates a new simulated connection that drops packets if the // receive buffer is full. -func NewSimConn(addr *net.UDPAddr, rtr Router) *SimConn { - return newSimConn(addr, rtr, false) +func NewSimConn(addr *net.UDPAddr) *SimConn { + return newSimConn(addr, false) } // NewBlockingSimConn creates a new simulated connection that blocks if the // receive buffer is full. Does not drop packets. -func NewBlockingSimConn(addr *net.UDPAddr, rtr Router) *SimConn { - return newSimConn(addr, rtr, true) +func NewBlockingSimConn(addr *net.UDPAddr) *SimConn { + return newSimConn(addr, true) } -func newSimConn(addr *net.UDPAddr, rtr Router, block bool) *SimConn { +func newSimConn(addr *net.UDPAddr, block bool) *SimConn { c := &SimConn{ recvBackPressure: block, - router: rtr, myAddr: addr, packetsToRead: make(chan Packet, 32), closedChan: make(chan struct{}), deadlineUpdated: make(chan struct{}, 1), } - rtr.AddNode(addr, c) return c } +func (c *SimConn) SetUpPacketReceiver(r PacketReceiver) { + c.upPacketReceiver = r +} + type ConnStats struct { BytesSent int BytesRcvd int @@ -213,7 +225,11 @@ func (c *SimConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { To: addr, buf: slices.Clone(p), } - return len(p), c.router.SendPacket(pkt) + if c.upPacketReceiver == nil { + panic("upPacketReceiver is nil. Did you forget to call simconn.SetUpPacketReceiver?") + } + c.upPacketReceiver.RecvPacket(pkt) + return len(p), nil } func (c *SimConn) UnicastAddr() net.Addr { diff --git a/simconn_test.go b/simconn_test.go index b37c8fb..42b3978 100644 --- a/simconn_test.go +++ b/simconn_test.go @@ -18,8 +18,13 @@ func TestSimConnBasicConnectivity(t *testing.T) { addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234} - conn1 := NewSimConn(addr1, router) - conn2 := NewSimConn(addr2, router) + conn1 := NewSimConn(addr1) + conn1.SetUpPacketReceiver(router) + router.AddNode(conn1.UnicastAddr(), conn1) + + conn2 := NewSimConn(addr2) + conn2.SetUpPacketReceiver(router) + router.AddNode(conn2.UnicastAddr(), conn2) // Test sending data from conn1 to conn2 testData := []byte("hello world") @@ -48,7 +53,9 @@ func TestSimConnDeadlines(t *testing.T) { router := &PerfectRouter{} addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} - conn := NewSimConn(addr1, router) + conn := NewSimConn(addr1) + conn.SetUpPacketReceiver(router) + router.AddNode(conn.UnicastAddr(), conn) t.Run("read deadline", func(t *testing.T) { deadline := time.Now().Add(10 * time.Millisecond) @@ -74,7 +81,9 @@ func TestSimConnClose(t *testing.T) { router := &PerfectRouter{} addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} - conn := NewSimConn(addr1, router) + conn := NewSimConn(addr1) + conn.SetUpPacketReceiver(router) + router.AddNode(conn.UnicastAddr(), conn) err := conn.Close() require.NoError(t, err) @@ -96,7 +105,9 @@ func TestSimConnLocalAddr(t *testing.T) { router := &PerfectRouter{} addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} - conn := NewSimConn(addr1, router) + conn := NewSimConn(addr1) + conn.SetUpPacketReceiver(router) + router.AddNode(conn.UnicastAddr(), conn) // Test default local address require.Equal(t, addr1, conn.LocalAddr()) @@ -107,111 +118,34 @@ func TestSimConnLocalAddr(t *testing.T) { require.Equal(t, customAddr, conn.LocalAddr()) } -func TestSimConnDeadlinesWithLatency(t *testing.T) { - router := &FixedLatencyRouter{ - PerfectRouter: PerfectRouter{}, - latency: 100 * time.Millisecond, - } - - addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} - addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234} - - conn1 := NewSimConn(addr1, router) - conn2 := NewSimConn(addr2, router) - - reset := func() { - router.RemoveNode(addr1) - router.RemoveNode(addr2) - - conn1 = NewSimConn(addr1, router) - conn2 = NewSimConn(addr2, router) - } - - t.Run("write succeeds within deadline", func(t *testing.T) { - deadline := time.Now().Add(200 * time.Millisecond) - err := conn1.SetWriteDeadline(deadline) - require.NoError(t, err) - - n, err := conn1.WriteTo([]byte("test"), addr2) - require.NoError(t, err) - require.Equal(t, 4, n) - reset() - }) - - t.Run("write fails after past deadline", func(t *testing.T) { - deadline := time.Now().Add(-time.Second) // Already expired - err := conn1.SetWriteDeadline(deadline) - require.NoError(t, err) - - _, err = conn1.WriteTo([]byte("test"), addr2) - require.ErrorIs(t, err, ErrDeadlineExceeded) - reset() - }) - - t.Run("read succeeds within deadline", func(t *testing.T) { - // Reset deadline and send a message - conn2.SetReadDeadline(time.Time{}) - testData := []byte("hello") - deadline := time.Now().Add(200 * time.Millisecond) - conn1.SetWriteDeadline(deadline) - _, err := conn1.WriteTo(testData, addr2) - require.NoError(t, err) - - // Set read deadline and try to read - deadline = time.Now().Add(200 * time.Millisecond) - err = conn2.SetReadDeadline(deadline) - require.NoError(t, err) - - buf := make([]byte, 1024) - n, addr, err := conn2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, addr1, addr) - require.Equal(t, testData, buf[:n]) - reset() - }) - - t.Run("read fails after deadline", func(t *testing.T) { - defer reset() - // Set a short deadline - deadline := time.Now().Add(50 * time.Millisecond) // Less than router latency - err := conn2.SetReadDeadline(deadline) - require.NoError(t, err) - - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - go func() { - defer wg.Done() - // Send data after setting deadline - _, err := conn1.WriteTo([]byte("test"), addr2) - require.NoError(t, err) - }() - - // Read should fail due to deadline - buf := make([]byte, 1024) - _, _, err = conn2.ReadFrom(buf) - require.ErrorIs(t, err, ErrDeadlineExceeded) - }) -} - func TestSimpleHolePunch(t *testing.T) { router := &SimpleFirewallRouter{ nodes: make(map[string]*simpleNodeFirewall), } // Create two peers - addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} - addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234} + addr1 := &net.UDPAddr{IP: IntToPublicIPv4(0), Port: 1234} + addr2 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234} + + peer1 := NewSimConn(addr1) + peer1.SetUpPacketReceiver(router) + router.AddNode(peer1.UnicastAddr(), peer1) - peer1 := NewSimConn(addr1, router) - peer2 := NewSimConn(addr2, router) + peer2 := NewSimConn(addr2) + peer2.SetUpPacketReceiver(router) + router.AddNode(peer2.UnicastAddr(), peer2) reset := func() { router.RemoveNode(addr1) router.RemoveNode(addr2) - peer1 = NewSimConn(addr1, router) - peer2 = NewSimConn(addr2, router) + peer1 = NewSimConn(addr1) + peer1.SetUpPacketReceiver(router) + router.AddNode(peer1.UnicastAddr(), peer1) + + peer2 = NewSimConn(addr2) + peer2.SetUpPacketReceiver(router) + router.AddNode(peer2.UnicastAddr(), peer2) } // Initially, direct communication between peer1 and peer2 should fail diff --git a/simlink.go b/simlink.go index 5ff967c..402a69b 100644 --- a/simlink.go +++ b/simlink.go @@ -1,24 +1,13 @@ package simnet import ( - "context" - "net" "sync" "time" - - "golang.org/x/time/rate" ) const Mibps = 1_000_000 -// Creates a new RateLimiter with the following parameters: -// bandwidth (in bits/sec). -// burstSize is in Bytes -func newRateLimiter(bandwidth int, burstSize int) *rate.Limiter { - // Convert bandwidth from bits/sec to bytes/sec - bytesPerSecond := rate.Limit(float64(bandwidth) / 8.0) - return rate.NewLimiter(bytesPerSecond, burstSize) -} +const DefaultFlowBucketCount = 128 // packetWithDeliveryTime holds a packet along with its delivery time and enqueue time type packetWithDeliveryTime struct { @@ -33,190 +22,131 @@ type LinkSettings struct { // MTU (Maximum Transmission Unit) specifies the maximum packet size in bytes MTU int -} -// SimulatedLink simulates a bidirectional network link with variable latency, -// bandwidth limiting, and CoDel-based bufferbloat mitigation -type SimulatedLink struct { - // Internal state for lifecycle management - closed chan struct{} - wg sync.WaitGroup - - // CoDel queues for bufferbloat control - downstreamQueue *codelQueue - upstreamQueue *codelQueue - - // Rate limiters enforce bandwidth constraints - upLimiter *rate.Limiter - downLimiter *rate.Limiter - - // Configuration for link characteristics - UplinkSettings LinkSettings - DownlinkSettings LinkSettings - - // Latency specifies a fixed network delay for downlink packets - // If both Latency and LatencyFunc are set, LatencyFunc takes precedence - Latency time.Duration - - // LatencyFunc computes the network delay for each downlink packet - // This allows variable latency based on packet source/destination - // If nil, Latency field is used instead - LatencyFunc func(Packet) time.Duration - - // Packet routing interfaces - UploadPacket Router - downloadPacket PacketReceiver + // FlowBucketCount sets the number of flow buckets for FQ-CoDel. If zero + // defaults to DefaultFlowBucketCount + FlowBucketCount int } -func (l *SimulatedLink) AddNode(addr net.Addr, receiver PacketReceiver) { - l.downloadPacket = receiver +// Simlink simulates a bidirectional network link with variable latency, +// bandwidth limiting, and CoDel-based bufferbloat mitigation +type Simlink struct { + up *linkDriver + down *linkDriver } -func (l *SimulatedLink) Start() { - if l.downloadPacket == nil { - panic("SimulatedLink.Start() called without having added a packet receiver") +func NewSimlink( + closeSignal chan struct{}, + linkSettings NodeBiDiLinkSettings, + upPacketReceiver PacketReceiver, + downPacketReceiver PacketReceiver, +) *Simlink { + const ( + target = 5 * time.Millisecond + interval = 100 * time.Millisecond + defaultMTU = 1500 + ) + + if linkSettings.Uplink.MTU == 0 { + linkSettings.Uplink.MTU = defaultMTU + } + if linkSettings.Downlink.MTU == 0 { + linkSettings.Downlink.MTU = defaultMTU } - l.closed = make(chan struct{}) - - // Sane defaults - if l.DownlinkSettings.MTU == 0 { - l.DownlinkSettings.MTU = 1400 + if linkSettings.Uplink.FlowBucketCount == 0 { + linkSettings.Uplink.FlowBucketCount = DefaultFlowBucketCount } - if l.UplinkSettings.MTU == 0 { - l.UplinkSettings.MTU = 1400 + if linkSettings.Downlink.FlowBucketCount == 0 { + linkSettings.Downlink.FlowBucketCount = DefaultFlowBucketCount } - // Initialize CoDel queues with 5ms target and 100ms interval - const target = 5 * time.Millisecond - const interval = 100 * time.Millisecond - l.downstreamQueue = newCodelQueue(target, interval) - l.upstreamQueue = newCodelQueue(target, interval) - - // Initialize rate limiters - const burstSizeInPackets = 16 - l.upLimiter = newRateLimiter(l.UplinkSettings.BitsPerSecond, l.UplinkSettings.MTU*burstSizeInPackets) - l.downLimiter = newRateLimiter(l.DownlinkSettings.BitsPerSecond, l.DownlinkSettings.MTU*burstSizeInPackets) - - l.wg.Add(2) - go l.backgroundDownlink() - go l.backgroundUplink() + return &Simlink{ + up: newLinkDriver( + target, interval, + linkSettings.Uplink.MTU, + linkSettings.Uplink.FlowBucketCount, + linkSettings.Uplink.MTU, linkSettings.Uplink.BitsPerSecond, + upPacketReceiver, + closeSignal, + ), + down: newLinkDriver( + target, interval, + linkSettings.Downlink.MTU, + linkSettings.Downlink.FlowBucketCount, + linkSettings.Downlink.MTU, linkSettings.Downlink.BitsPerSecond, + downPacketReceiver, + closeSignal, + ), + } } -func (l *SimulatedLink) Close() error { - close(l.closed) - l.downstreamQueue.Close() - l.upstreamQueue.Close() - l.wg.Wait() - return nil +func (l *Simlink) Start(wg *sync.WaitGroup) { + l.up.Start(wg) + l.down.Start(wg) } -func (l *SimulatedLink) backgroundDownlink() { - defer l.wg.Done() - - for { - select { - case <-l.closed: - return - default: - } - - // Dequeue a packet (this will block until packet is ready for delivery) - p, ok := l.downstreamQueue.Dequeue() - if !ok { - return - } - - // Calculate sojourn time (time spent in queue) - sojournTime := time.Since(p.DeliveryTime) - - // Check if CoDel wants to drop this packet - shouldDrop := l.downstreamQueue.shouldDrop(sojournTime) - if shouldDrop { - // Drop the packet and continue to next one - continue - } - - // Apply rate limiting before delivery - l.downLimiter.WaitN(context.Background(), len(p.buf)) - - // Deliver the packet - l.downloadPacket.RecvPacket(p.Packet) - } +type linkDriver struct { + newPacket chan Packet + q fqCoDel + rateLink *RateLink + closeSignal chan struct{} } -func (l *SimulatedLink) backgroundUplink() { - defer l.wg.Done() - - for { - select { - case <-l.closed: - return - default: - } - - // Dequeue a packet (this will block until packet is ready for delivery) - p, ok := l.upstreamQueue.Dequeue() - if !ok { - return - } - - // Calculate sojourn time (time spent in queue) - sojournTime := time.Since(p.DeliveryTime) - - // Check if CoDel wants to drop this packet - shouldDrop := l.upstreamQueue.shouldDrop(sojournTime) - if shouldDrop { - // Drop the packet and continue to next one - continue - } - - // Apply rate limiting before delivery - l.upLimiter.WaitN(context.Background(), len(p.buf)) - - // Deliver the packet - _ = l.UploadPacket.SendPacket(p.Packet) +func newLinkDriver( + target, interval time.Duration, + quantum int, + flowCount int, + mtu int, bandwidth int, + receiver PacketReceiver, + closeSignal chan struct{}) *linkDriver { + return &linkDriver{ + newPacket: make(chan Packet, 1_024), + q: newFqCoDel(target, interval, quantum, flowCount), + closeSignal: closeSignal, + rateLink: NewRateLink(bandwidth, mtu, receiver), } } -func (l *SimulatedLink) SendPacket(p Packet) error { - if len(p.buf) > l.UplinkSettings.MTU { - // Drop packet if it's too large - return nil - } - - // Uplink has no latency - packets are delivered immediately - deliveryTime := time.Now() - - // Enqueue packet with delivery time to CoDel queue - // Rate limiting happens after dequeue in background goroutine - l.upstreamQueue.Enqueue(&packetWithDeliveryTime{ - Packet: p, - DeliveryTime: deliveryTime, - }) - - return nil +func (d *linkDriver) RecvPacket(p Packet) { + d.newPacket <- p } -func (l *SimulatedLink) RecvPacket(p Packet) { - if len(p.buf) > l.DownlinkSettings.MTU { - // Drop packet if it's too large - return - } - - // Calculate delivery time based on downlink latency - var latency time.Duration - if l.LatencyFunc != nil { - latency = l.LatencyFunc(p) - } else { - latency = l.Latency - } - deliveryTime := time.Now().Add(latency) - - // Enqueue packet with delivery time to CoDel queue - // Rate limiting happens after dequeue in background goroutine - l.downstreamQueue.Enqueue(&packetWithDeliveryTime{ - Packet: p, - DeliveryTime: deliveryTime, +func (d *linkDriver) Start(wg *sync.WaitGroup) { + wgGo(wg, func() { + deqTimer := time.NewTimer(0) + deqTimer.Stop() + var pendingPacket *Packet + + for { + select { + case <-d.closeSignal: + return + case packet := <-d.newPacket: + d.q.Enqueue(packet) + if pendingPacket == nil { + deqTimer.Reset(0) + } + case <-deqTimer.C: + for { + if pendingPacket != nil { + d.rateLink.RecvPacket(*pendingPacket) + pendingPacket = nil + } + + p, ok := d.q.Dequeue() + if ok { + pendingPacket = &p + now := time.Now() + if d.rateLink.AllowN(now, len(p.buf)) { + continue + } + delayDeq := d.rateLink.Reserve(now, len(p.buf)) + deqTimer.Reset(delayDeq) + } + break + } + } + } }) } diff --git a/simlink_test.go b/simlink_test.go index 33058a0..964d358 100644 --- a/simlink_test.go +++ b/simlink_test.go @@ -6,21 +6,20 @@ import ( "fmt" "math" "net" + "net/netip" + "sync" + "sync/atomic" "testing" "testing/synctest" "time" + + "github.com/marcopolo/simnet/internal/require" ) type testRouter struct { - onSend func(p Packet) onRecv func(p Packet) } -func (r *testRouter) SendPacket(p Packet) error { - r.onSend(p) - return nil -} - func (r *testRouter) RecvPacket(p Packet) { r.onRecv(p) } @@ -29,17 +28,60 @@ func (r *testRouter) AddNode(addr net.Addr, receiver PacketReceiver) { r.onRecv = receiver.RecvPacket } -func TestBandwidthLimiterAndLatency_synctest(t *testing.T) { - for _, testUpload := range []bool{true, false} { +func TestLinkDriver(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const mtu = 1500 + var recvdPackets atomic.Uint32 + tr := testRouter{ + onRecv: func(p Packet) { + recvdPackets.Add(1) + }, + } + + closeSignal := make(chan struct{}) + ld := newLinkDriver( + 5*time.Millisecond, + 100*time.Millisecond, + 10*mtu, + 128, + mtu, + 50*Mibps, + &tr, + closeSignal, + ) + + var wg sync.WaitGroup + ld.Start(&wg) + + defer wg.Wait() + defer close(closeSignal) + + ld.RecvPacket(Packet{ + buf: []byte("Hello World"), + To: net.UDPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:1234")), + From: net.UDPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.5:1234")), + }) + + require.Equal(t, uint32(0), recvdPackets.Load()) + time.Sleep(10 * time.Millisecond) + require.Equal(t, uint32(1), recvdPackets.Load()) + }) +} + +func TestBandwidthLimiter_synctest(t *testing.T) { + for _, testUpload := range []bool{true} { t.Run(fmt.Sprintf("testing upload=%t", testUpload), func(t *testing.T) { synctest.Test(t, func(t *testing.T) { const expectedSpeed = 10 * Mibps - const downlinkLatency = 10 * time.Millisecond const MTU = 1400 linkSettings := LinkSettings{ BitsPerSecond: expectedSpeed, MTU: MTU, } + bidiLinkSettings := NodeBiDiLinkSettings{ + Uplink: linkSettings, + Downlink: linkSettings, + } recvStartTimeChan := make(chan time.Time, 1) recvStarted := false @@ -53,26 +95,22 @@ func TestBandwidthLimiterAndLatency_synctest(t *testing.T) { } router := &testRouter{} - if testUpload { - router.onSend = packetHandler - } else { - router.onRecv = packetHandler - } - link := SimulatedLink{ - UplinkSettings: linkSettings, - DownlinkSettings: linkSettings, - LatencyFunc: func(p Packet) time.Duration { return downlinkLatency }, - UploadPacket: router, - downloadPacket: router, - } - - link.Start() + router.onRecv = packetHandler + closeSignal := make(chan struct{}) + var wg sync.WaitGroup + link := NewSimlink( + closeSignal, + bidiLinkSettings, + router, + router, + ) + + link.Start(&wg) // Send 10MiB of data chunk := make([]byte, MTU) bytesSent := 0 - sendStartTime := time.Now() { totalBytes := 10 << 20 // Blast a bunch of packets @@ -80,10 +118,15 @@ func TestBandwidthLimiterAndLatency_synctest(t *testing.T) { // This sleep shouldn't limit the speed. 1400 Bytes/100us = 14KB/ms = 14MB/s = 14*8 Mbps // but it acts as a simple pacer to avoid just dropping the packets when the link is saturated. time.Sleep(100 * time.Microsecond) + p := Packet{ + buf: chunk, + To: net.UDPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:1234")), + From: net.UDPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.5:1234")), + } if testUpload { - _ = link.SendPacket(Packet{buf: chunk}) + link.up.RecvPacket(p) } else { - link.RecvPacket(Packet{buf: chunk}) + link.down.RecvPacket(p) } bytesSent += len(chunk) } @@ -91,33 +134,15 @@ func TestBandwidthLimiterAndLatency_synctest(t *testing.T) { // Wait for delayed packets to be sent time.Sleep(40 * time.Millisecond) - fmt.Printf("sent: %d\n", bytesSent) + t.Logf("sent: %d\n", bytesSent) + + close(closeSignal) + wg.Wait() - link.Close() - fmt.Printf("bytesRead: %d\n", bytesRead) + t.Logf("bytesRead: %d\n", bytesRead) recvStartTime := <-recvStartTimeChan duration := time.Since(recvStartTime) - observedLatency := recvStartTime.Sub(sendStartTime) - // Uplink is now instant (no latency), only downlink has latency - var expectedLatency time.Duration - if testUpload { - // Uplink test: expect near-zero latency - expectedLatency = 0 - t.Logf("observed latency: %s (uplink is instant)\n", observedLatency) - if observedLatency > 5*time.Millisecond { - t.Fatalf("observed latency %s is too high for instant uplink", observedLatency) - } - } else { - // Downlink test: expect configured latency - expectedLatency = downlinkLatency - percentErrorLatency := math.Abs(observedLatency.Seconds()-expectedLatency.Seconds()) / expectedLatency.Seconds() - t.Logf("observed latency: %s, expected latency: %s, percent error: %f\n", observedLatency, expectedLatency, percentErrorLatency) - if percentErrorLatency > 0.20 { - t.Fatalf("observed latency %s is wrong", observedLatency) - } - } - observedSpeed := 8 * float64(bytesRead) / duration.Seconds() t.Logf("observed speed: %f Mbps over %s\n", observedSpeed/Mibps, duration) percentErrorSpeed := math.Abs(observedSpeed-float64(expectedSpeed)) / float64(expectedSpeed) @@ -129,105 +154,3 @@ func TestBandwidthLimiterAndLatency_synctest(t *testing.T) { }) } } - -type linkAdapter struct { - link PacketReceiver -} - -var _ Router = &linkAdapter{} - -// AddNode implements Router. -func (c *linkAdapter) AddNode(addr net.Addr, receiver PacketReceiver) { - c.link = receiver -} - -// SendPacket implements Router. -func (c *linkAdapter) SendPacket(p Packet) error { - c.link.RecvPacket(p) - return nil -} - -func TestBandwidthLimiterAndLatencyConnectedLinks_synctest(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - const expectedSpeed = 100 * Mibps - const downlinkLatency = 10 * time.Millisecond - // Only downlink has latency, so total latency is 1x downlink latency - const expectedLatency = downlinkLatency - const MTU = 1400 - linkSettings := LinkSettings{ - BitsPerSecond: expectedSpeed, - MTU: MTU, - } - - recvStartTimeChan := make(chan time.Time, 1) - recvStarted := false - bytesRead := 0 - packetHandler := func(p Packet) { - if !recvStarted { - recvStarted = true - recvStartTimeChan <- time.Now() - } - bytesRead += len(p.buf) - } - r := &testRouter{ - onRecv: packetHandler, - } - - link2 := SimulatedLink{ - UplinkSettings: linkSettings, - DownlinkSettings: linkSettings, - LatencyFunc: func(p Packet) time.Duration { return downlinkLatency }, - downloadPacket: r, - } - link1 := SimulatedLink{ - UplinkSettings: linkSettings, - DownlinkSettings: linkSettings, - LatencyFunc: func(p Packet) time.Duration { return downlinkLatency }, - UploadPacket: &linkAdapter{link: &link2}, - downloadPacket: &testRouter{}, - } - - link1.Start() - link2.Start() - - // Send 10MiB of data - chunk := make([]byte, MTU) - bytesSent := 0 - - sendStartTime := time.Now() - { - totalBytes := 10 << 20 - // Blast a bunch of packets - for bytesSent < totalBytes { - time.Sleep(100 * time.Microsecond) - _ = link1.SendPacket(Packet{buf: chunk}) - bytesSent += len(chunk) - } - } - - // Wait for delayed packets to be sent - time.Sleep(40 * time.Millisecond) - fmt.Printf("sent: %d\n", bytesSent) - - link1.Close() - link2.Close() - fmt.Printf("bytesRead: %d\n", bytesRead) - recvStartTime := <-recvStartTimeChan - duration := time.Since(recvStartTime) - - observedLatency := recvStartTime.Sub(sendStartTime) - percentErrorLatency := math.Abs(observedLatency.Seconds()-expectedLatency.Seconds()) / expectedLatency.Seconds() - t.Logf("observed latency: %s, expected latency: %s, percent error: %f\n", observedLatency, expectedLatency, percentErrorLatency) - if percentErrorLatency > 0.20 { - t.Fatalf("observed latency %s is wrong", observedLatency) - } - - observedSpeed := 8 * float64(bytesRead) / duration.Seconds() - t.Logf("observed speed: %f Mbps over %s\n", observedSpeed/Mibps, duration) - percentErrorSpeed := math.Abs(observedSpeed-float64(expectedSpeed)) / float64(expectedSpeed) - t.Logf("observed speed: %f Mbps, expected speed: %d Mbps, percent error: %f\n", observedSpeed/Mibps, expectedSpeed/Mibps, percentErrorSpeed) - if percentErrorSpeed > 0.20 { - t.Fatalf("observed speed %f Mbps is too far from expected speed %d Mbps. Percent error: %f", observedSpeed/Mibps, expectedSpeed/Mibps, percentErrorSpeed) - } - }) -} diff --git a/simnet.go b/simnet.go index 76c116a..fda897b 100644 --- a/simnet.go +++ b/simnet.go @@ -1,17 +1,40 @@ package simnet import ( - "errors" - "fmt" + "log/slog" "net" + "sync" "time" ) +func StaticLatency(duration time.Duration) func(*Packet) time.Duration { + return func(*Packet) time.Duration { + return duration + } +} + // Simnet is a simulated network that manages connections between nodes // with configurable network conditions. type Simnet struct { - router PerfectRouter - links []*SimulatedLink + // LatencyFunc defines the latency added when routing a given packet. + // The latency is allowed to be dynamic and change packet to packet (which + // could lead to packet reordering). + // + // A simple use case can use `StaticLatency(duration)` to set a static + // latency for all packets. + // + // More complex use cases can define a latency map between endpoints and + // have this function return the expected latency. + LatencyFunc func(*Packet) time.Duration + + // Optional, if unset will use the default slog logger. + Logger *slog.Logger + + started bool + closeSignal chan struct{} + wg sync.WaitGroup + router VariableLatencyRouter + links []*Simlink } // NodeBiDiLinkSettings defines the bidirectional link settings for a network node. @@ -22,49 +45,52 @@ type NodeBiDiLinkSettings struct { Downlink LinkSettings // Uplink configures the settings for outgoing traffic from this node Uplink LinkSettings - - // Latency specifies a fixed network delay for downlink packets only - // If both Latency and LatencyFunc are set, LatencyFunc takes precedence - Latency time.Duration - - // LatencyFunc computes the network delay for each downlink packet - // This allows variable latency based on packet source/destination - // If nil, Latency field is used instead - LatencyFunc func(Packet) time.Duration } -func (n *Simnet) Start() error { +// Start starts the simulated network and related goroutines +func (n *Simnet) Start() { + n.started = true + if n.Logger == nil { + n.Logger = slog.Default() + } + // Log whenever the router fails to route a packet (likely a test setup bug). + n.router.OnDrop = LogOnDrop(n.Logger) + n.router.LatencyFunc = n.LatencyFunc + n.router.CloseSignal = n.closeSignal + n.router.Start(&n.wg) for _, link := range n.links { - link.Start() + link.Start(&n.wg) } - return nil } -func (n *Simnet) Close() error { - var errs error - for _, link := range n.links { - err := link.Close() - if err != nil { - errs = errors.Join(errs, err) - } - } - if errs != nil { - return fmt.Errorf("failed to close some links: %w", errs) +func (n *Simnet) Close() { + close(n.closeSignal) + n.wg.Wait() +} + +func (n *Simnet) init() { + if n.closeSignal == nil { + n.closeSignal = make(chan struct{}) } - return nil } func (n *Simnet) NewEndpoint(addr *net.UDPAddr, linkSettings NodeBiDiLinkSettings) *SimConn { - link := &SimulatedLink{ - DownlinkSettings: linkSettings.Downlink, - UplinkSettings: linkSettings.Uplink, - Latency: linkSettings.Latency, - LatencyFunc: linkSettings.LatencyFunc, - UploadPacket: &n.router, + n.init() + if n.started { + panic("Must add endpoints before starting the network") } - c := NewBlockingSimConn(addr, link) + + c := NewBlockingSimConn(addr) + link := NewSimlink( + n.closeSignal, + linkSettings, + &n.router, + c, + ) + c.SetUpPacketReceiver(link.up) + n.router.AddNode(addr, link.down) n.links = append(n.links, link) - n.router.AddNode(addr, link) + return c } diff --git a/simnet_synctest_test.go b/simnet_synctest_test.go index 8dab8fe..a546adf 100644 --- a/simnet_synctest_test.go +++ b/simnet_synctest_test.go @@ -3,7 +3,6 @@ package simnet import ( - "fmt" "math" "net" "testing" @@ -16,14 +15,8 @@ import ( const oneMbps = 1_000_000 -func newConn(simnet *Simnet, address *net.UDPAddr, linkSettings NodeBiDiLinkSettings) *SimConn { - return simnet.NewEndpoint(address, linkSettings) -} - -func TestSimnetWIthSynctest(t *testing.T) { +func TestSimnetWithSynctest(t *testing.T) { synctest.Test(t, func(t *testing.T) { - router := &Simnet{} - const bandwidth = 10 * oneMbps const latency = 10 * time.Millisecond linkSettings := NodeBiDiLinkSettings{ @@ -33,22 +26,25 @@ func TestSimnetWIthSynctest(t *testing.T) { Uplink: LinkSettings{ BitsPerSecond: bandwidth, }, - Latency: latency, + } + + nw := &Simnet{ + LatencyFunc: StaticLatency(latency), } addressA := net.UDPAddr{ IP: net.ParseIP("1.0.0.1"), Port: 8000, } - connA := newConn(router, &addressA, linkSettings) + connA := nw.NewEndpoint(&addressA, linkSettings) addressB := net.UDPAddr{ IP: net.ParseIP("1.0.0.2"), Port: 8000, } - connB := newConn(router, &addressB, linkSettings) + connB := nw.NewEndpoint(&addressB, linkSettings) - router.Start() - defer router.Close() + nw.Start() + defer nw.Close() start := time.Now() connA.WriteTo([]byte("hello"), &addressB) @@ -71,7 +67,6 @@ func TestSimnetWIthSynctest(t *testing.T) { func TestSimnetBandwidthWithSynctest(t *testing.T) { synctest.Test(t, func(t *testing.T) { - router := &Simnet{} const bandwidth = 40 * oneMbps const latency = 10 * time.Millisecond @@ -85,23 +80,24 @@ func TestSimnetBandwidthWithSynctest(t *testing.T) { BitsPerSecond: bandwidth, MTU: MTU, }, - Latency: latency, + } + nw := &Simnet{ + LatencyFunc: StaticLatency(latency), } addressA := net.UDPAddr{ IP: net.ParseIP("1.0.0.1"), Port: 8000, } - connA := newConn(router, &addressA, linkSettings) + connA := nw.NewEndpoint(&addressA, linkSettings) addressB := net.UDPAddr{ IP: net.ParseIP("1.0.0.2"), Port: 8000, } - connB := newConn(router, &addressB, linkSettings) + connB := nw.NewEndpoint(&addressB, linkSettings) - err := router.Start() - require.NoError(t, err) - defer router.Close() + nw.Start() + defer nw.Close() readDone := make(chan struct{}) @@ -149,8 +145,8 @@ func TestSimnetBandwidthWithSynctest(t *testing.T) { observedBandwidth := float64(bytesRead*8) / readDuration.Seconds() expectedBandwidth := float64(bandwidth) - fmt.Println("sent bytes", bytesSent) - fmt.Println("Read bytes", bytesRead) + t.Log("sent bytes", bytesSent) + t.Log("Read bytes", bytesRead) percentDiffBandwidth := math.Abs(observedBandwidth-expectedBandwidth) / expectedBandwidth t.Logf("observed bandwidth: %v mbps, expected bandwidth: %v mbps, percent diff: %v", observedBandwidth/oneMbps, expectedBandwidth/oneMbps, percentDiffBandwidth) if percentDiffBandwidth > 0.20 { diff --git a/simnet_test.go b/simnet_test.go index 1f479d2..74856a9 100644 --- a/simnet_test.go +++ b/simnet_test.go @@ -14,12 +14,14 @@ import ( // Example showing a simple echo using Simnet and the returned net.PacketConn. func ExampleSimnet_echo() { + // Create the simulated network and two endpoints - n := &simnet.Simnet{} + n := &simnet.Simnet{ + LatencyFunc: simnet.StaticLatency(5 * time.Millisecond), + } settings := simnet.NodeBiDiLinkSettings{ Downlink: simnet.LinkSettings{BitsPerSecond: 10 * simnet.Mibps}, Uplink: simnet.LinkSettings{BitsPerSecond: 10 * simnet.Mibps}, - Latency: 5 * time.Millisecond, } addrA := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} @@ -28,7 +30,7 @@ func ExampleSimnet_echo() { client := n.NewEndpoint(addrA, settings) server := n.NewEndpoint(addrB, settings) - _ = n.Start() + n.Start() defer n.Close() // Simple echo server using the returned PacketConn @@ -36,7 +38,7 @@ func ExampleSimnet_echo() { go func() { defer close(done) buf := make([]byte, 1024) - server.SetReadDeadline(time.Now().Add(2 * time.Second)) + server.SetReadDeadline(time.Now().Add(1 * time.Second)) n, src, err := server.ReadFrom(buf) if err != nil { return @@ -46,7 +48,11 @@ func ExampleSimnet_echo() { // Client sends a message and waits for the echo response client.SetReadDeadline(time.Now().Add(2 * time.Second)) - _, _ = client.WriteTo([]byte("ping"), addrB) + _, err := client.WriteTo([]byte("ping"), addrB) + if err != nil { + fmt.Println("Error writing to server:", err) + return + } buf := make([]byte, 1024) nRead, _, _ := client.ReadFrom(buf) @@ -62,13 +68,16 @@ func ExampleSimnet_echo() { func TestSimnet_pingWithDelay(t *testing.T) { synctest.Test(t, func(t *testing.T) { + const latency = 400 * time.Millisecond // Create the simulated network and two endpoints - n := &simnet.Simnet{} - latency := 400 * time.Millisecond + n := &simnet.Simnet{ + LatencyFunc: func(p *simnet.Packet) time.Duration { + return latency + }, + } settings := simnet.NodeBiDiLinkSettings{ Downlink: simnet.LinkSettings{BitsPerSecond: 10 * simnet.Mibps}, Uplink: simnet.LinkSettings{BitsPerSecond: 10 * simnet.Mibps}, - Latency: latency / 2, // Each endpoint has downlink latency, so RTT = 2 * (latency/2) = latency } addrA := &net.UDPAddr{IP: net.ParseIP("1.0.0.1"), Port: 9001} @@ -77,10 +86,7 @@ func TestSimnet_pingWithDelay(t *testing.T) { client := n.NewEndpoint(addrA, settings) server := n.NewEndpoint(addrB, settings) - err := n.Start() - if err != nil { - t.Fatalf("Failed to start simnet: %v", err) - } + n.Start() defer n.Close() // Simple echo server using the returned PacketConn @@ -126,7 +132,7 @@ func TestSimnet_pingWithDelay(t *testing.T) { // Client sends first ping client.SetReadDeadline(time.Now().Add(1 * time.Second)) - _, err = client.WriteTo([]byte("ping1"), addrB) + _, err := client.WriteTo([]byte("ping1"), addrB) if err != nil { t.Fatalf("Client failed to write ping1: %v", err) }