diff --git a/platform/firewall/cgroup_linux.go b/platform/firewall/cgroup_linux.go index 287f90e3f..6f3c91fd0 100644 --- a/platform/firewall/cgroup_linux.go +++ b/platform/firewall/cgroup_linux.go @@ -5,16 +5,21 @@ import ( "os" "path/filepath" "strings" - "syscall" boshlog "github.com/cloudfoundry/bosh-utils/logger" + "golang.org/x/sys/unix" ) const cgroupLogTag = "cgroup" -// getCurrentCgroupPath reads /proc/self/cgroup and extracts the cgroupv2 path. -// Returns path WITHOUT leading slash (e.g., "system.slice/runc-bpm-galera-agent.scope") -// to match the format used by the nft CLI. +// getCurrentCgroupPath reads /proc/self/cgroup and determines the effective +// cgroup v2 path. +// +// Returns path WITHOUT leading slash to match the format used by the nft CLI. +// (e.g. "system.slice/runc-bpm-galera-agent.scope") +// +// On hybrid systems (e.g., Ubuntu Jammy), the path is automatically prefixed +// with "unified/" to align with the hybrid mount point. func getCurrentCgroupPath(logger boshlog.Logger) (string, error) { data, err := os.ReadFile("/proc/self/cgroup") if err != nil { @@ -29,6 +34,10 @@ func getCurrentCgroupPath(logger boshlog.Logger) (string, error) { if strings.HasPrefix(line, "0::") { path := strings.TrimPrefix(line, "0::") path = strings.TrimPrefix(path, "/") + path, err := cgroupv2Path(path) + if err != nil { + return "", fmt.Errorf("determining cgroup v2 path: %w", err) + } logger.Info(cgroupLogTag, "Detected cgroupv2 path: %s", path) return path, nil } @@ -37,32 +46,49 @@ func getCurrentCgroupPath(logger boshlog.Logger) (string, error) { return "", fmt.Errorf("cgroupv2 path not found in /proc/self/cgroup") } -// isCgroupAccessible checks if the cgroup path is accessible and functional -// for nftables socket cgroupv2 matching. +// cgroupv2Path canonicalizes a path based on the detected cgroup hierarchy. 
// -// This returns false in these cases: -// - Cgroup path doesn't exist in /sys/fs/cgroup -// - Hybrid cgroup system (cgroupv2 mounted but no controllers delegated) -// - Nested containers where cgroup path is different from host view -func isCgroupAccessible(logger boshlog.Logger, cgroupPath string) bool { - fullPath := filepath.Join("/sys/fs/cgroup", cgroupPath) - if _, err := os.Stat(fullPath); err != nil { - logger.Info(cgroupLogTag, "Cgroup path doesn't exist: %s", fullPath) - return false +// On unified systems the path is returned unchanged; +// on hybrid systems it is prefixed with "unified/" to match the cgroupv2 +// mount at /sys/fs/cgroup/unified. +// +// Returns an error if the cgroup mode cannot be determined. +func cgroupv2Path(path string) (string, error) { + switch detectCgroupMode() { + case unifiedMode: + return path, nil + case hybridMode: + return filepath.Join("unified", path), nil + default: + return "", fmt.Errorf("unknown cgroup mode") } +} - controllers, err := os.ReadFile("/sys/fs/cgroup/cgroup.controllers") - if err != nil { - logger.Info(cgroupLogTag, "Cannot read cgroup.controllers: %v", err) - return false +type cgroupMode int + +const ( + unifiedMode cgroupMode = iota // Pure v2 + hybridMode // v1 with v2 at /unified + unknownMode +) + +// detectCgroupMode determines the system's cgroup hierarchy mode. +// +// Returns `unifiedMode`, if /sys/fs/cgroup is a cgroup2 filesystem. +// Returns `hybridMode`, if /sys/fs/cgroup/unified is a cgroup2 filesystem. +// Returns `unknownMode`, if cgroup2 was otherwise not detected. 
+func detectCgroupMode() cgroupMode { + var st unix.Statfs_t + + if err := unix.Statfs("/sys/fs/cgroup", &st); err == nil && st.Type == unix.CGROUP2_SUPER_MAGIC { + return unifiedMode } - if len(strings.TrimSpace(string(controllers))) == 0 { - logger.Info(cgroupLogTag, "Hybrid cgroup system detected (no controllers in cgroupv2)") - return false + if err := unix.Statfs("/sys/fs/cgroup/unified", &st); err == nil && st.Type == unix.CGROUP2_SUPER_MAGIC { + return hybridMode } - return true + return unknownMode } // getCgroupInodeID returns the inode ID for the cgroup path. @@ -72,8 +98,8 @@ func isCgroupAccessible(logger boshlog.Logger, cgroupPath string) bool { func getCgroupInodeID(cgroupPath string) (uint64, error) { fullPath := filepath.Join("/sys/fs/cgroup", cgroupPath) - var stat syscall.Stat_t - if err := syscall.Stat(fullPath, &stat); err != nil { + var stat unix.Stat_t + if err := unix.Stat(fullPath, &stat); err != nil { return 0, fmt.Errorf("stat %s: %w", fullPath, err) } diff --git a/platform/firewall/nftables_firewall.go b/platform/firewall/nftables_firewall.go index c78fb49ae..5a1b2bc4f 100644 --- a/platform/firewall/nftables_firewall.go +++ b/platform/firewall/nftables_firewall.go @@ -7,7 +7,7 @@ import ( "fmt" "net" gonetURL "net/url" - "os" + "os/user" "strconv" "strings" @@ -17,6 +17,8 @@ import ( "github.com/google/nftables/expr" "github.com/google/nftables/userdata" "golang.org/x/sys/unix" + + "github.com/cloudfoundry/bosh-agent/v2/settings" ) // NftablesConn abstracts the nftables connection for testing @@ -93,6 +95,8 @@ func (r *realNftablesConn) CloseLasting() error { return r.conn.CloseLasting() } +type UserLookup func(username string) (*user.User, error) + // NftablesFirewall implements Manager and NatsFirewallHook using nftables with UID-based matching type NftablesFirewall struct { conn NftablesConn @@ -103,6 +107,7 @@ type NftablesFirewall struct { monitChain *nftables.Chain monitJobsChain *nftables.Chain natsChain *nftables.Chain + 
userLookup UserLookup } // NewNftablesFirewall creates a new nftables-based firewall manager @@ -115,17 +120,19 @@ func NewNftablesFirewall(logger boshlog.Logger) (Manager, error) { return NewNftablesFirewallWithDeps( &realNftablesConn{conn: conn}, &realDNSResolver{}, + user.Lookup, logger, ), nil } // NewNftablesFirewallWithDeps creates a firewall manager with injected dependencies (for testing) -func NewNftablesFirewallWithDeps(conn NftablesConn, resolver DNSResolver, logger boshlog.Logger) Manager { +func NewNftablesFirewallWithDeps(conn NftablesConn, resolver DNSResolver, userLookup UserLookup, logger boshlog.Logger) Manager { return &NftablesFirewall{ - conn: conn, - resolver: resolver, - logger: logger, - logTag: "NftablesFirewall", + conn: conn, + resolver: resolver, + logger: logger, + logTag: "NftablesFirewall", + userLookup: userLookup, } } @@ -182,7 +189,7 @@ func (f *NftablesFirewall) EnableMonitAccess() error { // 2. Try cgroup-based rule first (better isolation) cgroupPath, err := getCurrentCgroupPath(f.logger) - if err == nil && isCgroupAccessible(f.logger, cgroupPath) { + if err == nil { inodeID, err := getCgroupInodeID(cgroupPath) if err == nil { f.logger.Info(f.logTag, "Using cgroup rule for: %s (inode: %d)", cgroupPath, inodeID) @@ -196,15 +203,22 @@ func (f *NftablesFirewall) EnableMonitAccess() error { } else { f.logger.Error(f.logTag, "Failed to get cgroup inode ID: %v", err) } - } else if err != nil { + } else { f.logger.Error(f.logTag, "Could not detect cgroup: %v", err) } // 3. 
Fallback to UID-based rule - uid := uint32(os.Getuid()) - f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", uid) + f.logger.Info(f.logTag, "Falling back to UID rule for vcap") + u, err := f.userLookup(settings.VCAPUsername) + if err != nil { + return fmt.Errorf("could not find vcap user: %w", err) + } + uid, err := strconv.ParseUint(u.Uid, 10, 32) + if err != nil { + return fmt.Errorf("could not parse vcap UID %q: %w", u.Uid, err) + } - return f.addUIDRule(uid) + return f.addUIDRule(uint32(uid)) } // SetupNATSFirewall creates firewall rules to protect NATS. diff --git a/platform/firewall/nftables_firewall_test.go b/platform/firewall/nftables_firewall_test.go index a02c4221b..02c7ec49a 100644 --- a/platform/firewall/nftables_firewall_test.go +++ b/platform/firewall/nftables_firewall_test.go @@ -6,7 +6,7 @@ import ( "encoding/binary" "errors" "net" - "os" + "os/user" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -30,8 +30,21 @@ var _ = Describe("NftablesFirewall", func() { BeforeEach(func() { fakeConn = &firewallfakes.FakeNftablesConn{} fakeResolver = &firewallfakes.FakeDNSResolver{} + var fakeUserLookup firewall.UserLookup = func(username string) (*user.User, error) { + if username != "vcap" { + return nil, errors.New("unexpected user") + } + + return &user.User{ + Uid: "1000", + Gid: "1000", + Username: "vcap", + Name: "BOSH System User", + HomeDir: "/home/vcap", + }, nil + } logger = boshlog.NewWriterLogger(boshlog.LevelDebug, GinkgoWriter) - manager = firewall.NewNftablesFirewallWithDeps(fakeConn, fakeResolver, logger) + manager = firewall.NewNftablesFirewallWithDeps(fakeConn, fakeResolver, fakeUserLookup, logger) }) Describe("SetupMonitFirewall", func() { @@ -218,7 +231,8 @@ var _ = Describe("NftablesFirewall", func() { Context("when the rule already exists (idempotency)", func() { It("does not add a duplicate UID rule", func() { // Simulate an existing UID rule matching the current UID - uid := uint32(os.Getuid()) + const vcapUid = 
1000 + uid := uint32(vcapUid) uidBytes := make([]byte, 4) binary.NativeEndian.PutUint32(uidBytes, uid)