Skip to content

Add JAX build integration to multi-arch CI (build-only, tarball-based)#4366

Draft
erman-gurses wants to merge 26 commits intomainfrom
users/erman-gurses/add-bumps-for-jax
Draft

Add JAX build integration to multi-arch CI (build-only, tarball-based)#4366
erman-gurses wants to merge 26 commits intomainfrom
users/erman-gurses/add-bumps-for-jax

Conversation

@erman-gurses
Copy link
Copy Markdown
Contributor

@erman-gurses erman-gurses commented Apr 7, 2026

[CI] Add JAX build integration to multi-arch CI (build-only, tarball-based)

Motivation

Introduce JAX as a first-class consumer of TheRock multi-arch artifacts in CI.

This change extends the existing control-plane plumbing to execute JAX builds in a clean, multi-arch–compliant way. Instead of building JAX against ad-hoc or hardcoded environments, JAX now consumes ROCm tarballs produced within the same CI run.


Scope

This PR combines:

1. Control-plane plumbing

  • Add build_jax / test_jax as JobGroupDecision in configure_multi_arch_ci.py
  • Enable via PR label:
    ci:build-jax
  • Restrict to pull_request events
  • Default behavior: JAX jobs are skipped unless explicitly enabled

This follows the existing multi-arch CI pattern used for PyTorch (build_pytorch / test_pytorch).


2. Build integration (multi-arch, tarball-based)

This PR introduces build-only JAX execution as a downstream consumer of TheRock artifacts:

  • Produce ROCm tarballs via:
    multi_arch_build_tarballs.yml
  • Generate presigned URLs per architecture using:
    build_tools/github_actions/generate_tarball_urls.py
  • Fan out JAX builds per AMDGPU family using CI matrix
  • Invoke:
    build_linux_jax_wheels.yml
    with tarball input

Each JAX job:
family → tarball → JAX build


Key Changes

CI Wiring

  • multi_arch_ci.yml
    • Add linux_build_tarballs job
    • Add linux_build_jax matrix job (per-family)
    • Wire tarball production → JAX consumption

Tarball Handling

  • Introduce:
    build_tools/github_actions/generate_tarball_urls.py
  • Generate presigned tarball URLs per family
  • Remove legacy single-family (jax_amdgpu_family) assumptions

JAX Workflow Wrapper

  • Add:
    .github/workflows/build_linux_jax_from_tarball.yml
  • Resolves tarball URL and invokes JAX build workflow
  • Clean separation between artifact resolution and build execution

JAX Workflow Updates

  • build_linux_jax_wheels.yml
    • Add run_tests and upload_wheels inputs
    • Gate test and upload jobs

Control Plane

  • configure_multi_arch_ci.py

    • Add build_jax / test_jax
    • Label-based enablement (ci:build-jax)
  • setup_multi_arch.yml

    • Export build_jax to downstream workflows

Behavior

Default:
JAX jobs are skipped

Opt-in:
ci:build-jax

When enabled:

  • JAX builds run per-family using multi-arch matrix
  • Each job consumes the corresponding ROCm tarball from the same CI run

@erman-gurses erman-gurses requested a review from WBobby April 7, 2026 02:34
@erman-gurses erman-gurses linked an issue Apr 7, 2026 that may be closed by this pull request
Comment thread .github/workflows/ci.yml Outdated
Comment thread build_tools/github_actions/configure_ci.py
@erman-gurses erman-gurses force-pushed the users/erman-gurses/add-bumps-for-jax branch from 2d7afc1 to d305c3e Compare April 7, 2026 06:28
@erman-gurses erman-gurses linked an issue Apr 7, 2026 that may be closed by this pull request
@erman-gurses erman-gurses requested a review from ScottTodd April 7, 2026 06:44
Comment thread build_tools/github_actions/configure_multi_arch_ci.py
Comment thread build_tools/github_actions/configure_multi_arch_ci.py
Comment on lines +589 to +594
build_jax=JobGroupDecision(
action=JobAction.RUN if enable_jax else JobAction.SKIP
),
test_jax=JobGroupDecision(
action=JobAction.RUN if enable_jax else JobAction.SKIP
),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scope

