Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Unit Tests

on:
pull_request:
branches:
- master
- release-*
workflow_dispatch:

concurrency:
group: unit-tests-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libhwloc-dev libdrm-dev

- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod

- name: Run unit tests
run: go test ./...
5 changes: 3 additions & 2 deletions internal/pkg/allocator/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package allocator

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -219,9 +220,9 @@ func scanAndPopulatePeerWeights(fromPath string, devices []*Device, lookupNodes

func fetchAllPairWeights(devices []*Device, p2pWeights map[int]map[int]int, folderPath string) error {
if len(devices) == 0 {
errMsg := fmt.Sprintf("Devices list is empty. Unable to calculate pair wise weights")
errMsg := "Devices list is empty. Unable to calculate pair wise weights"
glog.Info(errMsg)
return fmt.Errorf(errMsg)
return errors.New(errMsg)
Comment thread
bhatnitish marked this conversation as resolved.
}
if folderPath == "" {
folderPath = topoRootPath
Expand Down
11 changes: 10 additions & 1 deletion internal/pkg/amdgpu/amdgpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,19 @@ func GetDevIdsFromTopology(topoRootParam ...string) map[int]string {
return renderDevIds
}

// FatalOnDriverUnavailable controls whether GetAMDGPUs calls glog.Fatalf
// when the amdgpu driver is not present. Tests set this to false so the
// test process isn't killed on machines without AMD GPUs.
var FatalOnDriverUnavailable = true

// GetAMDGPUs return a map of AMD GPU on a node identified by the part of the pci address
func GetAMDGPUs() map[string]map[string]interface{} {
if _, err := os.Stat("/sys/module/amdgpu/drivers/"); err != nil {
glog.Fatalf("amdgpu driver unavailable. exiting with exit code 2. error: %s", err)
if FatalOnDriverUnavailable {
glog.Fatalf("amdgpu driver unavailable. exiting with exit code 2. error: %s", err)
}
glog.Warningf("amdgpu driver unavailable: %s", err)
Comment thread
bhatnitish marked this conversation as resolved.
return map[string]map[string]interface{}{}
}

//ex: /sys/module/amdgpu/drivers/pci:amdgpu/0000:19:00.0
Expand Down
70 changes: 38 additions & 32 deletions internal/pkg/amdgpu/amdgpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"regexp"
"strings"
"testing"
)

func TestMain(m *testing.M) {
FatalOnDriverUnavailable = false
os.Exit(m.Run())
}

func hasAMDGPU(t *testing.T) bool {
devices := GetAMDGPUs()

Expand Down Expand Up @@ -221,38 +227,38 @@ func TestRenderDevIdsFromTopology(t *testing.T) {
renderDevIds := GetDevIdsFromTopology("../../../testdata/topology-parsing-mi308")

expDevIds := map[int]string{
128: "598046273873802902",
129: "598046273873802902",
130: "598046273873802902",
131: "598046273873802902",
136: "11803749423592941193",
137: "11803749423592941193",
138: "11803749423592941193",
139: "11803749423592941193",
144: "10187445671099294242",
145: "10187445671099294242",
146: "10187445671099294242",
147: "10187445671099294242",
152: "9604994527082705072",
153: "9604994527082705072",
154: "9604994527082705072",
155: "9604994527082705072",
160: "17466021589395472075",
161: "17466021589395472075",
162: "17466021589395472075",
163: "17466021589395472075",
168: "1044926823201815193",
169: "1044926823201815193",
170: "1044926823201815193",
171: "1044926823201815193",
176: "13372828617950681944",
177: "13372828617950681944",
178: "13372828617950681944",
179: "13372828617950681944",
184: "6576958293045616595",
185: "6576958293045616595",
186: "6576958293045616595",
187: "6576958293045616595"}
128: "0000:0a:00:0",
129: "0000:0a:00:0",
130: "0000:0a:00:0",
131: "0000:0a:00:0",
136: "0000:80:00:0",
137: "0000:80:00:0",
138: "0000:80:00:0",
139: "0000:80:00:0",
144: "0000:a4:00:0",
145: "0000:a4:00:0",
146: "0000:a4:00:0",
147: "0000:a4:00:0",
152: "0000:c8:00:0",
153: "0000:c8:00:0",
154: "0000:c8:00:0",
155: "0000:c8:00:0",
160: "0001:0b:00:0",
161: "0001:0b:00:0",
162: "0001:0b:00:0",
163: "0001:0b:00:0",
168: "0001:81:00:0",
169: "0001:81:00:0",
170: "0001:81:00:0",
171: "0001:81:00:0",
176: "0001:a5:00:0",
177: "0001:a5:00:0",
178: "0001:a5:00:0",
179: "0001:a5:00:0",
184: "0001:c9:00:0",
185: "0001:c9:00:0",
186: "0001:c9:00:0",
187: "0001:c9:00:0"}
if !reflect.DeepEqual(renderDevIds, expDevIds) {
val, _ := json.MarshalIndent(renderDevIds, "", " ")
exp, _ := json.MarshalIndent(expDevIds, "", " ")
Expand Down
Loading