diff --git a/.gitignore b/.gitignore index daf913b1..966dc06a 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ _testmain.go *.exe *.test *.prof + +# Coverage +coverage.out diff --git a/.mockery.yml b/.mockery.yml new file mode 100644 index 00000000..5b97245e --- /dev/null +++ b/.mockery.yml @@ -0,0 +1,20 @@ +all: false +dir: 'mocks' +filename: 'mock_{{.InterfaceName | snakecase}}.go' +force-file-write: true +formatter: goimports +generate: true +include-auto-generated: false +log-level: info +structname: 'Mock{{.InterfaceName}}' +pkgname: 'mocks' +recursive: false +require-template-schema-exists: true +template: testify +template-schema: '{{.Template}}.schema.json' +packages: + github.com/enbility/zeroconf/v3/api: + interfaces: + PacketConn: + ConnectionFactory: + InterfaceProvider: diff --git a/README.md b/README.md index eb365035..a223bb69 100644 --- a/README.md +++ b/README.md @@ -22,24 +22,30 @@ Target environments: private LAN/Wifi, small or isolated networks. ## Install Nothing is as easy as that: ```bash -$ go get -u github.com/enbility/zeroconf/v2 +$ go get -u github.com/enbility/zeroconf/v3 ``` ## Browse for services in your local network ```go entries := make(chan *zeroconf.ServiceEntry) -go func(results <-chan *zeroconf.ServiceEntry) { - for entry := range results { - log.Println(entry) +removed := make(chan *zeroconf.ServiceEntry) + +go func() { + for { + select { + case entry := <-entries: + log.Println("Found:", entry) + case entry := <-removed: + log.Println("Removed:", entry) + } } - log.Println("No more entries.") -}(entries) +}() ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() // Discover all services on the network (e.g. _workstation._tcp) -err = zeroconf.Browse(ctx, "_workstation._tcp", "local.", entries) +err := zeroconf.Browse(ctx, "_workstation._tcp", "local.", entries, removed) if err != nil { log.Fatalln("Failed to browse:", err.Error()) } @@ -53,7 +59,23 @@ See https://github.com/enbility/zeroconf/blob/master/examples/resolv/client.go. ## Lookup a specific service instance ```go -// Example filled soon. +entries := make(chan *zeroconf.ServiceEntry) + +go func() { + for entry := range entries { + log.Println("Found:", entry) + } +}() + +ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) +defer cancel() +// Lookup a specific service instance by name +err := zeroconf.Lookup(ctx, "MyService", "_workstation._tcp", "local.", entries) +if err != nil { + log.Fatalln("Failed to lookup:", err.Error()) +} + +<-ctx.Done() ``` ## Register a service @@ -81,6 +103,29 @@ Multiple subtypes may be added to service name, separated by commas. E.g `_works See https://github.com/enbility/zeroconf/blob/master/examples/register/server.go. +## Testing Support (v3) + +Version 3 introduces interface-based abstractions for improved testability. You can inject mock connections for unit testing without requiring real network access: + +```go +// Create mock connections using the provided interfaces +mockFactory := &MyMockConnectionFactory{} + +// Client with mock connections +client, err := zeroconf.NewClient(zeroconf.WithClientConnFactory(mockFactory)) + +// Server with mock connections +server, err := zeroconf.RegisterProxy( + "MyService", "_http._tcp", "local.", 8080, + "myhost.local.", []string{"192.168.1.100"}, + []string{"txtvers=1"}, + nil, // interfaces + zeroconf.WithServerConnFactory(mockFactory), +) +``` + +See the `api/` package for interface definitions and `mocks/` for mockery-generated mocks. + ## Features and ToDo's This list gives a quick impression about the state of this library. See what needs to be done and submit a pull request :) @@ -89,6 +134,8 @@ See what needs to be done and submit a pull request :) * [x] Multiple IPv6 / IPv4 addresses support * [x] Send multiple probes (exp. back-off) if no service answers (*) * [x] Timestamp entries for TTL checks +* [x] Service removal notifications via `removed` channel +* [x] Interface-based abstractions for testability (v3) * [ ] Compare new multicasts with already received services _Notes:_ diff --git a/V3_REFACTORING_PLAN.md b/V3_REFACTORING_PLAN.md new file mode 100644 index 00000000..99884dfc --- /dev/null +++ b/V3_REFACTORING_PLAN.md @@ -0,0 +1,248 @@ +# ZeroConf v3 Refactoring Plan + +## Goals + +1. **Testability**: Enable unit testing without real network access +2. **Interfaces**: Define clear abstractions at network boundaries +3. **Dependency Injection**: Allow mock injection for testing +4. **Test Coverage**: Target 85%+ coverage with meaningful unit tests +5. **Generated Mocks**: Use mockery for maintainable mocks +6. **Remove Global State**: Move package-level vars into config structs + +## Key Insight: ControlMessage Simplification + +Analysis of the codebase shows that only `IfIndex` is ever used from `ipv4.ControlMessage` and `ipv6.ControlMessage`. This allows us to create a unified `PacketConn` interface that works for both IPv4 and IPv6: + +```go +// Instead of exposing ControlMessage, we just expose ifIndex +ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) +WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) +``` + +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Public API │ +│ Browse() / Lookup() / Register() / RegisterProxy() │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Client / Server │ +│ - Use api.PacketConn interface (not concrete types) │ +│ - Accept ConnectionFactory via options │ +│ - Use InterfaceProvider internally for default interfaces │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ api/ Package │ +│ PacketConn / ConnectionFactory / InterfaceProvider │ +└─────────────────────────────────────────────────────────────┘ + │ + ┌───────────────┴───────────────┐ + ▼ ▼ +┌──────────────────────────┐ ┌──────────────────────────┐ +│ Real Implementations │ │ mocks/ Package │ +│ - ipv4PacketConn │ │ - MockPacketConn │ +│ - ipv6PacketConn │ │ - MockConnectionFactory │ +│ - defaultConnFactory │ │ - MockInterfaceProvider │ +│ - defaultIfaceProvider │ │ (generated by mockery) │ +└──────────────────────────┘ └──────────────────────────┘ +``` + +## Package Structure + +``` +zeroconf/v3/ +├── api/ # Pure interfaces (no internal deps) +│ └── interfaces.go # PacketConn, ConnectionFactory, InterfaceProvider +├── mocks/ # Generated mocks (mockery) +│ ├── mock_packet_conn.go +│ ├── mock_connection_factory.go +│ └── mock_interface_provider.go +├── .mockery.yml # Mockery configuration +├── client.go # Client implementation +├── server.go # Server implementation +├── conn_ipv4.go # ipv4PacketConn wrapper +├── conn_ipv6.go # ipv6PacketConn wrapper +├── conn_factory.go # defaultConnectionFactory +├── conn_provider.go # defaultInterfaceProvider +├── mdns.go # Network constants (mDNS addresses) +├── service.go # ServiceEntry, ServiceRecord +├── utils.go # Helper functions +├── doc.go # Package documentation +├── *_test.go # Tests +└── examples/ # Example applications +``` + +## Interface Definitions (in api/interfaces.go) + +### PacketConn + +```go +type PacketConn interface { + ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) + WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) + Close() error + JoinGroup(ifi *net.Interface, group net.Addr) error + LeaveGroup(ifi *net.Interface, group net.Addr) error + SetMulticastTTL(ttl int) error + SetMulticastHopLimit(hopLimit int) error + SetMulticastInterface(ifi *net.Interface) error +} +``` + +### ConnectionFactory + +```go +type ConnectionFactory interface { + CreateIPv4Conn(ifaces []net.Interface) (PacketConn, error) + CreateIPv6Conn(ifaces []net.Interface) (PacketConn, error) +} +``` + +### InterfaceProvider + +```go +type InterfaceProvider interface { + MulticastInterfaces() []net.Interface +} +``` + +## Implementation Phases + +### Phase 1: Package Structure & Interfaces ✓ + +**Completed:** +- [x] Create `api/` package with interface definitions +- [x] Configure mockery (`.mockery.yml`) +- [x] Generate mocks in `mocks/` package +- [x] Update implementation to import `api/` +- [x] Update tests to use `mocks/` + +--- + +### Phase 2: Connection Wrappers (File Split) ✓ + +**Completed:** +- [x] `conn_ipv4.go` - ipv4PacketConn wrapper +- [x] `conn_ipv6.go` - ipv6PacketConn wrapper +- [x] `conn_factory.go` - defaultConnectionFactory +- [x] `conn_provider.go` - defaultInterfaceProvider with MulticastInterfaces() +- [x] Removed `conn_wrapper.go` (split into above files) + +--- + +### Phase 3: InterfaceProvider Implementation ✓ + +**Completed:** +- [x] Created `defaultInterfaceProvider` implementing `api.InterfaceProvider` +- [x] Moved `listMulticastInterfaces()` into `defaultInterfaceProvider.MulticastInterfaces()` +- [x] Used internally via `NewInterfaceProvider().MulticastInterfaces()` in client/server + +**Design Decision:** Removed `WithIfaceProvider` options after review - they added complexity without clear benefit since: +- `Register()` already accepts `ifaces []net.Interface` directly +- `SelectIfaces()` option exists for client +- Interface selection is simpler as a direct parameter than an injected provider + +--- + +### Phase 4: Server Improvements ✓ + +**Changes to `server.go`:** +- [x] Change `Server.ipv4conn` to use `api.PacketConn` +- [x] Change `Server.ipv6conn` to use `api.PacketConn` +- [x] Add `WithServerConnFactory()` option +- [x] Remove deprecated `Server.TTL()` method + +--- + +### Phase 5: Client Improvements ✓ + +**Changes to `client.go`:** +- [x] Change connection fields to use `api.PacketConn` +- [x] Add `WithClientConnFactory()` option +- [x] Export `Client` type (renamed `client` -> `Client`) +- [x] Add `NewClient()` constructor + +--- + +### Phase 6: Coverage & Cleanup + +1. Run coverage report, identify gaps +2. Add tests for untested functions +3. Update doc.go for v3 +4. Final integration test pass + +--- + +## File Changes Summary + +| File | Action | Description | +|------|--------|-------------| +| `api/interfaces.go` | DONE | Interface definitions | +| `mocks/*.go` | DONE | Generated mocks (mockery) | +| `.mockery.yml` | DONE | Mockery configuration | +| `conn_ipv4.go` | NEW | IPv4 PacketConn wrapper | +| `conn_ipv6.go` | NEW | IPv6 PacketConn wrapper | +| `conn_factory.go` | NEW | defaultConnectionFactory | +| `conn_provider.go` | NEW | defaultInterfaceProvider + listMulticastInterfaces | +| `conn_wrapper.go` | DELETE | Split into above files | +| `mdns.go` | RENAME | Network constants (was connection.go) | +| `server.go` | MODIFY | Add WithServerConnFactory, remove TTL() | +| `client.go` | MODIFY | Export Client, add NewClient, add WithClientConnFactory | +| `server_unit_test.go` | DONE | Unit tests with mocks | +| `client_unit_test.go` | DONE | Unit tests with mocks | + +--- + +## Breaking Changes + +1. **Module path**: `github.com/enbility/zeroconf/v3` +2. **Exported `Client` type**: New public API +3. **`NewClient()` function**: New constructor +4. **Removed**: `Server.TTL()` method (was deprecated) + +## Backward Compatibility + +Main API functions remain compatible: +- `Browse(ctx, service, domain, entries, removed, opts...)` - unchanged +- `Lookup(ctx, instance, service, domain, entries, opts...)` - unchanged +- `Register(instance, service, domain, port, text, ifaces, opts...)` - unchanged +- `RegisterProxy(...)` - unchanged + +New optional features via options: +- `WithClientConnFactory(factory)` / `WithServerConnFactory(factory)` - for injecting mock connections in tests + +--- + +## Mock Generation + +Using mockery v3. Configuration in `.mockery.yml`: + +```yaml +packages: + github.com/enbility/zeroconf/v3/api: + interfaces: + PacketConn: + ConnectionFactory: + InterfaceProvider: +``` + +Regenerate mocks: +```bash +mockery +``` + +--- + +## Success Criteria + +- [x] All existing tests pass +- [ ] Test coverage > 85% (currently 72.7%) +- [x] All network I/O behind interfaces +- [x] Unit tests run without network access +- [x] Mocks generated automatically +- [ ] Documentation updated diff --git a/api/interfaces.go b/api/interfaces.go new file mode 100644 index 00000000..141c392e --- /dev/null +++ b/api/interfaces.go @@ -0,0 +1,57 @@ +// Package api defines the core interfaces for the zeroconf library. +// These interfaces enable dependency injection and testing. +package api + +import "net" + +//go:generate mockery + +// PacketConn abstracts IPv4/IPv6 multicast packet connections. +// This interface unifies ipv4.PacketConn and ipv6.PacketConn by extracting +// only the IfIndex from ControlMessage, which is the only field used. +type PacketConn interface { + // ReadFrom reads a packet from the connection. + // Returns the number of bytes read, the interface index the packet arrived on, + // the source address, and any error. + ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) + + // WriteTo writes a packet to the destination address. + // The ifIndex specifies which interface to send from (0 for default/all). + WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) + + // Close closes the connection. + Close() error + + // JoinGroup joins the multicast group on the specified interface. + JoinGroup(ifi *net.Interface, group net.Addr) error + + // LeaveGroup leaves the multicast group on the specified interface. + LeaveGroup(ifi *net.Interface, group net.Addr) error + + // SetMulticastTTL sets the TTL for outgoing multicast packets (IPv4). + SetMulticastTTL(ttl int) error + + // SetMulticastHopLimit sets the hop limit for outgoing multicast packets (IPv6). + SetMulticastHopLimit(hopLimit int) error + + // SetMulticastInterface sets the default interface for outgoing multicast. + // Used as fallback on platforms where ControlMessage is not supported (Windows). + SetMulticastInterface(ifi *net.Interface) error +} + +// ConnectionFactory creates multicast connections. +// This abstraction allows injecting mock connections for testing. +type ConnectionFactory interface { + // CreateIPv4Conn creates an IPv4 multicast connection joined to the mDNS group. + CreateIPv4Conn(ifaces []net.Interface) (PacketConn, error) + + // CreateIPv6Conn creates an IPv6 multicast connection joined to the mDNS group. + CreateIPv6Conn(ifaces []net.Interface) (PacketConn, error) +} + +// InterfaceProvider lists network interfaces. +// This abstraction allows injecting mock interface lists for testing. +type InterfaceProvider interface { + // MulticastInterfaces returns all network interfaces capable of multicast. + MulticastInterfaces() []net.Interface +} diff --git a/client.go b/client.go index 5d845414..08aca3a9 100644 --- a/client.go +++ b/client.go @@ -3,17 +3,14 @@ package zeroconf import ( "context" "fmt" - "log" "math/rand" "net" "reflect" - "runtime" "strings" "time" + "github.com/enbility/zeroconf/v3/api" "github.com/miekg/dns" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) // IPType specifies the IP traffic the client listens for. @@ -32,15 +29,19 @@ const ( var initialQueryInterval = 4 * time.Second // Client structure encapsulates both IPv4/IPv6 UDP connections. -type client struct { - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface +type Client struct { + ipv4conn api.PacketConn + ipv6conn api.PacketConn + ipv4Mgr *InterfaceManager + ipv6Mgr *InterfaceManager + provider api.InterfaceProvider } type clientOpts struct { - listenOn IPType - ifaces []net.Interface + listenOn IPType + ifaces []net.Interface + connFactory api.ConnectionFactory + provider api.InterfaceProvider } // ClientOption fills the option struct to configure intefaces, etc. @@ -64,6 +65,22 @@ func SelectIfaces(ifaces []net.Interface) ClientOption { } } +// WithClientConnFactory sets a custom connection factory for the client. +// This is primarily useful for testing with mock connections. +func WithClientConnFactory(factory api.ConnectionFactory) ClientOption { + return func(o *clientOpts) { + o.connFactory = factory + } +} + +// WithClientInterfaceProvider sets a custom interface provider for the client. +// This is primarily useful for testing with mock interface lists. +func WithClientInterfaceProvider(provider api.InterfaceProvider) ClientOption { + return func(o *clientOpts) { + o.provider = provider + } +} + // Browse for all services of a given type in a given domain. // Received entries are sent on the entries channel. // It blocks until the context is canceled (or an error occurs). @@ -112,8 +129,20 @@ func applyOpts(options ...ClientOption) clientOpts { return conf } -func (c *client) run(ctx context.Context, params *lookupParams) error { +func (c *Client) run(ctx context.Context, params *lookupParams) error { + // Run immediate sync on startup to catch any interfaces that changed + // between client creation and run() + c.syncInterfaces() + ctx, cancel := context.WithCancel(ctx) + + // Start interface sync in background + syncDone := make(chan struct{}) + go func() { + defer close(syncDone) + c.runInterfaceSync(ctx) + }() + done := make(chan struct{}) go func() { defer close(done) @@ -125,6 +154,7 @@ func (c *client) run(ctx context.Context, params *lookupParams) error { err := c.periodicQuery(ctx, params) cancel() <-done + <-syncDone return err } @@ -133,42 +163,77 @@ func defaultParams(service string) *lookupParams { return newLookupParams("", service, "local", false, make(chan *ServiceEntry), make(chan *ServiceEntry)) } -// Client structure constructor -func newClient(opts clientOpts) (*client, error) { +// NewClient creates a new mDNS client with the given options. +// This is the low-level constructor. For most use cases, prefer Browse() or Lookup(). +func NewClient(opts ...ClientOption) (*Client, error) { + return newClient(applyOpts(opts...)) +} + +// newClient is the internal constructor that takes pre-applied options. +func newClient(opts clientOpts) (*Client, error) { + // Get interface provider (use default if not injected for testing) + provider := opts.provider + if provider == nil { + provider = NewInterfaceProvider() + } + ifaces := opts.ifaces - if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + var requested []string + + // Determine mode based on whether interfaces were explicitly provided + if len(ifaces) > 0 { + // Explicit mode: extract names for the manager + requested = make([]string, len(ifaces)) + for i, iface := range ifaces { + requested[i] = iface.Name + } + } else { + // Dynamic mode: get current interfaces + ifaces = provider.MulticastInterfaces() + } + + factory := opts.connFactory + if factory == nil { + factory = NewConnectionFactory() } + + // Create SEPARATE managers for IPv4 and IPv6. + // This ensures IPv6 failures don't affect IPv4 (and vice versa). + ipv4Mgr := NewInterfaceManager(ifaces, requested) + ipv6Mgr := NewInterfaceManager(ifaces, requested) + // IPv4 interfaces - var ipv4conn *ipv4.PacketConn + var ipv4conn api.PacketConn if (opts.listenOn & IPv4) > 0 { var err error - ipv4conn, err = joinUdp4Multicast(ifaces) + ipv4conn, err = factory.CreateIPv4Conn(ifaces) if err != nil { return nil, err } } // IPv6 interfaces - var ipv6conn *ipv6.PacketConn + var ipv6conn api.PacketConn if (opts.listenOn & IPv6) > 0 { var err error - ipv6conn, err = joinUdp6Multicast(ifaces) + ipv6conn, err = factory.CreateIPv6Conn(ifaces) if err != nil { return nil, err } } - return &client{ + return &Client{ ipv4conn: ipv4conn, ipv6conn: ipv6conn, - ifaces: ifaces, + ipv4Mgr: ipv4Mgr, + ipv6Mgr: ipv6Mgr, + provider: provider, }, nil } var cleanupFreq = 5 * time.Second // Start listeners and waits for the shutdown signal from exit channel -func (c *client) mainloop(ctx context.Context, params *lookupParams) { +func (c *Client) mainloop(ctx context.Context, params *lookupParams) { // start listening for responses msgCh := make(chan *dns.Msg, 32) if c.ipv4conn != nil { @@ -319,7 +384,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } // Shutdown client will close currently open connections and channel implicitly. -func (c *client) shutdown() { +func (c *Client) shutdown() { if c.ipv4conn != nil { c.ipv4conn.Close() } @@ -330,22 +395,8 @@ func (c *client) shutdown() { // Data receiving routine reads from connection, unpacks packets into dns.Msg // structures and sends them to a given msgCh channel -func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { - var readFrom func([]byte) (n int, src net.Addr, err error) - - switch pConn := l.(type) { - case *ipv6.PacketConn: - readFrom = func(b []byte) (n int, src net.Addr, err error) { - n, _, src, err = pConn.ReadFrom(b) - return - } - case *ipv4.PacketConn: - readFrom = func(b []byte) (n int, src net.Addr, err error) { - n, _, src, err = pConn.ReadFrom(b) - return - } - - default: +func (c *Client) recv(ctx context.Context, conn api.PacketConn, msgCh chan *dns.Msg) { + if conn == nil { return } @@ -355,12 +406,11 @@ func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { // Handles the following cases: // - ReadFrom aborts with error due to closed UDP connection -> causes ctx cancel // - ReadFrom aborts otherwise. - // TODO: the context check can be removed. Verify! if ctx.Err() != nil || fatalErr != nil { return } - n, _, err := readFrom(buf) + n, _, _, err := conn.ReadFrom(buf) if err != nil { fatalErr = err continue @@ -384,7 +434,7 @@ func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) { // the main processing loop or some timeout/cancel fires. // TODO: move error reporting to shutdown function as periodicQuery is called from // go routine context. -func (c *client) periodicQuery(ctx context.Context, params *lookupParams) error { +func (c *Client) periodicQuery(ctx context.Context, params *lookupParams) error { // Do the first query immediately. if err := c.query(params); err != nil { return err @@ -426,7 +476,7 @@ func (c *client) periodicQuery(ctx context.Context, params *lookupParams) error // Performs the actual query by service name (browse) or service instance name (lookup), // start response listeners goroutines and loops over the entries channel. -func (c *client) query(params *lookupParams) error { +func (c *Client) query(params *lookupParams) error { var serviceName, serviceInstanceName string serviceName = fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) @@ -447,45 +497,74 @@ func (c *client) query(params *lookupParams) error { return c.sendQuery(m) } -// Pack the dns.Msg and write to available connections (multicast) -func (c *client) sendQuery(msg *dns.Msg) error { +// sendQuery packs the dns.Msg and writes to available connections (multicast). +// +// THE CRITICAL FIX: Dynamic iteration using ActiveIndices(). +// Gets a fresh snapshot of active indices on each call. The snapshot may become +// stale during iteration (race with syncInterfaces), but this is BENIGN because: +// - Sends to removed indices fail immediately +// - MarkFailed is idempotent (safe to call on already-removed index) +// - New indices are picked up on the next sendQuery call +func (c *Client) sendQuery(msg *dns.Msg) error { buf, err := msg.Pack() if err != nil { return err } + + // IPv4: iterate over CURRENT active indices if c.ipv4conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv4.ControlMessage - for ifi := range c.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index - default: - if err := c.ipv4conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } + for _, idx := range c.ipv4Mgr.ActiveIndices() { + if _, err := c.ipv4conn.WriteTo(buf, idx, ipv4Addr); err != nil { + c.ipv4Mgr.MarkFailed(idx, err) } - _, _ = c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) } } + + // IPv6: same pattern, separate manager if c.ipv6conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv6.ControlMessage - for ifi := range c.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index - default: - if err := c.ipv6conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } + for _, idx := range c.ipv6Mgr.ActiveIndices() { + if _, err := c.ipv6conn.WriteTo(buf, idx, ipv6Addr); err != nil { + c.ipv6Mgr.MarkFailed(idx, err) } - _, _ = c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) } } + return nil } + +// runInterfaceSync periodically polls for interface changes. +func (c *Client) runInterfaceSync(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.syncInterfaces() + } + } +} + +// syncInterfaces polls for interface changes and recovers interfaces. +func (c *Client) syncInterfaces() { + current := c.provider.MulticastInterfaces() + + // Helper to sync a single manager + syncManager := func(mgr *InterfaceManager, conn api.PacketConn, groupIP net.IP) { + if conn == nil || mgr == nil { + return + } + for _, iface := range mgr.Sync(current) { + if err := conn.JoinGroup(&iface, &net.UDPAddr{IP: groupIP}); err != nil { + mgr.SetBackoff(iface.Name) + } else { + mgr.Activate(iface) + } + } + } + + syncManager(c.ipv4Mgr, c.ipv4conn, mdnsGroupIPv4) + syncManager(c.ipv6Mgr, c.ipv6conn, mdnsGroupIPv6) +} diff --git a/client_unit_test.go b/client_unit_test.go new file mode 100644 index 00000000..b05e3909 --- /dev/null +++ b/client_unit_test.go @@ -0,0 +1,715 @@ +package zeroconf + +import ( + "context" + "errors" + "net" + "sync" + "syscall" + "testing" + "time" + + "github.com/enbility/zeroconf/v3/api" + "github.com/enbility/zeroconf/v3/mocks" + "github.com/miekg/dns" + "github.com/stretchr/testify/mock" +) + +// testClient creates a Client with mock connections and InterfaceManagers. +// This is a helper for unit tests that need to create a Client directly. +func testClient(ipv4conn, ipv6conn api.PacketConn, ifaces []net.Interface) *Client { + return &Client{ + ipv4conn: ipv4conn, + ipv6conn: ipv6conn, + ipv4Mgr: NewInterfaceManager(ifaces, nil), + ipv6Mgr: NewInterfaceManager(ifaces, nil), + provider: NewInterfaceProvider(), + } +} + +// TestClient_InterfaceDisconnect_StopsSendingToFailedInterface is the key integration test +// that verifies the fix for the original issue: when an interface disconnects, we should +// stop sending to it rather than generating infinite warning logs. +// +// Original issue: Interface disconnects -> WriteTo fails -> code keeps trying -> infinite warnings +// Expected behavior: Interface disconnects -> WriteTo fails -> interface removed -> no more attempts +func TestClient_InterfaceDisconnect_StopsSendingToFailedInterface(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Two interfaces: eth0 (will fail) and wlan0 (stays healthy) + ifaces := []net.Interface{ + {Index: 1, Name: "eth0"}, + {Index: 2, Name: "wlan0"}, + } + + // Track calls per interface + var mu sync.Mutex + callsToEth0 := 0 + callsToWlan0 := 0 + + // eth0 (index 1) will return ENETDOWN error (simulating disconnect) + // wlan0 (index 2) will succeed + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + defer mu.Unlock() + if ifIndex == 1 { + callsToEth0++ + // Simulate interface gone - this is the error that was causing infinite warnings + return 0, syscall.ENETDOWN + } + callsToWlan0++ + return len(b), nil + }).Maybe() + + c := testClient(mockIPv4, nil, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + // First query: both interfaces should be attempted + // eth0 fails with ENETDOWN, wlan0 succeeds + _ = c.sendQuery(msg) + + mu.Lock() + firstEth0Calls := callsToEth0 + firstWlan0Calls := callsToWlan0 + mu.Unlock() + + if firstEth0Calls != 1 { + t.Errorf("First query: expected 1 call to eth0, got %d", firstEth0Calls) + } + if firstWlan0Calls != 1 { + t.Errorf("First query: expected 1 call to wlan0, got %d", firstWlan0Calls) + } + + // Second query: eth0 should NOT be attempted (it was marked failed) + // Only wlan0 should receive the query + _ = c.sendQuery(msg) + + mu.Lock() + secondEth0Calls := callsToEth0 + secondWlan0Calls := callsToWlan0 + mu.Unlock() + + // THE KEY ASSERTION: eth0 should NOT have been called again + // This is the fix for the infinite warning issue + if secondEth0Calls != 1 { + t.Errorf("Second query: expected eth0 to NOT be called again (still 1), got %d calls total", secondEth0Calls) + } + if secondWlan0Calls != 2 { + t.Errorf("Second query: expected wlan0 to be called (now 2), got %d calls total", secondWlan0Calls) + } + + // Third query: same behavior - eth0 still excluded + _ = c.sendQuery(msg) + + mu.Lock() + thirdEth0Calls := callsToEth0 + thirdWlan0Calls := callsToWlan0 + mu.Unlock() + + if thirdEth0Calls != 1 { + t.Errorf("Third query: eth0 should still be excluded (1 call total), got %d", thirdEth0Calls) + } + if thirdWlan0Calls != 3 { + t.Errorf("Third query: expected wlan0 calls to be 3, got %d", thirdWlan0Calls) + } + + t.Logf("SUCCESS: After eth0 disconnect, subsequent queries only went to wlan0") + t.Logf("eth0 calls: %d (only the initial failed attempt)", thirdEth0Calls) + t.Logf("wlan0 calls: %d (all 3 queries)", thirdWlan0Calls) +} + +// TestClient_AllInterfacesDisconnect_NoInfiniteLoop verifies that if ALL interfaces +// disconnect, we don't enter an infinite loop - we just have no interfaces to send to. +func TestClient_AllInterfacesDisconnect_NoInfiniteLoop(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + + callCount := 0 + var mu sync.Mutex + + // Interface always returns ENETDOWN + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + callCount++ + mu.Unlock() + return 0, syscall.ENETDOWN + }).Maybe() + + c := testClient(mockIPv4, nil, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + // Send multiple queries + for i := 0; i < 10; i++ { + _ = c.sendQuery(msg) + } + + mu.Lock() + finalCount := callCount + mu.Unlock() + + // Should only have 1 call - the first one that failed and removed the interface + // Without the fix, this would be 10 (one per query, each generating a warning) + if finalCount != 1 { + t.Errorf("Expected only 1 call to failed interface, got %d (suggests interface not removed)", finalCount) + } + + t.Logf("SUCCESS: Only %d call to disconnected interface across 10 queries", finalCount) +} + +// TestClient_SendQuery_WritesToConnections verifies sendQuery writes to both connections +func TestClient_SendQuery_WritesToConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + + // Expect WriteTo to be called on both connections + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + c := testClient(mockIPv4, mockIPv6, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_MultipleInterfaces verifies sendQuery writes to all interfaces +func TestClient_SendQuery_MultipleInterfaces(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + ifaces := []net.Interface{ + {Index: 1, Name: "eth0"}, + {Index: 2, Name: "wlan0"}, + {Index: 3, Name: "lo0"}, + } + + // Expect WriteTo to be called 3 times on each connection (once per interface) + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv4.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv4.EXPECT().WriteTo(mock.Anything, 3, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 3, mock.Anything).Return(0, nil).Once() + + c := testClient(mockIPv4, mockIPv6, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_IPv4Only verifies sendQuery handles IPv4-only client +func TestClient_SendQuery_IPv4Only(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + c := testClient(mockIPv4, nil, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_SendQuery_IPv6Only verifies sendQuery handles IPv6-only client +func TestClient_SendQuery_IPv6Only(t *testing.T) { + mockIPv6 := mocks.NewMockPacketConn(t) + + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + c := testClient(nil, mockIPv6, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := c.sendQuery(msg) + if err != nil { + t.Fatalf("sendQuery failed: %v", err) + } +} + +// TestClient_Shutdown_ClosesConnections verifies shutdown properly closes connections +func TestClient_Shutdown_ClosesConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + mockIPv4.EXPECT().Close().Return(nil).Once() + mockIPv6.EXPECT().Close().Return(nil).Once() + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + c := testClient(mockIPv4, mockIPv6, ifaces) + + c.shutdown() +} + +// TestClientConfig verifies client configuration options +func TestClientConfig(t *testing.T) { + t.Run("default options", func(t *testing.T) { + opts := applyOpts() + if opts.listenOn != IPv4AndIPv6 { + t.Errorf("Expected default listenOn IPv4AndIPv6, got %d", opts.listenOn) + } + }) + + t.Run("IPv4 only", func(t *testing.T) { + opts := applyOpts(SelectIPTraffic(IPv4)) + if opts.listenOn != IPv4 { + t.Errorf("Expected listenOn IPv4, got %d", opts.listenOn) + } + }) + + t.Run("IPv6 only", func(t *testing.T) { + opts := applyOpts(SelectIPTraffic(IPv6)) + if opts.listenOn != IPv6 { + t.Errorf("Expected listenOn IPv6, got %d", opts.listenOn) + } + }) + + t.Run("custom interfaces", func(t *testing.T) { + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + opts := applyOpts(SelectIfaces(ifaces)) + if len(opts.ifaces) != 1 { + t.Errorf("Expected 1 interface, got %d", len(opts.ifaces)) + } + }) +} + +// TestNewClient_WithMockFactory verifies newClient uses the connection factory +func TestNewClient_WithMockFactory(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + opts := clientOpts{ + listenOn: IPv4AndIPv6, + connFactory: factory, + } + + c, err := newClient(opts) + if err != nil { + t.Fatalf("newClient failed: %v", err) + } + + if c.ipv4conn != mockIPv4 { + t.Error("Expected mock IPv4 connection to be used") + } + if c.ipv6conn != mockIPv6 { + t.Error("Expected mock IPv6 connection to be used") + } +} + +// TestNewClient_ExportedConstructor verifies the exported NewClient constructor +func TestNewClient_ExportedConstructor(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + c, err := NewClient(WithClientConnFactory(factory)) + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + if c.ipv4conn != mockIPv4 { + t.Error("Expected mock IPv4 connection to be used") + } + if c.ipv6conn != mockIPv6 { + t.Error("Expected mock IPv6 connection to be used") + } +} + +// TestWithClientConnFactory verifies the WithClientConnFactory option +func TestWithClientConnFactory(t *testing.T) { + factory := mocks.NewMockConnectionFactory(t) + + opts := applyOpts(WithClientConnFactory(factory)) + + if opts.connFactory != factory { + t.Error("Expected connection factory to be set") + } +} + +// TestClient_Query_WithInstance verifies query builds correct message for Lookup +func TestClient_Query_WithInstance(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Capture the DNS message to verify it contains SRV and TXT questions + var capturedMsg []byte + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + capturedMsg = make([]byte, len(b)) + copy(capturedMsg, b) + return len(b), nil + }).Once() + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + c := testClient(mockIPv4, nil, ifaces) + + params := newLookupParams("myservice", "_http._tcp", "local", false, + make(chan *ServiceEntry), make(chan *ServiceEntry)) + + err := c.query(params) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + // Parse the captured message + msg := new(dns.Msg) + if err := msg.Unpack(capturedMsg); err != nil { + t.Fatalf("Failed to unpack captured message: %v", err) + } + + // For instance lookup, we expect SRV and TXT questions + if len(msg.Question) != 2 { + t.Fatalf("Expected 2 questions for instance lookup, got %d", len(msg.Question)) + } + + // Check question types + hasSRV := false + hasTXT := false + for _, q := range msg.Question { + if q.Qtype == dns.TypeSRV { + hasSRV = true + } + if q.Qtype == dns.TypeTXT { + hasTXT = true + } + } + + if !hasSRV { + t.Error("Expected SRV question for instance lookup") + } + if !hasTXT { + t.Error("Expected TXT question for instance lookup") + } +} + +// TestClient_Query_Browse verifies query builds correct message for Browse +func TestClient_Query_Browse(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + var capturedMsg []byte + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + capturedMsg = make([]byte, len(b)) + copy(capturedMsg, b) + return len(b), nil + }).Once() + + ifaces := []net.Interface{{Index: 1, Name: "eth0"}} + c := testClient(mockIPv4, nil, ifaces) + + // No instance = browse mode + params := newLookupParams("", "_http._tcp", "local", true, + make(chan *ServiceEntry), make(chan *ServiceEntry)) + + err := c.query(params) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + msg := new(dns.Msg) + if err := msg.Unpack(capturedMsg); err != nil { + t.Fatalf("Failed to unpack captured message: %v", err) + } + + // For browse, we expect a single PTR question + if len(msg.Question) != 1 { + t.Fatalf("Expected 1 question for browse, got %d", len(msg.Question)) + } + + if msg.Question[0].Qtype != dns.TypePTR { + t.Errorf("Expected PTR question for browse, got %d", msg.Question[0].Qtype) + } +} + +// createMockDNSResponse creates a complete DNS response for testing Lookup +func createMockDNSResponse(instanceName, hostName string, port uint16, ip net.IP) []byte { + msg := new(dns.Msg) + msg.Response = true + + // SRV record + msg.Answer = append(msg.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 120, + }, + Priority: 0, + Weight: 0, + Port: port, + Target: hostName, + }) + + // TXT record + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 120, + }, + Txt: []string{"key=value"}, + }) + + // A record + msg.Extra = append(msg.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: hostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: ip, + }) + + data, _ := msg.Pack() + return data +} + +// TestBrowse_WithMockConnections tests the full Browse flow with mocked connections +func TestBrowse_WithMockConnections(t *testing.T) { + // Reduce query interval for faster test + oldInterval := initialQueryInterval + initialQueryInterval = 50 * time.Millisecond + defer func() { initialQueryInterval = oldInterval }() + + mockIPv4 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + + // Create a DNS response with PTR record (for browse) + instanceName := "myservice._http._tcp.local." + serviceName := "_http._tcp.local." + hostName := "myhost.local." + + msg := new(dns.Msg) + msg.Response = true + + // PTR record pointing to the instance + msg.Answer = append(msg.Answer, &dns.PTR{ + Hdr: dns.RR_Header{ + Name: serviceName, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 120, + }, + Ptr: instanceName, + }) + + // SRV record + msg.Answer = append(msg.Answer, &dns.SRV{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 120, + }, + Port: 8080, + Target: hostName, + }) + + // TXT record + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: instanceName, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 120, + }, + Txt: []string{"version=1.0"}, + }) + + // A record + msg.Extra = append(msg.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: hostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: net.ParseIP("192.168.1.100"), + }) + + responseData, _ := msg.Pack() + + var readCount int + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + readCount++ + count := readCount + mu.Unlock() + + if count == 1 { + copy(b, responseData) + return len(responseData), 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 5353}, nil + } + time.Sleep(100 * time.Millisecond) + return 0, 0, nil, errors.New("context cancelled") + }).Maybe() + mockIPv4.EXPECT().Close().Return(nil).Maybe() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + entries := make(chan *ServiceEntry, 1) + removed := make(chan *ServiceEntry, 1) + + var browseErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + browseErr = Browse(ctx, "_http._tcp", "local", entries, removed, + WithClientConnFactory(factory), + SelectIPTraffic(IPv4)) + }() + + select { + case entry := <-entries: + if entry.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", entry.Instance) + } + if entry.Port != 8080 { + t.Errorf("Expected port 8080, got %d", entry.Port) + } + if len(entry.Text) == 0 || entry.Text[0] != "version=1.0" { + t.Errorf("Expected text 'version=1.0', got %v", entry.Text) + } + cancel() + case <-ctx.Done(): + t.Log("Context done before receiving entry") + } + + wg.Wait() + + if browseErr != nil && browseErr != context.DeadlineExceeded && browseErr != context.Canceled { + t.Errorf("Browse returned unexpected error: %v", browseErr) + } +} + +// TestLookup_WithMockConnections tests the full Lookup flow with mocked connections +func TestLookup_WithMockConnections(t *testing.T) { + // Reduce query interval for faster test + oldInterval := initialQueryInterval + initialQueryInterval = 50 * time.Millisecond + defer func() { initialQueryInterval = oldInterval }() + + mockIPv4 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + // Factory returns our mock connection (IPv4 only since we use SelectIPTraffic(IPv4)) + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + + // Create the DNS response + instanceName := "myservice._http._tcp.local." + hostName := "myhost.local." + responseData := createMockDNSResponse(instanceName, hostName, 8080, net.ParseIP("192.168.1.100")) + + // Track ReadFrom calls + var readCount int + var mu sync.Mutex + + // WriteTo for queries - just accept them + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + // ReadFrom returns the response once, then blocks + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + readCount++ + count := readCount + mu.Unlock() + + if count == 1 { + // First call: return the DNS response + copy(b, responseData) + return len(responseData), 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 5353}, nil + } + // Subsequent calls: block until test ends (simulates waiting for more data) + time.Sleep(100 * time.Millisecond) + return 0, 0, nil, errors.New("context cancelled") + }).Maybe() + + // Close when shutdown + mockIPv4.EXPECT().Close().Return(nil).Maybe() + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + entries := make(chan *ServiceEntry, 1) + + // Run Lookup in background + var lookupErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + lookupErr = Lookup(ctx, "myservice", "_http._tcp", "local", entries, + WithClientConnFactory(factory), + SelectIPTraffic(IPv4)) + }() + + // Wait for entry or timeout + select { + case entry := <-entries: + if entry.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", entry.Instance) + } + if entry.Port != 8080 { + t.Errorf("Expected port 8080, got %d", entry.Port) + } + if entry.HostName != hostName { + t.Errorf("Expected hostname '%s', got '%s'", hostName, entry.HostName) + } + if len(entry.AddrIPv4) == 0 { + t.Error("Expected IPv4 address") + } else if !entry.AddrIPv4[0].Equal(net.ParseIP("192.168.1.100")) { + t.Errorf("Expected IP 192.168.1.100, got %s", entry.AddrIPv4[0]) + } + // Success - cancel to clean up + cancel() + case <-ctx.Done(): + t.Log("Context done before receiving entry (may be timing issue)") + } + + wg.Wait() + + // Context cancellation is expected, not an error for Lookup + if lookupErr != nil && lookupErr != context.DeadlineExceeded && lookupErr != context.Canceled { + t.Errorf("Lookup returned unexpected error: %v", lookupErr) + } +} diff --git a/conn_factory.go b/conn_factory.go new file mode 100644 index 00000000..a706b60d --- /dev/null +++ b/conn_factory.go @@ -0,0 +1,72 @@ +package zeroconf + +import ( + "fmt" + "net" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// defaultConnectionFactory is the production implementation of api.ConnectionFactory. +// It creates real UDP multicast connections for mDNS communication. +type defaultConnectionFactory struct{} + +// Compile-time interface check +var _ api.ConnectionFactory = (*defaultConnectionFactory)(nil) + +// NewConnectionFactory creates a new default connection factory. +func NewConnectionFactory() api.ConnectionFactory { + return &defaultConnectionFactory{} +} + +func (f *defaultConnectionFactory) CreateIPv4Conn(ifaces []net.Interface) (api.PacketConn, error) { + udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + if err != nil { + return nil, err + } + + pkConn := ipv4.NewPacketConn(udpConn) + _ = pkConn.SetControlMessage(ipv4.FlagInterface, true) + + var failedJoins int + for _, iface := range ifaces { + if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + failedJoins++ + } + } + if failedJoins == len(ifaces) { + pkConn.Close() + return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", ifaces) + } + + _ = pkConn.SetMulticastTTL(255) + + return newIPv4PacketConn(pkConn), nil +} + +func (f *defaultConnectionFactory) CreateIPv6Conn(ifaces []net.Interface) (api.PacketConn, error) { + udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if err != nil { + return nil, err + } + + pkConn := ipv6.NewPacketConn(udpConn) + _ = pkConn.SetControlMessage(ipv6.FlagInterface, true) + + var failedJoins int + for _, iface := range ifaces { + if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + failedJoins++ + } + } + if failedJoins == len(ifaces) { + pkConn.Close() + return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", ifaces) + } + + _ = pkConn.SetMulticastHopLimit(255) + + return newIPv6PacketConn(pkConn), nil +} diff --git a/conn_ipv4.go b/conn_ipv4.go new file mode 100644 index 00000000..6acc8479 --- /dev/null +++ b/conn_ipv4.go @@ -0,0 +1,93 @@ +package zeroconf + +import ( + "fmt" + "net" + "runtime" + "syscall" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv4" +) + +// ipv4PacketConn wraps ipv4.PacketConn to implement api.PacketConn interface. +// This adapter is needed because ipv4.PacketConn uses ControlMessage for +// interface selection, but we only need the IfIndex field. +type ipv4PacketConn struct { + conn *ipv4.PacketConn +} + +// Compile-time interface check +var _ api.PacketConn = (*ipv4PacketConn)(nil) + +// newIPv4PacketConn creates a new IPv4 PacketConn wrapper. +func newIPv4PacketConn(conn *ipv4.PacketConn) *ipv4PacketConn { + return &ipv4PacketConn{conn: conn} +} + +func (c *ipv4PacketConn) ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) { + n, cm, src, err := c.conn.ReadFrom(b) + if cm != nil { + ifIndex = cm.IfIndex + } + return +} + +func (c *ipv4PacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) { + // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG + // On Windows, the ControlMessage for WriteTo is not implemented. + // Use SetMulticastInterface as fallback. + var cm *ipv4.ControlMessage + + if ifIndex != 0 { + switch runtime.GOOS { + case "darwin", "ios", "linux": + cm = &ipv4.ControlMessage{IfIndex: ifIndex} + + default: + // Windows and other platforms: validate and set interface. + // CRITICAL: Return errors instead of logging them. The caller + // (via InterfaceManager.MarkFailed) handles removal and backoff. + iface, err := net.InterfaceByIndex(ifIndex) + if err != nil { + // Interface gone - wrap with ENXIO so isInterfaceGone() detects it + return 0, fmt.Errorf("interface index %d: %w", ifIndex, syscall.ENXIO) + } + // Verify interface is actually up + if iface.Flags&net.FlagUp == 0 { + return 0, fmt.Errorf("interface %s is down: %w", iface.Name, syscall.ENETDOWN) + } + if err := c.conn.SetMulticastInterface(iface); err != nil { + // Return the actual error - may contain WSAENETDOWN or similar + return 0, fmt.Errorf("set multicast interface %s: %w", iface.Name, err) + } + } + } + + return c.conn.WriteTo(b, cm, dst) +} + +func (c *ipv4PacketConn) Close() error { + return c.conn.Close() +} + +func (c *ipv4PacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.JoinGroup(ifi, group) +} + +func (c *ipv4PacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.LeaveGroup(ifi, group) +} + +func (c *ipv4PacketConn) SetMulticastTTL(ttl int) error { + return c.conn.SetMulticastTTL(ttl) +} + +func (c *ipv4PacketConn) SetMulticastHopLimit(hopLimit int) error { + // IPv4 doesn't have hop limit, this is a no-op + return nil +} + +func (c *ipv4PacketConn) SetMulticastInterface(ifi *net.Interface) error { + return c.conn.SetMulticastInterface(ifi) +} diff --git a/conn_ipv6.go b/conn_ipv6.go new file mode 100644 index 00000000..c723148f --- /dev/null +++ b/conn_ipv6.go @@ -0,0 +1,93 @@ +package zeroconf + +import ( + "fmt" + "net" + "runtime" + "syscall" + + "github.com/enbility/zeroconf/v3/api" + "golang.org/x/net/ipv6" +) + +// ipv6PacketConn wraps ipv6.PacketConn to implement api.PacketConn interface. +// This adapter is needed because ipv6.PacketConn uses ControlMessage for +// interface selection, but we only need the IfIndex field. +type ipv6PacketConn struct { + conn *ipv6.PacketConn +} + +// Compile-time interface check +var _ api.PacketConn = (*ipv6PacketConn)(nil) + +// newIPv6PacketConn creates a new IPv6 PacketConn wrapper. +func newIPv6PacketConn(conn *ipv6.PacketConn) *ipv6PacketConn { + return &ipv6PacketConn{conn: conn} +} + +func (c *ipv6PacketConn) ReadFrom(b []byte) (n int, ifIndex int, src net.Addr, err error) { + n, cm, src, err := c.conn.ReadFrom(b) + if cm != nil { + ifIndex = cm.IfIndex + } + return +} + +func (c *ipv6PacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (n int, err error) { + // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG + // On Windows, the ControlMessage for WriteTo is not implemented. + // Use SetMulticastInterface as fallback. + var cm *ipv6.ControlMessage + + if ifIndex != 0 { + switch runtime.GOOS { + case "darwin", "ios", "linux": + cm = &ipv6.ControlMessage{IfIndex: ifIndex} + + default: + // Windows and other platforms: validate and set interface. + // CRITICAL: Return errors instead of logging them. The caller + // (via InterfaceManager.MarkFailed) handles removal and backoff. + iface, err := net.InterfaceByIndex(ifIndex) + if err != nil { + // Interface gone - wrap with ENXIO so isInterfaceGone() detects it + return 0, fmt.Errorf("interface index %d: %w", ifIndex, syscall.ENXIO) + } + // Verify interface is actually up + if iface.Flags&net.FlagUp == 0 { + return 0, fmt.Errorf("interface %s is down: %w", iface.Name, syscall.ENETDOWN) + } + if err := c.conn.SetMulticastInterface(iface); err != nil { + // Return the actual error - may contain WSAENETDOWN or similar + return 0, fmt.Errorf("set multicast interface %s: %w", iface.Name, err) + } + } + } + + return c.conn.WriteTo(b, cm, dst) +} + +func (c *ipv6PacketConn) Close() error { + return c.conn.Close() +} + +func (c *ipv6PacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.JoinGroup(ifi, group) +} + +func (c *ipv6PacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return c.conn.LeaveGroup(ifi, group) +} + +func (c *ipv6PacketConn) SetMulticastTTL(ttl int) error { + // IPv6 doesn't have TTL, this is a no-op + return nil +} + +func (c *ipv6PacketConn) SetMulticastHopLimit(hopLimit int) error { + return c.conn.SetMulticastHopLimit(hopLimit) +} + +func (c *ipv6PacketConn) SetMulticastInterface(ifi *net.Interface) error { + return c.conn.SetMulticastInterface(ifi) +} diff --git a/conn_provider.go b/conn_provider.go new file mode 100644 index 00000000..611e8536 --- /dev/null +++ b/conn_provider.go @@ -0,0 +1,37 @@ +package zeroconf + +import ( + "net" + + "github.com/enbility/zeroconf/v3/api" +) + +// defaultInterfaceProvider is the production implementation of api.InterfaceProvider. +// It lists network interfaces capable of multicast communication. +type defaultInterfaceProvider struct{} + +// Compile-time interface check +var _ api.InterfaceProvider = (*defaultInterfaceProvider)(nil) + +// NewInterfaceProvider creates a new default interface provider. +func NewInterfaceProvider() api.InterfaceProvider { + return &defaultInterfaceProvider{} +} + +// MulticastInterfaces returns all network interfaces that are up and support multicast. +func (p *defaultInterfaceProvider) MulticastInterfaces() []net.Interface { + var interfaces []net.Interface + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + for _, ifi := range ifaces { + if (ifi.Flags & net.FlagUp) == 0 { + continue + } + if (ifi.Flags & net.FlagMulticast) > 0 { + interfaces = append(interfaces, ifi) + } + } + return interfaces +} diff --git a/connection.go b/connection.go deleted file mode 100644 index 0efbac12..00000000 --- a/connection.go +++ /dev/null @@ -1,119 +0,0 @@ -package zeroconf - -import ( - "fmt" - "net" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -var ( - // Multicast groups used by mDNS - mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251) - mdnsGroupIPv6 = net.ParseIP("ff02::fb") - - // mDNS wildcard addresses - mdnsWildcardAddrIPv4 = &net.UDPAddr{ - IP: net.ParseIP("224.0.0.0"), - Port: 5353, - } - mdnsWildcardAddrIPv6 = &net.UDPAddr{ - IP: net.ParseIP("ff02::"), - // IP: net.ParseIP("fd00::12d3:26e7:48db:e7d"), - Port: 5353, - } - - // mDNS endpoint addresses - ipv4Addr = &net.UDPAddr{ - IP: mdnsGroupIPv4, - Port: 5353, - } - ipv6Addr = &net.UDPAddr{ - IP: mdnsGroupIPv6, - Port: 5353, - } -) - -func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { - udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) - if err != nil { - return nil, err - } - - // Join multicast groups to receive announcements - pkConn := ipv6.NewPacketConn(udpConn) - _ = pkConn.SetControlMessage(ipv6.FlagInterface, true) - - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } - // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int - for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // log.Println("Udp6 JoinGroup failed for iface ", iface) - failedJoins++ - } - } - if failedJoins == len(interfaces) { - pkConn.Close() - return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", interfaces) - } - - _ = pkConn.SetMulticastHopLimit(255) - - return pkConn, nil -} - -func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { - udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) - if err != nil { - // log.Printf("[ERR] bonjour: Failed to bind to udp4 mutlicast: %v", err) - return nil, err - } - - // Join multicast groups to receive announcements - pkConn := ipv4.NewPacketConn(udpConn) - _ = pkConn.SetControlMessage(ipv4.FlagInterface, true) - - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } - // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int - for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - // log.Println("Udp4 JoinGroup failed for iface ", iface) - failedJoins++ - } - } - if failedJoins == len(interfaces) { - pkConn.Close() - return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", interfaces) - } - - _ = pkConn.SetMulticastTTL(255) - - return pkConn, nil -} - -func listMulticastInterfaces() []net.Interface { - var interfaces []net.Interface - ifaces, err := net.Interfaces() - if err != nil { - return nil - } - for _, ifi := range ifaces { - if (ifi.Flags & net.FlagUp) == 0 { - continue - } - if (ifi.Flags & net.FlagMulticast) > 0 { - interfaces = append(interfaces, ifi) - } - } - - return interfaces -} diff --git a/error_classify.go b/error_classify.go new file mode 100644 index 00000000..d70fb8d8 --- /dev/null +++ b/error_classify.go @@ -0,0 +1,60 @@ +package zeroconf + +import ( + "errors" + "strings" + "syscall" +) + +// Windows socket error codes (not in standard syscall package). +// These constants are safe to define cross-platform because errors.Is() +// performs type comparison - on non-Windows systems, these simply won't match. +const ( + WSAENETDOWN syscall.Errno = 10050 // Network is down + WSAEADDRNOTAVAIL syscall.Errno = 10049 // Cannot assign requested address + WSAEINVAL syscall.Errno = 10022 // Invalid argument +) + +// interfaceGoneErrors lists all error codes that indicate an interface is gone. +var interfaceGoneErrors = []syscall.Errno{ + // Unix errors + syscall.ENXIO, // "no such device or address" + syscall.ENODEV, // "no such device" + syscall.EADDRNOTAVAIL, // "can't assign requested address" + syscall.EINVAL, // "invalid argument" (stale ifIndex) + syscall.ENETDOWN, // "network is down" + syscall.ENETUNREACH, // "network unreachable" + // Windows errors + WSAENETDOWN, + WSAEADDRNOTAVAIL, + WSAEINVAL, +} + +// isInterfaceGone returns true if the error indicates the interface +// is no longer available and should be removed from active set. +// +// Uses errors.Is() for proper unwrapping of fmt.Errorf("%w") chains. +func isInterfaceGone(err error) bool { + if err == nil { + return false + } + + // Check known error codes + for _, e := range interfaceGoneErrors { + if errors.Is(err, e) { + return true + } + } + + // Fallback: check error message patterns for unknown error codes. + // This handles edge cases on unusual platforms or new OS versions. + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "no such device") || + strings.Contains(errStr, "no such interface") || + strings.Contains(errStr, "network is down") || + strings.Contains(errStr, "network is unreachable") { + return true + } + + return false +} diff --git a/error_classify_test.go b/error_classify_test.go new file mode 100644 index 00000000..1b37258a --- /dev/null +++ b/error_classify_test.go @@ -0,0 +1,131 @@ +package zeroconf + +import ( + "errors" + "fmt" + "syscall" + "testing" +) + +func TestIsInterfaceGone_Nil_ReturnsFalse(t *testing.T) { + if isInterfaceGone(nil) { + t.Error("expected false for nil error") + } +} + +func TestIsInterfaceGone_ENXIO_ReturnsTrue(t *testing.T) { + err := syscall.ENXIO + if !isInterfaceGone(err) { + t.Errorf("expected true for ENXIO, got false") + } +} + +func TestIsInterfaceGone_ENODEV_ReturnsTrue(t *testing.T) { + err := syscall.ENODEV + if !isInterfaceGone(err) { + t.Errorf("expected true for ENODEV, got false") + } +} + +func TestIsInterfaceGone_EADDRNOTAVAIL_ReturnsTrue(t *testing.T) { + err := syscall.EADDRNOTAVAIL + if !isInterfaceGone(err) { + t.Errorf("expected true for EADDRNOTAVAIL, got false") + } +} + +func TestIsInterfaceGone_EINVAL_ReturnsTrue(t *testing.T) { + err := syscall.EINVAL + if !isInterfaceGone(err) { + t.Errorf("expected true for EINVAL, got false") + } +} + +func TestIsInterfaceGone_ENETDOWN_ReturnsTrue(t *testing.T) { + err := syscall.ENETDOWN + if !isInterfaceGone(err) { + t.Errorf("expected true for ENETDOWN, got false") + } +} + +func TestIsInterfaceGone_ENETUNREACH_ReturnsTrue(t *testing.T) { + err := syscall.ENETUNREACH + if !isInterfaceGone(err) { + t.Errorf("expected true for ENETUNREACH, got false") + } +} + +func TestIsInterfaceGone_WrappedError_ReturnsTrue(t *testing.T) { + // errors.Is() should unwrap fmt.Errorf("%w") chains + wrapped := fmt.Errorf("send failed: %w", syscall.ENXIO) + if !isInterfaceGone(wrapped) { + t.Errorf("expected true for wrapped ENXIO, got false") + } +} + +func TestIsInterfaceGone_DoubleWrappedError_ReturnsTrue(t *testing.T) { + inner := fmt.Errorf("inner: %w", syscall.ENETDOWN) + outer := fmt.Errorf("outer: %w", inner) + if !isInterfaceGone(outer) { + t.Errorf("expected true for double-wrapped ENETDOWN, got false") + } +} + +func TestIsInterfaceGone_TransientError_ReturnsFalse(t *testing.T) { + // EAGAIN is a transient error - should not remove interface + err := syscall.EAGAIN + if isInterfaceGone(err) { + t.Errorf("expected false for EAGAIN (transient), got true") + } +} + +func TestIsInterfaceGone_ETIMEDOUT_ReturnsFalse(t *testing.T) { + // Timeout is transient - interface might still be fine + err := syscall.ETIMEDOUT + if isInterfaceGone(err) { + t.Errorf("expected false for ETIMEDOUT (transient), got true") + } +} + +func TestIsInterfaceGone_GenericError_ReturnsFalse(t *testing.T) { + err := errors.New("some random error") + if isInterfaceGone(err) { + t.Errorf("expected false for generic error, got true") + } +} + +func TestIsInterfaceGone_FallbackMessageParsing_NoSuchDevice(t *testing.T) { + // Fallback for unknown error codes with recognizable message + err := errors.New("operation failed: no such device") + if !isInterfaceGone(err) { + t.Errorf("expected true for 'no such device' message, got false") + } +} + +func TestIsInterfaceGone_FallbackMessageParsing_NetworkDown(t *testing.T) { + err := errors.New("send error: network is down") + if !isInterfaceGone(err) { + t.Errorf("expected true for 'network is down' message, got false") + } +} + +func TestIsInterfaceGone_FallbackMessageParsing_NetworkUnreachable(t *testing.T) { + err := errors.New("cannot route: network is unreachable") + if !isInterfaceGone(err) { + t.Errorf("expected true for 'network is unreachable' message, got false") + } +} + +func TestIsInterfaceGone_FallbackMessageParsing_NoSuchInterface(t *testing.T) { + err := errors.New("interface eth0: no such interface") + if !isInterfaceGone(err) { + t.Errorf("expected true for 'no such interface' message, got false") + } +} + +func TestIsInterfaceGone_FallbackMessageParsing_CaseInsensitive(t *testing.T) { + err := errors.New("NETWORK IS DOWN") + if !isInterfaceGone(err) { + t.Errorf("expected true for uppercase 'NETWORK IS DOWN', got false") + } +} diff --git a/examples/proxyservice/server.go b/examples/proxyservice/server.go index a601f52e..4ad27a1b 100644 --- a/examples/proxyservice/server.go +++ b/examples/proxyservice/server.go @@ -9,7 +9,7 @@ import ( "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/examples/register/server.go b/examples/register/server.go index eefb72e3..e37755f5 100644 --- a/examples/register/server.go +++ b/examples/register/server.go @@ -9,7 +9,7 @@ import ( "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/examples/resolv/client.go b/examples/resolv/client.go index f435ac58..6a3b3d30 100644 --- a/examples/resolv/client.go +++ b/examples/resolv/client.go @@ -6,7 +6,7 @@ import ( "log" "time" - "github.com/enbility/zeroconf/v2" + "github.com/enbility/zeroconf/v3" ) var ( diff --git a/go.mod b/go.mod index 599b7926..8e6d8f8e 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,20 @@ -module github.com/enbility/zeroconf/v2 +module github.com/enbility/zeroconf/v3 go 1.22.0 require ( github.com/miekg/dns v1.1.62 + github.com/stretchr/testify v1.11.1 golang.org/x/net v0.29.0 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/tools v0.25.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3207eb13..bc8d7283 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= @@ -10,3 +18,7 @@ golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/interface_manager.go b/interface_manager.go new file mode 100644 index 00000000..521385a3 --- /dev/null +++ b/interface_manager.go @@ -0,0 +1,251 @@ +package zeroconf + +import ( + "net" + "sync" + "time" +) + +// Backoff intervals for adaptive retry strategy. +// Fast first retry for user-initiated reconnects, then progressive delay. +const ( + backoffFirst = 1 * time.Second // First retry: fast for quick reconnects + backoffSecond = 5 * time.Second // Second retry: moderate delay + backoffMax = 30 * time.Second // Subsequent retries: avoid thrashing +) + +// failureState tracks failure history for adaptive backoff. +type failureState struct { + count int // Number of consecutive failures + retryAt time.Time // Don't retry until this time +} + +// InterfaceManager tracks active and failed interfaces for one IP version. +// Thread-safe. Create separate instances for IPv4 and IPv6. +// +// Concurrency model: +// - ActiveIndices() returns a snapshot; iteration is lock-free +// - MarkFailed() is idempotent; safe to call even if already removed +// - Sync() runs periodically in background; updates are atomic +type InterfaceManager struct { + mu sync.RWMutex + active map[int]string // ifIndex -> name (currently usable) + failures map[string]*failureState // name -> failure tracking (adaptive backoff) + requested []string // Mode selector: + // nil = dynamic mode (accept any multicast interface) + // non-nil = explicit mode (only names in this slice) + // NOTE: Empty slice []string{} is treated as explicit mode + // with NO allowed interfaces - almost certainly a bug. + // Callers should pass nil for dynamic mode, not empty slice. +} + +// NewInterfaceManager creates a manager with initial interfaces. +// If requested is nil, dynamic mode is used (accepts new interfaces). +// If requested is non-nil, only those interface names are ever used. +func NewInterfaceManager(initial []net.Interface, requested []string) *InterfaceManager { + m := &InterfaceManager{ + active: make(map[int]string, len(initial)), + failures: make(map[string]*failureState), + requested: requested, + } + for _, iface := range initial { + m.active[iface.Index] = iface.Name + } + return m +} + +// ActiveIndices returns current active interface indices. +// Call this in send loops - never cache the result. +// +// The returned slice is a snapshot. The caller iterates over it while +// the sync goroutine may modify the active map. This is safe because: +// - Sends to removed indices fail fast and call MarkFailed (idempotent) +// - New indices are picked up on the next ActiveIndices() call +func (m *InterfaceManager) ActiveIndices() []int { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]int, 0, len(m.active)) + for idx := range m.active { + result = append(result, idx) + } + return result +} + +// MarkFailed removes an interface from active set if error indicates it's gone. +// Uses adaptive backoff: first failure = 1s, second = 5s, third+ = 30s. +// +// This method is IDEMPOTENT: safe to call even if the interface was already +// removed by a concurrent Sync() call. +// +// Returns true if the error indicated the interface is gone. +func (m *InterfaceManager) MarkFailed(ifIndex int, err error) bool { + if !isInterfaceGone(err) { + return false // Transient error, don't remove + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Get name from active map (if still present) + name := m.active[ifIndex] + if name == "" { + // Already removed - this is the benign race case + // We can't set backoff without knowing the name, but that's OK: + // either Sync() already set it, or we don't have enough info. + return true + } + + // Remove from active (idempotent - no-op if not present) + delete(m.active, ifIndex) + + // Update failure tracking with adaptive backoff + m.recordFailure(name) + + return true +} + +// recordFailure updates the failure state for an interface (must hold lock). +func (m *InterfaceManager) recordFailure(name string) { + state := m.failures[name] + if state == nil { + state = &failureState{} + m.failures[name] = state + } + state.count++ + + // Adaptive backoff based on failure count + var backoff time.Duration + switch state.count { + case 1: + backoff = backoffFirst // 1s - fast retry for quick reconnects + case 2: + backoff = backoffSecond // 5s - moderate delay + default: + backoff = backoffMax // 30s - avoid thrashing + } + state.retryAt = time.Now().Add(backoff) +} + +// Sync updates state based on currently available interfaces. +// Returns interfaces that were recovered and need JoinGroup calls. +// +// Handles: +// - Disappeared interfaces (removes from active, sets backoff) +// - Index changes (interface reconnects with different index) +// - New interfaces in dynamic mode +// - Recovery after backoff expires +func (m *InterfaceManager) Sync(current []net.Interface) []net.Interface { + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + currentByName := make(map[string]net.Interface, len(current)) + for _, iface := range current { + currentByName[iface.Name] = iface + } + + // Step 1: Remove disappeared interfaces + for idx, name := range m.active { + if _, exists := currentByName[name]; !exists { + delete(m.active, idx) + m.recordFailure(name) + } + } + + // Step 2: Find interfaces to recover and clean up stale indices + var recovered []net.Interface + for _, iface := range current { + if m.shouldRecover(iface, now) { + // Clean up stale index before adding to recovered list + m.cleanupStaleIndex(iface) + recovered = append(recovered, iface) + } + } + + return recovered +} + +// shouldRecover checks if an interface should be recovered (must hold lock). +// NOTE: This is a pure predicate - it does NOT mutate state. +// Use cleanupStaleIndex() separately to handle index changes. +func (m *InterfaceManager) shouldRecover(iface net.Interface, now time.Time) bool { + // Check if already active with same index + if existingName, ok := m.active[iface.Index]; ok && existingName == iface.Name { + return false // Already active, nothing to do + } + + // Check mode restrictions + if !m.isAllowed(iface.Name) { + return false + } + + // Check backoff + if state := m.failures[iface.Name]; state != nil && now.Before(state.retryAt) { + return false + } + + return true +} + +// isAllowed checks if interface name is allowed by mode (must hold lock). +func (m *InterfaceManager) isAllowed(name string) bool { + if m.requested == nil { + return true // Dynamic mode: allow all + } + for _, allowed := range m.requested { + if allowed == name { + return true + } + } + return false // Explicit mode: not in requested set +} + +// cleanupStaleIndex removes old index mapping if interface reconnected with new index. +// Must hold lock. Call this before adding new mapping for recovered interfaces. +func (m *InterfaceManager) cleanupStaleIndex(iface net.Interface) { + for idx, name := range m.active { + if name == iface.Name && idx != iface.Index { + delete(m.active, idx) + return // Only one stale mapping possible per name + } + } +} + +// Activate adds an interface to the active set. +// Called after successful JoinGroup. Clears failure history. +// Handles the case where interface reconnected with a different index. +func (m *InterfaceManager) Activate(iface net.Interface) { + m.mu.Lock() + defer m.mu.Unlock() + + // Remove stale index mapping if interface reconnected with new index + m.cleanupStaleIndex(iface) + + m.active[iface.Index] = iface.Name + delete(m.failures, iface.Name) // Clear failure history on success +} + +// SetBackoff marks an interface as temporarily failed (e.g., JoinGroup failed). +// Increments the failure counter for adaptive backoff. +func (m *InterfaceManager) SetBackoff(ifName string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.recordFailure(ifName) +} + +// GetActiveInterfaces returns full interface objects for all active indices. +// Used for IP address collection (avoids race between ActiveIndices and lookup). +func (m *InterfaceManager) GetActiveInterfaces() []net.Interface { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]net.Interface, 0, len(m.active)) + for idx := range m.active { + if iface, err := net.InterfaceByIndex(idx); err == nil { + result = append(result, *iface) + } + } + return result +} diff --git a/interface_manager_test.go b/interface_manager_test.go new file mode 100644 index 00000000..217da9ac --- /dev/null +++ b/interface_manager_test.go @@ -0,0 +1,576 @@ +package zeroconf + +import ( + "net" + "sort" + "sync" + "syscall" + "testing" + "time" +) + +// Helper to create mock interfaces +func mockInterface(index int, name string) net.Interface { + return net.Interface{ + Index: index, + Name: name, + Flags: net.FlagUp | net.FlagMulticast, + } +} + +func mockInterfaces(specs ...struct{ idx int; name string }) []net.Interface { + result := make([]net.Interface, len(specs)) + for i, s := range specs { + result[i] = mockInterface(s.idx, s.name) + } + return result +} + +// ============================================================================ +// NewInterfaceManager Tests +// ============================================================================ + +func TestInterfaceManager_NewDynamicMode(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + + // nil requested = dynamic mode + mgr := NewInterfaceManager(ifaces, nil) + + indices := mgr.ActiveIndices() + if len(indices) != 2 { + t.Errorf("expected 2 active indices, got %d", len(indices)) + } +} + +func TestInterfaceManager_NewExplicitMode(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + requested := []string{"eth0", "wlan0"} + + mgr := NewInterfaceManager(ifaces, requested) + + indices := mgr.ActiveIndices() + if len(indices) != 2 { + t.Errorf("expected 2 active indices, got %d", len(indices)) + } +} + +func TestInterfaceManager_NewEmptyInitial(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + indices := mgr.ActiveIndices() + if len(indices) != 0 { + t.Errorf("expected 0 active indices, got %d", len(indices)) + } +} + +// ============================================================================ +// ActiveIndices Tests +// ============================================================================ + +func TestInterfaceManager_ActiveIndices_ReturnsSnapshot(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(5, "wlan0")} + mgr := NewInterfaceManager(ifaces, nil) + + indices := mgr.ActiveIndices() + + // Should contain both indices + sort.Ints(indices) + if len(indices) != 2 || indices[0] != 1 || indices[1] != 5 { + t.Errorf("expected [1, 5], got %v", indices) + } +} + +func TestInterfaceManager_ActiveIndices_ReturnsEmptySliceNotNil(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + indices := mgr.ActiveIndices() + + if indices == nil { + t.Error("expected empty slice, got nil") + } + if len(indices) != 0 { + t.Errorf("expected length 0, got %d", len(indices)) + } +} + +// ============================================================================ +// MarkFailed Tests +// ============================================================================ + +func TestInterfaceManager_MarkFailed_InterfaceGoneError_RemovesInterface(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + mgr := NewInterfaceManager(ifaces, nil) + + // ENXIO indicates interface is gone + removed := mgr.MarkFailed(1, syscall.ENXIO) + + if !removed { + t.Error("expected MarkFailed to return true for ENXIO") + } + + indices := mgr.ActiveIndices() + if len(indices) != 1 { + t.Errorf("expected 1 active index after removal, got %d", len(indices)) + } + if indices[0] != 2 { + t.Errorf("expected remaining index to be 2, got %d", indices[0]) + } +} + +func TestInterfaceManager_MarkFailed_TransientError_KeepsInterface(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // EAGAIN is transient - should not remove + removed := mgr.MarkFailed(1, syscall.EAGAIN) + + if removed { + t.Error("expected MarkFailed to return false for transient error") + } + + indices := mgr.ActiveIndices() + if len(indices) != 1 { + t.Errorf("expected interface to remain active, got %d active", len(indices)) + } +} + +func TestInterfaceManager_MarkFailed_Idempotent_SafeWhenAlreadyRemoved(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // First removal + mgr.MarkFailed(1, syscall.ENXIO) + + // Second removal of same index - should not panic + removed := mgr.MarkFailed(1, syscall.ENXIO) + + // Returns true because error still indicates "interface gone" + if !removed { + t.Error("expected MarkFailed to return true even when already removed") + } + + indices := mgr.ActiveIndices() + if len(indices) != 0 { + t.Errorf("expected 0 active indices, got %d", len(indices)) + } +} + +func TestInterfaceManager_MarkFailed_UnknownIndex_DoesNotPanic(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + // Index 999 was never added - should not panic + removed := mgr.MarkFailed(999, syscall.ENXIO) + + if !removed { + t.Error("expected true because error indicates interface gone") + } +} + +// ============================================================================ +// Adaptive Backoff Tests +// ============================================================================ + +func TestInterfaceManager_AdaptiveBackoff_FirstFailure1s(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // Fail the interface + mgr.MarkFailed(1, syscall.ENXIO) + + // Check backoff is ~1s + mgr.mu.RLock() + state := mgr.failures["eth0"] + mgr.mu.RUnlock() + + if state == nil { + t.Fatal("expected failure state to exist") + } + if state.count != 1 { + t.Errorf("expected count 1, got %d", state.count) + } + + // retryAt should be ~1s from now + expectedBackoff := backoffFirst + actualBackoff := time.Until(state.retryAt) + if actualBackoff < expectedBackoff-100*time.Millisecond || actualBackoff > expectedBackoff+100*time.Millisecond { + t.Errorf("expected backoff ~%v, got %v", expectedBackoff, actualBackoff) + } +} + +func TestInterfaceManager_AdaptiveBackoff_SecondFailure5s(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // First failure + mgr.MarkFailed(1, syscall.ENXIO) + + // Manually re-add and fail again (simulating Sync + re-fail) + mgr.mu.Lock() + mgr.active[1] = "eth0" + mgr.mu.Unlock() + mgr.MarkFailed(1, syscall.ENXIO) + + mgr.mu.RLock() + state := mgr.failures["eth0"] + mgr.mu.RUnlock() + + if state.count != 2 { + t.Errorf("expected count 2, got %d", state.count) + } + + expectedBackoff := backoffSecond + actualBackoff := time.Until(state.retryAt) + if actualBackoff < expectedBackoff-100*time.Millisecond || actualBackoff > expectedBackoff+100*time.Millisecond { + t.Errorf("expected backoff ~%v, got %v", expectedBackoff, actualBackoff) + } +} + +func TestInterfaceManager_AdaptiveBackoff_ThirdFailure30s(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // Three failures + for i := 0; i < 3; i++ { + mgr.mu.Lock() + mgr.active[1] = "eth0" + mgr.mu.Unlock() + mgr.MarkFailed(1, syscall.ENXIO) + } + + mgr.mu.RLock() + state := mgr.failures["eth0"] + mgr.mu.RUnlock() + + if state.count != 3 { + t.Errorf("expected count 3, got %d", state.count) + } + + expectedBackoff := backoffMax + actualBackoff := time.Until(state.retryAt) + if actualBackoff < expectedBackoff-100*time.Millisecond || actualBackoff > expectedBackoff+100*time.Millisecond { + t.Errorf("expected backoff ~%v, got %v", expectedBackoff, actualBackoff) + } +} + +// ============================================================================ +// Sync Tests +// ============================================================================ + +func TestInterfaceManager_Sync_DetectsDisappeared(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + mgr := NewInterfaceManager(ifaces, nil) + + // wlan0 disappeared + current := []net.Interface{mockInterface(1, "eth0")} + + recovered := mgr.Sync(current) + + // Nothing to recover (eth0 was already active) + if len(recovered) != 0 { + t.Errorf("expected 0 recovered, got %d", len(recovered)) + } + + // wlan0 should be removed + indices := mgr.ActiveIndices() + if len(indices) != 1 || indices[0] != 1 { + t.Errorf("expected [1], got %v", indices) + } + + // wlan0 should have failure state + mgr.mu.RLock() + _, hasFailure := mgr.failures["wlan0"] + mgr.mu.RUnlock() + if !hasFailure { + t.Error("expected failure state for wlan0") + } +} + +func TestInterfaceManager_Sync_RecoversAfterBackoff(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // Remove eth0 and set backoff in the past + mgr.mu.Lock() + delete(mgr.active, 1) + mgr.failures["eth0"] = &failureState{ + count: 1, + retryAt: time.Now().Add(-1 * time.Second), // Backoff expired + } + mgr.mu.Unlock() + + // eth0 reappears + current := []net.Interface{mockInterface(1, "eth0")} + + recovered := mgr.Sync(current) + + if len(recovered) != 1 { + t.Fatalf("expected 1 recovered, got %d", len(recovered)) + } + if recovered[0].Name != "eth0" { + t.Errorf("expected eth0 to be recovered, got %s", recovered[0].Name) + } +} + +func TestInterfaceManager_Sync_RespectsBackoffNotExpired(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // Remove eth0 and set backoff in the future + mgr.mu.Lock() + delete(mgr.active, 1) + mgr.failures["eth0"] = &failureState{ + count: 1, + retryAt: time.Now().Add(10 * time.Second), // Backoff NOT expired + } + mgr.mu.Unlock() + + current := []net.Interface{mockInterface(1, "eth0")} + + recovered := mgr.Sync(current) + + // Should NOT recover yet + if len(recovered) != 0 { + t.Errorf("expected 0 recovered (backoff not expired), got %d", len(recovered)) + } +} + +func TestInterfaceManager_Sync_RespectsExplicitMode(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + requested := []string{"eth0"} // Only eth0 allowed + mgr := NewInterfaceManager(ifaces, requested) + + // New interface wlan0 appears (not in requested list) + current := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + + recovered := mgr.Sync(current) + + // wlan0 should NOT be recovered (not in requested) + for _, iface := range recovered { + if iface.Name == "wlan0" { + t.Error("wlan0 should not be recovered in explicit mode") + } + } + + // Only eth0 should be active + indices := mgr.ActiveIndices() + if len(indices) != 1 || indices[0] != 1 { + t.Errorf("expected [1], got %v", indices) + } +} + +func TestInterfaceManager_Sync_AcceptsNewInDynamicMode(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) // nil = dynamic mode + + // New interface wlan0 appears + current := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + + recovered := mgr.Sync(current) + + // wlan0 should be in recovered list + found := false + for _, iface := range recovered { + if iface.Name == "wlan0" { + found = true + break + } + } + if !found { + t.Error("expected wlan0 to be in recovered list (dynamic mode)") + } +} + +func TestInterfaceManager_Sync_DetectsIndexChange(t *testing.T) { + // eth0 starts with index 1 + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // eth0 reconnects with index 5 (different index, same name) + current := []net.Interface{mockInterface(5, "eth0")} + + recovered := mgr.Sync(current) + + // eth0 should be recovered with new index + if len(recovered) != 1 { + t.Fatalf("expected 1 recovered, got %d", len(recovered)) + } + if recovered[0].Index != 5 || recovered[0].Name != "eth0" { + t.Errorf("expected {5, eth0}, got {%d, %s}", recovered[0].Index, recovered[0].Name) + } + + // Old index 1 should be removed + mgr.mu.RLock() + _, hasOld := mgr.active[1] + mgr.mu.RUnlock() + if hasOld { + t.Error("old index 1 should be removed") + } +} + +// ============================================================================ +// Activate Tests +// ============================================================================ + +func TestInterfaceManager_Activate_AddsToActive(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + iface := mockInterface(3, "eth1") + mgr.Activate(iface) + + indices := mgr.ActiveIndices() + if len(indices) != 1 || indices[0] != 3 { + t.Errorf("expected [3], got %v", indices) + } +} + +func TestInterfaceManager_Activate_ClearsFailureHistory(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + // Set up failure state + mgr.mu.Lock() + mgr.failures["eth1"] = &failureState{count: 5, retryAt: time.Now().Add(time.Hour)} + mgr.mu.Unlock() + + // Activate should clear it + iface := mockInterface(3, "eth1") + mgr.Activate(iface) + + mgr.mu.RLock() + _, hasFailure := mgr.failures["eth1"] + mgr.mu.RUnlock() + + if hasFailure { + t.Error("expected failure history to be cleared after Activate") + } +} + +func TestInterfaceManager_Activate_HandlesIndexChange(t *testing.T) { + // Start with eth0 at index 1 + ifaces := []net.Interface{mockInterface(1, "eth0")} + mgr := NewInterfaceManager(ifaces, nil) + + // Activate eth0 with new index 5 + mgr.Activate(mockInterface(5, "eth0")) + + indices := mgr.ActiveIndices() + sort.Ints(indices) + + // Should only have index 5, not both 1 and 5 + if len(indices) != 1 || indices[0] != 5 { + t.Errorf("expected [5], got %v", indices) + } +} + +// ============================================================================ +// SetBackoff Tests +// ============================================================================ + +func TestInterfaceManager_SetBackoff_SetsFailureState(t *testing.T) { + mgr := NewInterfaceManager(nil, nil) + + mgr.SetBackoff("eth0") + + mgr.mu.RLock() + state := mgr.failures["eth0"] + mgr.mu.RUnlock() + + if state == nil { + t.Fatal("expected failure state to be set") + } + if state.count != 1 { + t.Errorf("expected count 1, got %d", state.count) + } +} + +// ============================================================================ +// GetActiveInterfaces Tests +// ============================================================================ + +func TestInterfaceManager_GetActiveInterfaces_ReturnsInterfaces(t *testing.T) { + // This test requires actual system interfaces, so we'll just test the empty case + mgr := NewInterfaceManager(nil, nil) + + ifaces := mgr.GetActiveInterfaces() + + if ifaces == nil { + t.Error("expected empty slice, got nil") + } + if len(ifaces) != 0 { + t.Errorf("expected 0 interfaces, got %d", len(ifaces)) + } +} + +// ============================================================================ +// Concurrency Tests +// ============================================================================ + +func TestInterfaceManager_Concurrent_ReadWrite(t *testing.T) { + ifaces := []net.Interface{mockInterface(1, "eth0"), mockInterface(2, "wlan0")} + mgr := NewInterfaceManager(ifaces, nil) + + var wg sync.WaitGroup + stop := make(chan struct{}) + + // Reader goroutine + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + _ = mgr.ActiveIndices() + } + } + }() + + // Writer goroutine - MarkFailed + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + select { + case <-stop: + return + default: + mgr.MarkFailed(1, syscall.ENXIO) + } + } + }() + + // Writer goroutine - Sync + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + select { + case <-stop: + return + default: + mgr.Sync(ifaces) + } + } + }() + + // Writer goroutine - Activate + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + select { + case <-stop: + return + default: + mgr.Activate(mockInterface(3, "eth1")) + } + } + }() + + // Let them run for a bit + time.Sleep(50 * time.Millisecond) + close(stop) + wg.Wait() + + // If we get here without deadlock or panic, the test passes +} diff --git a/mdns.go b/mdns.go new file mode 100644 index 00000000..03c0a1da --- /dev/null +++ b/mdns.go @@ -0,0 +1,30 @@ +package zeroconf + +import "net" + +// mDNS network constants per RFC 6762 +var ( + // Multicast groups used by mDNS + mdnsGroupIPv4 = net.IPv4(224, 0, 0, 251) + mdnsGroupIPv6 = net.ParseIP("ff02::fb") + + // mDNS wildcard addresses for listening + mdnsWildcardAddrIPv4 = &net.UDPAddr{ + IP: net.ParseIP("224.0.0.0"), + Port: 5353, + } + mdnsWildcardAddrIPv6 = &net.UDPAddr{ + IP: net.ParseIP("ff02::"), + Port: 5353, + } + + // mDNS endpoint addresses for sending + ipv4Addr = &net.UDPAddr{ + IP: mdnsGroupIPv4, + Port: 5353, + } + ipv6Addr = &net.UDPAddr{ + IP: mdnsGroupIPv6, + Port: 5353, + } +) diff --git a/mocks/mock_connection_factory.go b/mocks/mock_connection_factory.go new file mode 100644 index 00000000..3f2e1b44 --- /dev/null +++ b/mocks/mock_connection_factory.go @@ -0,0 +1,163 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + "github.com/enbility/zeroconf/v3/api" + mock "github.com/stretchr/testify/mock" +) + +// NewMockConnectionFactory creates a new instance of MockConnectionFactory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockConnectionFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *MockConnectionFactory { + mock := &MockConnectionFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockConnectionFactory is an autogenerated mock type for the ConnectionFactory type +type MockConnectionFactory struct { + mock.Mock +} + +type MockConnectionFactory_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConnectionFactory) EXPECT() *MockConnectionFactory_Expecter { + return &MockConnectionFactory_Expecter{mock: &_m.Mock} +} + +// CreateIPv4Conn provides a mock function for the type MockConnectionFactory +func (_mock *MockConnectionFactory) CreateIPv4Conn(ifaces []net.Interface) (api.PacketConn, error) { + ret := _mock.Called(ifaces) + + if len(ret) == 0 { + panic("no return value specified for CreateIPv4Conn") + } + + var r0 api.PacketConn + var r1 error + if returnFunc, ok := ret.Get(0).(func([]net.Interface) (api.PacketConn, error)); ok { + return returnFunc(ifaces) + } + if returnFunc, ok := ret.Get(0).(func([]net.Interface) api.PacketConn); ok { + r0 = returnFunc(ifaces) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(api.PacketConn) + } + } + if returnFunc, ok := ret.Get(1).(func([]net.Interface) error); ok { + r1 = returnFunc(ifaces) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockConnectionFactory_CreateIPv4Conn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIPv4Conn' +type MockConnectionFactory_CreateIPv4Conn_Call struct { + *mock.Call +} + +// CreateIPv4Conn is a helper method to define mock.On call +// - ifaces []net.Interface +func (_e *MockConnectionFactory_Expecter) CreateIPv4Conn(ifaces interface{}) *MockConnectionFactory_CreateIPv4Conn_Call { + return &MockConnectionFactory_CreateIPv4Conn_Call{Call: _e.mock.On("CreateIPv4Conn", ifaces)} +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) Run(run func(ifaces []net.Interface)) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []net.Interface + if args[0] != nil { + arg0 = args[0].([]net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) Return(packetConn api.PacketConn, err error) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Return(packetConn, err) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv4Conn_Call) RunAndReturn(run func(ifaces []net.Interface) (api.PacketConn, error)) *MockConnectionFactory_CreateIPv4Conn_Call { + _c.Call.Return(run) + return _c +} + +// CreateIPv6Conn provides a mock function for the type MockConnectionFactory +func (_mock *MockConnectionFactory) CreateIPv6Conn(ifaces []net.Interface) (api.PacketConn, error) { + ret := _mock.Called(ifaces) + + if len(ret) == 0 { + panic("no return value specified for CreateIPv6Conn") + } + + var r0 api.PacketConn + var r1 error + if returnFunc, ok := ret.Get(0).(func([]net.Interface) (api.PacketConn, error)); ok { + return returnFunc(ifaces) + } + if returnFunc, ok := ret.Get(0).(func([]net.Interface) api.PacketConn); ok { + r0 = returnFunc(ifaces) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(api.PacketConn) + } + } + if returnFunc, ok := ret.Get(1).(func([]net.Interface) error); ok { + r1 = returnFunc(ifaces) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockConnectionFactory_CreateIPv6Conn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIPv6Conn' +type MockConnectionFactory_CreateIPv6Conn_Call struct { + *mock.Call +} + +// CreateIPv6Conn is a helper method to define mock.On call +// - ifaces []net.Interface +func (_e *MockConnectionFactory_Expecter) CreateIPv6Conn(ifaces interface{}) *MockConnectionFactory_CreateIPv6Conn_Call { + return &MockConnectionFactory_CreateIPv6Conn_Call{Call: _e.mock.On("CreateIPv6Conn", ifaces)} +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) Run(run func(ifaces []net.Interface)) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []net.Interface + if args[0] != nil { + arg0 = args[0].([]net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) Return(packetConn api.PacketConn, err error) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Return(packetConn, err) + return _c +} + +func (_c *MockConnectionFactory_CreateIPv6Conn_Call) RunAndReturn(run func(ifaces []net.Interface) (api.PacketConn, error)) *MockConnectionFactory_CreateIPv6Conn_Call { + _c.Call.Return(run) + return _c +} diff --git a/mocks/mock_interface_provider.go b/mocks/mock_interface_provider.go new file mode 100644 index 00000000..60f7f27a --- /dev/null +++ b/mocks/mock_interface_provider.go @@ -0,0 +1,84 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + mock "github.com/stretchr/testify/mock" +) + +// NewMockInterfaceProvider creates a new instance of MockInterfaceProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockInterfaceProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *MockInterfaceProvider { + mock := &MockInterfaceProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockInterfaceProvider is an autogenerated mock type for the InterfaceProvider type +type MockInterfaceProvider struct { + mock.Mock +} + +type MockInterfaceProvider_Expecter struct { + mock *mock.Mock +} + +func (_m *MockInterfaceProvider) EXPECT() *MockInterfaceProvider_Expecter { + return &MockInterfaceProvider_Expecter{mock: &_m.Mock} +} + +// MulticastInterfaces provides a mock function for the type MockInterfaceProvider +func (_mock *MockInterfaceProvider) MulticastInterfaces() []net.Interface { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for MulticastInterfaces") + } + + var r0 []net.Interface + if returnFunc, ok := ret.Get(0).(func() []net.Interface); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]net.Interface) + } + } + return r0 +} + +// MockInterfaceProvider_MulticastInterfaces_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MulticastInterfaces' +type MockInterfaceProvider_MulticastInterfaces_Call struct { + *mock.Call +} + +// MulticastInterfaces is a helper method to define mock.On call +func (_e *MockInterfaceProvider_Expecter) MulticastInterfaces() *MockInterfaceProvider_MulticastInterfaces_Call { + return &MockInterfaceProvider_MulticastInterfaces_Call{Call: _e.mock.On("MulticastInterfaces")} +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) Run(run func()) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) Return(interfaces []net.Interface) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Return(interfaces) + return _c +} + +func (_c *MockInterfaceProvider_MulticastInterfaces_Call) RunAndReturn(run func() []net.Interface) *MockInterfaceProvider_MulticastInterfaces_Call { + _c.Call.Return(run) + return _c +} diff --git a/mocks/mock_packet_conn.go b/mocks/mock_packet_conn.go new file mode 100644 index 00000000..c4c9da66 --- /dev/null +++ b/mocks/mock_packet_conn.go @@ -0,0 +1,495 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "net" + + mock "github.com/stretchr/testify/mock" +) + +// NewMockPacketConn creates a new instance of MockPacketConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPacketConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPacketConn { + mock := &MockPacketConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockPacketConn is an autogenerated mock type for the PacketConn type +type MockPacketConn struct { + mock.Mock +} + +type MockPacketConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPacketConn) EXPECT() *MockPacketConn_Expecter { + return &MockPacketConn_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) Close() error { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func() error); ok { + r0 = returnFunc() + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockPacketConn_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockPacketConn_Expecter) Close() *MockPacketConn_Close_Call { + return &MockPacketConn_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockPacketConn_Close_Call) Run(run func()) *MockPacketConn_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPacketConn_Close_Call) Return(err error) *MockPacketConn_Close_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_Close_Call) RunAndReturn(run func() error) *MockPacketConn_Close_Call { + _c.Call.Return(run) + return _c +} + +// JoinGroup provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + ret := _mock.Called(ifi, group) + + if len(ret) == 0 { + panic("no return value specified for JoinGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface, net.Addr) error); ok { + r0 = returnFunc(ifi, group) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_JoinGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JoinGroup' +type MockPacketConn_JoinGroup_Call struct { + *mock.Call +} + +// JoinGroup is a helper method to define mock.On call +// - ifi *net.Interface +// - group net.Addr +func (_e *MockPacketConn_Expecter) JoinGroup(ifi interface{}, group interface{}) *MockPacketConn_JoinGroup_Call { + return &MockPacketConn_JoinGroup_Call{Call: _e.mock.On("JoinGroup", ifi, group)} +} + +func (_c *MockPacketConn_JoinGroup_Call) Run(run func(ifi *net.Interface, group net.Addr)) *MockPacketConn_JoinGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + var arg1 net.Addr + if args[1] != nil { + arg1 = args[1].(net.Addr) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockPacketConn_JoinGroup_Call) Return(err error) *MockPacketConn_JoinGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_JoinGroup_Call) RunAndReturn(run func(ifi *net.Interface, group net.Addr) error) *MockPacketConn_JoinGroup_Call { + _c.Call.Return(run) + return _c +} + +// LeaveGroup provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + ret := _mock.Called(ifi, group) + + if len(ret) == 0 { + panic("no return value specified for LeaveGroup") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface, net.Addr) error); ok { + r0 = returnFunc(ifi, group) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_LeaveGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LeaveGroup' +type MockPacketConn_LeaveGroup_Call struct { + *mock.Call +} + +// LeaveGroup is a helper method to define mock.On call +// - ifi *net.Interface +// - group net.Addr +func (_e *MockPacketConn_Expecter) LeaveGroup(ifi interface{}, group interface{}) *MockPacketConn_LeaveGroup_Call { + return &MockPacketConn_LeaveGroup_Call{Call: _e.mock.On("LeaveGroup", ifi, group)} +} + +func (_c *MockPacketConn_LeaveGroup_Call) Run(run func(ifi *net.Interface, group net.Addr)) *MockPacketConn_LeaveGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + var arg1 net.Addr + if args[1] != nil { + arg1 = args[1].(net.Addr) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockPacketConn_LeaveGroup_Call) Return(err error) *MockPacketConn_LeaveGroup_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_LeaveGroup_Call) RunAndReturn(run func(ifi *net.Interface, group net.Addr) error) *MockPacketConn_LeaveGroup_Call { + _c.Call.Return(run) + return _c +} + +// ReadFrom provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) ReadFrom(b []byte) (int, int, net.Addr, error) { + ret := _mock.Called(b) + + if len(ret) == 0 { + panic("no return value specified for ReadFrom") + } + + var r0 int + var r1 int + var r2 net.Addr + var r3 error + if returnFunc, ok := ret.Get(0).(func([]byte) (int, int, net.Addr, error)); ok { + return returnFunc(b) + } + if returnFunc, ok := ret.Get(0).(func([]byte) int); ok { + r0 = returnFunc(b) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func([]byte) int); ok { + r1 = returnFunc(b) + } else { + r1 = ret.Get(1).(int) + } + if returnFunc, ok := ret.Get(2).(func([]byte) net.Addr); ok { + r2 = returnFunc(b) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(net.Addr) + } + } + if returnFunc, ok := ret.Get(3).(func([]byte) error); ok { + r3 = returnFunc(b) + } else { + r3 = ret.Error(3) + } + return r0, r1, r2, r3 +} + +// MockPacketConn_ReadFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadFrom' +type MockPacketConn_ReadFrom_Call struct { + *mock.Call +} + +// ReadFrom is a helper method to define mock.On call +// - b []byte +func (_e *MockPacketConn_Expecter) ReadFrom(b interface{}) *MockPacketConn_ReadFrom_Call { + return &MockPacketConn_ReadFrom_Call{Call: _e.mock.On("ReadFrom", b)} +} + +func (_c *MockPacketConn_ReadFrom_Call) Run(run func(b []byte)) *MockPacketConn_ReadFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_ReadFrom_Call) Return(n int, ifIndex int, src net.Addr, err error) *MockPacketConn_ReadFrom_Call { + _c.Call.Return(n, ifIndex, src, err) + return _c +} + +func (_c *MockPacketConn_ReadFrom_Call) RunAndReturn(run func(b []byte) (int, int, net.Addr, error)) *MockPacketConn_ReadFrom_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastHopLimit provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastHopLimit(hopLimit int) error { + ret := _mock.Called(hopLimit) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastHopLimit") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(int) error); ok { + r0 = returnFunc(hopLimit) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastHopLimit_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastHopLimit' +type MockPacketConn_SetMulticastHopLimit_Call struct { + *mock.Call +} + +// SetMulticastHopLimit is a helper method to define mock.On call +// - hopLimit int +func (_e *MockPacketConn_Expecter) SetMulticastHopLimit(hopLimit interface{}) *MockPacketConn_SetMulticastHopLimit_Call { + return &MockPacketConn_SetMulticastHopLimit_Call{Call: _e.mock.On("SetMulticastHopLimit", hopLimit)} +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) Run(run func(hopLimit int)) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) Return(err error) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastHopLimit_Call) RunAndReturn(run func(hopLimit int) error) *MockPacketConn_SetMulticastHopLimit_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastInterface provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastInterface(ifi *net.Interface) error { + ret := _mock.Called(ifi) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastInterface") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*net.Interface) error); ok { + r0 = returnFunc(ifi) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastInterface_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastInterface' +type MockPacketConn_SetMulticastInterface_Call struct { + *mock.Call +} + +// SetMulticastInterface is a helper method to define mock.On call +// - ifi *net.Interface +func (_e *MockPacketConn_Expecter) SetMulticastInterface(ifi interface{}) *MockPacketConn_SetMulticastInterface_Call { + return &MockPacketConn_SetMulticastInterface_Call{Call: _e.mock.On("SetMulticastInterface", ifi)} +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) Run(run func(ifi *net.Interface)) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *net.Interface + if args[0] != nil { + arg0 = args[0].(*net.Interface) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) Return(err error) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastInterface_Call) RunAndReturn(run func(ifi *net.Interface) error) *MockPacketConn_SetMulticastInterface_Call { + _c.Call.Return(run) + return _c +} + +// SetMulticastTTL provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) SetMulticastTTL(ttl int) error { + ret := _mock.Called(ttl) + + if len(ret) == 0 { + panic("no return value specified for SetMulticastTTL") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(int) error); ok { + r0 = returnFunc(ttl) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockPacketConn_SetMulticastTTL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMulticastTTL' +type MockPacketConn_SetMulticastTTL_Call struct { + *mock.Call +} + +// SetMulticastTTL is a helper method to define mock.On call +// - ttl int +func (_e *MockPacketConn_Expecter) SetMulticastTTL(ttl interface{}) *MockPacketConn_SetMulticastTTL_Call { + return &MockPacketConn_SetMulticastTTL_Call{Call: _e.mock.On("SetMulticastTTL", ttl)} +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) Run(run func(ttl int)) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 int + if args[0] != nil { + arg0 = args[0].(int) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) Return(err error) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockPacketConn_SetMulticastTTL_Call) RunAndReturn(run func(ttl int) error) *MockPacketConn_SetMulticastTTL_Call { + _c.Call.Return(run) + return _c +} + +// WriteTo provides a mock function for the type MockPacketConn +func (_mock *MockPacketConn) WriteTo(b []byte, ifIndex int, dst net.Addr) (int, error) { + ret := _mock.Called(b, ifIndex, dst) + + if len(ret) == 0 { + panic("no return value specified for WriteTo") + } + + var r0 int + var r1 error + if returnFunc, ok := ret.Get(0).(func([]byte, int, net.Addr) (int, error)); ok { + return returnFunc(b, ifIndex, dst) + } + if returnFunc, ok := ret.Get(0).(func([]byte, int, net.Addr) int); ok { + r0 = returnFunc(b, ifIndex, dst) + } else { + r0 = ret.Get(0).(int) + } + if returnFunc, ok := ret.Get(1).(func([]byte, int, net.Addr) error); ok { + r1 = returnFunc(b, ifIndex, dst) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockPacketConn_WriteTo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WriteTo' +type MockPacketConn_WriteTo_Call struct { + *mock.Call +} + +// WriteTo is a helper method to define mock.On call +// - b []byte +// - ifIndex int +// - dst net.Addr +func (_e *MockPacketConn_Expecter) WriteTo(b interface{}, ifIndex interface{}, dst interface{}) *MockPacketConn_WriteTo_Call { + return &MockPacketConn_WriteTo_Call{Call: _e.mock.On("WriteTo", b, ifIndex, dst)} +} + +func (_c *MockPacketConn_WriteTo_Call) Run(run func(b []byte, ifIndex int, dst net.Addr)) *MockPacketConn_WriteTo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + var arg1 int + if args[1] != nil { + arg1 = args[1].(int) + } + var arg2 net.Addr + if args[2] != nil { + arg2 = args[2].(net.Addr) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockPacketConn_WriteTo_Call) Return(n int, err error) *MockPacketConn_WriteTo_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockPacketConn_WriteTo_Call) RunAndReturn(run func(b []byte, ifIndex int, dst net.Addr) (int, error)) *MockPacketConn_WriteTo_Call { + _c.Call.Return(run) + return _c +} diff --git a/server.go b/server.go index 71b7bf8c..66ba5faa 100644 --- a/server.go +++ b/server.go @@ -6,14 +6,12 @@ import ( "math/rand" "net" "os" - "runtime" "strings" "sync" "time" + "github.com/enbility/zeroconf/v3/api" "github.com/miekg/dns" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) const ( @@ -24,7 +22,9 @@ const ( var defaultTTL uint32 = 3200 type serverOpts struct { - ttl uint32 + ttl uint32 + connFactory api.ConnectionFactory + provider api.InterfaceProvider } func applyServerOpts(options ...ServerOption) serverOpts { @@ -50,6 +50,22 @@ func TTL(ttl uint32) ServerOption { } } +// WithServerConnFactory sets a custom connection factory for the server. +// This is primarily useful for testing with mock connections. +func WithServerConnFactory(factory api.ConnectionFactory) ServerOption { + return func(o *serverOpts) { + o.connFactory = factory + } +} + +// WithServerInterfaceProvider sets a custom interface provider for the server. +// This is primarily useful for testing with mock interface lists. +func WithServerInterfaceProvider(provider api.InterfaceProvider) ServerOption { + return func(o *serverOpts) { + o.provider = provider + } +} + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface, opts ...ServerOption) (*Server, error) { @@ -83,7 +99,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces } if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + ifaces = NewInterfaceProvider().MulticastInterfaces() } for _, iface := range ifaces { @@ -149,7 +165,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips } if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + ifaces = NewInterfaceProvider().MulticastInterfaces() } s, err := newServer(ifaces, applyServerOpts(opts...)) @@ -170,9 +186,11 @@ const ( // Server structure encapsulates both IPv4/IPv6 UDP connections type Server struct { service *ServiceEntry - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface + ipv4conn api.PacketConn + ipv6conn api.PacketConn + ipv4Mgr *InterfaceManager + ipv6Mgr *InterfaceManager + provider api.InterfaceProvider shouldShutdown chan struct{} shutdownLock sync.Mutex @@ -183,11 +201,37 @@ type Server struct { // Constructs server structure func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { - ipv4conn, err4 := joinUdp4Multicast(ifaces) + // Get interface provider (use default if not injected for testing) + provider := opts.provider + if provider == nil { + provider = NewInterfaceProvider() + } + + factory := opts.connFactory + if factory == nil { + factory = NewConnectionFactory() + } + + // Determine mode + var requested []string + if len(ifaces) > 0 { + requested = make([]string, len(ifaces)) + for i, iface := range ifaces { + requested[i] = iface.Name + } + } else { + ifaces = provider.MulticastInterfaces() + } + + // Create SEPARATE managers for IPv4 and IPv6. + ipv4Mgr := NewInterfaceManager(ifaces, requested) + ipv6Mgr := NewInterfaceManager(ifaces, requested) + + ipv4conn, err4 := factory.CreateIPv4Conn(ifaces) if err4 != nil { log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) } - ipv6conn, err6 := joinUdp6Multicast(ifaces) + ipv6conn, err6 := factory.CreateIPv6Conn(ifaces) if err6 != nil { log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) } @@ -199,7 +243,9 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { s := &Server{ ipv4conn: ipv4conn, ipv6conn: ipv6conn, - ifaces: ifaces, + ipv4Mgr: ipv4Mgr, + ipv6Mgr: ipv6Mgr, + provider: provider, ttl: opts.ttl, shouldShutdown: make(chan struct{}), } @@ -210,14 +256,18 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { func (s *Server) start() { if s.ipv4conn != nil { s.refCount.Add(1) - go s.recv4(s.ipv4conn) + go s.recvLoop(s.ipv4conn) } if s.ipv6conn != nil { s.refCount.Add(1) - go s.recv6(s.ipv6conn) + go s.recvLoop(s.ipv6conn) } s.refCount.Add(1) go s.probe() + + // Start interface sync goroutine + s.refCount.Add(1) + go s.runInterfaceSync() } // SetText updates and announces the TXT records @@ -226,13 +276,6 @@ func (s *Server) SetText(text []string) { s.announceText() } -// TTL sets the TTL for DNS replies -// -// Deprecated: This method is racy. Use the TTL server option instead. -func (s *Server) TTL(ttl uint32) { - s.ttl = ttl -} - // Shutdown closes all udp connections and unregisters the service func (s *Server) Shutdown() { s.shutdownLock.Lock() @@ -259,33 +302,9 @@ func (s *Server) Shutdown() { s.isShutdown = true } -// recv4 is a long running routine to receive packets from an interface -func (s *Server) recv4(c *ipv4.PacketConn) { - defer s.refCount.Done() - if c == nil { - return - } - buf := make([]byte, 65536) - for { - select { - case <-s.shouldShutdown: - return - default: - var ifIndex int - n, cm, from, err := c.ReadFrom(buf) - if err != nil { - continue - } - if cm != nil { - ifIndex = cm.IfIndex - } - _ = s.parsePacket(buf[:n], ifIndex, from) - } - } -} - -// recv6 is a long running routine to receive packets from an interface -func (s *Server) recv6(c *ipv6.PacketConn) { +// recvLoop is a long running routine to receive packets from a connection. +// It uses the PacketConn interface, allowing for mock injection in tests. +func (s *Server) recvLoop(c api.PacketConn) { defer s.refCount.Done() if c == nil { return @@ -296,13 +315,15 @@ func (s *Server) recv6(c *ipv6.PacketConn) { case <-s.shouldShutdown: return default: - var ifIndex int - n, cm, from, err := c.ReadFrom(buf) + n, ifIndex, from, err := c.ReadFrom(buf) if err != nil { - continue - } - if cm != nil { - ifIndex = cm.IfIndex + // Backoff to prevent CPU spin on persistent errors + select { + case <-s.shouldShutdown: + return + case <-time.After(50 * time.Millisecond): + continue + } } _ = s.parsePacket(buf[:n], ifIndex, from) } @@ -609,7 +630,12 @@ func (s *Server) probe() { // at least a factor of two with every response sent. timeout := time.Second for i := 0; i < multicastRepetitions; i++ { - for _, intf := range s.ifaces { + // Use active interfaces from both managers + activeIfaces := s.ipv4Mgr.GetActiveInterfaces() + if len(activeIfaces) == 0 { + activeIfaces = s.ipv6Mgr.GetActiveInterfaces() + } + for _, intf := range activeIfaces { resp := new(dns.Msg) resp.MsgHdr.Response = true // TODO: make response authoritative if we are the publisher @@ -738,24 +764,11 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } addr := from.(*net.UDPAddr) if addr.IP.To4() != nil { - if ifIndex != 0 { - var wcm ipv4.ControlMessage - wcm.IfIndex = ifIndex - _, err = s.ipv4conn.WriteTo(buf, &wcm, addr) - } else { - _, err = s.ipv4conn.WriteTo(buf, nil, addr) - } - return err - } else { - if ifIndex != 0 { - var wcm ipv6.ControlMessage - wcm.IfIndex = ifIndex - _, err = s.ipv6conn.WriteTo(buf, &wcm, addr) - } else { - _, err = s.ipv6conn.WriteTo(buf, nil, addr) - } + _, err = s.ipv4conn.WriteTo(buf, ifIndex, addr) return err } + _, err = s.ipv6conn.WriteTo(buf, ifIndex, addr) + return err } // multicastResponse is used to send a multicast response packet @@ -764,67 +777,37 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { if err != nil { return fmt.Errorf("failed to pack msg %v: %w", msg, err) } + + // Send to IPv4 multicast group if s.ipv4conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv4.ControlMessage + var indices []int if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + indices = []int{ifIndex} } else { - for _, intf := range s.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = intf.Index - default: - if err := s.ipv4conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + indices = s.ipv4Mgr.ActiveIndices() + } + for _, idx := range indices { + if _, err := s.ipv4conn.WriteTo(buf, idx, ipv4Addr); err != nil { + s.ipv4Mgr.MarkFailed(idx, err) } } } + // Send to IPv6 multicast group if s.ipv6conn != nil { - // See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG - // As of Golang 1.18.4 - // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. - var wcm ipv6.ControlMessage + var indices []int if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + indices = []int{ifIndex} } else { - for _, intf := range s.ifaces { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = intf.Index - default: - if err := s.ipv6conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) - } - } - _, _ = s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + indices = s.ipv6Mgr.ActiveIndices() + } + for _, idx := range indices { + if _, err := s.ipv6conn.WriteTo(buf, idx, ipv6Addr); err != nil { + s.ipv6Mgr.MarkFailed(idx, err) } } } + return nil } @@ -837,3 +820,28 @@ func isUnicastQuestion(q dns.Question) bool { // for this particular question. (See Section 5.4.) return q.Qclass&qClassCacheFlush != 0 } + +// runInterfaceSync periodically syncs the interface managers with the current +// system interface state, detecting recovered interfaces. +func (s *Server) runInterfaceSync() { + defer s.refCount.Done() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.shouldShutdown: + return + case <-ticker.C: + s.syncInterfaces() + } + } +} + +// syncInterfaces updates both interface managers with current system state. +func (s *Server) syncInterfaces() { + current := s.provider.MulticastInterfaces() + s.ipv4Mgr.Sync(current) + s.ipv6Mgr.Sync(current) +} diff --git a/server_unit_test.go b/server_unit_test.go new file mode 100644 index 00000000..fb0bd851 --- /dev/null +++ b/server_unit_test.go @@ -0,0 +1,774 @@ +package zeroconf + +import ( + "errors" + "net" + "sync" + "syscall" + "testing" + "time" + + "github.com/enbility/zeroconf/v3/api" + "github.com/enbility/zeroconf/v3/mocks" + "github.com/miekg/dns" + "github.com/stretchr/testify/mock" +) + +// TestServer_Recv_BacksOffOnError verifies that recv backs off when ReadFrom returns errors +// This is the fix for the CPU spin bug. +func TestServer_Recv_BacksOffOnError(t *testing.T) { + mockConn := mocks.NewMockPacketConn(t) + + // Track call count + var callCount int + var mu sync.Mutex + + // Configure ReadFrom to always return an error + mockConn.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + mu.Lock() + callCount++ + mu.Unlock() + return 0, 0, nil, errors.New("mock read error") + }).Maybe() + + s := &Server{ + shouldShutdown: make(chan struct{}), + ttl: 3200, + } + + // recvLoop calls s.refCount.Done() on exit, so we need to Add first + s.refCount.Add(1) + + // Start recv in background + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s.recvLoop(mockConn) + }() + + // Let it run briefly + time.Sleep(200 * time.Millisecond) + + // Shutdown + close(s.shouldShutdown) + wg.Wait() + + mu.Lock() + calls := callCount + mu.Unlock() + + // With 50ms backoff and 200ms runtime, we expect roughly 4 calls max + // Without backoff, we'd see thousands of calls + if calls > 10 { + t.Errorf("Expected few calls with backoff, got %d (suggests spinning)", calls) + } + t.Logf("ReadFrom called %d times in 200ms with backoff", calls) +} + +// TestServer_Recv_ProcessesPacket verifies that recv correctly processes incoming packets +func TestServer_Recv_ProcessesPacket(t *testing.T) { + // Create a valid DNS query packet + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + packetData, err := msg.Pack() + if err != nil { + t.Fatalf("Failed to pack DNS message: %v", err) + } + + // We can test the packet parsing directly + parsed := new(dns.Msg) + if err := parsed.Unpack(packetData); err != nil { + t.Fatalf("Failed to unpack: %v", err) + } + + if len(parsed.Question) != 1 { + t.Errorf("Expected 1 question, got %d", len(parsed.Question)) + } + if parsed.Question[0].Name != "_test._tcp.local." { + t.Errorf("Expected question name _test._tcp.local., got %s", parsed.Question[0].Name) + } +} + +// testServer creates a Server with InterfaceManagers for testing. +// This helper avoids direct struct construction with the removed ifaces field. +func testServer(ipv4conn, ipv6conn api.PacketConn, ifaces []net.Interface) *Server { + return &Server{ + ipv4conn: ipv4conn, + ipv6conn: ipv6conn, + ipv4Mgr: NewInterfaceManager(ifaces, nil), + ipv6Mgr: NewInterfaceManager(ifaces, nil), + provider: NewInterfaceProvider(), + shouldShutdown: make(chan struct{}), + ttl: 3200, + } +} + +// TestServer_InterfaceDisconnect_StopsSendingToFailedInterface verifies that when +// a network interface disconnects during multicast response, the server stops +// attempting to send to that interface. This is the server-side fix for the +// infinite warning log issue. +func TestServer_InterfaceDisconnect_StopsSendingToFailedInterface(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Two interfaces: eth0 (will fail) and wlan0 (stays healthy) + ifaces := []net.Interface{ + {Index: 1, Name: "eth0"}, + {Index: 2, Name: "wlan0"}, + } + + // Track calls per interface + var mu sync.Mutex + callsToEth0 := 0 + callsToWlan0 := 0 + + // eth0 (index 1) returns ENETDOWN, wlan0 (index 2) succeeds + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + defer mu.Unlock() + if ifIndex == 1 { + callsToEth0++ + return 0, syscall.ENETDOWN + } + callsToWlan0++ + return len(b), nil + }).Maybe() + + s := testServer(mockIPv4, nil, ifaces) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + // First multicast: both interfaces attempted + _ = s.multicastResponse(msg, 0) + + mu.Lock() + firstEth0 := callsToEth0 + firstWlan0 := callsToWlan0 + mu.Unlock() + + if firstEth0 != 1 || firstWlan0 != 1 { + t.Errorf("First response: expected 1 call each, got eth0=%d wlan0=%d", firstEth0, firstWlan0) + } + + // Second multicast: eth0 should be excluded + _ = s.multicastResponse(msg, 0) + + mu.Lock() + secondEth0 := callsToEth0 + secondWlan0 := callsToWlan0 + mu.Unlock() + + if secondEth0 != 1 { + t.Errorf("Second response: eth0 should NOT be called again, got %d total calls", secondEth0) + } + if secondWlan0 != 2 { + t.Errorf("Second response: wlan0 should have 2 calls, got %d", secondWlan0) + } + + t.Logf("SUCCESS: Server stops sending to disconnected interface") + t.Logf("eth0 calls: %d, wlan0 calls: %d", secondEth0, secondWlan0) +} + +// TestServer_MulticastResponse_WritesToConnections verifies multicast sends to both connections +func TestServer_MulticastResponse_WritesToConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + iface := net.Interface{Index: 1, Name: "eth0"} + + // Expect WriteTo to be called on both connections + mockIPv4.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 1, mock.Anything).Return(0, nil).Once() + + s := testServer(mockIPv4, mockIPv6, []net.Interface{iface}) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + err := s.multicastResponse(msg, 0) + if err != nil { + t.Fatalf("multicastResponse failed: %v", err) + } +} + +// TestServer_MulticastResponse_SpecificInterface verifies multicast to specific interface +func TestServer_MulticastResponse_SpecificInterface(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Expect WriteTo to be called with specific interface index 2 + mockIPv4.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + mockIPv6.EXPECT().WriteTo(mock.Anything, 2, mock.Anything).Return(0, nil).Once() + + s := testServer(mockIPv4, mockIPv6, []net.Interface{{Index: 1, Name: "eth0"}, {Index: 2, Name: "wlan0"}}) + + msg := new(dns.Msg) + msg.SetQuestion("_test._tcp.local.", dns.TypePTR) + + // Send to specific interface (index 2) + err := s.multicastResponse(msg, 2) + if err != nil { + t.Fatalf("multicastResponse failed: %v", err) + } +} + +// TestServer_Shutdown_ClosesConnections verifies shutdown properly closes connections +func TestServer_Shutdown_ClosesConnections(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Expect Close and WriteTo (for unregister) to be called + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).Return(0, nil).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.AnythingOfType("int"), mock.Anything).Return(0, nil).Maybe() + mockIPv4.EXPECT().Close().Return(nil).Once() + mockIPv6.EXPECT().Close().Return(nil).Once() + + s := testServer(mockIPv4, mockIPv6, []net.Interface{{Index: 1, Name: "eth0"}}) + s.service = newServiceEntry("test", "_test._tcp", "local") + s.service.Port = 8080 + s.service.HostName = "test.local." + + s.Shutdown() +} + +// TestServerConfig verifies server configuration options +func TestServerConfig(t *testing.T) { + t.Run("default TTL", func(t *testing.T) { + opts := applyServerOpts() + if opts.ttl != defaultTTL { + t.Errorf("Expected default TTL %d, got %d", defaultTTL, opts.ttl) + } + }) + + t.Run("custom TTL", func(t *testing.T) { + opts := applyServerOpts(TTL(1000)) + if opts.ttl != 1000 { + t.Errorf("Expected TTL 1000, got %d", opts.ttl) + } + }) +} + +// TestWithServerConnFactory verifies the WithServerConnFactory option +func TestWithServerConnFactory(t *testing.T) { + factory := mocks.NewMockConnectionFactory(t) + + opts := applyServerOpts(WithServerConnFactory(factory)) + + if opts.connFactory != factory { + t.Error("Expected connection factory to be set") + } +} + +// TestIsKnownAnswer verifies known-answer suppression logic +func TestIsKnownAnswer(t *testing.T) { + t.Run("empty response answers", func(t *testing.T) { + resp := &dns.Msg{} + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false when response has no answers") + } + }) + + t.Run("empty query answers", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{} + if isKnownAnswer(resp, query) { + t.Error("Expected false when query has no answers") + } + }) + + t.Run("non-PTR response", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Ttl: 100}, + A: net.ParseIP("192.168.1.1"), + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for non-PTR response") + } + }) + + t.Run("matching known answer with sufficient TTL", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 60}, // >= 100/2 + Ptr: "test._http._tcp.local.", + }, + }, + } + if !isKnownAnswer(resp, query) { + t.Error("Expected true for matching known answer with sufficient TTL") + } + }) + + t.Run("matching known answer with insufficient TTL", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 40}, // < 100/2 + Ptr: "test._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for known answer with insufficient TTL") + } + }) + + t.Run("non-matching PTR", func(t *testing.T) { + resp := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "test._http._tcp.local.", + }, + }, + } + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{Rrtype: dns.TypePTR, Ttl: 100}, + Ptr: "other._http._tcp.local.", + }, + }, + } + if isKnownAnswer(resp, query) { + t.Error("Expected false for non-matching PTR") + } + }) +} + +// TestServer_HandleQuestion verifies question handling logic +func TestServer_HandleQuestion(t *testing.T) { + createTestServer := func() *Server { + s := &Server{ + ttl: 3200, + shouldShutdown: make(chan struct{}), + service: newServiceEntry("myservice", "_http._tcp", "local"), + } + s.service.Port = 8080 + s.service.HostName = "myhost.local." + s.service.Text = []string{"key=value"} + return s + } + + t.Run("nil service", func(t *testing.T) { + s := &Server{ + ttl: 3200, + shouldShutdown: make(chan struct{}), + service: nil, + } + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: "_http._tcp.local.", Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("Expected no error for nil service, got %v", err) + } + if len(resp.Answer) != 0 { + t.Error("Expected no answers for nil service") + } + }) + + t.Run("service type query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceTypeName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service type query") + } + }) + + t.Run("service name query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service name query") + } + }) + + t.Run("service instance query", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: s.service.ServiceInstanceName(), Qtype: dns.TypeSRV} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for service instance query") + } + }) + + t.Run("subtype query", func(t *testing.T) { + s := createTestServer() + s.service.Subtypes = []string{"_printer"} + resp := &dns.Msg{} + query := &dns.Msg{} + subtypeName := "_printer._sub." + s.service.ServiceName() + q := dns.Question{Name: subtypeName, Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) == 0 { + t.Error("Expected answers for subtype query") + } + }) + + t.Run("unknown query name", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + query := &dns.Msg{} + q := dns.Question{Name: "_unknown._tcp.local.", Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + if len(resp.Answer) != 0 { + t.Error("Expected no answers for unknown query") + } + }) + + t.Run("known answer suppression", func(t *testing.T) { + s := createTestServer() + resp := &dns.Msg{} + // Query with known answer + query := &dns.Msg{ + Answer: []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypePTR, + Ttl: 3200, // >= s.ttl/2 + }, + Ptr: s.service.ServiceInstanceName(), + }, + }, + } + q := dns.Question{Name: s.service.ServiceName(), Qtype: dns.TypePTR} + + err := s.handleQuestion(q, resp, query, 1) + if err != nil { + t.Errorf("handleQuestion failed: %v", err) + } + // Answer should be suppressed + if len(resp.Answer) != 0 { + t.Error("Expected answer to be suppressed due to known-answer") + } + }) +} + +// TestRegisterProxy_Validation tests RegisterProxy input validation +func TestRegisterProxy_Validation(t *testing.T) { + t.Run("missing instance name", func(t *testing.T) { + _, err := RegisterProxy("", "_http._tcp", "local", 8080, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing instance name") + } + }) + + t.Run("missing service name", func(t *testing.T) { + _, err := RegisterProxy("myservice", "", "local", 8080, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing service name") + } + }) + + t.Run("missing host name", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 8080, "", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing host name") + } + }) + + t.Run("missing port", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 0, "myhost", []string{"192.168.1.1"}, nil, nil) + if err == nil { + t.Error("Expected error for missing port") + } + }) + + t.Run("invalid IP address", func(t *testing.T) { + _, err := RegisterProxy("myservice", "_http._tcp", "local", 8080, "myhost", []string{"invalid-ip"}, nil, nil) + if err == nil { + t.Error("Expected error for invalid IP address") + } + }) +} + +// setupMockServerConnections creates mock connections for server tests +func setupMockServerConnections(t *testing.T) (*mocks.MockPacketConn, *mocks.MockPacketConn, api.ConnectionFactory) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + factory := mocks.NewMockConnectionFactory(t) + + factory.EXPECT().CreateIPv4Conn(mock.Anything).Return(mockIPv4, nil).Once() + factory.EXPECT().CreateIPv6Conn(mock.Anything).Return(mockIPv6, nil).Once() + + return mockIPv4, mockIPv6, factory +} + +// TestRegisterProxy_WithMockConnections tests RegisterProxy with mocked connections +func TestRegisterProxy_WithMockConnections(t *testing.T) { + mockIPv4, mockIPv6, factory := setupMockServerConnections(t) + + // Mock ReadFrom to block until shutdown + mockIPv4.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + time.Sleep(50 * time.Millisecond) + return 0, 0, nil, errors.New("shutdown") + }).Maybe() + mockIPv6.EXPECT().ReadFrom(mock.Anything).RunAndReturn(func(b []byte) (int, int, net.Addr, error) { + time.Sleep(50 * time.Millisecond) + return 0, 0, nil, errors.New("shutdown") + }).Maybe() + + // Mock WriteTo for probes and announcements + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + // Mock Close + mockIPv4.EXPECT().Close().Return(nil).Maybe() + mockIPv6.EXPECT().Close().Return(nil).Maybe() + + // Register the proxy service + server, err := RegisterProxy( + "myservice", + "_http._tcp", + "local", + 8080, + "myhost", + []string{"192.168.1.100", "fe80::1"}, + []string{"key=value"}, + []net.Interface{{Index: 1, Name: "eth0"}}, + WithServerConnFactory(factory), + ) + if err != nil { + t.Fatalf("RegisterProxy failed: %v", err) + } + defer server.Shutdown() + + // Verify service was set up correctly + if server.service.Instance != "myservice" { + t.Errorf("Expected instance 'myservice', got '%s'", server.service.Instance) + } + if server.service.Port != 8080 { + t.Errorf("Expected port 8080, got %d", server.service.Port) + } + if len(server.service.AddrIPv4) != 1 { + t.Errorf("Expected 1 IPv4 address, got %d", len(server.service.AddrIPv4)) + } + if len(server.service.AddrIPv6) != 1 { + t.Errorf("Expected 1 IPv6 address, got %d", len(server.service.AddrIPv6)) + } +} + +// TestServer_SetText tests the SetText method +func TestServer_SetText(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Track WriteTo calls to verify announcement was sent + var writeCount int + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + writeCount++ + mu.Unlock() + return len(b), nil + }).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + s := testServer(mockIPv4, mockIPv6, []net.Interface{{Index: 1, Name: "eth0"}}) + s.service = newServiceEntry("test", "_test._tcp", "local") + s.service.Port = 8080 + s.service.HostName = "test.local." + s.service.Text = []string{"old=value"} + + // Update text + s.SetText([]string{"new=value"}) + + // Verify text was updated + if len(s.service.Text) != 1 || s.service.Text[0] != "new=value" { + t.Errorf("Expected text 'new=value', got %v", s.service.Text) + } + + // Verify announcement was sent (WriteTo was called) + mu.Lock() + if writeCount == 0 { + t.Error("Expected announcement to be sent after SetText") + } + mu.Unlock() +} + +// TestServer_HandleQuery_RespondsToQueries tests server responding to mDNS queries +func TestServer_HandleQuery_RespondsToQueries(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + mockIPv6 := mocks.NewMockPacketConn(t) + + // Capture responses + var capturedResponses [][]byte + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + responseCopy := make([]byte, len(b)) + copy(responseCopy, b) + capturedResponses = append(capturedResponses, responseCopy) + mu.Unlock() + return len(b), nil + }).Maybe() + mockIPv6.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe() + + s := testServer(mockIPv4, mockIPv6, []net.Interface{{Index: 1, Name: "eth0"}}) + s.service = newServiceEntry("myservice", "_http._tcp", "local") + s.service.Port = 8080 + s.service.HostName = "myhost.local." + s.service.Text = []string{"key=value"} + s.service.AddrIPv4 = []net.IP{net.ParseIP("192.168.1.100")} + + // Create a query for our service + query := new(dns.Msg) + query.SetQuestion("_http._tcp.local.", dns.TypePTR) + + // Handle the query + err := s.handleQuery(query, 1, &net.UDPAddr{IP: net.ParseIP("192.168.1.50"), Port: 5353}) + if err != nil { + t.Fatalf("handleQuery failed: %v", err) + } + + // Verify response was sent + mu.Lock() + responseCount := len(capturedResponses) + mu.Unlock() + + if responseCount == 0 { + t.Error("Expected response to be sent for matching query") + } + + // Parse and verify the response + if responseCount > 0 { + mu.Lock() + respData := capturedResponses[0] + mu.Unlock() + + resp := new(dns.Msg) + if err := resp.Unpack(respData); err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if len(resp.Answer) == 0 { + t.Error("Expected answers in response") + } + } +} + +// TestServer_UnicastResponse tests unicast response handling +func TestServer_UnicastResponse(t *testing.T) { + mockIPv4 := mocks.NewMockPacketConn(t) + + // Capture the destination address to verify unicast + var capturedDst net.Addr + var mu sync.Mutex + + mockIPv4.EXPECT().WriteTo(mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(b []byte, ifIndex int, dst net.Addr) (int, error) { + mu.Lock() + capturedDst = dst + mu.Unlock() + return len(b), nil + }).Once() + + s := testServer(mockIPv4, nil, []net.Interface{{Index: 1, Name: "eth0"}}) + s.service = newServiceEntry("myservice", "_http._tcp", "local") + s.service.Port = 8080 + s.service.HostName = "myhost.local." + + // Send unicast response + msg := new(dns.Msg) + msg.SetQuestion("_http._tcp.local.", dns.TypePTR) + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.50"), Port: 5353} + + err := s.unicastResponse(msg, 1, clientAddr) + if err != nil { + t.Fatalf("unicastResponse failed: %v", err) + } + + // Verify response was sent to the client's address + mu.Lock() + defer mu.Unlock() + if capturedDst == nil { + t.Error("Expected response to be sent") + } else { + udpAddr, ok := capturedDst.(*net.UDPAddr) + if !ok { + t.Error("Expected UDP address") + } else if !udpAddr.IP.Equal(net.ParseIP("192.168.1.50")) { + t.Errorf("Expected response to 192.168.1.50, got %s", udpAddr.IP) + } + } +} diff --git a/version.json b/version.json index 26f9a280..f6c42b90 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v2.2.0" + "version": "v3.0.0" }