This PR only introduces control-plane plumbing. Follow-up PRs will:

  • produce JAX-consumable ROCm artifacts
  • invoke the JAX workflow
  • add validation/tests

The plumbing here looks correct, but I'd prefer to adjust the sequencing so people don't start expecting this to actually do anything (though to that point, we could also remove test_pytorch 🤔)

  1. Produce ROCm tarballs (I'm sort of adding this as part of [Multi-arch] Add multi-arch release pipelines #3334), OR switch JAX to build from ROCm python packages
  2. Add JAX building to multi-arch CI workflows, together with this plumbing (they can be separate PRs, but I'd get them up for review together so the period between advertising that it is available and it actually working is short)
  3. Add JAX testing to multi-arch CI workflows

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect I'll have ROCm tarballs produced by multi-arch CI workflows within a few days, so this can probably compose pretty easily on top of that 🤞

Copy link
Copy Markdown
Contributor Author

@erman-gurses erman-gurses Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks.

I’ll wait for (1) to land on your side, then follow up with a PR for (2). We can review this plumbing PR together with the JAX build integration PR to minimize the gap between enabling it and it actually working.

After that, I’ll proceed with (3).

Will address rest of the comments today.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ScottTodd, I combined (1) and (2) because once tarball production was available, the remaining JAX build integration became the direct consumer of the new control-plane outputs (build_jax, jax_amdgpu_family).

Keeping them separate would leave a no-op state where CI exposes JAX enablement but does not execute any JAX workflow. Combining them makes the path coherent end-to-end, while still keeping the scope limited to build-only execution.

Comment thread build_tools/github_actions/configure_multi_arch_ci.py Outdated
@erman-gurses erman-gurses changed the title Add opt-in JAX build knob for CI/bump PR workflows Add JAX build integration to multi-arch CI (build-only, tarball-based) Apr 22, 2026
@erman-gurses erman-gurses marked this pull request as draft April 22, 2026 06:39
@erman-gurses erman-gurses force-pushed the users/erman-gurses/add-bumps-for-jax branch from e2b78da to 395f69e Compare April 27, 2026 15:42
Copy link
Copy Markdown
Member

@ScottTodd ScottTodd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by review pointing out some details that matter significantly (please be careful testing such changes)

Comment thread .github/workflows/multi_arch_ci.yml Outdated
Comment thread build_tools/github_actions/generate_tarball_urls.py Outdated
@erman-gurses
Copy link
Copy Markdown
Contributor Author

erman-gurses commented Apr 27, 2026

@ScottTodd, Those comments are helpful thanks. Not tested anything yet - will do after the fixes.

@erman-gurses
Copy link
Copy Markdown
Contributor Author

erman-gurses commented Apr 28, 2026

@ScottTodd, let me know if you still see any critical issue for the testing.

@erman-gurses erman-gurses requested a review from ScottTodd April 28, 2026 06:36
Comment on lines +27 to +51
generate_tarball_url:
name: Generate tarball URL | ${{ inputs.amdgpu_family }}
runs-on: ubuntu-24.04
outputs:
tarball_url: ${{ steps.compute_tarball_url.outputs.tarball_url }}
steps:
- name: Checking out repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@b47578312673ae6fa5b5096b330d9fbac3d116df # v4.2.1
with:
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ github.event_name == 'pull_request' && 'developer' || 'artifact' }}-role
aws-region: us-east-2

- name: Generate tarball download URL
id: compute_tarball_url
run: |
python build_tools/github_actions/generate_tarball_urls.py \
--run-id "${{ inputs.artifact_run_id }}" \
--platform linux \
--release-type "${{ inputs.release_type }}" \
--package-version "${{ inputs.package_version }}" \
--dist-amdgpu-families "${{ inputs.dist_amdgpu_families }}" \
--family "${{ inputs.amdgpu_family }}"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can pass the URL in to the workflow instead of computing here.

The workflow then can build from any tarball you provide, instead of needing to infer where a tarball should be. If we change the structure of the tarballs or their index pages then this workflow code would break. If a developer wants to build from a dev or release tarball this would not support that.

