diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 2644f8e9..2891aa21 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -112,8 +112,20 @@ class _TritonOutOfResources(Exception): _SMALL_BATCH_PAR_TILE_K = int(os.environ.get("PYJUICE_SB_PAR_TK", 16)) # Minimum block size for the small-batch (batch < 16) block-sparse path. Below this the sparse # kernel is actually faster (its lower launch/tiling overhead beats the node-tiling parallelism when -# the node dimension is small) -- measured crossover is ~128 on an RTX PRO 6000. Tunable via env. -_SMALL_BATCH_MIN_BLOCK_SIZE = int(os.environ.get("PYJUICE_SB_MIN_BS", 128)) +# the node dimension is small). Re-measured on an RTX PRO 6000 after the small-batch block-sparse +# tiling/kernels landed: block_size >= 32 is now faster on the block-sparse path (eager +7..14%, and +# up to ~2x under CUDA graphs at block_size 64), while block_size == 16 regresses ~8% in eager (and +# is neutral under CUDA graphs) -- so the crossover dropped from ~128 to 32. Tunable via env. +_SMALL_BATCH_MIN_BLOCK_SIZE = int(os.environ.get("PYJUICE_SB_MIN_BS", 32)) + +# Block-sparse edge-tile trim: the compiler pads each partition's edge count up to a power of 2 +# (`next_power_of_2`), but the block-sparse kernels iterate the edge dimension in fixed `TILE_SIZE_K` +# tiles. Tiles entirely beyond the partition's REAL (pre-pow2) max edge count are all padding (param 0 +# -> the -inf dummy child element 0 -> contribute exactly 0), so iterating only +# `ceil(real_max / TILE_SIZE_K)` tiles is bit-identical and skips the padded work. The raw cids/pids +# stay pow2-wide (untouched), so the sparse / prod / compilation kernels and `mode="sparse"` are +# unaffected. Toggle for A/B; bit-identical so on by default. 
+_BLOCK_SPARSE_EDGE_TRIM = os.environ.get("PYJUICE_EDGE_TRIM", "1") != "0" class SumLayer(Layer, nn.Module): @@ -310,6 +322,8 @@ def __init__(self, nodes: Sequence[SumNodes], global_nid_start: int, self._cached_bk_par_cuda = dict() self._cached_bk_par_choice = dict() self._cached_bk_par_sparse_choice = dict() + # Edge-tile trim: per (id(cids), TILE_SIZE_K) -> (num_edges, cids, pids, pfids); trimmed contiguous copies when the trim shrinks the width, else the untouched originals. NOTE(review): id() values can be reused once a tensor is freed -- confirm this cache is cleared wherever the partitioned cids are replaced. + self._cached_bk_par_trim = dict() self._bk_par_scratch = None def to(self, device): @@ -546,9 +560,10 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, mode = self.BLOCK_SPARSE elif params.dim() == 1 and self.block_size >= _SMALL_BATCH_MIN_BLOCK_SIZE and num_edges >= 16: # Small batch (< 16): the sparse kernel leaves the large node dimension un-tiled (one - # program per node block -> ~1 SM busy, >10x slowdown). For a large enough block size the - # block-sparse small-batch tiling is far faster; smaller blocks fall through to the sparse - # kernel below (its lower launch/tiling overhead wins there -- measured crossover ~128). + # program per node block -> ~1 SM busy, >10x slowdown). For block_size >= the threshold the + # block-sparse small-batch tiling is far faster; smaller blocks fall through to the + # sparse kernel below (its lower launch/tiling overhead wins there -- measured crossover + # is 32, the default of the env-tunable `_SMALL_BATCH_MIN_BLOCK_SIZE`). mode = self.BLOCK_SPARSE elif self.block_size == 1 or num_edges < 4: # In this case, we should definitely use the sparse implementation @@ -662,16 +677,27 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten signature = ("block_sparse", partition_id, TILE_SIZE_K) if signature not in self._cached_fw_pcids: - # Pre-compute pointer increments for `cids` and `pids` - - cids = cids.clone().reshape(cids.size(0), K_NUM_TILES, TILE_SIZE_K) + # Pre-compute pointer increments for `cids` and `pids`. 
Edge-tile trim (see + # `_BLOCK_SPARSE_EDGE_TRIM`): the per-partition padding to a power of 2 adds trailing tiles + # that are pure padding (param 0 -> the -inf dummy child element 0 -> contribute exactly 0), + # so iterating only the tiles up to the REAL max edge count is bit-identical. The real + # count is read straight off `cids` -- padding is a zero suffix (real children point to + # elements >= 1; element 0 is the dummy) -- so no partition metadata is needed. On a cache + # hit `K_NUM_TILES` is recovered from the (already-trimmed) cached increment tensor below. + if _BLOCK_SPARSE_EDGE_TRIM: + real_max = int((cids != 0).any(dim = 0).sum()) + if real_max > 0: + K_NUM_TILES = min(K_NUM_TILES, triton.cdiv(real_max, TILE_SIZE_K)) + eff_num_edges = K_NUM_TILES * TILE_SIZE_K + + cids = cids[:, :eff_num_edges].clone().reshape(cids.size(0), K_NUM_TILES, TILE_SIZE_K) cids_start = cids[:,0,:].contiguous() cids_increment = torch.cat( (cids[:,1:,:] - cids[:,:-1,:], cids[:,0:1,:] * 0), dim = 1 ).contiguous() - pids = pids.clone().reshape(pids.size(0), K_NUM_TILES, TILE_SIZE_K) + pids = pids[:, :eff_num_edges].clone().reshape(pids.size(0), K_NUM_TILES, TILE_SIZE_K) pids_start = pids[:,0,:].contiguous() pids_increment = torch.cat( (pids[:,1:,:] - pids[:,:-1,:], pids[:,0:1,:] * 0), @@ -695,6 +721,12 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten self._cached_fw_cuda[signature] = [ebase, pbase, cuda_ok] else: cids_start, cids_increment, pids_start, pids_increment = self._cached_fw_pcids[signature] + K_NUM_TILES = cids_increment.size(1) # recover the (possibly trimmed) tile count + + # Keep `num_edges` consistent with the (possibly trimmed) tile count: the small-batch CUDA + # forward iterates `edge in [0, num_edges)` off the per-block first child, so it must not run + # past the trimmed (contiguity-verified) range. The tlmm CUDA / Triton paths use K_NUM_TILES. 
+ num_edges = K_NUM_TILES * TILE_SIZE_K partial_eval = 1 if local_ids is not None else 0 BLOCK_SIZE_M = self.block_size @@ -1404,24 +1436,33 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo signature = ("block_sparse", partition_id, TILE_SIZE_K) if signature not in self._cached_bk_parids: - # Pre-compute pointer increments for `parids` and `parpids` + # Pre-compute pointer increments for `parids` and `parpids`. Edge-tile trim (see + # `_BLOCK_SPARSE_EDGE_TRIM`): the parent dim maps to `num_edges = parids.size(1) * block_size`; + # tiles entirely beyond the REAL max parent count are pure padding (contribute 0), so they + # are dropped. The real count is read off `parids` (padding is a zero suffix -- the dummy + # parent is 0). On a cache hit `K_NUM_TILES` is recovered from the cached increment below. + if _BLOCK_SPARSE_EDGE_TRIM: + real_max = int((parids != 0).any(dim = 0).sum()) + if real_max > 0: + K_NUM_TILES = min(K_NUM_TILES, triton.cdiv(real_max * self.block_size, TILE_SIZE_K)) + eff_pars = (K_NUM_TILES * TILE_SIZE_K) // self.block_size if TILE_SIZE_K < self.block_size: ptr_inc_step = 1 num_rep = self.block_size // TILE_SIZE_K - parids = (parids[:,:,None].repeat(1, 1, num_rep) + \ + parids = (parids[:, :eff_pars, None].repeat(1, 1, num_rep) + \ torch.arange(0, self.block_size, TILE_SIZE_K, device = parids.device)[None,None,:]).reshape( parids.size(0), K_NUM_TILES, 1) - parpids = (parpids[:,:,None].repeat(1, 1, num_rep) + \ + parpids = (parpids[:, :eff_pars, None].repeat(1, 1, num_rep) + \ torch.arange(0, self.block_size, TILE_SIZE_K, device = parpids.device)[None,None,:]).reshape( parpids.size(0), K_NUM_TILES, 1) else: ptr_inc_step = TILE_SIZE_K // self.block_size - parids = parids.reshape(parids.size(0), K_NUM_TILES, ptr_inc_step) - parpids = parpids.reshape(parpids.size(0), K_NUM_TILES, ptr_inc_step) + parids = parids[:, :eff_pars].reshape(parids.size(0), K_NUM_TILES, ptr_inc_step) + parpids = parpids[:, 
:eff_pars].reshape(parpids.size(0), K_NUM_TILES, ptr_inc_step) parids_start = parids[:,0,:].contiguous() parids_increment = torch.cat( @@ -1452,6 +1493,12 @@ def _cumbase(start, incr): self._cached_bk_ele_cuda[signature] = [ele_ebase, ele_pbase, ele_cuda_ok] else: parids_start, parids_increment, parpids_start, parpids_increment, ptr_inc_step = self._cached_bk_parids[signature] + K_NUM_TILES = parids_increment.size(1) # recover the (possibly trimmed) tile count + + # Keep `num_edges` consistent with the (possibly trimmed) tile count: the small-batch CUDA ele + # backward iterates `edge in [0, num_edges)` off the per-block first child, so it must not run + # past the trimmed (contiguity-verified) range. The tlmm CUDA / Triton paths use K_NUM_TILES. + num_edges = K_NUM_TILES * TILE_SIZE_K partial_eval = 1 if local_ids is not None else 0 BLOCK_SIZE_M = cs_block_size @@ -1760,7 +1807,7 @@ def _par_flow_collision_free(self, pfids: torch.Tensor) -> bool: return self._par_collision_free_cache[key] def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, - element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, + element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, allow_modify_flows: bool = False, propagation_alg: str = "LL", logspace_flows: bool = False, negate_pflows: bool = False, @@ -1852,6 +1899,28 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor # untouched), so this is bit-identical. TILE_SIZE_K = _SMALL_BATCH_PAR_TILE_K + # Edge-tile trim (see `_BLOCK_SPARSE_EDGE_TRIM`): the grid spans `cdiv(num_edges, TILE_SIZE_K)` + # edge-tile programs; programs entirely beyond the REAL (pre-pow2) max edge count only touch + # padding (param 0 -> the -inf dummy child -> 0 flow), so shrinking `num_edges` to it is + # bit-identical. 
The real count is read off `cids` (padding is a zero suffix -- real children + # point to elements >= 1; element 0 is the dummy), and the trimmed CONTIGUOUS cids/pids/pfids + # are cached (keyed by TILE_SIZE_K, which the parameter-flow tuning can double): the kernels + # index them as `nblock * num_edges + edge`, so num_edges must equal the tensors' actual width. + # Caching the slice -- rather than re-slicing every call -- avoids per-call copy kernels that + # otherwise regress the backward at small batch. The cids-read runs on cache miss (warmup only). + if _BLOCK_SPARSE_EDGE_TRIM: + tkey = (id(cids), TILE_SIZE_K) + trimmed = self._cached_bk_par_trim.get(tkey) + if trimmed is None: + real_max = int((cids != 0).any(dim = 0).sum()) + eff = triton.cdiv(real_max, TILE_SIZE_K) * TILE_SIZE_K + if 0 < eff < num_edges: + trimmed = (eff, cids[:, :eff].contiguous(), pids[:, :eff].contiguous(), pfids[:, :eff].contiguous()) + else: + trimmed = (num_edges, cids, pids, pfids) + self._cached_bk_par_trim[tkey] = trimmed + num_edges, cids, pids, pfids = trimmed + grid = (triton.cdiv(num_edges, TILE_SIZE_K), triton.cdiv(layer_n_nodes, TILE_SIZE_M)) # Optional CUDA fast path (CuTe/fp16/TMA), autotuned vs Triton INTO A SCRATCH buffer so the diff --git a/tests/model/batch_size_consistency_test.py b/tests/model/batch_size_consistency_test.py index 0c8bf1b4..a2148f9f 100644 --- a/tests/model/batch_size_consistency_test.py +++ b/tests/model/batch_size_consistency_test.py @@ -225,16 +225,17 @@ def test_small_batch_block_sparse_fast_path(): For a large block size the sparse sum kernels leave the big node/edge dimensions un-tiled (one program per node block -> ~1 SM busy, >10x slower than the block-sparse kernels). 
Layers with - `block_size >= 128` are therefore routed to the block-sparse forward / element-flow backward at - small batch too (with a small-batch tiling heuristic that splits the node dimension for SM - occupancy); the parameter-flow backward falls back to the sparse kernel (correct for any batch). - The results must match the sparse path (forced here as the reference), for both the forward LL - and the accumulated parameter flows. + `block_size >= _SMALL_BATCH_MIN_BLOCK_SIZE` (=32 after re-profiling) are therefore routed to the + block-sparse forward / element-flow backward at small batch too (with a small-batch tiling + heuristic that splits the node dimension for SM occupancy); the parameter-flow backward falls + back to the sparse kernel (correct for any batch). The results must match the sparse path (forced + here as the reference), for both the forward LL and the accumulated parameter flows. The 32/64 + cases pin the lowered crossover (these were on the sparse path before). """ device = torch.device("cuda:0") - for num_latents in [128, 512]: # block_size = num_latents >= 128 -> small-batch block-sparse path + for num_latents in [32, 64, 128, 512]: # block_size = num_latents >= 32 -> small-batch block-sparse path torch.manual_seed(num_latents) ns = juice.structures.GeneralizedHMM( seq_length = 4, num_latents = num_latents, homogeneous = True, @@ -537,6 +538,140 @@ def test_small_batch_par_backward_cuda_matches_triton(): sl.BACKWARD_PAR_FLOW_CUDA = saved +def _clear_sum_caches(pc): + _names = ["_cached_fw_pcids", "_cached_fw_cuda", "_cached_fw_cuda_choice", "_cached_fw_sb", + "_cached_bk_parids", "_cached_bk_ele_cuda", "_cached_bk_ele_choice", "_cached_bk_ele_sb", + "_cached_bk_par_cuda", "_cached_bk_par_choice", "_cached_bk_par_sparse_choice", + "_cached_bk_par_trim"] + for lg in pc.inner_layer_groups: + for layer in lg: + if type(layer).__name__ == "SumLayer": + for nm in _names: + if hasattr(layer, nm): + getattr(layer, nm).clear() + + +def 
_trim_fires(pc): + # the edge trim can fire only if some partition's REAL max edge count (the number of cids columns holding a non-dummy entry; + # padding is a contiguous zero suffix, real children point to elements >= 1) is below the + # pow2-padded width -- a necessary condition only: whole `TILE_SIZE_K` tiles are dropped, so the gap must cover a full tile. + for lg in pc.inner_layer_groups: + for layer in lg: + if type(layer).__name__ == "SumLayer": + for cids in layer.partitioned_cids: + if int((cids != 0).any(dim = 0).sum()) < int(cids.size(1)): + return True + return False + + +def test_edge_trim_block_sparse_bit_identical(): + """ + The block-sparse edge-tile trim (`_BLOCK_SPARSE_EDGE_TRIM`) drops the fully-padded edge tiles that + the pow2 padding adds. On the Triton path with a small block size (no bf16 tensor-core dot) it must + be exactly bit-identical to the untrimmed result -- forward LL AND accumulated parameter flows. + HCLT (block_size 32, fan-in 160 -> padded to 256) genuinely triggers the trim (asserted). + """ + import pyjuice.layer.sum_layer as sl + + device = torch.device("cuda:0") + torch.manual_seed(160) + xs = torch.randint(0, 256, [400, 16]).float() + ns = juice.structures.HCLT(xs, num_latents = 160) + ns.init_parameters(perturbation = 2.0) + pc = juice.compile(ns) + pc.to(device) + + assert _trim_fires(pc), "edge trim did not fire on HCLT-160 (the test would be vacuous)" + + saved = (sl.FORWARD_SUM_CUDA, sl.BACKWARD_ELE_FLOW_CUDA, sl.BACKWARD_PAR_FLOW_CUDA, sl._BLOCK_SPARSE_EDGE_TRIM) + try: + # CUDA off -> the exact Triton path (the CUDA kernels are only numerically equivalent). + # batch 64 (>= _GAP_BATCH_MAX) exercises the BACKWARD_PAR_FLOW_TUNED path, which doubles + # TILE_SIZE_K -- so the trim is checked against a changed tile size too. 
+ sl.FORWARD_SUM_CUDA = sl.BACKWARD_ELE_FLOW_CUDA = sl.BACKWARD_PAR_FLOW_CUDA = False + for batch_size in [1, 2, 8, 16, 64]: + data = torch.randint(0, 256, [batch_size, 16], device = device) + + sl._BLOCK_SPARSE_EDGE_TRIM = True + _clear_sum_caches(pc) + ll_trim = pc(data).clone() + pc.backward(data, flows_memory = 0.0, allow_modify_flows = False) + pf_trim = pc.param_flows.clone() + + sl._BLOCK_SPARSE_EDGE_TRIM = False + _clear_sum_caches(pc) + ll_full = pc(data).clone() + pc.backward(data, flows_memory = 0.0, allow_modify_flows = False) + pf_full = pc.param_flows.clone() + + assert torch.equal(ll_trim, ll_full), f"edge-trim forward LL not bit-identical at batch={batch_size}" + assert torch.equal(pf_trim, pf_full), f"edge-trim param flows not bit-identical at batch={batch_size}" + finally: + sl.FORWARD_SUM_CUDA, sl.BACKWARD_ELE_FLOW_CUDA, sl.BACKWARD_PAR_FLOW_CUDA, sl._BLOCK_SPARSE_EDGE_TRIM = saved + + +def test_edge_trim_cuda_matches_sparse(): + """ + Guards the trim's interaction with the small-batch CUDA fast path: that kernel iterates + `child = sb_ebase + edge` over `edge in [0, num_edges)` assuming global contiguity, so if the trim + shrank the tile count without ALSO shrinking `num_edges` it would read PAST the real children (a + silent OOB). On a globally-contiguous, non-pow2-fan-in model (HMM, block_size 128, fan-in 640 -> + padded to 1024) the trimmed CUDA result must still match the exact sparse reference within the CUDA + kernels' tolerance. Asserts both the trim fires AND the small-batch CUDA path is actually reached. + Skipped if the CUDA kernels can't be built. 
+ """ + import pyjuice.layer.sum_layer as sl + from pyjuice.layer.kernels import c as cuda_kernels + + if not (torch.cuda.is_available() and cuda_kernels.smallbatch_fw_is_available()): + pytest.skip("small-batch CUDA kernels unavailable (no nvcc/ninja)") + + device = torch.device("cuda:0") + torch.manual_seed(640) + ns = juice.structures.GeneralizedHMM( + seq_length = 4, num_latents = 640, homogeneous = True, + input_dist = juice.distributions.Categorical(num_cats = 6) + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.compile(ns) + pc.to(device) + + assert _trim_fires(pc), "edge trim did not fire on HMM-640 (the test would be vacuous)" + + sums = [L for lg in pc.inner_layer_groups for L in lg if type(L).__name__ == "SumLayer"] + saved = (sl.FORWARD_SUM_CUDA, sl.BACKWARD_ELE_FLOW_CUDA, sl.BACKWARD_PAR_FLOW_CUDA, sl._BLOCK_SPARSE_EDGE_TRIM) + try: + sl._BLOCK_SPARSE_EDGE_TRIM = True + sb_cuda_reached = False + for batch_size in [1, 2, 8]: + data = torch.randint(0, 6, [batch_size, 4], device = device) + + # exact reference: the sparse kernels (no trim, no CUDA, no bf16 tensor-core dot) + sl.FORWARD_SUM_CUDA = sl.BACKWARD_ELE_FLOW_CUDA = sl.BACKWARD_PAR_FLOW_CUDA = False + _clear_sum_caches(pc) + ll_ref = pc(data, mode = "sparse").clone() + pc.backward(data, mode = "sparse", flows_memory = 0.0, allow_modify_flows = False) + pf_ref = pc.param_flows.clone() + + # under test: the trimmed CUDA fast paths + sl.FORWARD_SUM_CUDA = sl.BACKWARD_ELE_FLOW_CUDA = sl.BACKWARD_PAR_FLOW_CUDA = True + _clear_sum_caches(pc) + ll_cuda = pc(data).clone() + pc.backward(data, flows_memory = 0.0, allow_modify_flows = False) + pf_cuda = pc.param_flows.clone() + sb_cuda_reached = sb_cuda_reached or any(v[2] for L in sums for v in L._cached_fw_sb.values()) + + assert torch.isfinite(ll_cuda).all() and torch.isfinite(pf_cuda).all() + assert (ll_cuda - ll_ref).abs().max() < 1e-2, \ + f"trimmed CUDA forward diverged from sparse at batch={batch_size}" + assert (pf_cuda - pf_ref).abs().max() 
/ (pf_ref.abs().max() + 1e-9) < 1e-2, \ + f"trimmed CUDA param flows diverged from sparse at batch={batch_size}" + + assert sb_cuda_reached, "small-batch CUDA path never reached (test does not guard the num_edges fix)" + finally: + sl.FORWARD_SUM_CUDA, sl.BACKWARD_ELE_FLOW_CUDA, sl.BACKWARD_PAR_FLOW_CUDA, sl._BLOCK_SPARSE_EDGE_TRIM = saved + + if __name__ == "__main__": test_hmm_batch_size_consistency() test_hmm_backward_small_batch()