Add JAX build integration to multi-arch CI (build-only, tarball-based)#4366
Add JAX build integration to multi-arch CI (build-only, tarball-based)#4366erman-gurses wants to merge 26 commits intomainfrom
Conversation
2d7afc1 to
d305c3e
Compare
| build_jax=JobGroupDecision( | ||
| action=JobAction.RUN if enable_jax else JobAction.SKIP | ||
| ), | ||
| test_jax=JobGroupDecision( | ||
| action=JobAction.RUN if enable_jax else JobAction.SKIP | ||
| ), |
There was a problem hiding this comment.
Scope
This PR only introduces control-plane plumbing. Follow-up PRs will:
- produce JAX-consumable ROCm artifacts
- invoke the JAX workflow
- add validation/tests
The plumbing here looks correct, but I'd prefer to adjust the sequencing so people don't start expecting this to actually do anything (though to that point, we could also remove test_pytorch 🤔)
- Produce ROCm tarballs (I'm sort of adding this as part of [Multi-arch] Add multi-arch release pipelines #3334), OR switch JAX to build from ROCm python packages
- Add JAX building to multi-arch CI workflows, together with this plumbing (they can be separate PRs, but I'd get them up for review together so the period between advertising that it is available and it actually working is short)
- Add JAX testing to multi-arch CI workflows
There was a problem hiding this comment.
I expect I'll have ROCm tarballs produced by multi-arch CI workflows within a few days, so this can probably compose pretty easily on top of that 🤞
There was a problem hiding this comment.
That makes sense, thanks.
I’ll wait for (1) to land on your side, then follow up with a PR for (2). We can review this plumbing PR together with the JAX build integration PR to minimize the gap between enabling it and it actually working.
After that, I’ll proceed with (3).
Will address rest of the comments today.
There was a problem hiding this comment.
@ScottTodd, I combined (1) and (2) because once tarball production was available, the remaining JAX build integration became the direct consumer of the new control-plane outputs (build_jax, jax_amdgpu_family).
Keeping them separate would leave a no-op state where CI exposes JAX enablement but does not execute any JAX workflow. Combining them makes the path coherent end-to-end, while still keeping the scope limited to build-only execution.
e2b78da to
395f69e
Compare
ScottTodd
left a comment
There was a problem hiding this comment.
drive-by review pointing out some details that matter significantly (please be careful testing such changes)
|
@ScottTodd, Those comments are helpful thanks. Not tested anything yet - will do after the fixes. |
|
@ScottTodd, let me know if you still see any critical issue for the testing. |
| generate_tarball_url: | ||
| name: Generate tarball URL | ${{ inputs.amdgpu_family }} | ||
| runs-on: ubuntu-24.04 | ||
| outputs: | ||
| tarball_url: ${{ steps.compute_tarball_url.outputs.tarball_url }} | ||
| steps: | ||
| - name: Checking out repository | ||
| uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 | ||
|
|
||
| - name: Configure AWS credentials | ||
| uses: aws-actions/configure-aws-credentials@b47578312673ae6fa5b5096b330d9fbac3d116df # v4.2.1 | ||
| with: | ||
| role-to-assume: arn:aws:iam::692859939525:role/therock-${{ github.event_name == 'pull_request' && 'developer' || 'artifact' }}-role | ||
| aws-region: us-east-2 | ||
|
|
||
| - name: Generate tarball download URL | ||
| id: compute_tarball_url | ||
| run: | | ||
| python build_tools/github_actions/generate_tarball_urls.py \ | ||
| --run-id "${{ inputs.artifact_run_id }}" \ | ||
| --platform linux \ | ||
| --release-type "${{ inputs.release_type }}" \ | ||
| --package-version "${{ inputs.package_version }}" \ | ||
| --dist-amdgpu-families "${{ inputs.dist_amdgpu_families }}" \ | ||
| --family "${{ inputs.amdgpu_family }}" |
There was a problem hiding this comment.
You can pass the URL in to the workflow instead of computing here.
The workflow then can build from any tarball you provide, instead of needing to infer where a tarball should be. If we change the structure of the tarballs or their index pages then this workflow code would break. If a developer wants to build from a dev or release tarball this would not support that.
See how we pass needs.build_python_packages.outputs.package_find_links_url and inputs.rocm_package_version here:
TheRock/.github/workflows/multi_arch_ci_linux.yml
Lines 197 to 245 in 8552a89
TheRock/.github/workflows/build_portable_linux_python_packages.yml
Lines 71 to 74 in 8552a89
TheRock/.github/workflows/build_portable_linux_python_packages.yml
Lines 94 to 95 in 8552a89
TheRock/build_tools/github_actions/upload_python_packages.py
Lines 282 to 286 in 8552a89
We can set a similar output here:
| cloudfront_url: "https://rocm.devreleases.amd.com/v2" | ||
| cloudfront_staging_url: "https://rocm.devreleases.amd.com/v2-staging" |
There was a problem hiding this comment.
These URLs should be similarly dynamic.
Some references:
- https://github.com/ROCm/TheRock/blob/main/build_tools/_therock_utils/s3_buckets.py
- https://github.com/ROCm/TheRock/blob/main/docs/development/s3_buckets.md
- https://github.com/ROCm/TheRock/blob/main/docs/development/workflow_outputs.md
Looks like https://github.com/ROCm/TheRock/blob/main/.github/workflows/build_linux_jax_wheels.yml is only using package_index_url: ${{ inputs.cloudfront_staging_url }} but then it always copies to a release bucket.
We could add a variant similar to https://github.com/ROCm/TheRock/blob/main/.github/workflows/build_portable_linux_pytorch_wheels_ci.yml that doesn't upload at all (or uploads to the artifacts buckets, not a release bucket), then a release workflow like https://github.com/ROCm/TheRock/blob/main/.github/workflows/multi_arch_release_linux_pytorch_wheels.yml that uploads
[CI] Add JAX build integration to multi-arch CI (build-only, tarball-based)
Motivation
Introduce JAX as a first-class consumer of TheRock multi-arch artifacts in CI.
This change extends the existing control-plane plumbing to execute JAX builds in a clean, multi-arch–compliant way. Instead of building JAX against ad-hoc or hardcoded environments, JAX now consumes ROCm tarballs produced within the same CI run.
Scope
This PR combines:
1. Control-plane plumbing
build_jax/test_jaxasJobGroupDecisioninconfigure_multi_arch_ci.pyci:build-jax
pull_requesteventsThis follows the existing multi-arch CI pattern used for PyTorch (
build_pytorch/test_pytorch).2. Build integration (multi-arch, tarball-based)
This PR introduces build-only JAX execution as a downstream consumer of TheRock artifacts:
multi_arch_build_tarballs.yml
build_tools/github_actions/generate_tarball_urls.py
build_linux_jax_wheels.yml
with tarball input
Each JAX job:
family → tarball → JAX build
Key Changes
CI Wiring
Tarball Handling
build_tools/github_actions/generate_tarball_urls.py
JAX Workflow Wrapper
.github/workflows/build_linux_jax_from_tarball.yml
JAX Workflow Updates
Control Plane
configure_multi_arch_ci.py
setup_multi_arch.yml
Behavior
Default:
JAX jobs are skipped
Opt-in:
ci:build-jax
When enabled: