diff --git a/pkg/port/builtin/builtin_test.go b/pkg/port/builtin/builtin_test.go index 7dd05284..2228ab23 100644 --- a/pkg/port/builtin/builtin_test.go +++ b/pkg/port/builtin/builtin_test.go @@ -30,4 +30,5 @@ func TestBuiltIn(t *testing.T) { } testsuite.Run(t, pf) testsuite.RunTCPTransparent(t, pf) + testsuite.RunUDPTransparent(t, pf) } diff --git a/pkg/port/builtin/child/child.go b/pkg/port/builtin/child/child.go index 88a7a455..1debb07e 100644 --- a/pkg/port/builtin/child/child.go +++ b/pkg/port/builtin/child/child.go @@ -147,7 +147,7 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er var targetConn net.Conn var err error - if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && dialProto == "tcp" && !net.ParseIP(req.SourceIP).IsLoopback() { + if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && (dialProto == "tcp" || dialProto == "udp") && !net.ParseIP(req.SourceIP).IsLoopback() { d.routingSetup.Do(func() { d.routingReady = d.setupTransparentRouting() }) if !d.routingReady { d.routingWarn.Do(func() { @@ -251,9 +251,16 @@ func (d *childDriver) setupTransparentRouting() bool { // transparentDial dials targetAddr using IP_TRANSPARENT, binding to the given // source IP and port so the backend service sees the real client address. func transparentDial(dialProto, targetAddr, sourceIP string, sourcePort int) (net.Conn, error) { + var localAddr net.Addr + switch dialProto { + case "tcp": + localAddr = &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort} + case "udp": + localAddr = &net.UDPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort} + } dialer := net.Dialer{ Timeout: time.Second, - LocalAddr: &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}, + LocalAddr: localAddr, Control: func(network, address string, c syscall.RawConn) error { var sockErr error if err := c.Control(func(fd uintptr) { diff --git a/pkg/port/builtin/msg/msg.go b/pkg/port/builtin/msg/msg.go index aef2437d..1f845f68 100644 --- a/pkg/port/builtin/msg/msg.go +++ b/pkg/port/builtin/msg/msg.go @@ -82,9 +82,17 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec, sourceAddr net.Addr) (int, ParentIP: spec.ParentIP, HostGatewayIP: hostGatewayIP(), } - if tcpAddr, ok := sourceAddr.(*net.TCPAddr); ok && tcpAddr != nil { - req.SourceIP = tcpAddr.IP.String() - req.SourcePort = tcpAddr.Port + switch a := sourceAddr.(type) { + case *net.TCPAddr: + if a != nil { + req.SourceIP = a.IP.String() + req.SourcePort = a.Port + } + case *net.UDPAddr: + if a != nil { + req.SourceIP = a.IP.String() + req.SourcePort = a.Port + } } if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil { return 0, err diff --git a/pkg/port/builtin/parent/udp/udp.go b/pkg/port/builtin/parent/udp/udp.go index 2bcd0637..716b27fb 100644 --- a/pkg/port/builtin/parent/udp/udp.go +++ b/pkg/port/builtin/parent/udp/udp.go @@ -24,9 +24,9 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch udpp := &udpproxy.UDPProxy{ LogWriter: logWriter, Listener: c, - BackendDial: func() (*net.UDPConn, error) { + BackendDial: func(from *net.UDPAddr) (*net.UDPConn, error) { // get fd from the child as an SCM_RIGHTS cmsg - fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, nil) + fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, from) if err != nil { return nil, err } diff --git a/pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go b/pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go index af7b7d5d..7668e8c8 100644 --- a/pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go +++ b/pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go @@ -49,7 +49,7 @@ type connTrackMap map[connTrackKey]*net.UDPConn type UDPProxy struct { LogWriter io.Writer Listener *net.UDPConn - BackendDial func() (*net.UDPConn, error) + BackendDial func(from *net.UDPAddr) (*net.UDPConn, error) connTrackTable connTrackMap connTrackLock sync.Mutex } @@ -108,7 +108,7 @@ func (proxy *UDPProxy) Run() { proxy.connTrackLock.Lock() proxyConn, hit := proxy.connTrackTable[*fromKey] if !hit { - proxyConn, err = proxy.BackendDial() + proxyConn, err = proxy.BackendDial(from) if err != nil { fmt.Fprintf(proxy.LogWriter, "Can't proxy a datagram to udp: %v\n", err) proxy.connTrackLock.Unlock() diff --git a/pkg/port/testsuite/testsuite.go b/pkg/port/testsuite/testsuite.go index a374c29d..8a3edc5d 100644 --- a/pkg/port/testsuite/testsuite.go +++ b/pkg/port/testsuite/testsuite.go @@ -36,6 +36,9 @@ func Main(m *testing.M, cf func() port.ChildDriver) { case "echoserver": runEchoServer() os.Exit(0) + case "udpechoserver": + runUDPEchoServer() + os.Exit(0) default: panic(fmt.Errorf("unknown mode: %q", mode)) } @@ -603,3 +606,232 @@ func testTCPTransparentWithPID(t *testing.T, d port.ParentDriver, childPID int) t.Fatal(err) } } + +// runUDPEchoServer is a re-exec mode that runs a minimal UDP server. +// It listens on 127.0.0.1:, signals readiness by closing fd 3, +// receives one datagram, writes the remote address to stdout, and echoes the data back. +func runUDPEchoServer() { + portStr := os.Getenv(reexecKeyEchoPort) + if portStr == "" { + panic("udpechoserver: missing " + reexecKeyEchoPort) + } + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:"+portStr) + if err != nil { + panic(fmt.Errorf("udpechoserver: resolve: %w", err)) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + panic(fmt.Errorf("udpechoserver: listen: %w", err)) + } + defer conn.Close() + // Signal readiness by closing fd 3 + readyW := os.NewFile(3, "ready") + readyW.Close() + + buf := make([]byte, 65507) + n, from, err := conn.ReadFromUDP(buf) + if err != nil { + panic(fmt.Errorf("udpechoserver: read: %w", err)) + } + fmt.Fprintln(os.Stdout, from.String()) + conn.WriteToUDP(buf[:n], from) +} + +func RunUDPTransparent(t *testing.T, pf func() port.ParentDriver) { + t.Run("TestUDPTransparent", func(t *testing.T) { TestUDPTransparent(t, pf()) }) +} + +func TestUDPTransparent(t *testing.T, d port.ParentDriver) { + ensureDeps(t, "nsenter") + t.Logf("creating USER+NET namespace") + opaque := d.OpaqueForChild() + opaqueJSON, err := json.Marshal(opaque) + if err != nil { + t.Fatal(err) + } + pr, pw, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + cmd := exec.Command("/proc/self/exe") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Env = append([]string{ + reexecKeyMode + "=child", + reexecKeyOpaque + "=" + string(opaqueJSON), + reexecKeyQuitFD + "=3"}, os.Environ()...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET, + UidMappings: []syscall.SysProcIDMap{ + { + ContainerID: 0, + HostID: os.Geteuid(), + Size: 1, + }, + }, + GidMappings: []syscall.SysProcIDMap{ + { + ContainerID: 0, + HostID: os.Getegid(), + Size: 1, + }, + }, + } + cmd.ExtraFiles = []*os.File{pr} + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + defer func() { + pw.Close() + cmd.Wait() + }() + childPID := cmd.Process.Pid + if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil { + t.Fatalf("%v, out=%s", err, string(out)) + } + testUDPTransparentWithPID(t, d, childPID) +} + +func testUDPTransparentWithPID(t *testing.T, d port.ParentDriver, childPID int) { + ensureDeps(t, "nsenter") + const childPort = 80 + + // Start parent driver + initComplete := make(chan struct{}) + quit := make(chan struct{}) + driverErr := make(chan error) + go func() { + cctx := &port.ChildContext{ + IP: nil, + } + driverErr <- d.RunParentDriver(initComplete, quit, cctx) + }() + select { + case <-initComplete: + case err := <-driverErr: + t.Fatal(err) + } + + // Start UDP echo server inside the child namespace + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + + readyR, readyW, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + + echoCmd := exec.Command("nsenter", "-U", "--preserve-credential", "-n", + "-t", strconv.Itoa(childPID), + exe) + echoCmd.Env = append([]string{ + reexecKeyMode + "=udpechoserver", + reexecKeyEchoPort + "=" + strconv.Itoa(childPort), + }, os.Environ()...) + echoCmd.Stdout = stdoutW + echoCmd.Stderr = os.Stderr + echoCmd.ExtraFiles = []*os.File{readyW} + echoCmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + } + if err := echoCmd.Start(); err != nil { + t.Fatal(err) + } + defer echoCmd.Process.Kill() + readyW.Close() + + io.ReadAll(readyR) + readyR.Close() + stdoutW.Close() + + // Allocate a parent port and add port forwarding + parentPort, err := allocateAvailablePort("udp") + if err != nil { + t.Fatal(err) + } + + var portStatus *port.Status + const maxAttempts = 10 + for attempt := 0; attempt < maxAttempts; attempt++ { + portStatus, err = d.AddPort(context.TODO(), + port.Spec{ + Proto: "udp", + ParentIP: "127.0.0.1", + ParentPort: parentPort, + ChildPort: childPort, + }) + if err == nil { + break + } + if attempt == maxAttempts-1 || !isAddrInUse(err) { + t.Fatal(err) + } + parentPort, err = allocateAvailablePort("udp") + if err != nil { + t.Fatal(err) + } + } + t.Logf("opened port: %+v", portStatus) + + // Give the proxy time to start + time.Sleep(500 * time.Millisecond) + + // Dial the parent port + conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", parentPort)) + if err != nil { + t.Fatal(err) + } + + clientAddr := conn.LocalAddr().String() + t.Logf("client local address: %s", clientAddr) + + if _, err := conn.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + + // Read echo response to ensure round-trip + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 1024) + if _, err := conn.Read(buf); err != nil { + t.Fatal(err) + } + conn.Close() + + // Read the remote address the echo server saw + scanner := bufio.NewScanner(stdoutR) + if !scanner.Scan() { + t.Fatal("failed to read remote address from UDP echo server") + } + serverSawAddr := scanner.Text() + t.Logf("server saw remote address: %s", serverSawAddr) + + clientHost, _, err := net.SplitHostPort(clientAddr) + if err != nil { + t.Fatalf("failed to parse client address %q: %v", clientAddr, err) + } + serverHost, _, err := net.SplitHostPort(serverSawAddr) + if err != nil { + t.Fatalf("failed to parse server-seen address %q: %v", serverSawAddr, err) + } + + if clientHost != serverHost { + t.Errorf("IP mismatch: client=%s, server saw=%s", clientHost, serverHost) + } + + // Cleanup + if err := d.RemovePort(context.TODO(), portStatus.ID); err != nil { + t.Fatal(err) + } + quit <- struct{}{} + if err := <-driverErr; err != nil { + t.Fatal(err) + } +}