Skip to content

docs: ADR for phase 4 — TwoTowerCF stays in PyTorch (deliberate)#106

Merged
JohnJacob-coder merged 2 commits into
mainfrom
docs/phase-4-pytorch-stays
May 27, 2026
Merged

docs: ADR for phase 4 — TwoTowerCF stays in PyTorch (deliberate)#106
JohnJacob-coder merged 2 commits into
mainfrom
docs/phase-4-pytorch-stays

Conversation

@Burton-David

Copy link
Copy Markdown
Owner

Phase 4 — TwoTowerCF stays in PyTorch

The pattern through Phase 3 was "find the slowest thing the data points at, rewrite it in whatever compiled language matches the workload." This phase exists to argue the discipline of not doing that rewrite for the neural module, and to make the argument durable enough that a future contributor doesn't quietly reopen the question without re-reading what was already considered.

Why

TwoTowerCF is already a thin Python wrapper over PyTorch's C++/CUDA ATen kernels. The Python interpreter is involved per batch, not per (user, positive, negative) triple the way it was in pure-numpy BPR. Applying the Phase-2 pattern here would mean swapping compiled-via-PyTorch for compiled-via-Rust-bindings-to-PyTorch — same underlying kernels with a worse ecosystem story.

Options considered, all rejected

  • Rewrite inner loop in Rust — no interpreter overhead left to eliminate
  • Port to tch-rs — same speed, lose Python autograd / Hugging Face / torch.compile / etc.
  • Hand-rolled CUDA kernels — premature; no specific bottleneck identified
  • JAX rewrite — splits the codebase across two ML frameworks for TPU-only gains
  • torch.compile — adds compile overhead that dominates on short epoch budgets; would force PyTorch 2.1+ floor. Worth revisiting if the neural module grows.

What's in the PR

  • docs/evolution/04-neural-stays-pytorch.md — the ADR
  • CHANGELOG.md — new Documented subsection pointing at it

No code changes anywhere in src/. Independent of Phase 3 (#105). Tag v0.4.0.

Verification

  • ruff check / ruff format --check clean
  • pytest — 135 passed, 1 skipped, 7 deselected (no test changes)

The pattern through Phase 3 was 'find the slowest thing the data points
at, rewrite it in whatever compiled language matches the workload.'
TwoTowerCF is already a thin Python wrapper over PyTorch's C++/CUDA
ATen kernels; applying the same pattern would mean swapping
compiled-via-PyTorch for compiled-via-Rust-bindings-to-PyTorch — same
underlying kernels with a worse ecosystem story.

This ADR documents the refusal: the options considered (Rust port,
tch-rs, custom CUDA, JAX, torch.compile), why each is rejected, and the
discipline that 'we measured before cutting' doesn't mean 'we always
cut.' Sometimes the right move is to recognize the workload is already
on the right tool and leave it alone.

No code changes. Tag v0.4.0.

@JohnJacob-coder JohnJacob-coder left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the ADR as the deliverable — this reads as engineering reasoning, not rationalization, so it clears the bar for a 'deliberate refusal' phase.

The decisive argument is quantitative and grounded in neural.py: BPR's pure-numpy loop was ~500k per-triple interpreter calls (interpreter-bound, which is why Rust won in Phase 2), whereas TwoTowerCF touches the interpreter only ~len(positives)/batch_size times per epoch — hundreds, each handing off to compiled ATen/CUDA. There's no equivalent interpreter overhead to eliminate. That's the real reason, and it's the same 'measured before cut' logic the rest of the arc uses.

tch-rs is honestly evaluated and rejected for stated reasons (same kernels, same speed, lose Python autograd / the pip PyTorch ecosystem / torch.compile / HF checkpoints). The other options (CUDA, JAX, torch.compile) are concrete and fairly weighed — torch.compile even names the condition under which it'd be worth revisiting. The 'what it does / doesn't mean' section keeps it from reading as PyTorch-fandom (PyTorch isn't special; the module can grow; optimization returns when a bottleneck is measured).

Docs-only, CI green. Optional, non-blocking: a short cProfile of TwoTowerCF.fit showing torch ops dominate would make the per-batch claim empirical rather than structural — but the architectural argument already stands on its own. LGTM. (v0.4.0 is the phase tag.)

@JohnJacob-coder JohnJacob-coder enabled auto-merge (squash) May 27, 2026 22:54
@JohnJacob-coder JohnJacob-coder merged commit f2f6204 into main May 27, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants