Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
name: CI

on:
push:
branches: [main]
pull_request:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read

env:
CARGO_TERM_COLOR: always

jobs:
fmt:
name: rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- uses: Swatinem/rust-cache@v2
- run: cargo fmt --all --check

nextest:
name: nextest (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@nextest

- name: Run tests with nextest
run: cargo nextest run --workspace --release --no-fail-fast

- name: Run doctests
run: cargo test --doc --workspace --release

coverage:
name: coverage
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@nextest
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate coverage report
run: cargo llvm-cov nextest --workspace --release
12 changes: 12 additions & 0 deletions src/primitive_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,21 @@ use crate::ADKey;
/// }
///
/// impl PrimitiveOp for AddOp {
/// type ADContext = ();
///
/// fn add() -> Self { AddOp }
/// fn linearize(
/// &self, _b: &mut FragmentBuilder<Self>,
/// _pi: &[GlobalValKey<Self>], _po: &[GlobalValKey<Self>],
/// t: &[Option<LocalValId>],
/// _ctx: &mut (),
/// ) -> Vec<Option<LocalValId>> {
/// vec![t[0].or(t[1])]
/// }
/// fn transpose_rule(
/// &self, _b: &mut FragmentBuilder<Self>,
/// ct: &[Option<LocalValId>], _i: &[ValRef<Self>], _m: &OpMode,
/// _ctx: &mut (),
/// ) -> Vec<Option<LocalValId>> {
/// vec![ct[0], ct[0]]
/// }
Expand All @@ -60,6 +64,12 @@ pub trait PrimitiveOp: GraphOp
where
Self::InputKey: ADKey,
{
/// Runtime AD context threaded through linearization and transpose.
///
/// This can carry information such as concrete shapes or guard decisions
/// that influence how AD rules emit graph structure.
type ADContext: Default;

/// Returns the addition operation used for cotangent accumulation
/// in `tidu::transpose`. When multiple cotangents flow to the same
/// `GlobalValKey`, transpose emits `Op::add()` nodes to sum them.
Expand All @@ -77,6 +87,7 @@ where
primal_in: &[GlobalValKey<Self>],
primal_out: &[GlobalValKey<Self>],
tangent_in: &[Option<LocalValId>],
ctx: &mut Self::ADContext,
) -> Vec<Option<LocalValId>>
where
Self: Sized;
Expand All @@ -91,6 +102,7 @@ where
cotangent_out: &[Option<LocalValId>],
inputs: &[ValRef<Self>],
mode: &OpMode,
ctx: &mut Self::ADContext,
) -> Vec<Option<LocalValId>>
where
Self: Sized;
Expand Down
Loading
Loading