From 7f56f9cf646d485d1d1e21bc65583334da2bf6e3 Mon Sep 17 00:00:00 2001 From: Roman Khafizianov Date: Mon, 27 Feb 2023 00:06:17 +0100 Subject: [PATCH 1/6] add socket write timeout; skip bad ifaces --- client.go | 41 ++++++++++----- connection.go | 39 ++++++++------- netinterface.go | 71 ++++++++++++++++++++++++++ server.go | 130 ++++++++++++++++++++++++++++++++++-------------- utils.go | 17 ++++++- 5 files changed, 228 insertions(+), 70 deletions(-) create mode 100644 netinterface.go diff --git a/client.go b/client.go index c0b2cae1..6a394c14 100644 --- a/client.go +++ b/client.go @@ -28,18 +28,23 @@ const ( IPv4AndIPv6 = IPv4 | IPv6 // default option ) -var initialQueryInterval = 4 * time.Second +var ( + initialQueryInterval = 4 * time.Second + defaultClientWriteTimeout = 10 * time.Second +) // Client structure encapsulates both IPv4/IPv6 UDP connections. type client struct { - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface + ipv4conn *ipv4.PacketConn + ipv6conn *ipv6.PacketConn + ifaces []net.Interface + writeTimeout time.Duration } type clientOpts struct { - listenOn IPType - ifaces []net.Interface + listenOn IPType + ifaces []net.Interface + writeTimeout time.Duration } // ClientOption fills the option struct to configure intefaces, etc. @@ -63,6 +68,13 @@ func SelectIfaces(ifaces []net.Interface) ClientOption { } } +// ClientWriteTimeout sets timeout for writing to the socket +func ClientWriteTimeout(duration time.Duration) ClientOption { + return func(o *clientOpts) { + o.writeTimeout = duration + } +} + // Browse for all services of a given type in a given domain. // Received entries are sent on the entries channel. // It blocks until the context is canceled (or an error occurs). @@ -100,7 +112,8 @@ func Lookup(ctx context.Context, instance, service, domain string, entries chan< func applyOpts(options ...ClientOption) clientOpts { // Apply default configuration and load supplied options. var conf = clientOpts{ - listenOn: IPv4AndIPv6, + listenOn: IPv4AndIPv6, + writeTimeout: defaultClientWriteTimeout, } for _, o := range options { if o != nil { @@ -137,11 +150,12 @@ func newClient(opts clientOpts) (*client, error) { if len(ifaces) == 0 { ifaces = listMulticastInterfaces() } + ifaceList := NewInterfaceList(ifaces) // IPv4 interfaces var ipv4conn *ipv4.PacketConn if (opts.listenOn & IPv4) > 0 { var err error - ipv4conn, err = joinUdp4Multicast(ifaces) + ipv4conn, err = joinUdp4Multicast(ifaceList) if err != nil { return nil, err } @@ -150,16 +164,17 @@ func newClient(opts clientOpts) (*client, error) { var ipv6conn *ipv6.PacketConn if (opts.listenOn & IPv6) > 0 { var err error - ipv6conn, err = joinUdp6Multicast(ifaces) + ipv6conn, err = joinUdp6Multicast(ifaceList) if err != nil { return nil, err } } return &client{ - ipv4conn: ipv4conn, - ipv6conn: ipv6conn, - ifaces: ifaces, + ipv4conn: ipv4conn, + ipv6conn: ipv6conn, + ifaces: ifaces, + writeTimeout: opts.writeTimeout, }, nil } @@ -451,6 +466,7 @@ func (c *client) sendQuery(msg *dns.Msg) error { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } + setDeadline(c.writeTimeout, c.ipv4conn) c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) } } @@ -468,6 +484,7 @@ func (c *client) sendQuery(msg *dns.Msg) error { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } + setDeadline(c.writeTimeout, c.ipv6conn) c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) } } diff --git a/connection.go b/connection.go index a0936d9e..07a9123a 100644 --- a/connection.go +++ b/connection.go @@ -35,7 +35,11 @@ var ( } ) -func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { +func joinUdp6Multicast(interfaces []*NetInterface) (*ipv6.PacketConn, error) { + if len(interfaces) == 0 { + return nil, fmt.Errorf("no interfaces to join multicast on") + } + udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) if err != nil { return nil, err @@ -45,19 +49,15 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { pkConn := ipv6.NewPacketConn(udpConn) pkConn.SetControlMessage(ipv6.FlagInterface, true) - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } // log.Println("Using multicast interfaces: ", interfaces) - - var failedJoins int + var anySucceeded bool for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // log.Println("Udp6 JoinGroup failed for iface ", iface) - failedJoins++ + if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv6}); err == nil { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagJoined) + anySucceeded = true } } - if failedJoins == len(interfaces) { + if !anySucceeded { pkConn.Close() return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", interfaces) } @@ -67,7 +67,11 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) { return pkConn, nil } -func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { +func joinUdp4Multicast(interfaces []*NetInterface) (*ipv4.PacketConn, error) { + if len(interfaces) == 0 { + return nil, fmt.Errorf("no interfaces to join multicast on") + } + udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) if err != nil { // log.Printf("[ERR] bonjour: Failed to bind to udp4 mutlicast: %v", err) @@ -78,19 +82,16 @@ func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) { pkConn := ipv4.NewPacketConn(udpConn) pkConn.SetControlMessage(ipv4.FlagInterface, true) - if len(interfaces) == 0 { - interfaces = listMulticastInterfaces() - } // log.Println("Using multicast interfaces: ", interfaces) + var anySucceed bool - var failedJoins int for _, iface := range interfaces { - if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - // log.Println("Udp4 JoinGroup failed for iface ", iface) - failedJoins++ + if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv4}); err == nil { + anySucceed = true + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagJoined) } } - if failedJoins == len(interfaces) { + if !anySucceed { pkConn.Close() return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", interfaces) } diff --git a/netinterface.go b/netinterface.go new file mode 100644 index 00000000..1c92c4f5 --- /dev/null +++ b/netinterface.go @@ -0,0 +1,71 @@ +package zeroconf + +import ( + "net" +) + +type NetInterface struct { + net.Interface + stateIPv4 NetInterfaceStateFlag + stateIPv6 NetInterfaceStateFlag +} + +type NetInterfaceScope int + +const ( + NetInterfaceScopeIPv4 NetInterfaceScope = iota + NetInterfaceScopeIPv6 +) + +type NetInterfaceList []*NetInterface + +type NetInterfaceStateFlag uint8 + +const ( + NetInterfaceStateFlagJoined NetInterfaceStateFlag = 1 << iota + NetInterfaceStateFlagRegistered +) + +func (i *NetInterface) HasFlags(scope NetInterfaceScope, flags ...NetInterfaceStateFlag) bool { + for _, flag := range flags { + if !i.HasFlag(scope, flag) { + return false + } + } + return true +} + +func (i *NetInterface) HasFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) bool { + if scope == NetInterfaceScopeIPv4 { + return i.stateIPv4&flag != 0 + } else if scope == NetInterfaceScopeIPv6 { + return i.stateIPv6&flag != 0 + } + return false +} + +func (i *NetInterface) SetFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) { + if scope == NetInterfaceScopeIPv4 { + i.stateIPv4 |= flag + return + } else if scope == NetInterfaceScopeIPv6 { + i.stateIPv6 |= flag + return + } +} + +func (list NetInterfaceList) GetByIndex(index int) *NetInterface { + for _, iface := range list { + if iface.Index == index { + return iface + } + } + return nil +} + +func NewInterfaceList(ifaces []net.Interface) (list NetInterfaceList) { + for i := range ifaces { + list = append(list, &NetInterface{Interface: ifaces[i]}) + } + return +} diff --git a/server.go b/server.go index 2dbf536e..6c1cef6e 100644 --- a/server.go +++ b/server.go @@ -21,16 +21,21 @@ const ( multicastRepetitions = 2 ) -var defaultTTL uint32 = 3200 +var ( + defaultTTL uint32 = 3200 + defaultServerWriteTimeout = 10 * time.Second +) type serverOpts struct { - ttl uint32 + ttl uint32 + writeTimeout time.Duration } func applyServerOpts(options ...ServerOption) serverOpts { // Apply default configuration and load supplied options. var conf = serverOpts{ - ttl: defaultTTL, + ttl: defaultTTL, + writeTimeout: defaultServerWriteTimeout, } for _, o := range options { if o != nil { @@ -50,6 +55,13 @@ func TTL(ttl uint32) ServerOption { } } +// WriteTimeout sets timeout for writing to the socket +func WriteTimeout(duration time.Duration) ServerOption { + return func(o *serverOpts) { + o.writeTimeout = duration + } +} + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface, opts ...ServerOption) (*Server, error) { @@ -169,28 +181,37 @@ const ( // Server structure encapsulates both IPv4/IPv6 UDP connections type Server struct { - service *ServiceEntry - ipv4conn *ipv4.PacketConn - ipv6conn *ipv6.PacketConn - ifaces []net.Interface + service *ServiceEntry + ipv4conn *ipv4.PacketConn + ipv6conn *ipv6.PacketConn + interfaces NetInterfaceList + // store if any write to iface was successful shouldShutdown chan struct{} shutdownLock sync.Mutex refCount sync.WaitGroup isShutdown bool ttl uint32 + writeTimeout time.Duration } // Constructs server structure func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { - ipv4conn, err4 := joinUdp4Multicast(ifaces) + if len(ifaces) == 0 { + ifaces = listMulticastInterfaces() + } + + ifaceList := NewInterfaceList(ifaces) + ipv4conn, err4 := joinUdp4Multicast(ifaceList) if err4 != nil { log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) } - ipv6conn, err6 := joinUdp6Multicast(ifaces) + + ipv6conn, err6 := joinUdp6Multicast(ifaceList) if err6 != nil { log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) } + if err4 != nil && err6 != nil { // No supported interface left. return nil, fmt.Errorf("no supported interface") @@ -199,8 +220,9 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { s := &Server{ ipv4conn: ipv4conn, ipv6conn: ipv6conn, - ifaces: ifaces, + interfaces: ifaceList, ttl: opts.ttl, + writeTimeout: opts.writeTimeout, shouldShutdown: make(chan struct{}), } @@ -549,7 +571,7 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { } // Perform probing & announcement -//TODO: implement a proper probing & conflict resolution +// TODO: implement a proper probing & conflict resolution func (s *Server) probe() { defer s.refCount.Done() @@ -590,7 +612,7 @@ func (s *Server) probe() { return } for i := 0; i < 3; i++ { - if err := s.multicastResponse(q, 0); err != nil { + if err := s.multicastResponse(q, 0, NetInterfaceStateFlagJoined); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) } timer.Reset(250 * time.Millisecond) @@ -609,7 +631,7 @@ func (s *Server) probe() { // at least a factor of two with every response sent. timeout := time.Second for i := 0; i < multicastRepetitions; i++ { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { resp := new(dns.Msg) resp.MsgHdr.Response = true // TODO: make response authoritative if we are the publisher @@ -617,7 +639,7 @@ func (s *Server) probe() { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} s.composeLookupAnswers(resp, s.ttl, intf.Index, true) - if err := s.multicastResponse(resp, intf.Index); err != nil { + if err := s.multicastResponse(resp, intf.Index, NetInterfaceStateFlagJoined); err != nil { log.Println("[ERR] zeroconf: failed to send announcement:", err.Error()) } } @@ -656,7 +678,9 @@ func (s *Server) unregister() error { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} s.composeLookupAnswers(resp, 0, 0, true) - return s.multicastResponse(resp, 0) + // cleanup ifaces we have NEVER written successfully. No need to unregister them + + return s.multicastResponse(resp, 0, NetInterfaceStateFlagJoined, NetInterfaceStateFlagRegistered) } func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache bool) []dns.RR { @@ -741,6 +765,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro if ifIndex != 0 { var wcm ipv4.ControlMessage wcm.IfIndex = ifIndex + setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv4conn.WriteTo(buf, &wcm, addr) } else { _, err = s.ipv4conn.WriteTo(buf, nil, addr) @@ -750,6 +775,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro if ifIndex != 0 { var wcm ipv6.ControlMessage wcm.IfIndex = ifIndex + setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv6conn.WriteTo(buf, &wcm, addr) } else { _, err = s.ipv6conn.WriteTo(buf, nil, addr) @@ -759,7 +785,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } // multicastResponse is used to send a multicast response packet -func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { +func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterfaceStateFlag) error { buf, err := msg.Pack() if err != nil { return fmt.Errorf("failed to pack msg %v: %w", msg, err) @@ -770,27 +796,42 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv4.ControlMessage if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv4, flags...) { + switch runtime.GOOS { + case "darwin", "ios", "linux": + wcm.IfIndex = ifIndex + default: + iface, _ := net.InterfaceByIndex(ifIndex) + if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + setDeadline(s.writeTimeout, s.ipv4conn) + n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagRegistered) } } - s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + } else { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv4, flags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = intf.Index default: - if err := s.ipv4conn.SetMulticastInterface(&intf); err != nil { + if err := s.ipv4conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + setDeadline(s.writeTimeout, s.ipv4conn) + n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagRegistered) + } + //s.ifaceOk[intf.Index] = s.ifaceOk[intf.Index] || n > 0 } } } @@ -801,27 +842,40 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv6.ControlMessage if ifIndex != 0 { - switch runtime.GOOS { - case "darwin", "ios", "linux": - wcm.IfIndex = ifIndex - default: - iface, _ := net.InterfaceByIndex(ifIndex) - if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv6, flags...) { + switch runtime.GOOS { + case "darwin", "ios", "linux": + wcm.IfIndex = ifIndex + default: + iface, _ := net.InterfaceByIndex(ifIndex) + if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { + log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + } + } + setDeadline(s.writeTimeout, s.ipv6conn) + n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagRegistered) } } - s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) } else { - for _, intf := range s.ifaces { + for _, intf := range s.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv6, flags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = intf.Index default: - if err := s.ipv6conn.SetMulticastInterface(&intf); err != nil { + if err := s.ipv6conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } - s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + setDeadline(s.writeTimeout, s.ipv6conn) + n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagRegistered) + } } } } diff --git a/utils.go b/utils.go index 106fc6e6..f1c52c8b 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,9 @@ package zeroconf -import "strings" +import ( + "strings" + "time" +) func parseSubtypes(service string) (string, []string) { subtypes := strings.Split(service, ",") @@ -11,3 +14,15 @@ func parseSubtypes(service string) (string, []string) { func trimDot(s string) string { return strings.Trim(s, ".") } + +type DeadlineSetter interface { + SetWriteDeadline(time.Time) error +} + +func setDeadline(timeout time.Duration, ds DeadlineSetter) { + if timeout != 0 { + ds.SetWriteDeadline(time.Now().Add(timeout)) + } else { + ds.SetWriteDeadline(time.Time{}) + } +} From 9dba2bace95a61fbab00e3b131fb7c3bf72d3dd3 Mon Sep 17 00:00:00 2001 From: Roman Khafizianov Date: Thu, 2 Mar 2023 11:11:01 +0100 Subject: [PATCH 2/6] serverOpts: make ipv6 and ipv4 options same like client --- connection.go | 4 +-- netinterface.go | 4 +-- server.go | 75 ++++++++++++++++++++++++++----------------------- 3 files changed, 44 insertions(+), 39 deletions(-) diff --git a/connection.go b/connection.go index 07a9123a..ec5075d1 100644 --- a/connection.go +++ b/connection.go @@ -53,7 +53,7 @@ func joinUdp6Multicast(interfaces []*NetInterface) (*ipv6.PacketConn, error) { var anySucceeded bool for _, iface := range interfaces { if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv6}); err == nil { - iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagJoined) + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) anySucceeded = true } } @@ -88,7 +88,7 @@ func joinUdp4Multicast(interfaces []*NetInterface) (*ipv4.PacketConn, error) { for _, iface := range interfaces { if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv4}); err == nil { anySucceed = true - iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagJoined) + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) } } if !anySucceed { diff --git a/netinterface.go b/netinterface.go index 1c92c4f5..bc9484df 100644 --- a/netinterface.go +++ b/netinterface.go @@ -22,8 +22,8 @@ type NetInterfaceList []*NetInterface type NetInterfaceStateFlag uint8 const ( - NetInterfaceStateFlagJoined NetInterfaceStateFlag = 1 << iota - NetInterfaceStateFlagRegistered + NetInterfaceStateFlagMulticastJoined NetInterfaceStateFlag = 1 << iota // we have joined the multicast group on this interface + NetInterfaceStateFlagMessageSent // we have successfully sent at least one message on this interface ) func (i *NetInterface) HasFlags(scope NetInterfaceScope, flags ...NetInterfaceStateFlag) bool { diff --git a/server.go b/server.go index 6c1cef6e..61948aa9 100644 --- a/server.go +++ b/server.go @@ -28,12 +28,14 @@ var ( type serverOpts struct { ttl uint32 + listenOn IPType writeTimeout time.Duration } func applyServerOpts(options ...ServerOption) serverOpts { // Apply default configuration and load supplied options. var conf = serverOpts{ + listenOn: IPv4AndIPv6, ttl: defaultTTL, writeTimeout: defaultServerWriteTimeout, } @@ -62,6 +64,14 @@ func WriteTimeout(duration time.Duration) ServerOption { } } +// ServerSelectIPTraffic selects the type of IP packets (IPv4, IPv6, or both) this +// instance listens for. +func ServerSelectIPTraffic(t IPType) ServerOption { + return func(o *serverOpts) { + o.listenOn = t + } +} + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface, opts ...ServerOption) (*Server, error) { @@ -98,16 +108,6 @@ func Register(instance, service, domain string, port int, text []string, ifaces ifaces = listMulticastInterfaces() } - for _, iface := range ifaces { - v4, v6 := addrsForInterface(&iface) - entry.AddrIPv4 = append(entry.AddrIPv4, v4...) - entry.AddrIPv6 = append(entry.AddrIPv6, v6...) - } - - if entry.AddrIPv4 == nil && entry.AddrIPv6 == nil { - return nil, fmt.Errorf("could not determine host IP addresses") - } - s, err := newServer(ifaces, applyServerOpts(opts...)) if err != nil { return nil, err @@ -202,30 +202,33 @@ func newServer(ifaces []net.Interface, opts serverOpts) (*Server, error) { } ifaceList := NewInterfaceList(ifaces) - ipv4conn, err4 := joinUdp4Multicast(ifaceList) - if err4 != nil { - log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) - } - - ipv6conn, err6 := joinUdp6Multicast(ifaceList) - if err6 != nil { - log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) - } - - if err4 != nil && err6 != nil { - // No supported interface left. - return nil, fmt.Errorf("no supported interface") - } s := &Server{ - ipv4conn: ipv4conn, - ipv6conn: ipv6conn, interfaces: ifaceList, ttl: opts.ttl, writeTimeout: opts.writeTimeout, shouldShutdown: make(chan struct{}), } + var err error + if (opts.listenOn & IPv4) > 0 { + s.ipv4conn, err = joinUdp4Multicast(ifaceList) + if err != nil { + log.Printf("[zeroconf] no suitable IPv4 interface: %s", err.Error()) + } + } + + if (opts.listenOn & IPv6) > 0 { + s.ipv6conn, err = joinUdp6Multicast(ifaceList) + if err != nil { + log.Printf("[zeroconf] no suitable IPv6 interface: %s", err.Error()) + } + + } + + if s.ipv6conn == nil && s.ipv4conn == nil { + return nil, fmt.Errorf("no supported interface") + } return s, nil } @@ -612,7 +615,7 @@ func (s *Server) probe() { return } for i := 0; i < 3; i++ { - if err := s.multicastResponse(q, 0, NetInterfaceStateFlagJoined); err != nil { + if err := s.multicastResponse(q, 0, NetInterfaceStateFlagMulticastJoined); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) } timer.Reset(250 * time.Millisecond) @@ -639,7 +642,7 @@ func (s *Server) probe() { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} s.composeLookupAnswers(resp, s.ttl, intf.Index, true) - if err := s.multicastResponse(resp, intf.Index, NetInterfaceStateFlagJoined); err != nil { + if err := s.multicastResponse(resp, intf.Index, NetInterfaceStateFlagMulticastJoined); err != nil { log.Println("[ERR] zeroconf: failed to send announcement:", err.Error()) } } @@ -680,7 +683,7 @@ func (s *Server) unregister() error { s.composeLookupAnswers(resp, 0, 0, true) // cleanup ifaces we have NEVER written successfully. No need to unregister them - return s.multicastResponse(resp, 0, NetInterfaceStateFlagJoined, NetInterfaceStateFlagRegistered) + return s.multicastResponse(resp, 0, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) } func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache bool) []dns.RR { @@ -761,7 +764,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro return err } addr := from.(*net.UDPAddr) - if addr.IP.To4() != nil { + if addr.IP.To4() != nil && s.ipv4conn != nil { if ifIndex != 0 { var wcm ipv4.ControlMessage wcm.IfIndex = ifIndex @@ -771,7 +774,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro _, err = s.ipv4conn.WriteTo(buf, nil, addr) } return err - } else { + } else if s.ipv6conn != nil { if ifIndex != 0 { var wcm ipv6.ControlMessage wcm.IfIndex = ifIndex @@ -781,6 +784,8 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro _, err = s.ipv6conn.WriteTo(buf, nil, addr) } return err + } else { + return fmt.Errorf("no suitable interface") } } @@ -809,7 +814,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf setDeadline(s.writeTimeout, s.ipv4conn) n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) if err == nil && n > 0 { - s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagRegistered) + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) } } @@ -829,7 +834,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf setDeadline(s.writeTimeout, s.ipv4conn) n, err := s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) if err == nil && n > 0 { - intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagRegistered) + intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) } //s.ifaceOk[intf.Index] = s.ifaceOk[intf.Index] || n > 0 } @@ -855,7 +860,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf setDeadline(s.writeTimeout, s.ipv6conn) n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) if err == nil && n > 0 { - s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagRegistered) + s.interfaces.GetByIndex(ifIndex).SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) } } } else { @@ -874,7 +879,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf setDeadline(s.writeTimeout, s.ipv6conn) n, err := s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) if err == nil && n > 0 { - intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagRegistered) + intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) } } } From fa3ab41a494189d4bfdcca58cc158b0c18935a6f Mon Sep 17 00:00:00 2001 From: Roman Khafizianov Date: Fri, 3 Mar 2023 16:13:30 +0100 Subject: [PATCH 3/6] add same flags logic to the client --- client.go | 39 ++++++++++++++++++++++++++------------- server.go | 15 ++++++++------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 6a394c14..320ceb65 100644 --- a/client.go +++ b/client.go @@ -37,7 +37,7 @@ var ( type client struct { ipv4conn *ipv4.PacketConn ipv6conn *ipv6.PacketConn - ifaces []net.Interface + interfaces NetInterfaceList writeTimeout time.Duration } @@ -173,7 +173,7 @@ func newClient(opts clientOpts) (*client, error) { return &client{ ipv4conn: ipv4conn, ipv6conn: ipv6conn, - ifaces: ifaces, + interfaces: ifaceList, writeTimeout: opts.writeTimeout, }, nil } @@ -443,31 +443,38 @@ func (c *client) query(params *lookupParams) error { m.SetQuestion(serviceName, dns.TypePTR) } m.RecursionDesired = false - return c.sendQuery(m) + // only send multicast queries to interfaces that we have joined + return c.sendQuery(m, NetInterfaceStateFlagMulticastJoined) } // Pack the dns.Msg and write to available connections (multicast) -func (c *client) sendQuery(msg *dns.Msg) error { +func (c *client) sendQuery(msg *dns.Msg, requiredFlags ...NetInterfaceStateFlag) error { buf, err := msg.Pack() if err != nil { - return err + return fmt.Errorf("failed to pack msg %v: %w", msg, err) } if c.ipv4conn != nil { // See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG // As of Golang 1.18.4 // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv4.ControlMessage - for ifi := range c.ifaces { + for _, intf := range c.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index + wcm.IfIndex = intf.Index default: - if err := c.ipv4conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { + if err := c.ipv4conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } setDeadline(c.writeTimeout, c.ipv4conn) - c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + n, err := c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) + } } } if c.ipv6conn != nil { @@ -475,17 +482,23 @@ func (c *client) sendQuery(msg *dns.Msg) error { // As of Golang 1.18.4 // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv6.ControlMessage - for ifi := range c.ifaces { + for _, intf := range c.interfaces { + if !intf.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { + continue + } switch runtime.GOOS { case "darwin", "ios", "linux": - wcm.IfIndex = c.ifaces[ifi].Index + wcm.IfIndex = intf.Index default: - if err := c.ipv6conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil { + if err := c.ipv6conn.SetMulticastInterface(&intf.Interface); err != nil { log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) } } setDeadline(c.writeTimeout, c.ipv6conn) - c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + n, err := c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) + if err == nil && n > 0 { + intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + } } } return nil diff --git a/server.go b/server.go index 61948aa9..5887e6bb 100644 --- a/server.go +++ b/server.go @@ -771,6 +771,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv4conn.WriteTo(buf, &wcm, addr) } else { + setDeadline(s.writeTimeout, s.ipv4conn) _, err = s.ipv4conn.WriteTo(buf, nil, addr) } return err @@ -778,9 +779,10 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro if ifIndex != 0 { var wcm ipv6.ControlMessage wcm.IfIndex = ifIndex - setDeadline(s.writeTimeout, s.ipv4conn) + setDeadline(s.writeTimeout, s.ipv6conn) _, err = s.ipv6conn.WriteTo(buf, &wcm, addr) } else { + setDeadline(s.writeTimeout, s.ipv6conn) _, err = s.ipv6conn.WriteTo(buf, nil, addr) } return err @@ -790,7 +792,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } // multicastResponse is used to send a multicast response packet -func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterfaceStateFlag) error { +func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, requiredFlags ...NetInterfaceStateFlag) error { buf, err := msg.Pack() if err != nil { return fmt.Errorf("failed to pack msg %v: %w", msg, err) @@ -801,7 +803,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv4.ControlMessage if ifIndex != 0 { - if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv4, flags...) { + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = ifIndex @@ -820,7 +822,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf } else { for _, intf := range s.interfaces { - if !intf.HasFlags(NetInterfaceScopeIPv4, flags...) { + if !intf.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) { continue } switch runtime.GOOS { @@ -836,7 +838,6 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf if err == nil && n > 0 { intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) } - //s.ifaceOk[intf.Index] = s.ifaceOk[intf.Index] || n > 0 } } } @@ -847,7 +848,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf // On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented. var wcm ipv6.ControlMessage if ifIndex != 0 { - if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv6, flags...) { + if v := s.interfaces.GetByIndex(ifIndex); v != nil && v.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { switch runtime.GOOS { case "darwin", "ios", "linux": wcm.IfIndex = ifIndex @@ -865,7 +866,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int, flags ...NetInterf } } else { for _, intf := range s.interfaces { - if !intf.HasFlags(NetInterfaceScopeIPv6, flags...) { + if !intf.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) { continue } switch runtime.GOOS { From d85785d4c9e75ca61ad49896073860740e17d198 Mon Sep 17 00:00:00 2001 From: Sergey Date: Tue, 27 Feb 2024 15:16:50 +0100 Subject: [PATCH 4/6] NetInterface: atomic set/has flag operations --- netinterface.go | 24 ++++++++++----- netinterface_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 netinterface_test.go diff --git a/netinterface.go b/netinterface.go index bc9484df..9fbc3b9b 100644 --- a/netinterface.go +++ b/netinterface.go @@ -2,6 +2,7 @@ package zeroconf import ( "net" + "sync/atomic" ) type NetInterface struct { @@ -19,7 +20,7 @@ const ( type NetInterfaceList []*NetInterface -type NetInterfaceStateFlag uint8 +type NetInterfaceStateFlag uint32 const ( NetInterfaceStateFlagMulticastJoined NetInterfaceStateFlag = 1 << iota // we have joined the multicast group on this interface @@ -35,22 +36,31 @@ func (i *NetInterface) HasFlags(scope NetInterfaceScope, flags ...NetInterfaceSt return true } +func (i *NetInterface) loadFlag(address *NetInterfaceStateFlag) NetInterfaceStateFlag { + return NetInterfaceStateFlag(atomic.LoadUint32((*uint32)(address))) +} + func (i *NetInterface) HasFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) bool { if scope == NetInterfaceScopeIPv4 { - return i.stateIPv4&flag != 0 + return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv4)&flag) != 0 } else if scope == NetInterfaceScopeIPv6 { - return i.stateIPv6&flag != 0 + return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv6)&flag) != 0 } return false } func (i *NetInterface) SetFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) { if scope == NetInterfaceScopeIPv4 { - i.stateIPv4 |= flag - return + i.setFlag(&i.stateIPv4, flag) } else if scope == NetInterfaceScopeIPv6 { - i.stateIPv6 |= flag - return + i.setFlag(&i.stateIPv6, flag) + } +} + +func (i *NetInterface) setFlag(address *NetInterfaceStateFlag, flag NetInterfaceStateFlag) { + // If (loaded value | flag) != previously loaded value, then repeat the operation + // This is the way to ensure atomicity of the operation + for !atomic.CompareAndSwapUint32((*uint32)(address), uint32(i.loadFlag(address)), uint32(i.loadFlag(address)|flag)) { } } diff --git a/netinterface_test.go b/netinterface_test.go new file mode 100644 index 00000000..fc1ea4da --- /dev/null +++ b/netinterface_test.go @@ -0,0 +1,73 @@ +package zeroconf + +import ( + "sync" + "testing" + "time" +) + +func TestSetFlagSimple(t *testing.T) { + t.Run("ipv4", func(t *testing.T) { + iface := &NetInterface{} + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) + if !iface.HasFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined) { + t.Error("expect true") + } + if iface.HasFlags(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect false") + } + + iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent) + if !iface.HasFlags(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + }) + + t.Run("ipv6", func(t *testing.T) { + iface := &NetInterface{} + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + if !iface.HasFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + if iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect false") + } + + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) + if !iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) { + t.Error("expect true") + } + }) +} + +func TestSetFlagConcurrent(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + iface := &NetInterface{} + wg.Add(2) + go func() { + defer wg.Done() + + for j := 0; j < 10; j++ { + if j%2 == 0 { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined) + } else { + iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent) + } + } + }() + go func() { + defer wg.Done() + + var eventuallyOk bool + for j := 0; j < 10; j++ { + eventuallyOk = iface.HasFlags(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined, NetInterfaceStateFlagMessageSent) + time.Sleep(time.Millisecond) + } + if !eventuallyOk { + t.Error("expect true") + } + }() + } + wg.Wait() +} From 7248a87dad0f951814348fef99ddb7dc8b47e9ea Mon Sep 17 00:00:00 2001 From: Sergey Date: Tue, 27 Feb 2024 16:02:09 +0100 Subject: [PATCH 5/6] NetInterface: fix comment --- netinterface.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netinterface.go b/netinterface.go index 9fbc3b9b..dac545a7 100644 --- a/netinterface.go +++ b/netinterface.go @@ -58,7 +58,8 @@ func (i *NetInterface) SetFlag(scope NetInterfaceScope, flag NetInterfaceStateFl } func (i *NetInterface) setFlag(address *NetInterfaceStateFlag, flag NetInterfaceStateFlag) { - // If (loaded value | flag) != previously loaded value, then repeat the operation + // If atomic value != previously loaded value, then repeat the operation + // If they are equal, then we can safely set the new value // This is the way to ensure atomicity of the operation for !atomic.CompareAndSwapUint32((*uint32)(address), uint32(i.loadFlag(address)), uint32(i.loadFlag(address)|flag)) { } From b8673467ef473c6c88f4370256d3a0f9273803ac Mon Sep 17 00:00:00 2001 From: Sergey Date: Wed, 28 Feb 2024 12:25:16 +0100 Subject: [PATCH 6/6] NetInterface: address code review --- netinterface.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/netinterface.go b/netinterface.go index dac545a7..b97d7211 100644 --- a/netinterface.go +++ b/netinterface.go @@ -61,7 +61,11 @@ func (i *NetInterface) setFlag(address *NetInterfaceStateFlag, flag NetInterface // If atomic value != previously loaded value, then repeat the operation // If they are equal, then we can safely set the new value // This is the way to ensure atomicity of the operation - for !atomic.CompareAndSwapUint32((*uint32)(address), uint32(i.loadFlag(address)), uint32(i.loadFlag(address)|flag)) { + for { + loadedValue := uint32(i.loadFlag(address)) + if atomic.CompareAndSwapUint32((*uint32)(address), loadedValue, loadedValue|uint32(flag)) { + break + } } }