diff --git a/internal/provisioning/hostagent/networkmanager/network_manager.go b/internal/provisioning/hostagent/networkmanager/network_manager.go index 6aefe393c..0aa8f415d 100644 --- a/internal/provisioning/hostagent/networkmanager/network_manager.go +++ b/internal/provisioning/hostagent/networkmanager/network_manager.go @@ -154,16 +154,26 @@ func (nm *NetworkManager) loadNetworkRequest() error { } func (nm *NetworkManager) run() { - nm.Lock() - defer nm.Unlock() + nm.RLock() + reqs := make([]NetworkRequest, 0, len(nm.reqs)) + for _, nr := range nm.reqs { + reqs = append(reqs, nr) + } + nm.RUnlock() - for uid, nr := range nm.reqs { + for _, nr := range reqs { if err := nm.processNetworkRequest(nr); err != nil { - klog.Errorf("failed to process network request, nr: %+v, err: %v", nm.reqs[uid], err) + klog.Errorf("failed to process network request, nr: %+v, err: %v", nr, err) } } } +func (nm *NetworkManager) removeRequest(uid string) { + nm.Lock() + defer nm.Unlock() + delete(nm.reqs, uid) +} + func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { nn := types.NamespacedName{Namespace: nr.DPUNamespace, Name: nr.DpuName} dpu := &provisioningv1.DPU{} @@ -180,7 +190,7 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { if err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove network request file: %w", err) } - delete(nm.reqs, nr.UID) + nm.removeRequest(nr.UID) klog.Infof("removed VF and network request for DPU %s/%s(UID: %s)", nr.DPUNamespace, nr.DpuName, nr.UID) return nil } @@ -270,16 +280,34 @@ func (nm *NetworkManager) processNetworkRequest(nr NetworkRequest) error { return nil } -func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { - nm.Lock() - defer nm.Unlock() +// lookupDevice checks preconditions and returns the PCI device for the DPU. +// Returns found=false when a network request for this DPU already exists. +func (nm *NetworkManager) lookupDevice(dpu *provisioningv1.DPU) (dev hostutil.Device, found bool, err error) { + nm.RLock() + defer nm.RUnlock() if !nm.initialized { - return fmt.Errorf("network manager is not initialized") - } else if dpu == nil { + return hostutil.Device{}, false, fmt.Errorf("network manager is not initialized") + } + if _, ok := nm.reqs[string(dpu.UID)]; ok { + return hostutil.Device{}, false, nil + } + dev, ok := nm.devicesBySN[dpu.Spec.SerialNumber] + if !ok { + return hostutil.Device{}, false, fmt.Errorf("PCI address of device %s not found", dpu.Spec.SerialNumber) + } + return dev, true, nil +} + +func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { + if dpu == nil { return fmt.Errorf("DPU is nil") } - if _, ok := nm.reqs[string(dpu.UID)]; ok { + dev, found, err := nm.lookupDevice(dpu) + if err != nil { + return err + } + if !found { return nil } @@ -289,10 +317,6 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { nr.SetDPUObjectMeta(*dpu) // use the PCI address collected locally, so that it's not affected by PCI address changes - dev, ok := nm.devicesBySN[nr.SerialNumber] - if !ok { - return fmt.Errorf("PCI address of device %s not found", nr.SerialNumber) - } nr.PCIAddress = dev.Address numOfVFs, err := nm.getNumOfVFs(dpu) @@ -317,10 +341,17 @@ func (nm *NetworkManager) AddNetworkRequest(dpu *provisioningv1.DPU) error { if err := writeNetworkRequestFile(nr); err != nil { return fmt.Errorf("failed to write network request file: %w", err) } - nm.reqs[nr.UID] = *nr + + nm.addRequest(nr) return nil } +func (nm *NetworkManager) addRequest(nr *NetworkRequest) { + nm.Lock() + defer nm.Unlock() + nm.reqs[nr.UID] = *nr +} + func (nm *NetworkManager) getNumOfVFs(dpu *provisioningv1.DPU) (int, error) { flavor := &provisioningv1.DPUFlavor{} if err := nm.Get(context.TODO(), types.NamespacedName{Namespace: dpu.Namespace, Name: dpu.Spec.DPUFlavor}, flavor); err != nil { diff --git a/internal/provisioning/hostagent/networkmanager/network_manager_test.go b/internal/provisioning/hostagent/networkmanager/network_manager_test.go index bf13a747f..267562a13 100644 --- a/internal/provisioning/hostagent/networkmanager/network_manager_test.go +++ b/internal/provisioning/hostagent/networkmanager/network_manager_test.go @@ -19,15 +19,26 @@ limitations under the License. package networkmanager import ( + "context" + "encoding/json" + "fmt" "os" + "path/filepath" + "sync" + "time" + operatorv1 "github.com/nvidia/doca-platform/api/operator/v1alpha1" provisioningv1 "github.com/nvidia/doca-platform/api/provisioning/v1alpha1" hostutil "github.com/nvidia/doca-platform/internal/provisioning/hostagent/util" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" ) var _ = Describe("NetworkManager", func() { @@ -210,4 +221,297 @@ var _ = Describe("NetworkManager", func() { } }) }) + + Context("NetworkManager.lookupDevice", Label("lookupDevice"), func() { + It("should return device when initialized, device exists, and no existing request", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + nm.devicesBySN["MT2334XZ0L"] = hostutil.Device{ + Address: "0000:03:00", + SerialNumber: "MT2334XZ0L", + NumOfPFs: 2, + } + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + dev, found, err := nm.lookupDevice(dpu) + Expect(err).NotTo(HaveOccurred()) + Expect(found).To(BeTrue()) + Expect(dev.Address).To(Equal("0000:03:00")) + }) + + It("should return found=false when request already exists for the UID", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + nm.devicesBySN["MT2334XZ0L"] = hostutil.Device{Address: "0000:03:00", SerialNumber: "MT2334XZ0L"} + nm.reqs["uid-1"] = NetworkRequest{UID: "uid-1"} + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, found, err := nm.lookupDevice(dpu) + Expect(err).NotTo(HaveOccurred()) + Expect(found).To(BeFalse()) + }) + + It("should return error when not initialized", func() { + nm := NewNetworkManager(nil) + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "MT2334XZ0L", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, _, err := nm.lookupDevice(dpu) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not initialized")) + }) + + It("should return error when device serial number not in map", func() { + nm := NewNetworkManager(nil) + nm.initialized = true + + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{UID: "uid-1"}, + Spec: provisioningv1.DPUSpec{SerialNumber: "UNKNOWN", NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}}, + } + + _, _, err := nm.lookupDevice(dpu) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not found")) + }) + }) + + Context("NetworkManager.addRequest and removeRequest", Label("addRequest", "removeRequest"), func() { + It("should insert a request into the map", func() { + nm := NewNetworkManager(nil) + nr := &NetworkRequest{UID: "uid-add-1", DpuName: "dpu-1"} + + nm.addRequest(nr) + + Expect(nm.reqs).To(HaveKey("uid-add-1")) + Expect(nm.reqs["uid-add-1"].DpuName).To(Equal("dpu-1")) + }) + + It("should remove a request from the map", func() { + nm := NewNetworkManager(nil) + nm.reqs["uid-rm-1"] = NetworkRequest{UID: "uid-rm-1"} + + nm.removeRequest("uid-rm-1") + + Expect(nm.reqs).NotTo(HaveKey("uid-rm-1")) + }) + + It("should handle removing a non-existent key", func() { + nm := NewNetworkManager(nil) + Expect(func() { nm.removeRequest("nonexistent") }).NotTo(Panic()) + }) + + It("should be safe under concurrent access", func() { + nm := NewNetworkManager(nil) + var wg sync.WaitGroup + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + uid := fmt.Sprintf("uid-%d", i) + nm.addRequest(&NetworkRequest{UID: uid}) + }(i) + } + wg.Wait() + Expect(nm.reqs).To(HaveLen(50)) + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + nm.removeRequest(fmt.Sprintf("uid-%d", i)) + }(i) + } + wg.Wait() + Expect(nm.reqs).To(BeEmpty()) + }) + }) + + Context("NetworkManager.processNetworkRequest", Label("processNetworkRequest"), func() { + var ( + testScheme *runtime.Scheme + tempDir string + origNetworkRequestDir string + ) + + BeforeEach(func() { + testScheme = runtime.NewScheme() + Expect(provisioningv1.AddToScheme(testScheme)).To(Succeed()) + Expect(operatorv1.AddToScheme(testScheme)).To(Succeed()) + + var err error + tempDir, err = os.MkdirTemp("", "process-test-*") + Expect(err).NotTo(HaveOccurred()) + origNetworkRequestDir = NetworkRequestDir + NetworkRequestDir = tempDir + }) + + AfterEach(func() { + NetworkRequestDir = origNetworkRequestDir + _ = os.RemoveAll(tempDir) + }) + + It("should clean up when DPU is not found", func() { + fakeClient := fake.NewClientBuilder().WithScheme(testScheme).Build() + nm := NewNetworkManager(fakeClient) + + nr := NetworkRequest{ + UID: "uid-gone", + DpuName: "nonexistent", + DPUNamespace: "default", + VFName: "nonexistent-vf", + } + nm.reqs[nr.UID] = nr + + data, err := json.Marshal(&nr) + Expect(err).NotTo(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(tempDir, nr.UID), data, 0644)).To(Succeed()) + + err = nm.processNetworkRequest(nr) + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).NotTo(HaveKey("uid-gone")) + Expect(filepath.Join(tempDir, nr.UID)).NotTo(BeAnExistingFile()) + }) + + It("should clean up when UID does not match", func() { + dpu := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dpu", + Namespace: "default", + UID: "uid-actual", + }, + } + fakeClient := fake.NewClientBuilder().WithScheme(testScheme).WithObjects(dpu).Build() + nm := NewNetworkManager(fakeClient) + + nr := NetworkRequest{ + UID: "uid-stale", + DpuName: "test-dpu", + DPUNamespace: "default", + VFName: "nonexistent-vf", + } + nm.reqs[nr.UID] = nr + + err := nm.processNetworkRequest(nr) + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).NotTo(HaveKey("uid-stale")) + }) + }) + + Context("Concurrency: AddNetworkRequest not blocked by run()", Label("concurrency"), func() { + var ( + testScheme *runtime.Scheme + tempDir string + origNetworkRequestDir string + ) + + BeforeEach(func() { + testScheme = runtime.NewScheme() + Expect(provisioningv1.AddToScheme(testScheme)).To(Succeed()) + Expect(operatorv1.AddToScheme(testScheme)).To(Succeed()) + + var err error + tempDir, err = os.MkdirTemp("", "concurrency-test-*") + Expect(err).NotTo(HaveOccurred()) + origNetworkRequestDir = NetworkRequestDir + NetworkRequestDir = tempDir + }) + + AfterEach(func() { + NetworkRequestDir = origNetworkRequestDir + _ = os.RemoveAll(tempDir) + }) + + It("should not block AddNetworkRequest while run() is processing", func() { + existingDPU := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "slow-dpu", + Namespace: "default", + UID: "uid-slow", + }, + } + newDPU := &provisioningv1.DPU{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fast-dpu", + Namespace: "default", + UID: "uid-fast", + }, + Spec: provisioningv1.DPUSpec{ + SerialNumber: "SN-FAST", + DPUFlavor: "test-flavor", + NodeEffect: provisioningv1.NodeEffect{Action: provisioningv1.Action{NoEffect: ptr.To(true)}}, + }, + } + flavor := &provisioningv1.DPUFlavor{ + ObjectMeta: metav1.ObjectMeta{Name: "test-flavor", Namespace: "default"}, + } + mtu := 1500 + dpfConfig := &operatorv1.DPFOperatorConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "config", Namespace: "default"}, + Spec: operatorv1.DPFOperatorConfigSpec{ + Networking: &operatorv1.Networking{ + ControlPlaneMTU: &mtu, + }, + }, + } + + slowClient := &slowGetClient{ + Client: fake.NewClientBuilder().WithScheme(testScheme).WithObjects(existingDPU, newDPU, flavor, dpfConfig).Build(), + delay: 500 * time.Millisecond, + slowKey: types.NamespacedName{Name: "slow-dpu", Namespace: "default"}, + } + + nm := NewNetworkManager(slowClient) + nm.initialized = true + nm.devicesBySN["SN-FAST"] = hostutil.Device{Address: "0000:04:00", SerialNumber: "SN-FAST"} + + nm.reqs["uid-slow"] = NetworkRequest{ + UID: "uid-slow", + DpuName: "slow-dpu", + DPUNamespace: "default", + } + + go nm.run() + + done := make(chan error, 1) + go func() { + done <- nm.AddNetworkRequest(newDPU) + }() + + select { + case err := <-done: + Expect(err).NotTo(HaveOccurred()) + Expect(nm.reqs).To(HaveKey("uid-fast")) + case <-time.After(200 * time.Millisecond): + Fail("AddNetworkRequest was blocked by run() — lock contention detected") + } + }) + }) }) + +// slowGetClient wraps a real client and injects a delay on Get calls for a +// specific object, simulating a slow API server for concurrency tests. +type slowGetClient struct { + client.Client + delay time.Duration + slowKey types.NamespacedName +} + +func (s *slowGetClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if key == s.slowKey { + time.Sleep(s.delay) + } + return s.Client.Get(ctx, key, obj, opts...) +}