diff --git a/client.go b/client.go index c0b2cae1..320ceb65 100644 --- a/client.go +++ b/client.go @@ -28,18 +28,23 @@ const ( IPv4AndIPv6 = IPv4 | IPv6 // default option ) -var initialQueryInterval = 4 * time.Second +var ( + initialQueryInterval = 4 * time.Second + defaultClientWriteTimeout = 10 * time.Second +) // Client structure encapsulates both IPv4/IPv6 UDP connections. type client struct { - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface + ipv4conn *ipv4.PacketConn + ipv6conn *ipv6.PacketConn + interfaces NetInterfaceList + writeTimeout time.Duration } type clientOpts struct { - listenOn IPType - ifaces []net.Interface + listenOn IPType + ifaces []net.Interface + writeTimeout time.Duration } // ClientOption fills the option struct to configure intefaces, etc. @@ -63,6 +68,13 @@ func SelectIfaces(ifaces []net.Interface) ClientOption { } } +// ClientWriteTimeout sets timeout for writing to the socket +func ClientWriteTimeout(duration time.Duration) ClientOption { + return func(o *clientOpts) { + o.writeTimeout = duration + } +} + // Browse for all services of a given type in a given domain. // Received entries are sent on the entries channel. // It blocks until the context is canceled (or an error occurs). @@ -100,7 +112,8 @@ func Lookup(ctx context.Context, instance, service, domain string, entries chan< func applyOpts(options ...ClientOption) clientOpts { // Apply default configuration and load supplied options. var conf = clientOpts{ - listenOn: IPv4AndIPv6, + listenOn: IPv4AndIPv6, + writeTimeout: defaultClientWriteTimeout, } for _, o := range options { if o != nil { @@ -137,11 +150,12 @@ func newClient(opts clientOpts) (*client, error) { if len(ifaces) == 0 { ifaces = listMulticastInterfaces() } + ifaceList := NewInterfaceList(ifaces) // IPv4 interfaces var ipv4conn *ipv4.PacketConn if (opts.listenOn & IPv4) > 0 { var err error - ipv4conn, err = joinUdp4Multicast(ifaces) + ipv4conn, err = joinUdp4Multicast(ifaceList) if err != nil { return nil, err } @@ -150,16 +164,17 @@ func newClient(opts clientOpts) (*client, error) { var ipv6conn *ipv6.PacketConn if (opts.listenOn & IPv6) > 0 { var err error - ipv6conn, err = joinUdp6Multicast(ifaces) + ipv6conn, err = joinUdp6Multicast(ifaceList) if err != nil { return nil, err } } return &client{ - ipv4conn: ipv4conn, - ipv6conn: ipv6conn, - ifaces: ifaces, + ipv4conn: ipv4conn, + ipv6conn: ipv6conn, + interfaces: ifaceList, + writeTimeout: opts.writeTimeout, }, nil } @@ -428,30 +443,38 @@ func (c *client) query(params *lookupParams) error { m.SetQuestion(serviceName, dns.TypePTR) } m.RecursionDesired = false - return c.sendQuery(m) + // only send multicast queries to interfaces that we have joined + return c.sendQuery(m, NetInterfaceStateFlagMulticastJoined) } // Pack the dns.Msg and write to available connections (multicast) -func (c *client) sendQuery(msg *dns.Msg) error { +func (c *client) sendQuery(msg *dns.Msg, requiredFlags ...NetInterfaceStateFlag) error { buf, err := msg.Pack() if err != nil { - return err + return fmt.Errorf("failed to pack msg %v: %w", msg, err) } if c.ipv4conn != nil { // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG // As of Golang 1.18.4 // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv4.ControlMessage - for ifi := range c.ifaces { + for _, intf := range c.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index + wcm.IfIndex = intf.Index default: - if err := c.ipv4conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { + if err := c.ipv4conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + setDeadline(c.writeTimeout, c.ipv4conn) + n, err := c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) + } } } if c.ipv6conn != nil { @@ -459,16 +482,23 @@ func (c *client) sendQuery(msg *dns.Msg) error { // As of Golang 1.18.4 // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv6.ControlMessage - for ifi := range c.ifaces { + for _, intf := range c.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index + wcm.IfIndex = intf.Index default: - if err := c.ipv6conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { + if err := c.ipv6conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + setDeadline(c.writeTimeout, c.ipv6conn) + n, err := c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + } } } return nil diff --git a/connection.go b/connection.go index a0936d9e..ec5075d1 100644 --- a/connection.go +++ b/connection.go @@ -35,7 +35,11 @@ var ( } ) -func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { +func joinUdp6Multicast(interfaces []*NetInterface) (*ipv6.PacketConn, error) { + if len(interfaces) == 0 { + return nil, fmt.Errorf("no interfaces to join multicast on") + } + udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) if err != nil { return nil, err @@ -45,19 +49,15 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { pkConn := ipv6.NewPacketConn(udpConn) pkConn.SetControlMessage(ipv6.FlagInterface, true) - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int + var anySucceeded bool for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // log.Println("Udp6 JoinGroup failed for iface ", iface) - failedJoins++ + if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv6}); err == nil { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) + anySucceeded = true } } - if failedJoins == len(interfaces) { + if !anySucceeded { pkConn.Close() return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", interfaces) } @@ -67,7 +67,11 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { return pkConn, nil } -func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { +func joinUdp4Multicast(interfaces []*NetInterface) (*ipv4.PacketConn, error) { + if len(interfaces) == 0 { + return nil, fmt.Errorf("no interfaces to join multicast on") + } + udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) if err != nil { // log.Printf("[ERR] bonjour: Failed to bind to udp4 mutlicast: %v", err) @@ -78,19 +82,16 @@ func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { pkConn := ipv4.NewPacketConn(udpConn) pkConn.SetControlMessage(ipv4.FlagInterface, true) - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } // log.Println("Using multicast interfaces: ", interfaces) + var anySucceed bool - var failedJoins int for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - // log.Println("Udp4 JoinGroup failed for iface ", iface) - failedJoins++ + if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv4}); err == nil { + anySucceed = true + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) } } - if failedJoins == len(interfaces) { + if !anySucceed { pkConn.Close() return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", interfaces) } diff --git a/netinterface.go b/netinterface.go new file mode 100644 index 00000000..b97d7211 --- /dev/null +++ b/netinterface.go @@ -0,0 +1,86 @@ +package zeroconf + +import ( + "net" + "sync/atomic" +) + +type NetInterface struct { + net.Interface + stateIPv4 NetInterfaceStateFlag + stateIPv6 NetInterfaceStateFlag +} + +type NetInterfaceScope int + +const ( + NetInterfaceScopeIPv4 NetInterfaceScope = iota + NetInterfaceScopeIPv6 +) + +type NetInterfaceList []*NetInterface + +type NetInterfaceStateFlag uint32 + +const ( + NetInterfaceStateFlagMulticastJoined NetInterfaceStateFlag = 1 << iota // we have joined the multicast group on this interface + NetInterfaceStateFlagMessageSent // we have successfully sent at least one message on this interface +) + +func (i *NetInterface) HasFlags(scope NetInterfaceScope, flags ...NetInterfaceStateFlag) bool { + for _, flag := range flags { + if !i.HasFlag(scope, flag) { + return false + } + } + return true +} + +func (i *NetInterface) loadFlag(address *NetInterfaceStateFlag) NetInterfaceStateFlag { + return NetInterfaceStateFlag(atomic.LoadUint32((*uint32)(address))) +} + +func (i *NetInterface) HasFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) bool { + if scope == NetInterfaceScopeIPv4 { + return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv4)&flag) != 0 + } else if scope == NetInterfaceScopeIPv6 { + return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv6)&flag) != 0 + } + return false +} + +func (i *NetInterface) SetFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) { + if scope == NetInterfaceScopeIPv4 { + i.setFlag(&i.stateIPv4, flag) + } else if scope == NetInterfaceScopeIPv6 { + i.setFlag(&i.stateIPv6, flag) + } +} + +func (i *NetInterface) setFlag(address *NetInterfaceStateFlag, flag NetInterfaceStateFlag) { + // If atomic value != previously loaded value, then repeat the operation + // If they are equal, then we can safely set the new value + // This is the way to ensure atomicity of the operation + for { + loadedValue := uint32(i.loadFlag(address)) + if atomic.CompareAndSwapUint32((*uint32)(address), loadedValue, loadedValue|uint32(flag)) { + break + } + } +} + +func (list NetInterfaceList) GetByIndex(index int) *NetInterface { + for _, iface := range list { + if iface.Index == index { + return iface + } + } + return nil +} + +func NewInterfaceList(ifaces []net.Interface) (list NetInterfaceList) { + for i := range ifaces { + list = append(list, &NetInterface{Interface: ifaces[i]}) + } + return +} diff --git a/netinterface_test.go b/netinterface_test.go new file mode 100644 index 00000000..fc1ea4da --- /dev/null +++ b/netinterface_test.go @@ -0,0 +1,73 @@ +package zeroconf + +import ( + "sync" + "testing" + "time" +) + +func TestSetFlagSimple(t *testing.T) { + t.Run("ipv4", func(t *testing.T) { + iface := &NetInterface{} + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) + if !iface.HasFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) { + t.Error("expect true") + } + if iface.HasFlags(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect false") + } + + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) + if !iface.HasFlags(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + }) + + t.Run("ipv6", func(t *testing.T) { + iface := &NetInterface{} + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + if !iface.HasFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + if iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect false") + } + + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) + if !iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + }) +} + +func TestSetFlagConcurrent(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + iface := &NetInterface{} + wg.Add(2) + go func() { + defer wg.Done() + + for j := 0; j < 10; j++ { + if j%2 == 0 { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) + } else { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + } + } + }() + go func() { + defer wg.Done() + + var eventuallyOk bool + for j := 0; j < 10; j++ { + eventuallyOk = iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) + time.Sleep(time.Millisecond) + } + if !eventuallyOk { + t.Error("expect true") + } + }() + } + wg.Wait() +} diff --git a/server.go b/server.go index 2dbf536e..5887e6bb 100644 --- a/server.go +++ b/server.go @@ -21,16 +21,23 @@ const ( multicastRepetitions = 2 ) -var defaultTTL uint32 = 3200 +var ( + defaultTTL uint32 = 3200 + defaultServerWriteTimeout = 10 * time.Second +) type serverOpts struct { - ttl uint32 + ttl uint32 + listenOn IPType + writeTimeout time.Duration } func applyServerOpts(options ...ServerOption) serverOpts { // Apply default configuration and load supplied options. var conf = serverOpts{ - ttl: defaultTTL, + listenOn: IPv4AndIPv6, + ttl: defaultTTL, + writeTimeout: defaultServerWriteTimeout, } for _, o := range options { if o != nil { @@ -50,6 +57,21 @@ func TTL(ttl uint32) ServerOption { } } +// WriteTimeout sets timeout for writing to the socket +func WriteTimeout(duration time.Duration) ServerOption { + return func(o *serverOpts) { + o.writeTimeout = duration + } +} + +// ServerSelectIPTraffic selects the type of IP packets (IPv4, IPv6, or both) this +// instance listens for. +func ServerSelectIPTraffic(t IPType) ServerOption { + return func(o *serverOpts) { + o.listenOn = t + } +} + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface, opts ...ServerOption) (*Server, error) { @@ -86,16 +108,6 @@ func Register(instance, service, domain string, port int, text []string, ifaces ifaces = listMulticastInterfaces() } - for _, iface := range ifaces { - v4, v6 := addrsForInterface(&iface) - entry.AddrIPv4 = append(entry.AddrIPv4, v4...) - entry.AddrIPv6 = append(entry.AddrIPv6, v6...) - } - - if entry.AddrIPv4 == nil && entry.AddrIPv6 == nil { - return nil, fmt.Errorf("could not determine host IP addresses") - } - s, err := newServer(ifaces, applyServerOpts(opts...)) if err != nil { return nil, err @@ -169,41 +181,54 @@ const ( // Server structure encapsulates both IPv4/IPv6 UDP connections type Server struct { - service *ServiceEntry - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface + service *ServiceEntry + ipv4conn *ipv4.PacketConn + ipv6conn *ipv6.PacketConn + interfaces NetInterfaceList + // store if any write to iface was successful shouldShutdown chan struct{} shutdownLock sync.Mutex refCount sync.WaitGroup isShutdown bool ttl uint32 + writeTimeout time.Duration } // Constructs server structure func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { - ipv4conn, err4 := joinUdp4Multicast(ifaces) - if err4 != nil { - log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) - } - ipv6conn, err6 := joinUdp6Multicast(ifaces) - if err6 != nil { - log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) - } - if err4 != nil && err6 != nil { - // No supported interface left. - return nil, fmt.Errorf("no supported interface") + if len(ifaces) == 0 { + ifaces = listMulticastInterfaces() } + ifaceList := NewInterfaceList(ifaces) + s := &Server{ - ipv4conn: ipv4conn, - ipv6conn: ipv6conn, - ifaces: ifaces, + interfaces: ifaceList, ttl: opts.ttl, + writeTimeout: opts.writeTimeout, shouldShutdown: make(chan struct{}), } + var err error + if (opts.listenOn & IPv4) > 0 { + s.ipv4conn, err = joinUdp4Multicast(ifaceList) + if err != nil { + log.Printf("[zeroconf] no suitable IPv4 interface: %s", err.Error()) + } + } + + if (opts.listenOn & IPv6) > 0 { + s.ipv6conn, err = joinUdp6Multicast(ifaceList) + if err != nil { + log.Printf("[zeroconf] no suitable IPv6 interface: %s", err.Error()) + } + + } + + if s.ipv6conn == nil && s.ipv4conn == nil { + return nil, fmt.Errorf("no supported interface") + } return s, nil } @@ -549,7 +574,7 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { } // Perform probing & announcement -//TODO: implement a proper probing & conflict resolution +// TODO: implement a proper probing & conflict resolution func (s *Server) probe() { defer s.refCount.Done() @@ -590,7 +615,7 @@ func (s *Server) probe() { return } for i := 0; i < 3; i++ { - if err := s.multicastResponse(q, 0); err != nil { + if err := s.multicastResponse(q, 0, NetInterfaceStateFlagMulticastJoined); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) } timer.Reset(250 * time.Millisecond) @@ -609,7 +634,7 @@ func (s *Server) probe() { // at least a factor of two with every response sent. timeout := time.Second for i := 0; i < multicastRepetitions; i++ { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { resp := new(dns.Msg) resp.MsgHdr.Response = true // TODO: make response authoritative if we are the publisher @@ -617,7 +642,7 @@ func (s *Server) probe() { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} s.composeLookupAnswers(resp, s.ttl, intf.Index, true) - if err := s.multicastResponse(resp, intf.Index); err != nil { + if err := s.multicastResponse(resp, intf.Index, NetInterfaceStateFlagMulticastJoined); err != nil { log.Println("[ERR] zeroconf: failed to send announcement:", err.Error()) } } @@ -656,7 +681,9 @@ func (s *Server) unregister() error { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} s.composeLookupAnswers(resp, 0, 0, true) - return s.multicastResponse(resp, 0) + // cleanup ifaces we have NEVER written successfully. No need to unregister them + + return s.multicastResponse(resp, 0, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) } func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache bool) []dns.RR { @@ -737,29 +764,35 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro return err } addr := from.(*net.UDPAddr) - if addr.IP.To4() != nil { + if addr.IP.To4() != nil && s.ipv4conn != nil { if ifIndex != 0 { var wcm ipv4.ControlMessage wcm.IfIndex = ifIndex + setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv4conn.WriteTo(buf, &wcm, addr) } else { + setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv4conn.WriteTo(buf, nil, addr) } return err - } else { + } else if s.ipv6conn != nil { if ifIndex != 0 { var wcm ipv6.ControlMessage wcm.IfIndex = ifIndex + setDeadline(s.writeTimeout, s.ipv6conn) _, err = s.ipv6conn.WriteTo(buf, &wcm, addr) } else { + setDeadline(s.writeTimeout, s.ipv6conn) _, err = s.ipv6conn.WriteTo(buf, nil, addr) } return err + } else { + return fmt.Errorf("no suitable interface") } } // multicastResponse is used to send a multicast response packet -func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { +func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, requiredFlags ...NetInterfaceStateFlag) error { buf, err := msg.Pack() if err != nil { return fmt.Errorf("failed to pack msg %v: %w", msg, err) @@ -770,27 +803,41 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv4.ControlMessage if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { + switch runtime.GOOS { + case "darwin", "ios", "linux": + wcm.IfIndex = ifIndex + default: + iface, _ := net.InterfaceByIndex(ifIndex) + if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + setDeadline(s.writeTimeout, s.ipv4conn) + n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) } } - s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + } else { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = intf.Index default: - if err := s.ipv4conn.SetMulticastInterface(&intf); err != nil { + if err := s.ipv4conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + setDeadline(s.writeTimeout, s.ipv4conn) + n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) + } } } } @@ -801,27 +848,40 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv6.ControlMessage if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { + switch runtime.GOOS { + case "darwin", "ios", "linux": + wcm.IfIndex = ifIndex + default: + iface, _ := net.InterfaceByIndex(ifIndex) + if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + setDeadline(s.writeTimeout, s.ipv6conn) + n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) } } - s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) } else { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = intf.Index default: - if err := s.ipv6conn.SetMulticastInterface(&intf); err != nil { + if err := s.ipv6conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + setDeadline(s.writeTimeout, s.ipv6conn) + n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + } } } } diff --git a/utils.go b/utils.go index 106fc6e6..f1c52c8b 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,9 @@ package zeroconf -import "strings" +import ( + "strings" + "time" +) func parseSubtypes(service string) (string, []string) { subtypes := strings.Split(service, ",") @@ -11,3 +14,15 @@ func parseSubtypes(service string) (string, []string) { func trimDot(s string) string { return strings.Trim(s, ".") } + +type DeadlineSetter interface { + SetWriteDeadline(time.Time) error +} + +func setDeadline(timeout time.Duration, ds DeadlineSetter) { + if timeout != 0 { + ds.SetWriteDeadline(time.Now().Add(timeout)) + } else { + ds.SetWriteDeadline(time.Time{}) + } +}