See how we pass needs.build_python_packages.outputs.package_find_links_url and inputs.rocm_package_version here:

  • build_python_packages:
    needs: [build_multi_arch_stages]
    name: Build Python Packages
    if: ${{ !failure() && !cancelled() && fromJSON(inputs.build_config).expect_failure == false }}
    uses: ./.github/workflows/build_portable_linux_python_packages.yml
    with:
    artifact_group: ${{ fromJSON(inputs.build_config).artifact_group }}
    amdgpu_families: ${{ fromJSON(inputs.build_config).dist_amdgpu_families }}
    multiarch_index: true
    package_version: ${{ inputs.rocm_package_version }}
    release_type: ${{ inputs.release_type }}
    permissions:
    contents: read
    id-token: write
    test_python_packages_per_family:
    needs: [build_python_packages]
    name: Test Python ${{ matrix.family_info.amdgpu_family }} | ${{ matrix.image.name }}
    if: ${{ !failure() && !cancelled() && fromJSON(inputs.build_config).expect_failure == false }}
    strategy:
    fail-fast: false
    matrix:
    family_info: ${{ fromJSON(inputs.build_config).per_family_info }}
    # Fan out wheel tests across multiple base images to catch distro-specific regressions.
    image:
    - name: ubuntu24.04
    url: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26
    - name: ubi10
    url: ghcr.io/rocm/no_rocm_image_ubi10@sha256:a10f34d6006a20d02cf688982de9dea147710927ed405a3b0d5c73b58a6030c0
    uses: ./.github/workflows/test_rocm_wheels.yml
    with:
    amdgpu_family: ${{ matrix.family_info.amdgpu_family }}
    test_runs_on: ${{ matrix.family_info.test-runs-on }}
    # TODO: Simplify to just `needs.build_python_packages.outputs.package_find_links_url`
    # once kpack split is always enabled.
    package_find_links_url: >-
    ${{
    needs.build_python_packages.outputs.kpack_split == 'true'
    && needs.build_python_packages.outputs.package_find_links_url
    || format('{0}/{1}/index.html',
    needs.build_python_packages.outputs.package_find_links_url,
    matrix.family_info.amdgpu_family)
    }}
    python_version: "3.12"
    rocm_version: ${{ inputs.rocm_package_version }}
    amdgpu_targets: ${{ matrix.family_info.amdgpu_targets }}
    kpack_split: ${{ needs.build_python_packages.outputs.kpack_split }}
    container_image_name: ${{ matrix.image.name }}
    container_image_url: ${{ matrix.image.url }}
  • outputs:
    package_find_links_url:
    description: URL for pip --find-links to install built packages
    value: ${{ jobs.build_rocm_wheels.outputs.package_find_links_url }}
  • outputs:
    package_find_links_url: ${{ steps.upload.outputs.package_find_links_url }}
  • log("Set github actions output")
    log("-------------------------")
    gha_set_output(
    {"package_find_links_url": index_url, "kpack_split": kpack_split}
    )

We can set a similar output here:

Comment on lines +66 to +67
cloudfront_url: "https://rocm.devreleases.amd.com/v2"
cloudfront_staging_url: "https://rocm.devreleases.amd.com/v2-staging"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These URLs should be similarly dynamic.

Some references:

Looks like https://github.com/ROCm/TheRock/blob/main/.github/workflows/build_linux_jax_wheels.yml is only using package_index_url: ${{ inputs.cloudfront_staging_url }} but then it always copies to a release bucket.

We could add a variant similar to https://github.com/ROCm/TheRock/blob/main/.github/workflows/build_portable_linux_pytorch_wheels_ci.yml that doesn't upload at all (or uploads to the artifacts buckets, not a release bucket), then a release workflow like https://github.com/ROCm/TheRock/blob/main/.github/workflows/multi_arch_release_linux_pytorch_wheels.yml that uploads

@erman-gurses erman-gurses added the ci:build-jax Enable Jax Build label May 7, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci:build-jax Enable Jax Build

Projects

Status: TODO

Development

Successfully merging this pull request may close these issues.

[CI] Build and test JAX Python packages as part of ci.yml [CI] Enable JAX builds on bump PR workflows

2 participants