# Summary of this patch (reviewer notes; not part of any hunk):
#
# * .github/workflows/unit-tests.yml (new file): "Unit Tests" workflow that
#   runs `go test ./...` on pull requests targeting master or release-* and
#   on manual dispatch. It installs libhwloc-dev and libdrm-dev, takes the Go
#   version from go.mod, has a 15 minute job timeout, and cancels an
#   in-progress run that belongs to the same ref.
# * internal/pkg/allocator/device.go: fetchAllPairWeights now builds its
#   empty-device-list error from a plain string with errors.New instead of
#   fmt.Sprintf / fmt.Errorf; the message text itself is unchanged.
# * internal/pkg/amdgpu/amdgpu.go: adds the exported variable
#   FatalOnDriverUnavailable (initialised to true). When it is false and
#   os.Stat on /sys/module/amdgpu/drivers/ fails, GetAMDGPUs logs a warning
#   and returns an empty map instead of calling glog.Fatalf.
# * internal/pkg/amdgpu/amdgpu_test.go: adds TestMain, which sets
#   FatalOnDriverUnavailable to false before running the tests, and replaces
#   the expected values in TestRenderDevIdsFromTopology (numeric ID strings
#   become strings of the form "0000:0a:00:0").
#
# NOTE(review): this patch arrived with its line breaks collapsed. The layout
# below (tabs in Go lines, YAML indentation, blank context lines) was rebuilt
# from the hunk headers; whitespace inside string literals could not be
# recovered and is reproduced as seen. Confirm with `git apply --check`.
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
new file mode 100644
index 000000000..38b706a86
--- /dev/null
+++ b/.github/workflows/unit-tests.yml
@@ -0,0 +1,31 @@
+name: Unit Tests
+
+on:
+  pull_request:
+    branches:
+      - master
+      - release-*
+  workflow_dispatch:
+
+concurrency:
+  group: unit-tests-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    timeout-minutes: 15
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Install system dependencies
+        run: sudo apt-get update && sudo apt-get install -y libhwloc-dev libdrm-dev
+
+      - name: Set up Go
+        uses: actions/setup-go@v5
+        with:
+          go-version-file: go.mod
+
+      - name: Run unit tests
+        run: go test ./...
diff --git a/internal/pkg/allocator/device.go b/internal/pkg/allocator/device.go
index 0358b099a..4cdf6b7c9 100644
--- a/internal/pkg/allocator/device.go
+++ b/internal/pkg/allocator/device.go
@@ -18,6 +18,7 @@ package allocator
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -219,9 +220,9 @@ func scanAndPopulatePeerWeights(fromPath string, devices []*Device, lookupNodes
 
 func fetchAllPairWeights(devices []*Device, p2pWeights map[int]map[int]int, folderPath string) error {
 	if len(devices) == 0 {
-		errMsg := fmt.Sprintf("Devices list is empty. Unable to calculate pair wise weights")
+		errMsg := "Devices list is empty. Unable to calculate pair wise weights"
 		glog.Info(errMsg)
-		return fmt.Errorf(errMsg)
+		return errors.New(errMsg)
 	}
 	if folderPath == "" {
 		folderPath = topoRootPath
diff --git a/internal/pkg/amdgpu/amdgpu.go b/internal/pkg/amdgpu/amdgpu.go
index 4947191d6..874969fea 100644
--- a/internal/pkg/amdgpu/amdgpu.go
+++ b/internal/pkg/amdgpu/amdgpu.go
@@ -147,10 +147,19 @@ func GetDevIdsFromTopology(topoRootParam ...string) map[int]string {
 	return renderDevIds
 }
 
+// FatalOnDriverUnavailable controls whether GetAMDGPUs calls glog.Fatalf
+// when the amdgpu driver is not present. Tests set this to false so the
+// test process isn't killed on machines without AMD GPUs.
+var FatalOnDriverUnavailable = true
+
 // GetAMDGPUs return a map of AMD GPU on a node identified by the part of the pci address
 func GetAMDGPUs() map[string]map[string]interface{} {
 	if _, err := os.Stat("/sys/module/amdgpu/drivers/"); err != nil {
-		glog.Fatalf("amdgpu driver unavailable. exiting with exit code 2. error: %s", err)
+		if FatalOnDriverUnavailable {
+			glog.Fatalf("amdgpu driver unavailable. exiting with exit code 2. error: %s", err)
+		}
+		glog.Warningf("amdgpu driver unavailable: %s", err)
+		return map[string]map[string]interface{}{}
 	}
 
 	//ex: /sys/module/amdgpu/drivers/pci:amdgpu/0000:19:00.0
diff --git a/internal/pkg/amdgpu/amdgpu_test.go b/internal/pkg/amdgpu/amdgpu_test.go
index 85f3df96c..28a33bd8c 100644
--- a/internal/pkg/amdgpu/amdgpu_test.go
+++ b/internal/pkg/amdgpu/amdgpu_test.go
@@ -20,6 +20,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
+	"os"
 	"path/filepath"
 	"reflect"
 	"regexp"
@@ -27,6 +28,11 @@ import (
 	"testing"
 )
 
+func TestMain(m *testing.M) {
+	FatalOnDriverUnavailable = false
+	os.Exit(m.Run())
+}
+
 func hasAMDGPU(t *testing.T) bool {
 	devices := GetAMDGPUs()
 
@@ -221,38 +227,38 @@ func TestRenderDevIdsFromTopology(t *testing.T) {
 	renderDevIds := GetDevIdsFromTopology("../../../testdata/topology-parsing-mi308")
 
 	expDevIds := map[int]string{
-		128: "598046273873802902",
-		129: "598046273873802902",
-		130: "598046273873802902",
-		131: "598046273873802902",
-		136: "11803749423592941193",
-		137: "11803749423592941193",
-		138: "11803749423592941193",
-		139: "11803749423592941193",
-		144: "10187445671099294242",
-		145: "10187445671099294242",
-		146: "10187445671099294242",
-		147: "10187445671099294242",
-		152: "9604994527082705072",
-		153: "9604994527082705072",
-		154: "9604994527082705072",
-		155: "9604994527082705072",
-		160: "17466021589395472075",
-		161: "17466021589395472075",
-		162: "17466021589395472075",
-		163: "17466021589395472075",
-		168: "1044926823201815193",
-		169: "1044926823201815193",
-		170: "1044926823201815193",
-		171: "1044926823201815193",
-		176: "13372828617950681944",
-		177: "13372828617950681944",
-		178: "13372828617950681944",
-		179: "13372828617950681944",
-		184: "6576958293045616595",
-		185: "6576958293045616595",
-		186: "6576958293045616595",
-		187: "6576958293045616595"}
+		128: "0000:0a:00:0",
+		129: "0000:0a:00:0",
+		130: "0000:0a:00:0",
+		131: "0000:0a:00:0",
+		136: "0000:80:00:0",
+		137: "0000:80:00:0",
+		138: "0000:80:00:0",
+		139: "0000:80:00:0",
+		144: "0000:a4:00:0",
+		145: "0000:a4:00:0",
+		146: "0000:a4:00:0",
+		147: "0000:a4:00:0",
+		152: "0000:c8:00:0",
+		153: "0000:c8:00:0",
+		154: "0000:c8:00:0",
+		155: "0000:c8:00:0",
+		160: "0001:0b:00:0",
+		161: "0001:0b:00:0",
+		162: "0001:0b:00:0",
+		163: "0001:0b:00:0",
+		168: "0001:81:00:0",
+		169: "0001:81:00:0",
+		170: "0001:81:00:0",
+		171: "0001:81:00:0",
+		176: "0001:a5:00:0",
+		177: "0001:a5:00:0",
+		178: "0001:a5:00:0",
+		179: "0001:a5:00:0",
+		184: "0001:c9:00:0",
+		185: "0001:c9:00:0",
+		186: "0001:c9:00:0",
+		187: "0001:c9:00:0"}
 	if !reflect.DeepEqual(renderDevIds, expDevIds) {
 		val, _ := json.MarshalIndent(renderDevIds, "", " ")
 		exp, _ := json.MarshalIndent(expDevIds, "", " ")