Small batch fast kernels #48
Merged
At batch < 16 the sparse sum kernels leave the large node/edge dims un-tiled: one program per node block keeps ~1 SM busy, which is up to ~30x slower than the tiled block-sparse kernels. Route large-block layers (block_size >= 128, tunable) to the block-sparse forward / element-flow backward at small batch too, using a small-batch tiling heuristic that decouples TILE_SIZE_K from the batch and splits the node dimension into many small tiles for SM occupancy (the batch is a short, masked, no-tensor-core inner dim).

The parameter-flow backward falls back to the sparse kernel for batch < 16: its block-sparse variant drops flows when batch % TILE_SIZE_B != 0, whereas the sparse kernel is correct for any batch and is a tiny fraction of the cost. The >= 16 path and the >= 16 tile heuristics are byte-identical (the small-batch branches are gated). Smaller blocks (< 128) stay on the sparse kernel, which is faster there (measured crossover ~128 on an RTX PRO 6000).

These reuse the existing, validated block-sparse Triton kernels (separate compiled specializations for the small-batch tile configs); no kernel math changed.

Measured speedup (GeneralizedHMM, block_size 1024, fwd+bwd): batch 2 6.7x, batch 4 6.5x, batch 8 2.3x; batch 16 unchanged. Agrees with the sparse path to rounding error (param flows relmax ~5e-6). Regression test: test_small_batch_block_sparse_fast_path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
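The routing described in this commit can be sketched as a small decision function. This is a minimal illustration, not the PR's code: the function name, the string labels, and the constant names are hypothetical, and only the thresholds (batch < 16, block_size >= 128) come from the commit message.

```python
# Hypothetical names; only the two thresholds are taken from the commit message.
SMALL_BATCH_THRESHOLD = 16
SMALL_BATCH_MIN_BLOCK_SIZE = 128  # described as tunable; measured crossover ~128


def pick_sum_kernels(batch_size: int, block_size: int) -> dict:
    """Return an illustrative label for the kernel family used by each pass."""
    if batch_size >= SMALL_BATCH_THRESHOLD:
        # The >= 16 path is untouched by this change, so it is not modelled here.
        return {"forward": "existing", "ele_backward": "existing",
                "par_backward": "existing"}
    if block_size >= SMALL_BATCH_MIN_BLOCK_SIZE:
        # Large blocks: block-sparse kernels with the small-batch tile config.
        # The parameter-flow backward stays on the sparse kernel, which is
        # correct for any batch size.
        return {"forward": "block_sparse_small_batch",
                "ele_backward": "block_sparse_small_batch",
                "par_backward": "sparse"}
    # Small blocks: the sparse kernels are faster, so keep them everywhere.
    return {"forward": "sparse", "ele_backward": "sparse",
            "par_backward": "sparse"}
```

For example, `pick_sum_kernels(4, 1024)` selects the small-batch block-sparse forward and element-flow backward, while the parameter-flow backward remains on the sparse kernel.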
Adds an optional plain-CUDA fast path for the sum-layer forward in the small-batch regime (batch < 16, block_size >= 128), where the Triton block-sparse kernels under-tile the large node dimension. The kernel (smallbatch_forward_sum.cu) uses 32-node coalesced warps + edge-split online-logsumexp + a within-block shared-memory combine; it is numerically equivalent to the Triton block-sparse forward (~1.5e-6 in log-space, well under the 1.5e-3 bar) and ~1.68x faster on the heaviest layer (~1.2x on the full forward), found best after sweeping 9 kernel variants + 5 pipelining strategies.

- smallbatch_forward_sum.cu: generalized v4 kernel (handles num_nblocks>1 via per-node-block ebase/pbase; block_size must be a multiple of 32; hard TORCH_CHECK on bad cfg so a no-op launch can't masquerade as correct against a reused output buffer).
- c/__init__.py: a _jit_plain loader path (no CUTLASS, no sm_90 -- plain CUDA works on any CUDA GPU, far broader than the CuTe kernels) + smallbatch_fw_is_available / smallbatch_forward_sum / smallbatch_fw_configs.
- sum_layer._forward_block_sparse: small-batch CUDA dispatch, gated identically to the small-batch tiling heuristic (batch<16, block_size>=_SMALL_BATCH_MIN_BLOCK_SIZE, LL, no tempering/partial-eval, contiguous layout), autotuned per (signature, batch) against the Triton small-batch launch and used only when it wins; otherwise falls through to Triton. Forward overwrites node_mars so autotune candidates need no scratch. The >=16 path and small-batch Triton path are unchanged.
- Regression test test_small_batch_forward_cuda_matches_triton (skips if no nvcc/ninja).

Routing verified: CUDA engages at batch 1/2/3/8, not at batch>=16 (bit-identical there); graceful Triton fallback when CUDA/CUTLASS/ninja absent. Full suite green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
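The "edge-split online-logsumexp + combine" idea can be illustrated on the CPU. The sketch below is an assumption about the general technique, not a transcription of smallbatch_forward_sum.cu: each edge split keeps a running (max, scaled sum) pair, and the pairs are merged afterwards, which is the role the shared-memory combine plays in the kernel. All function names are hypothetical.

```python
import math
from functools import reduce


def online_lse(terms):
    """Stream over log-space terms, keeping (running max, sum scaled by that max)."""
    m, s = -math.inf, 0.0
    for v in terms:
        if v == -math.inf:
            continue  # contributes exp(-inf) = 0
        if v > m:
            s = s * math.exp(m - v) + 1.0  # rescale the old sum to the new max
            m = v
        else:
            s += math.exp(v - m)
    return m, s


def combine(a, b):
    """Merge two partial (max, sum) states from different edge splits."""
    (ma, sa), (mb, sb) = a, b
    m = max(ma, mb)
    if m == -math.inf:
        return m, 0.0
    return m, sa * math.exp(ma - m) + sb * math.exp(mb - m)


def split_lse(terms, num_splits):
    """Log-sum-exp computed as independent partial reductions plus a combine step."""
    partials = [online_lse(terms[i::num_splits]) for i in range(num_splits)]
    m, s = reduce(combine, partials)
    return m + math.log(s) if s > 0.0 else -math.inf
```

Because every partial state carries its own maximum, the splits can run independently and the final result does not overflow even when the terms are large.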
The CuTe/plain sum-layer CUDA kernels (tlmm forward, ele backward, par backward, small-batch forward) launched with kernel<<<grid,blk,smem>>>, i.e. on the default stream. Under torch.cuda.graph capture of the forward/backward this issues them on the wrong stream, so the captured graph has incorrect dependencies and replay silently corrupts the parameter flows (~2% rel-err). Pass c10::cuda::getCurrentCUDAStream() as the launch stream so the kernels capture correctly (as the input-layer cat_backward.cu already does); CUDA-graph replay is now bit-exact. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
At batch < 16 the 2D product kernel's BLOCK_M heuristic balloons to block_size, so each product layer launches a single thread block that walks block_size nodes serially (~1 of 188 SMs busy). On a small-batch HMM this product kernel was 86% of the fwd+bwd GPU time. Cap BLOCK_M at small batch (_SMALL_BATCH_PROD_TILE_M, default 8, env PYJUICE_SB_PROD_TM) so the node dimension fans out across many programs. Pure tiling -> bit-identical; ~22x faster product kernel, ~1.9x faster small-batch fwd+bwd. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
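The effect of the cap on the launch grid is simple arithmetic, sketched below. The helper names are hypothetical and the uncapped heuristic is treated as an input rather than reproduced; only the default of 8, the environment variable PYJUICE_SB_PROD_TM, and the batch < 16 gate come from the commit message.

```python
import os

SMALL_BATCH_THRESHOLD = 16  # gate stated in the commit message


def cdiv(a, b):
    """Ceiling division, as used when sizing a launch grid."""
    return (a + b - 1) // b


def capped_block_m(uncapped_block_m, batch_size, env=None):
    """Apply the small-batch cap to whatever the existing heuristic picked."""
    env = os.environ if env is None else env
    cap = int(env.get("PYJUICE_SB_PROD_TM", "8"))
    if batch_size < SMALL_BATCH_THRESHOLD:
        return min(uncapped_block_m, cap)
    return uncapped_block_m  # larger batches keep the existing choice


def num_node_programs(num_nodes, block_m):
    """Number of programs along the node dimension for a given tile height."""
    return cdiv(num_nodes, block_m)
```

With 1024 nodes, an uncapped BLOCK_M of 1024 yields a single program, while the capped value of 8 yields 128 programs for the scheduler to spread across SMs.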
Plain-CUDA warp-per-child kernel for the small-batch element-flow backward, mirroring the small-batch forward. Each child node gets a warp that streams its own edges' flows from global memory and reduces them with an in-register online log-sum-exp (no shared staging, no barrier). Autotuned against the Triton csmm2 kernel into a scratch buffer (corruption-safe) and used only when it wins; falls back to Triton otherwise. Numerically equivalent (~3e-6). Adds the small-batch dispatch in _backward_block_sparse_ele_flows plus regression tests for the ele kernel and the product-layer tiling.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sparse element-flow backward set BLOCK_M = cs_block_size, so each node-block ran as a single serial program (~1 of 188 SMs busy). On a small-batch HMM the block_size==1 passthrough layer that routes here was a single 90us launch. Add a within-block tile split (TILES_PER_BLOCK) to _bk_triton_sparse_ele_kernel and cap BLOCK_M at small batch (_SMALL_BATCH_SPARSE_TILE_M, default 8, env PYJUICE_SB_SPARSE_TM) so the node dimension fans across many programs. TILES_PER_BLOCK==1 reproduces the original indexing; pure tiling -> bit-identical (torch.equal). ~38x faster on that layer (90us -> 2.4us); kernel floor of the small-batch fwd+bwd 294us -> 205us. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
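The within-block tile split amounts to remapping the program id to a (block, tile) pair. The sketch below shows that index arithmetic on the host; it is an illustration under assumptions, not the indexing of _bk_triton_sparse_ele_kernel itself. The function name and the `block_starts` list (first node id of each block) are hypothetical.

```python
def program_tiles(block_starts, block_size, block_m):
    """Return the node range handled by each program id."""
    if block_size % block_m != 0:
        raise ValueError("block_m must divide block_size")
    tiles_per_block = block_size // block_m
    tiles = []
    for pid in range(len(block_starts) * tiles_per_block):
        # With tiles_per_block == 1 this reduces to block_id = pid, tile_id = 0,
        # i.e. the original one-program-per-block indexing.
        block_id, tile_id = divmod(pid, tiles_per_block)
        start = block_starts[block_id] + tile_id * block_m
        tiles.append(range(start, start + block_m))
    return tiles
```

Splitting changes only which program touches which node; the set of nodes covered is the same, which is why a pure tiling change can leave the results unchanged.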
Batch sizes in [16,64) fall between the small-batch (<16) path and the >=64-aligned CUDA path. There the budget heuristics under-tile three kernels -- worst at batch=16 (eager 1905us, cudagraph 1835us): the block-sparse param-flow kernel got ~8 programs (~1 SM, 963us), the product kernel ~588us, and the sparse element-flow kernel 128us. Fixes: (1) gate the par TUNED TILE_SIZE_K doubling to batch>=64 and, for batch<64, shrink the par output-column tile TILE_SIZE_K (bit-safe: only TILE_SIZE_M sets the max-stabilization group); (2) extend the product and sparse-ele node-tile caps from <16 to <64. All pure tiling -> bit-identical (torch.equal). Result at batch=16: par kernel 963->24us (40x), kernel floor 1815->231us; cudagraph 1835->253us (7.3x), eager 1905->1107us. Large batch (64/256/1024) unchanged (gates inactive; TUNED still fires at >=64). New regression test test_gap_batch_tiling_matches_untiled. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sparse Triton par kernel (the batch<16 fallback) re-reads the shared children per node and was the largest kernel in the small-batch backward floor (~71us). This plain-CUDA kernel exploits the node-contiguous param layout (like the forward): a node-warp (lane = node -> coalesced node-contiguous param load AND param-flow store) with the edges split across the grid. Each (node,edge) flow is independent, so with a single batch tile + collision-free (untied) flows the write is a plain read-add-store -- no atomics, no cross-edge reduction. Launches on the current stream (CUDA-graph-safe). Numerically equivalent (~4e-11). Autotuned vs the Triton sparse par INTO A SCRATCH buffer (param_flows is read-accumulate-write) and used only when it wins; gated to LL, logspace flows, allow_modify_flows/negate/tempering off, single batch tile, collision-free flows, block_size a multiple of 32, batch<16. ~2.5x over Triton (4.0 vs 10.3us/layer clock-pinned); batch=2 cudagraph 226->178us (21%), kernel floor 205->155us. Large batch unchanged (gated batch<16). New regression test test_small_batch_par_backward_cuda_matches_triton. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
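Why the autotuner needs a scratch buffer here can be shown with a toy accumulating kernel. This is a minimal sketch of the general pattern, assuming nothing about the PR's autotuner beyond "time each candidate on a scratch copy, then run only the winner on the real buffer"; every name below is hypothetical and the Python functions stand in for GPU kernels.

```python
import time


def accumulate_loop(inputs, out):
    """Stand-in kernel: read-accumulate-write into `out`."""
    for i, x in enumerate(inputs):
        out[i] += x


def accumulate_zip(inputs, out):
    """A second stand-in kernel with the same semantics."""
    out[:] = [o + x for o, x in zip(out, inputs)]


def autotune_accumulating(candidates, inputs, param_flows, repeats=3):
    """Pick the fastest candidate without ever touching the real accumulator."""
    best_name, best_time = None, float("inf")
    for name, kernel in candidates.items():
        scratch = list(param_flows)  # timing runs accumulate here, not in param_flows
        start = time.perf_counter()
        for _ in range(repeats):
            kernel(inputs, scratch)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_name, best_time = name, elapsed
    return best_name
```

If the timing runs wrote into `param_flows` directly, each repeat would add the flows again and corrupt the accumulated values. An overwriting kernel, such as the forward that overwrites node_mars, does not have this problem.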
Triton's per-launch Python (the 'binder': unwrap constexprs + alignment/divisibility specialization per arg + cache-key + launch_metadata) dominates PyJuice's small-batch eager step. The original FastJITFunction (now renamed FastJITFunction2) was disabled for triton>=3 (launch ABI change). FastJITFunction3x caches the compiled kernel under a signature that is a SUPERSET of everything triton specializes on (device, constexpr values, per-tensor dtype+16B-alignment-class, every other arg value, launch meta-params) and, on a cache hit, calls CompiledKernel.run directly with launch_metadata=None (only consumed by optional profiling hooks -> not a silent-correctness risk). Anything unexpected (callable grid, unknown kwarg, missing arg, any exception) falls back to the stock triton.jit launch, so results are always correct. Opt out with PYJUICE_FAST_LAUNCH=0. Also memoize the small-batch CUDA config lists (they were re-entering the pybind module per layer per step). ~6% faster HMM eager fwd+bwd (more on Triton-heavy workloads, ~11%); bit-identical; full layer+model suite passes (26) plus a dedicated launcher safety test hammering alignment/dtype/constexpr variations vs triton.jit. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
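The caching idea can be sketched without Triton. The code below is a simplified model, not FastJITFunction3x: `FakeTensor`, `CachedLauncher`, `compile_fn`, and `slow_launch` are hypothetical stand-ins, and the signature covers only the ingredients the commit message lists (device, constexpr values, per-tensor dtype and 16-byte alignment class, other argument values, launch meta-parameters).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor: only what the signature needs."""
    dtype: str
    data_ptr: int


class CachedLauncher:
    def __init__(self, compile_fn, slow_launch):
        self.compile_fn = compile_fn    # hypothetical: builds a kernel for a signature
        self.slow_launch = slow_launch  # hypothetical: the stock, always-correct path
        self.cache = {}
        self.fast_hits = 0

    def _signature(self, device, args, constexprs, meta):
        arg_key = tuple(
            ("tensor", a.dtype, a.data_ptr % 16 == 0) if isinstance(a, FakeTensor)
            else ("value", type(a).__name__, a)
            for a in args
        )
        return (device, arg_key,
                tuple(sorted(constexprs.items())), tuple(sorted(meta.items())))

    def launch(self, grid, device, args, constexprs, meta):
        try:
            if callable(grid):
                raise TypeError("callable grid is left to the stock launcher")
            key = self._signature(device, args, constexprs, meta)
            kernel = self.cache.get(key)  # raises TypeError if an arg is unhashable
            if kernel is None:
                kernel = self.cache[key] = self.compile_fn(key)
            else:
                self.fast_hits += 1
        except Exception:
            # Anything unexpected goes through the slow path.
            return self.slow_launch(grid, device, args, constexprs, meta)
        return kernel(grid, args)
```

Keying on a superset of what the compiler specializes on means a cache hit can never reuse a kernel compiled for different assumptions; the cost is a few extra cache entries, for instance one per alignment class.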
Speed up the kernels under small batch sizes.