Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 84 additions & 15 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,20 @@ class _TritonOutOfResources(Exception):
_SMALL_BATCH_PAR_TILE_K = int(os.environ.get("PYJUICE_SB_PAR_TK", 16))
# Minimum block size for the small-batch (batch < 16) block-sparse path. Below this the sparse
# kernel is actually faster (its lower launch/tiling overhead beats the node-tiling parallelism when
# the node dimension is small) -- measured crossover is ~128 on an RTX PRO 6000. Tunable via env.
_SMALL_BATCH_MIN_BLOCK_SIZE = int(os.environ.get("PYJUICE_SB_MIN_BS", 128))
# the node dimension is small). Re-measured on an RTX PRO 6000 after the small-batch block-sparse
# tiling/kernels landed: block_size >= 32 is now faster on the block-sparse path (eager +7..14%, and
# up to ~2x under CUDA graphs at block_size 64), while block_size == 16 regresses ~8% in eager (and
# is neutral under CUDA graphs) -- so the crossover dropped from ~128 to 32. Tunable via the
# `PYJUICE_SB_MIN_BS` environment variable.
_SMALL_BATCH_MIN_BLOCK_SIZE = int(os.environ.get("PYJUICE_SB_MIN_BS", 32))

# Block-sparse edge-tile trim: the compiler pads each partition's edge count up to a power of 2
# (`next_power_of_2`), but the block-sparse kernels iterate the edge dimension in fixed `TILE_SIZE_K`
# tiles. Tiles entirely beyond the partition's REAL (pre-pow2) max edge count are all padding (param 0
# -> the -inf dummy child element 0 -> contribute exactly 0), so iterating only
# `ceil(real_max / TILE_SIZE_K)` tiles is bit-identical and skips the padded work. The raw cids/pids
# stay pow2-wide (untouched), so the sparse / prod / compilation kernels and `mode="sparse"` are
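# unaffected. Set `PYJUICE_EDGE_TRIM=0` to disable the trim (e.g. for A/B timing); any other value,
# or leaving it unset, keeps it on -- the default, since results are bit-identical.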
_BLOCK_SPARSE_EDGE_TRIM = os.environ.get("PYJUICE_EDGE_TRIM", "1") != "0"


class SumLayer(Layer, nn.Module):
Expand Down Expand Up @@ -310,6 +322,8 @@ def __init__(self, nodes: Sequence[SumNodes], global_nid_start: int,
self._cached_bk_par_cuda = dict()
self._cached_bk_par_choice = dict()
self._cached_bk_par_sparse_choice = dict()
# Edge-tile trim: per (id(cids), TILE_SIZE_K) -> trimmed contiguous (num_edges, cids, pids, pfids).
self._cached_bk_par_trim = dict()
self._bk_par_scratch = None

def to(self, device):
Expand Down Expand Up @@ -546,9 +560,10 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
mode = self.BLOCK_SPARSE
elif params.dim() == 1 and self.block_size >= _SMALL_BATCH_MIN_BLOCK_SIZE and num_edges >= 16:
# Small batch (< 16): the sparse kernel leaves the large node dimension un-tiled (one
# program per node block -> ~1 SM busy, >10x slowdown). For a large enough block size the
# block-sparse small-batch tiling is far faster; smaller blocks fall through to the sparse
# kernel below (its lower launch/tiling overhead wins there -- measured crossover ~128).
# program per node block -> ~1 SM busy, >10x slowdown). For block_size >= 32 the
# block-sparse small-batch tiling is far faster; block_size == 16 falls through to the
# sparse kernel below (its lower launch/tiling overhead wins there -- measured crossover
# is 32, see `_SMALL_BATCH_MIN_BLOCK_SIZE`).
mode = self.BLOCK_SPARSE
elif self.block_size == 1 or num_edges < 4:
# In this case, we should definitely use the sparse implementation
Expand Down Expand Up @@ -662,16 +677,27 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten

signature = ("block_sparse", partition_id, TILE_SIZE_K)
if signature not in self._cached_fw_pcids:
# Pre-compute pointer increments for `cids` and `pids`

cids = cids.clone().reshape(cids.size(0), K_NUM_TILES, TILE_SIZE_K)
# Pre-compute pointer increments for `cids` and `pids`. Edge-tile trim (see
# `_BLOCK_SPARSE_EDGE_TRIM`): the per-partition padding to a power of 2 adds trailing tiles
# that are pure padding (param 0 -> the -inf dummy child element 0 -> contribute exactly 0),
# so iterating only the tiles up to the REAL max edge count is bit-identical. The real
# count is read straight off `cids` -- padding is a zero suffix (real children point to
# elements >= 1; element 0 is the dummy) -- so no partition metadata is needed. On a cache
# hit `K_NUM_TILES` is recovered from the (already-trimmed) cached increment tensor below.
if _BLOCK_SPARSE_EDGE_TRIM:
real_max = int((cids != 0).any(dim = 0).sum())
if real_max > 0:
K_NUM_TILES = min(K_NUM_TILES, triton.cdiv(real_max, TILE_SIZE_K))
eff_num_edges = K_NUM_TILES * TILE_SIZE_K

cids = cids[:, :eff_num_edges].clone().reshape(cids.size(0), K_NUM_TILES, TILE_SIZE_K)
cids_start = cids[:,0,:].contiguous()
cids_increment = torch.cat(
(cids[:,1:,:] - cids[:,:-1,:], cids[:,0:1,:] * 0),
dim = 1
).contiguous()

pids = pids.clone().reshape(pids.size(0), K_NUM_TILES, TILE_SIZE_K)
pids = pids[:, :eff_num_edges].clone().reshape(pids.size(0), K_NUM_TILES, TILE_SIZE_K)
pids_start = pids[:,0,:].contiguous()
pids_increment = torch.cat(
(pids[:,1:,:] - pids[:,:-1,:], pids[:,0:1,:] * 0),
Expand All @@ -695,6 +721,12 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
self._cached_fw_cuda[signature] = [ebase, pbase, cuda_ok]
else:
cids_start, cids_increment, pids_start, pids_increment = self._cached_fw_pcids[signature]
K_NUM_TILES = cids_increment.size(1) # recover the (possibly trimmed) tile count

# Keep `num_edges` consistent with the (possibly trimmed) tile count: the small-batch CUDA
# forward iterates `edge in [0, num_edges)` off the per-block first child, so it must not run
# past the trimmed (contiguity-verified) range. The tlmm CUDA / Triton paths use K_NUM_TILES.
num_edges = K_NUM_TILES * TILE_SIZE_K

partial_eval = 1 if local_ids is not None else 0
BLOCK_SIZE_M = self.block_size
Expand Down Expand Up @@ -1404,24 +1436,33 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo

signature = ("block_sparse", partition_id, TILE_SIZE_K)
if signature not in self._cached_bk_parids:
# Pre-compute pointer increments for `parids` and `parpids`
# Pre-compute pointer increments for `parids` and `parpids`. Edge-tile trim (see
# `_BLOCK_SPARSE_EDGE_TRIM`): the parent dim maps to `num_edges = parids.size(1) * block_size`;
# tiles entirely beyond the REAL max parent count are pure padding (contribute 0), so they
# are dropped. The real count is read off `parids` (padding is a zero suffix -- the dummy
# parent is 0). On a cache hit `K_NUM_TILES` is recovered from the cached increment below.
if _BLOCK_SPARSE_EDGE_TRIM:
real_max = int((parids != 0).any(dim = 0).sum())
if real_max > 0:
K_NUM_TILES = min(K_NUM_TILES, triton.cdiv(real_max * self.block_size, TILE_SIZE_K))
eff_pars = (K_NUM_TILES * TILE_SIZE_K) // self.block_size

if TILE_SIZE_K < self.block_size:
ptr_inc_step = 1

num_rep = self.block_size // TILE_SIZE_K
parids = (parids[:,:,None].repeat(1, 1, num_rep) + \
parids = (parids[:, :eff_pars, None].repeat(1, 1, num_rep) + \
torch.arange(0, self.block_size, TILE_SIZE_K, device = parids.device)[None,None,:]).reshape(
parids.size(0), K_NUM_TILES, 1)
parpids = (parpids[:,:,None].repeat(1, 1, num_rep) + \
parpids = (parpids[:, :eff_pars, None].repeat(1, 1, num_rep) + \
torch.arange(0, self.block_size, TILE_SIZE_K, device = parpids.device)[None,None,:]).reshape(
parpids.size(0), K_NUM_TILES, 1)

else:
ptr_inc_step = TILE_SIZE_K // self.block_size

parids = parids.reshape(parids.size(0), K_NUM_TILES, ptr_inc_step)
parpids = parpids.reshape(parpids.size(0), K_NUM_TILES, ptr_inc_step)
parids = parids[:, :eff_pars].reshape(parids.size(0), K_NUM_TILES, ptr_inc_step)
parpids = parpids[:, :eff_pars].reshape(parpids.size(0), K_NUM_TILES, ptr_inc_step)

parids_start = parids[:,0,:].contiguous()
parids_increment = torch.cat(
Expand Down Expand Up @@ -1452,6 +1493,12 @@ def _cumbase(start, incr):
self._cached_bk_ele_cuda[signature] = [ele_ebase, ele_pbase, ele_cuda_ok]
else:
parids_start, parids_increment, parpids_start, parpids_increment, ptr_inc_step = self._cached_bk_parids[signature]
K_NUM_TILES = parids_increment.size(1) # recover the (possibly trimmed) tile count

# Keep `num_edges` consistent with the (possibly trimmed) tile count: the small-batch CUDA ele
# backward iterates `edge in [0, num_edges)` off the per-block first child, so it must not run
# past the trimmed (contiguity-verified) range. The tlmm CUDA / Triton paths use K_NUM_TILES.
num_edges = K_NUM_TILES * TILE_SIZE_K

partial_eval = 1 if local_ids is not None else 0
BLOCK_SIZE_M = cs_block_size
Expand Down Expand Up @@ -1760,7 +1807,7 @@ def _par_flow_collision_free(self, pfids: torch.Tensor) -> bool:
return self._par_collision_free_cache[key]

def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor,
element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor,
element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor,
cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor,
allow_modify_flows: bool = False, propagation_alg: str = "LL",
logspace_flows: bool = False, negate_pflows: bool = False,
Expand Down Expand Up @@ -1852,6 +1899,28 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor
# untouched), so this is bit-identical.
TILE_SIZE_K = _SMALL_BATCH_PAR_TILE_K

# Edge-tile trim (see `_BLOCK_SPARSE_EDGE_TRIM`): the grid spans `cdiv(num_edges, TILE_SIZE_K)`
# edge-tile programs; programs entirely beyond the REAL (pre-pow2) max edge count only touch
# padding (param 0 -> the -inf dummy child -> 0 flow), so shrinking `num_edges` to it is
# bit-identical. The real count is read off `cids` (padding is a zero suffix -- real children
# point to elements >= 1; element 0 is the dummy), and the trimmed CONTIGUOUS cids/pids/pfids
# are cached (keyed by TILE_SIZE_K, which the parameter-flow tuning can double): the kernels
# index them as `nblock * num_edges + edge`, so num_edges must equal the tensors' actual width.
# Caching the slice -- rather than re-slicing every call -- avoids per-call copy kernels that
# otherwise regress the backward at small batch. The cids-read runs on cache miss (warmup only).
if _BLOCK_SPARSE_EDGE_TRIM:
tkey = (id(cids), TILE_SIZE_K)
trimmed = self._cached_bk_par_trim.get(tkey)
if trimmed is None:
real_max = int((cids != 0).any(dim = 0).sum())
eff = triton.cdiv(real_max, TILE_SIZE_K) * TILE_SIZE_K
if 0 < eff < num_edges:
trimmed = (eff, cids[:, :eff].contiguous(), pids[:, :eff].contiguous(), pfids[:, :eff].contiguous())
else:
trimmed = (num_edges, cids, pids, pfids)
self._cached_bk_par_trim[tkey] = trimmed
num_edges, cids, pids, pfids = trimmed

grid = (triton.cdiv(num_edges, TILE_SIZE_K), triton.cdiv(layer_n_nodes, TILE_SIZE_M))

# Optional CUDA fast path (CuTe/fp16/TMA), autotuned vs Triton INTO A SCRATCH buffer so the
Expand Down
Loading
Loading