From ee8097661ded0f740fe4725a1ad97777676ed41f Mon Sep 17 00:00:00 2001 From: Igal Tsoiref Date: Thu, 28 May 2026 22:32:14 -0400 Subject: [PATCH] fix(hostagent): minimize lock contention in NetworkManager run() and AddNetworkRequest() previously held the write lock across slow operations (API calls, PCI ops, netlink, network config), causing AddNetworkRequest HTTP handler to block for the entire run() cycle. Refactor to hold locks only for in-memory map operations: - run(): snapshot reqs under RLock, process lock-free - AddNetworkRequest(): RLock for map checks, no lock for API calls, brief Lock for map write - processNetworkRequest(): use removeRequest() helper for thread-safe map deletion Co-authored-by: Cursor --- .../networkmanager/network_manager.go | 63 +++- .../networkmanager/network_manager_test.go | 304 ++++++++++++++++++ 2 files changed, 351 insertions(+), 16 deletions(-) diff --git a/internal/provisioning/hostagent/networkmanager/network_manager.go b/internal/provisioning/hostagent/networkmanager/network_manager.go index 6aefe393c..0aa8f415d 100644 --- a/internal/provisioning/hostagent/networkmanager/network_manager.go +++ b/internal/provisioning/hostagent/networkmanager/network_manager.go @@ -154,16 +154,26 @@ func (nm *NetworkManager) loadNetworkRequest() error { } func (nm *NetworkManager) run() { - nm.Lock() - defer nm.Unlock() + nm.RLock() + reqs := make([]NetworkRequest, 0, len(nm.reqs)) + for _, nr := range nm.reqs { + reqs = append(reqs, nr) + } + nm.RUnlock() - for uid, nr := range nm.reqs { + for _, nr := range reqs { if err := nm.processNetworkRequest(nr); err != nil { - klog.Errorf("failed to process network request, nr: %+v, err: %v", nm.reqs[uid], err) + klog.Errorf("failed to process network request, nr: %+v, err: %v", nr, err) } } } +func (nm *NetworkManager) removeRequest(uid string) { + nm.Lock() + defer nm.Unlock() + delete(nm.reqs, uid) +} + func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { nn := types.NamespacedName{Namespace: nr.DPUNamespace, Name: nr.DpuName} dpu := &provisioningv1.DPU{} @@ -180,7 +190,7 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { if err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove network request file: %w", err) } - delete(nm.reqs, nr.UID) + nm.removeRequest(nr.UID) klog.Infof("removed VF and network request for DPU %s/%s(UID: %s)", nr.DPUNamespace, nr.DpuName, nr.UID) return nil } @@ -270,16 +280,34 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { return nil } -func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { - nm.Lock() - defer nm.Unlock() +// lookupDevice checks preconditions and returns the PCI device for the DPU. +// Returns found=false when a network request for this DPU already exists. +func (nm *NetworkManager) lookupDevice(dpu *provisioningv1.DPU) (dev hostutil.Device, found bool, err error) { + nm.RLock() + defer nm.RUnlock() if !nm.initialized { - return fmt.Errorf("network manager is not initialized") - } else if dpu == nil { + return hostutil.Device{}, false, fmt.Errorf("network manager is not initialized") + } + if _, ok := nm.reqs[string(dpu.UID)]; ok { + return hostutil.Device{}, false, nil + } + dev, ok := nm.devicesBySN[dpu.Spec.SerialNumber] + if !ok { + return hostutil.Device{}, false, fmt.Errorf("PCI address of device %s not found", dpu.Spec.SerialNumber) + } + return dev, true, nil +} + +func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { + if dpu == nil { return fmt.Errorf("DPU is nil") } - if _, ok := nm.reqs[string(dpu.UID)]; ok { + dev, found, err := nm.lookupDevice(dpu) + if err != nil { + return err + } + if !found { return nil } @@ -289,10 +317,6 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { nr.SetDPUObjectMeta(*dpu) // use the PCI address collected locally, so that it's not affected by PCI address changes - dev, ok := nm.devicesBySN[nr.SerialNumber] - if !ok { - return fmt.Errorf("PCI address of device %s not found", nr.SerialNumber) - } nr.PCIAddress = dev.Address numOfVFs, err := nm.getNumOfVFs(dpu) @@ -317,10 +341,17 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { if err := writeNetworkRequestFile(nr); err != nil { return fmt.Errorf("failed to write network request file: %w", err) } - nm.reqs[nr.UID] = *nr + + nm.addRequest(nr) return nil } +func (nm *NetworkManager) addRequest(nr *NetworkRequest) { + nm.Lock() + defer nm.Unlock() + nm.reqs[nr.UID] = *nr +} + func (nm *NetworkManager) getNumOfVFs(dpu *provisioningv1.DPU) (int, error) { flavor := &provisioningv1.DPUFlavor{} if err := nm.Get(context.TODO(), types.NamespacedName{Namespace: dpu.Namespace, Name: dpu.Spec.DPUFlavor}, flavor); err != nil { diff --git a/internal/provisioning/hostagent/networkmanager/network_manager_test.go b/internal/provisioning/hostagent/networkmanager/network_manager_test.go index bf13a747f..267562a13 100644 --- a/internal/provisioning/hostagent/networkmanager/network_manager_test.go +++ b/internal/provisioning/hostagent/networkmanager/network_manager_test.go @@ -19,15 +19,26 @@ limitations under the License. package networkmanager import ( + "context" + "encoding/json" + "fmt" "os" + "path/filepath" + "sync" + "time" + operatorv1 "github.com/nvidia/doca-platform/api/operator/v1alpha1" provisioningv1 "github.com/nvidia/doca-platform/api/provisioning/v1alpha1" hostutil "github.com/nvidia/doca-platform/internal/provisioning/hostagent/util" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) var _ = Describe("NetworkManager", func() { @@ -210,4 +221,297 @@ var _ = Describe("NetworkManager", func() { } }) }) + + Context("NetworkManager.lookupDevice", Label("lookupDevice"), func() { + It("should return device when initialized, device exists, and no existing request", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + nm.devicesBySN["MT2334XZ0L"] = hostutil.Device{ + Address: "0000:03:00", + SerialNumber: "MT2334XZ0L", + NumOfPFs: 2, + } + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + dev, found, err := nm.lookupDevice(dpu) + Expect(err).NotTo(HaveOccurred()) + Expect(found).To(BeTrue()) + Expect(dev.Address).To(Equal("0000:03:00")) + }) + + It("should return found=false when request already exists for the UID", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + nm.devicesBySN["MT2334XZ0L"] = hostutil.Device{Address: "0000:03:00", SerialNumber: "MT2334XZ0L"} + nm.reqs["uid-1"] = NetworkRequest{UID: "uid-1"} + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, found, err := nm.lookupDevice(dpu) + Expect(err).NotTo(HaveOccurred()) + Expect(found).To(BeFalse()) + }) + + It("should return error when not initialized", func() { + nm := NewNetworkManager(nil) + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, _, err := nm.lookupDevice(dpu) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not initialized")) + }) + + It("should return error when device serial number not in map", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "UNKNOWN", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, _, err := nm.lookupDevice(dpu) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not found")) + }) + }) + + Context("NetworkManager.addRequest and removeRequest", Label("addRequest", "removeRequest"), func() { + It("should insert a request into the map", func() { + nm := NewNetworkManager(nil) + nr := &NetworkRequest{UID: "uid-add-1", DpuName: "dpu-1"} + + nm.addRequest(nr) + + Expect(nm.reqs).To(HaveKey("uid-add-1")) + Expect(nm.reqs["uid-add-1"].DpuName).To(Equal("dpu-1")) + }) + + It("should remove a request from the map", func() { + nm := NewNetworkManager(nil) + nm.reqs["uid-rm-1"] = NetworkRequest{UID: "uid-rm-1"} + + nm.removeRequest("uid-rm-1") + + Expect(nm.reqs).NotTo(HaveKey("uid-rm-1")) + }) + + It("should handle removing a non-existent key", func() { + nm := NewNetworkManager(nil) + Expect(func() { nm.removeRequest("nonexistent") }).NotTo(Panic()) + }) + + It("should be safe under concurrent access", func() { + nm := NewNetworkManager(nil) + var wg sync.WaitGroup + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + uid := fmt.Sprintf("uid-%d", i) + nm.addRequest(&NetworkRequest{UID: uid}) + }(i) + } + wg.Wait() + Expect(nm.reqs).To(HaveLen(50)) + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + nm.removeRequest(fmt.Sprintf("uid-%d", i)) + }(i) + } + wg.Wait() + Expect(nm.reqs).To(BeEmpty()) + }) + }) + + Context("NetworkManager.processNetworkRequest", Label("processNetworkRequest"), func() { + var ( + testScheme *runtime.Scheme + tempDir string + origNetworkRequestDir string + ) + + BeforeEach(func() { + testScheme = runtime.NewScheme() + Expect(provisioningv1.AddToScheme(testScheme)).To(Succeed()) + Expect(operatorv1.AddToScheme(testScheme)).To(Succeed()) + + var err error + tempDir, err = os.MkdirTemp("", "process-test-*") + Expect(err).NotTo(HaveOccurred()) + origNetworkRequestDir = NetworkRequestDir + NetworkRequestDir = tempDir + }) + + AfterEach(func() { + NetworkRequestDir = origNetworkRequestDir + _ = os.RemoveAll(tempDir) + }) + + It("should clean up when DPU is not found", func() { + fakeClient := fake.NewClientBuilder().WithScheme(testScheme).Build() + nm := NewNetworkManager(fakeClient) + + nr := NetworkRequest{ + UID: "uid-gone", + DpuName: "nonexistent", + DPUNamespace: "default", + VFName: "nonexistent-vf", + } + nm.reqs[nr.UID] = nr + + data, err := json.Marshal(&nr) + Expect(err).NotTo(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(tempDir, nr.UID), data, 0644)).To(Succeed()) + + err = nm.processNetworkRequest(nr) + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).NotTo(HaveKey("uid-gone")) + Expect(filepath.Join(tempDir, nr.UID)).NotTo(BeAnExistingFile()) + }) + + It("should clean up when UID does not match", func() { + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dpu", + Namespace: "default", + UID: "uid-actual", + }, + } + fakeClient := fake.NewClientBuilder().WithScheme(testScheme).WithObjects(dpu).Build() + nm := NewNetworkManager(fakeClient) + + nr := NetworkRequest{ + UID: "uid-stale", + DpuName: "test-dpu", + DPUNamespace: "default", + VFName: "nonexistent-vf", + } + nm.reqs[nr.UID] = nr + + err := nm.processNetworkRequest(nr) + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).NotTo(HaveKey("uid-stale")) + }) + }) + + Context("Concurrency: AddNetworkRequest not blocked by run()", Label("concurrency"), func() { + var ( + testScheme *runtime.Scheme + tempDir string + origNetworkRequestDir string + ) + + BeforeEach(func() { + testScheme = runtime.NewScheme() + Expect(provisioningv1.AddToScheme(testScheme)).To(Succeed()) + Expect(operatorv1.AddToScheme(testScheme)).To(Succeed()) + + var err error + tempDir, err = os.MkdirTemp("", "concurrency-test-*") + Expect(err).NotTo(HaveOccurred()) + origNetworkRequestDir = NetworkRequestDir + NetworkRequestDir = tempDir + }) + + AfterEach(func() { + NetworkRequestDir = origNetworkRequestDir + _ = os.RemoveAll(tempDir) + }) + + It("should not block AddNetworkRequest while run() is processing", func() { + existingDPU := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "slow-dpu", + Namespace: "default", + UID: "uid-slow", + }, + } + newDPU := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fast-dpu", + Namespace: "default", + UID: "uid-fast", + }, + Spec: provisioningv1.DPUSpec{ + SerialNumber: "SN-FAST", + DPUFlavor: "test-flavor", + NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}, + }, + } + flavor := &provisioningv1.DPUFlavor{ + ObjectMeta: metav1.ObjectMeta{Name: "test-flavor", Namespace: "default"}, + } + mtu := 1500 + dpfConfig := &operatorv1.DPFOperatorConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "config", Namespace: "default"}, + Spec: operatorv1.DPFOperatorConfigSpec{ + Networking: &operatorv1.Networking{ + ControlPlaneMTU: &mtu, + }, + }, + } + + slowClient := &slowGetClient{ + Client: fake.NewClientBuilder().WithScheme(testScheme).WithObjects(existingDPU, newDPU, flavor, dpfConfig).Build(), + delay: 500 * time.Millisecond, + slowKey: types.NamespacedName{Name: "slow-dpu", Namespace: "default"}, + } + + nm := NewNetworkManager(slowClient) + nm.initialized = true + nm.devicesBySN["SN-FAST"] = hostutil.Device{Address: "0000:04:00", SerialNumber: "SN-FAST"} + + nm.reqs["uid-slow"] = NetworkRequest{ + UID: "uid-slow", + DpuName: "slow-dpu", + DPUNamespace: "default", + } + + go nm.run() + + done := make(chan error, 1) + go func() { + done <- nm.AddNetworkRequest(newDPU) + }() + + select { + case err := <-done: + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).To(HaveKey("uid-fast")) + case <-time.After(200 * time.Millisecond): + Fail("AddNetworkRequest was blocked by run() — lock contention detected") + } + }) + }) }) + +// slowGetClient wraps a real client and injects a delay on Get calls for a +// specific object, simulating a slow API server for concurrency tests. +type slowGetClient struct { + client.Client + delay time.Duration + slowKey types.NamespacedName +} + +func (s *slowGetClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if key == s.slowKey { + time.Sleep(s.delay) + } + return s.Client.Get(ctx, key, obj, opts...) +}