Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 51 additions & 25 deletions platform/firewall/cgroup_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@ import (
"os"
"path/filepath"
"strings"
"syscall"

boshlog "github.com/cloudfoundry/bosh-utils/logger"
"golang.org/x/sys/unix"
)

const cgroupLogTag = "cgroup"

// getCurrentCgroupPath reads /proc/self/cgroup and extracts the cgroupv2 path.
// Returns path WITHOUT leading slash (e.g., "system.slice/runc-bpm-galera-agent.scope")
// to match the format used by the nft CLI.
// getCurrentCgroupPath reads /proc/self/cgroup and determines the effective
// cgroup v2 path.
//
// Returns path WITHOUT leading slash to match the format used by the nft CLI.
// (e.g. "system.slice/runc-bpm-galera-agent.scope")
//
// On hybrid systems (e.g., Ubuntu Jammy), the path is automatically prefixed
// with "unified/" to align with the hybrid mount point.
func getCurrentCgroupPath(logger boshlog.Logger) (string, error) {
data, err := os.ReadFile("/proc/self/cgroup")
if err != nil {
Expand All @@ -29,6 +34,10 @@ func getCurrentCgroupPath(logger boshlog.Logger) (string, error) {
if strings.HasPrefix(line, "0::") {
path := strings.TrimPrefix(line, "0::")
path = strings.TrimPrefix(path, "/")
path, err := cgroupv2Path(path)
if err != nil {
return "", fmt.Errorf("determining cgroup v2 path: %w", err)
}
logger.Info(cgroupLogTag, "Detected cgroupv2 path: %s", path)
return path, nil
}
Expand All @@ -37,32 +46,49 @@ func getCurrentCgroupPath(logger boshlog.Logger) (string, error) {
return "", fmt.Errorf("cgroupv2 path not found in /proc/self/cgroup")
}

// isCgroupAccessible checks if the cgroup path is accessible and functional
// for nftables socket cgroupv2 matching.
// cgroupv2Path canonicalizes a path based on the detected cgroup hierarchy.
//
// This returns false in these cases:
// - Cgroup path doesn't exist in /sys/fs/cgroup
// - Hybrid cgroup system (cgroupv2 mounted but no controllers delegated)
// - Nested containers where cgroup path is different from host view
func isCgroupAccessible(logger boshlog.Logger, cgroupPath string) bool {
fullPath := filepath.Join("/sys/fs/cgroup", cgroupPath)
if _, err := os.Stat(fullPath); err != nil {
logger.Info(cgroupLogTag, "Cgroup path doesn't exist: %s", fullPath)
return false
// On unified systems the path is returned unchanged
// on hybrid systems it is prefixed with "unified/" to match the cgroupv2
// mount at /sys/fs/cgroup/unified.
//
// Returns an error if the cgroup mode cannot be determined.
func cgroupv2Path(path string) (string, error) {
switch detectCgroupMode() {
case unifiedMode:
return path, nil
case hybridMode:
return filepath.Join("unified", path), nil
default:
return "", fmt.Errorf("unknown cgroup mode")
}
}

controllers, err := os.ReadFile("/sys/fs/cgroup/cgroup.controllers")
if err != nil {
logger.Info(cgroupLogTag, "Cannot read cgroup.controllers: %v", err)
return false
type cgroupMode int

const (
unifiedMode cgroupMode = iota // Pure v2
hybridMode // v1 with v2 at /unified
unknownMode
)

// detectCgroupMode determines the system's cgroup hierarchy mode.
//
// Returns `unifiedMode`, if /sys/fs/cgroup is a cgroup2 filesystem.
// Returns `hybridMode`, if /sys/fs/cgroup/unified is a cgroup2 filesystem.
// Returns `unknownMode`, if cgroup2 was otherwise not detected.
func detectCgroupMode() cgroupMode {
var st unix.Statfs_t

if err := unix.Statfs("/sys/fs/cgroup", &st); err == nil && st.Type == unix.CGROUP2_SUPER_MAGIC {
return unifiedMode
}

if len(strings.TrimSpace(string(controllers))) == 0 {
logger.Info(cgroupLogTag, "Hybrid cgroup system detected (no controllers in cgroupv2)")
return false
if err := unix.Statfs("/sys/fs/cgroup/unified", &st); err == nil && st.Type == unix.CGROUP2_SUPER_MAGIC {
return hybridMode
}

return true
return unknownMode
}

// getCgroupInodeID returns the inode ID for the cgroup path.
Expand All @@ -72,8 +98,8 @@ func isCgroupAccessible(logger boshlog.Logger, cgroupPath string) bool {
func getCgroupInodeID(cgroupPath string) (uint64, error) {
fullPath := filepath.Join("/sys/fs/cgroup", cgroupPath)

var stat syscall.Stat_t
if err := syscall.Stat(fullPath, &stat); err != nil {
var stat unix.Stat_t
if err := unix.Stat(fullPath, &stat); err != nil {
return 0, fmt.Errorf("stat %s: %w", fullPath, err)
}

Expand Down
36 changes: 25 additions & 11 deletions platform/firewall/nftables_firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"fmt"
"net"
gonetURL "net/url"
"os"
"os/user"
"strconv"
"strings"

Expand All @@ -17,6 +17,8 @@ import (
"github.com/google/nftables/expr"
"github.com/google/nftables/userdata"
"golang.org/x/sys/unix"

"github.com/cloudfoundry/bosh-agent/v2/settings"
)

// NftablesConn abstracts the nftables connection for testing
Expand Down Expand Up @@ -93,6 +95,8 @@ func (r *realNftablesConn) CloseLasting() error {
return r.conn.CloseLasting()
}

type UserLookup func(username string) (*user.User, error)

// NftablesFirewall implements Manager and NatsFirewallHook using nftables with UID-based matching
type NftablesFirewall struct {
conn NftablesConn
Expand All @@ -103,6 +107,7 @@ type NftablesFirewall struct {
monitChain *nftables.Chain
monitJobsChain *nftables.Chain
natsChain *nftables.Chain
userLookup UserLookup
}

// NewNftablesFirewall creates a new nftables-based firewall manager
Expand All @@ -115,17 +120,19 @@ func NewNftablesFirewall(logger boshlog.Logger) (Manager, error) {
return NewNftablesFirewallWithDeps(
&realNftablesConn{conn: conn},
&realDNSResolver{},
user.Lookup,
logger,
), nil
}

// NewNftablesFirewallWithDeps creates a firewall manager with injected dependencies (for testing)
func NewNftablesFirewallWithDeps(conn NftablesConn, resolver DNSResolver, logger boshlog.Logger) Manager {
func NewNftablesFirewallWithDeps(conn NftablesConn, resolver DNSResolver, userLookup UserLookup, logger boshlog.Logger) Manager {
return &NftablesFirewall{
conn: conn,
resolver: resolver,
logger: logger,
logTag: "NftablesFirewall",
conn: conn,
resolver: resolver,
logger: logger,
logTag: "NftablesFirewall",
userLookup: userLookup,
}
}

Expand Down Expand Up @@ -182,7 +189,7 @@ func (f *NftablesFirewall) EnableMonitAccess() error {

// 2. Try cgroup-based rule first (better isolation)
cgroupPath, err := getCurrentCgroupPath(f.logger)
if err == nil && isCgroupAccessible(f.logger, cgroupPath) {
if err == nil {
inodeID, err := getCgroupInodeID(cgroupPath)
if err == nil {
f.logger.Info(f.logTag, "Using cgroup rule for: %s (inode: %d)", cgroupPath, inodeID)
Expand All @@ -196,15 +203,22 @@ func (f *NftablesFirewall) EnableMonitAccess() error {
} else {
f.logger.Error(f.logTag, "Failed to get cgroup inode ID: %v", err)
}
} else if err != nil {
} else {
f.logger.Error(f.logTag, "Could not detect cgroup: %v", err)
}

// 3. Fallback to UID-based rule
uid := uint32(os.Getuid())
f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", uid)
f.logger.Info(f.logTag, "Falling back to UID rule for vcap")
u, err := f.userLookup(settings.VCAPUsername)
if err != nil {
return fmt.Errorf("could not find vcap user: %w", err)
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return fmt.Errorf("could not parse vcap UID %q: %w", u.Uid, err)
}

return f.addUIDRule(uid)
return f.addUIDRule(uint32(uid))
}

// SetupNATSFirewall creates firewall rules to protect NATS.
Expand Down
20 changes: 17 additions & 3 deletions platform/firewall/nftables_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"encoding/binary"
"errors"
"net"
"os"
"os/user"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand All @@ -30,8 +30,21 @@ var _ = Describe("NftablesFirewall", func() {
BeforeEach(func() {
fakeConn = &firewallfakes.FakeNftablesConn{}
fakeResolver = &firewallfakes.FakeDNSResolver{}
var fakeUserLookup firewall.UserLookup = func(username string) (*user.User, error) {
if username != "vcap" {
return nil, errors.New("unexpected user")
}

return &user.User{
Uid: "1000",
Gid: "1000",
Username: "vcap",
Name: "BOSH System User",
HomeDir: "/home/vcap",
}, nil
}
logger = boshlog.NewWriterLogger(boshlog.LevelDebug, GinkgoWriter)
manager = firewall.NewNftablesFirewallWithDeps(fakeConn, fakeResolver, logger)
manager = firewall.NewNftablesFirewallWithDeps(fakeConn, fakeResolver, fakeUserLookup, logger)
})

Describe("SetupMonitFirewall", func() {
Expand Down Expand Up @@ -218,7 +231,8 @@ var _ = Describe("NftablesFirewall", func() {
Context("when the rule already exists (idempotency)", func() {
It("does not add a duplicate UID rule", func() {
// Simulate an existing UID rule matching the current UID
uid := uint32(os.Getuid())
const vcapUid = 1000
uid := uint32(vcapUid)
uidBytes := make([]byte, 4)
binary.NativeEndian.PutUint32(uidBytes, uid)

Expand Down