Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/port/builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ func TestBuiltIn(t *testing.T) {
}
testsuite.Run(t, pf)
testsuite.RunTCPTransparent(t, pf)
testsuite.RunUDPTransparent(t, pf)
}
11 changes: 9 additions & 2 deletions pkg/port/builtin/child/child.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er

var targetConn net.Conn
var err error
if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && dialProto == "tcp" && !net.ParseIP(req.SourceIP).IsLoopback() {
if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && (dialProto == "tcp" || dialProto == "udp") && !net.ParseIP(req.SourceIP).IsLoopback() {
d.routingSetup.Do(func() { d.routingReady = d.setupTransparentRouting() })
if !d.routingReady {
d.routingWarn.Do(func() {
Expand Down Expand Up @@ -251,9 +251,16 @@ func (d *childDriver) setupTransparentRouting() bool {
// transparentDial dials targetAddr using IP_TRANSPARENT, binding to the given
// source IP and port so the backend service sees the real client address.
func transparentDial(dialProto, targetAddr, sourceIP string, sourcePort int) (net.Conn, error) {
var localAddr net.Addr
switch dialProto {
case "tcp":
localAddr = &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}
case "udp":
localAddr = &net.UDPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}
}
dialer := net.Dialer{
Timeout: time.Second,
LocalAddr: &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort},
LocalAddr: localAddr,
Control: func(network, address string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
Expand Down
14 changes: 11 additions & 3 deletions pkg/port/builtin/msg/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,17 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec, sourceAddr net.Addr) (int,
ParentIP: spec.ParentIP,
HostGatewayIP: hostGatewayIP(),
}
if tcpAddr, ok := sourceAddr.(*net.TCPAddr); ok && tcpAddr != nil {
req.SourceIP = tcpAddr.IP.String()
req.SourcePort = tcpAddr.Port
switch a := sourceAddr.(type) {
case *net.TCPAddr:
if a != nil {
req.SourceIP = a.IP.String()
req.SourcePort = a.Port
}
case *net.UDPAddr:
if a != nil {
req.SourceIP = a.IP.String()
req.SourcePort = a.Port
}
}
if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil {
return 0, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/port/builtin/parent/udp/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch
udpp := &udpproxy.UDPProxy{
LogWriter: logWriter,
Listener: c,
BackendDial: func() (*net.UDPConn, error) {
BackendDial: func(from *net.UDPAddr) (*net.UDPConn, error) {
// get fd from the child as an SCM_RIGHTS cmsg
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, nil)
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, from)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type connTrackMap map[connTrackKey]*net.UDPConn
type UDPProxy struct {
LogWriter io.Writer
Listener *net.UDPConn
BackendDial func() (*net.UDPConn, error)
BackendDial func(from *net.UDPAddr) (*net.UDPConn, error)
connTrackTable connTrackMap
connTrackLock sync.Mutex
}
Expand Down Expand Up @@ -108,7 +108,7 @@ func (proxy *UDPProxy) Run() {
proxy.connTrackLock.Lock()
proxyConn, hit := proxy.connTrackTable[*fromKey]
if !hit {
proxyConn, err = proxy.BackendDial()
proxyConn, err = proxy.BackendDial(from)
if err != nil {
fmt.Fprintf(proxy.LogWriter, "Can't proxy a datagram to udp: %v\n", err)
proxy.connTrackLock.Unlock()
Expand Down
232 changes: 232 additions & 0 deletions pkg/port/testsuite/testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ func Main(m *testing.M, cf func() port.ChildDriver) {
case "echoserver":
runEchoServer()
os.Exit(0)
case "udpechoserver":
runUDPEchoServer()
os.Exit(0)
default:
panic(fmt.Errorf("unknown mode: %q", mode))
}
Expand Down Expand Up @@ -603,3 +606,232 @@ func testTCPTransparentWithPID(t *testing.T, d port.ParentDriver, childPID int)
t.Fatal(err)
}
}

// runUDPEchoServer is a re-exec mode that runs a minimal UDP server.
// It listens on 127.0.0.1:<port>, signals readiness by closing fd 3,
// receives one datagram, writes the remote address to stdout, and echoes the data back.
func runUDPEchoServer() {
portStr := os.Getenv(reexecKeyEchoPort)
if portStr == "" {
panic("udpechoserver: missing " + reexecKeyEchoPort)
}
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:"+portStr)
if err != nil {
panic(fmt.Errorf("udpechoserver: resolve: %w", err))
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
panic(fmt.Errorf("udpechoserver: listen: %w", err))
}
defer conn.Close()
// Signal readiness by closing fd 3
readyW := os.NewFile(3, "ready")
readyW.Close()

buf := make([]byte, 65507)
n, from, err := conn.ReadFromUDP(buf)
if err != nil {
panic(fmt.Errorf("udpechoserver: read: %w", err))
}
fmt.Fprintln(os.Stdout, from.String())
conn.WriteToUDP(buf[:n], from)
}

func RunUDPTransparent(t *testing.T, pf func() port.ParentDriver) {
t.Run("TestUDPTransparent", func(t *testing.T) { TestUDPTransparent(t, pf()) })
}

func TestUDPTransparent(t *testing.T, d port.ParentDriver) {
ensureDeps(t, "nsenter")
t.Logf("creating USER+NET namespace")
opaque := d.OpaqueForChild()
opaqueJSON, err := json.Marshal(opaque)
if err != nil {
t.Fatal(err)
}
pr, pw, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
cmd := exec.Command("/proc/self/exe")
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
cmd.Env = append([]string{
reexecKeyMode + "=child",
reexecKeyOpaque + "=" + string(opaqueJSON),
reexecKeyQuitFD + "=3"}, os.Environ()...)
cmd.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGKILL,
Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET,
UidMappings: []syscall.SysProcIDMap{
{
ContainerID: 0,
HostID: os.Geteuid(),
Size: 1,
},
},
GidMappings: []syscall.SysProcIDMap{
{
ContainerID: 0,
HostID: os.Getegid(),
Size: 1,
},
},
}
cmd.ExtraFiles = []*os.File{pr}
if err := cmd.Start(); err != nil {
t.Fatal(err)
}
defer func() {
pw.Close()
cmd.Wait()
}()
childPID := cmd.Process.Pid
if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil {
t.Fatalf("%v, out=%s", err, string(out))
}
testUDPTransparentWithPID(t, d, childPID)
}

func testUDPTransparentWithPID(t *testing.T, d port.ParentDriver, childPID int) {
ensureDeps(t, "nsenter")
const childPort = 80

// Start parent driver
initComplete := make(chan struct{})
quit := make(chan struct{})
driverErr := make(chan error)
go func() {
cctx := &port.ChildContext{
IP: nil,
}
driverErr <- d.RunParentDriver(initComplete, quit, cctx)
}()
select {
case <-initComplete:
case err := <-driverErr:
t.Fatal(err)
}

// Start UDP echo server inside the child namespace
exe, err := os.Executable()
if err != nil {
t.Fatal(err)
}

readyR, readyW, err := os.Pipe()
if err != nil {
t.Fatal(err)
}

stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatal(err)
}

echoCmd := exec.Command("nsenter", "-U", "--preserve-credential", "-n",
"-t", strconv.Itoa(childPID),
exe)
echoCmd.Env = append([]string{
reexecKeyMode + "=udpechoserver",
reexecKeyEchoPort + "=" + strconv.Itoa(childPort),
}, os.Environ()...)
echoCmd.Stdout = stdoutW
echoCmd.Stderr = os.Stderr
echoCmd.ExtraFiles = []*os.File{readyW}
echoCmd.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGKILL,
}
if err := echoCmd.Start(); err != nil {
t.Fatal(err)
}
defer echoCmd.Process.Kill()
readyW.Close()

io.ReadAll(readyR)
readyR.Close()
stdoutW.Close()

// Allocate a parent port and add port forwarding
parentPort, err := allocateAvailablePort("udp")
if err != nil {
t.Fatal(err)
}

var portStatus *port.Status
const maxAttempts = 10
for attempt := 0; attempt < maxAttempts; attempt++ {
portStatus, err = d.AddPort(context.TODO(),
port.Spec{
Proto: "udp",
ParentIP: "127.0.0.1",
ParentPort: parentPort,
ChildPort: childPort,
})
if err == nil {
break
}
if attempt == maxAttempts-1 || !isAddrInUse(err) {
t.Fatal(err)
}
parentPort, err = allocateAvailablePort("udp")
if err != nil {
t.Fatal(err)
}
}
t.Logf("opened port: %+v", portStatus)

// Give the proxy time to start
time.Sleep(500 * time.Millisecond)

// Dial the parent port
conn, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", parentPort))
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should test dialing non-loopback addr

if err != nil {
t.Fatal(err)
}

clientAddr := conn.LocalAddr().String()
t.Logf("client local address: %s", clientAddr)

if _, err := conn.Write([]byte("hello")); err != nil {
t.Fatal(err)
}

// Read echo response to ensure round-trip
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 1024)
if _, err := conn.Read(buf); err != nil {
t.Fatal(err)
}
conn.Close()

// Read the remote address the echo server saw
scanner := bufio.NewScanner(stdoutR)
if !scanner.Scan() {
t.Fatal("failed to read remote address from UDP echo server")
}
serverSawAddr := scanner.Text()
t.Logf("server saw remote address: %s", serverSawAddr)

clientHost, _, err := net.SplitHostPort(clientAddr)
if err != nil {
t.Fatalf("failed to parse client address %q: %v", clientAddr, err)
}
serverHost, _, err := net.SplitHostPort(serverSawAddr)
if err != nil {
t.Fatalf("failed to parse server-seen address %q: %v", serverSawAddr, err)
}

if clientHost != serverHost {
t.Errorf("IP mismatch: client=%s, server saw=%s", clientHost, serverHost)
}

// Cleanup
if err := d.RemovePort(context.TODO(), portStatus.ID); err != nil {
t.Fatal(err)
}
quit <- struct{}{}
if err := <-driverErr; err != nil {
t.Fatal(err)
}
}