Conversation
!test
Review updated until commit cb8be65

Description
| Relevant files |
|---|
| Enhancement |
| Tests |
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Breaking API Change
Force-pushed 78310cd to ba2f27d

!test

Force-pushed f3a9608 to f184027

!test

Force-pushed f184027 to b7345f7

!test

Force-pushed aaa42fd to 12cc04e

!test
```cpp
      getProjectedExtent(id), commonOrConstExtent(ca_map_, id));
}

void ContiguousInnerDimensionsMapper::addProjectedExtent(
```
```cpp
  // Ordering of dimensions is important in this analysis, if an ordering is
  // contiguous in the reference, but not the target tensor views, then we
  // cannot consider that a contiguous merge dimension for vectorization.
  auto projected_logical = projectId(filtered_ids, logical_domain);
```
projected_logical gives me the wrong impression that the whole logical domain is projected. In fact, it's still as filtered as filtered_ids.
Greptile Overview

Greptile Summary

This PR refactors `ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize`. The main functional change is that the contig-inner-size computation now attempts to multiply projected extents of allocation IDs after matching them against the mapper's logical IDs.

Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant S as Scheduler/VectorizeHeuristic
    participant M as ContiguousInnerDimensionsMapper
    participant TV as TensorView (reference)
    participant TV2 as TensorView (target)
    S->>M: map(reference_tv, logical_ids)
    activate M
    M->>M: recording_=true
    M->>M: addProjectedExtent(logical_id, commonOrConstExtent)
    M->>M: projectId(filtered_ids, logical_domain, logical_domain)
    M->>M: projectId(filtered_ids, root_domain, allocation_domain)
    M->>M: recording_=false
    M->>M: traverse spanning tree
    deactivate M
    S->>M: getTvToContigMergeOfInnerSizeMap()
    activate M
    loop for each tv in tv_infos_
        M->>M: getContigMergeOfInnerSize(tv)
        M->>TV2: alloc = getMaybeAllocationDomain()
        M->>TV2: contiguity = getContiguity()
        M->>M: projected_dims = mappedLogicalIds(tv)
        M->>M: iterate alloc & contiguity (reverse)
        M->>M: logical_id = ir_utils::getReachableIds(logical_domain, {alloc_id})
        M->>M: if logical_id matches next projected_dim
        M->>M: product *= getProjectedExtent(alloc_id)
    end
    deactivate M
```
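To make the matching loop in the diagram concrete, here is a small self-contained sketch of the reverse lockstep walk. The `Dim` type, the name-based matching, and the sample extents are toy stand-ins invented for illustration, not nvFuser API:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string_view>
#include <vector>

// Toy stand-in for an IterDomain: just a name and a static extent.
struct Dim {
  const char* name;
  int64_t extent;
};

int main() {
  // Allocation order of the target tensor (innermost last) and its
  // per-dimension contiguity flags.
  std::vector<Dim> alloc = {{"i0", 8}, {"i1", 16}, {"i2", 4}};
  std::vector<std::optional<bool>> contiguity = {true, true, true};

  // Dims recorded in the reference's order; both sequences are walked
  // innermost-first, mirroring getContigMergeOfInnerSize.
  std::vector<const char*> projected_dims = {"i0", "i1", "i2"};

  int64_t product = 1;
  auto projected = projected_dims.rbegin();
  for (size_t i = alloc.size(); i-- > 0;) {
    if (!contiguity[i].has_value() || !*contiguity[i]) {
      break; // a non-contiguous dim ends the contiguous inner merge
    }
    if (projected == projected_dims.rend() ||
        std::string_view(*projected) != alloc[i].name) {
      break; // order mismatch (e.g. a transpose) stops the expansion
    }
    product *= alloc[i].extent;
    ++projected;
  }
  std::cout << "contig merge of inner size: " << product << "\n"; // 8*16*4 = 512
}
```

A transpose in the target would reorder `alloc` relative to `projected_dims`, so the walk breaks on the first mismatch and the product keeps only the innermost matched extents.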
```cpp
  for (auto [alloc_id, cont] :
       zip(alloc | std::views::reverse, contiguity | std::views::reverse)) {
    auto is_treated_as_size_one = [](IterDomain* id) {
      return id->isReduction() || id->isBroadcast() || id->isParallelized() ||
          id->extent()->isOneInt();
    };
    if (is_treated_as_size_one(alloc_id)) {
      continue;
    }
```
```diff
-        auto contiguity_i = contiguity.at(alloc_ii);
-        if (!contiguity_i.has_value()) {
-          NVF_THROW("contiguity flag at alloc_ii can't be null");
-        } else {
-          // Not contiguous
-          if (!contiguity_i.value()) {
-            break;
-          }
+    NVF_ERROR(cont.has_value());
+    if (!cont.value()) {
+      break;
+    }
```
```diff
-    // Get the logical ID corresponding to the allocation ID.
-    auto exprs = DependencyCheck::getAllExprsBetween(
-        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
-        {alloc_iid});
-    IterDomain* logical_id = alloc_iid;
-    Val* num_devices = tv->container()->oneVal();
-    bool only_valid_device_split = true;
-    for (Expr* expr : exprs | std::views::reverse) {
-      if (!isValidDeviceSplit(expr)) {
-        only_valid_device_split = false;
-        break;
-      }
-      auto* split = expr->as<Split>();
-      logical_id = split->in();
-      num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
+    while (projected_dim != projected_dims.rend() &&
+           is_treated_as_size_one(*projected_dim)) {
+      projected_dim++;
+    }
```
```diff
-    // Non device split could lead to padding, which prevents vectorization
-    if (!only_valid_device_split) {
-      break;
-    }
+    IterDomain* logical_id = [&]() {
+      std::vector<IterDomain*> reachable_ids =
+          ir_utils::getReachableIds(tv->getLogicalDomain(), {alloc_id});
+      NVF_ERROR_EQ(reachable_ids.size(), 1);
+      return reachable_ids.front();
+    }();
```
```diff
     // Mapping order isn't correct, cannot expand vectorization dimension.
-    if (projected_dims[--projected_dims_i] != logical_id) {
+    if (projected_dim == projected_dims.rend() ||
+        *projected_dim != logical_id) {
       break;
     }
```
````diff
-    Val* sharded_extent;
-    if (logical_id->isDeviceDim()) {
-      sharded_extent = tv->container()->oneVal();
-    } else {
-      sharded_extent = SimplifyingIrBuilder::divExpr(
-          getProjectedExtent(logical_id), num_devices);
-    }
-    product_of_inner_extents =
-        SimplifyingIrBuilder::mulExpr(product_of_inner_extents, sharded_extent);
+    // This assumes projected_dim can be matched only once. This assumption is
+    // OK for now but when we get to non-outermost sharding such as
+    // ```
+    // [iS0]
+    //  /  \.
+    // iS1 iS2
+    //  /  \.
+    // iDIDx3 iS4
+    // ```
+    // We may want to allow multiple contiguous allocation IDs to match
+    // projected_dim.
+    projected_dim++;
````
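For intuition on the removed branch: taking an illustrative projected extent of 1024 for `logical_id`, with two nested device splits contributing factors 2 and 2 (so `num_devices` = 4), the old code folded `sharded_extent` = 1024 / 4 = 256 into the product, while a `logical_id` that was itself a DID dimension contributed 1.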
```cpp
    product_of_inner_extents = SimplifyingIrBuilder::mulExpr(
        product_of_inner_extents, getProjectedExtent(alloc_id));
```
Allocation extent mismatch
getContigMergeOfInnerSize now multiplies getProjectedExtent(alloc_id) (i.e., allocation ID) after matching logical_id against mappedLogicalIds(tv) (logical IDs). projected_extent_ values originate from recording logical/root projections and are not guaranteed to include allocation IDs; in those cases getProjectedExtent(alloc_id) will throw "Not projected" at runtime. Even when present, using allocation-ID extents here changes semantics vs the previous logical-ID-based computation and can incorrectly size the contig inner extent for TVs with an allocation permutation. Consider multiplying the projected extent for the matched logical_id (or ensure allocation IDs are always recorded consistently before using them here).
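A minimal sketch of the first option the comment suggests, assuming `logical_id` (which the loop already computes via `ir_utils::getReachableIds`) is the right key into the recorded projections; a suggestion to evaluate, not a verified fix:

```cpp
    // Key the lookup by the matched *logical* ID, which is what the recorded
    // projections are keyed by, rather than by the allocation ID. Whether this
    // preserves the intended semantics for permuted allocation domains would
    // need to be verified.
    product_of_inner_extents = SimplifyingIrBuilder::mulExpr(
        product_of_inner_extents, getProjectedExtent(logical_id));
```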
```diff
+  const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
+  const std::vector<std::optional<bool>>& contiguity = tv->getContiguity();

   NVF_ERROR(hasMappedDims(tv));

   const std::vector<IterDomain*>& projected_dims = mappedLogicalIds(tv);
-  auto alloc_no_reductions = TensorDomain::noReductions(alloc);

-  std::vector<std::optional<bool>> contiguity = tv->domain()->contiguity();
+  NVF_ERROR_EQ(contiguity.size(), alloc.size());
-  // Appears after reductions the reduction domain often has a contiguity entry.
-  // This only matters if the result of the reduction is an output
-  if (contiguity.size() != alloc_no_reductions.size()) {
-    std::vector<std::optional<bool>> new_contiguity;
-    for (auto i : arange(alloc.size())) {
-      if (!alloc[i]->isReduction()) {
-        new_contiguity.push_back(contiguity.at(i));
-      }
-    }
-    contiguity = new_contiguity;
-  }

-  auto alloc_no_reductions_size = alloc_no_reductions.size();

-  NVF_ERROR_EQ(alloc_no_reductions_size, contiguity.size());

   Val* product_of_inner_extents = tv->container()->oneVal();
   // Order is important, need to make sure dimensions match up correctly with
   // what was propogated through the mapper. The mapper's dimensions is
   // propogated in the order of the reference, if that order doesn't match the
   // tensor we're mapping too then a transpose interfered with expanded the
   // vectorize dimension.
-  size_t projected_dims_i = projected_dims.size();

-  for (auto i : arange(alloc_no_reductions_size)) {
-    if (projected_dims_i == 0) {
-      break;
-    }
-    auto alloc_ii = alloc_no_reductions_size - i - 1;
-    auto alloc_iid = alloc_no_reductions.at(alloc_ii);

-    if (alloc_iid->extent()->isOneInt() || alloc_iid->isBroadcast()) {
-      if (projected_dims[projected_dims_i - 1] == alloc_iid) {
-        --projected_dims_i;
-      }
+  auto projected_dim = projected_dims.rbegin();
+  // Wish I could `zip(alloc, contiguity) | std::views::reverse` here. It
+  // doesn't compile.
+  for (auto [alloc_id, cont] :
+       zip(alloc | std::views::reverse, contiguity | std::views::reverse)) {
```
Contiguity/alloc size assumption
This loop zips tv->getMaybeAllocationDomain() with tv->getContiguity() and iterates them in lockstep. If getContiguity() is defined in terms of the logical/root domain (as it historically was via tv->domain()->contiguity()), TVs with a distinct allocation domain can have a different rank/order, making the zip silently drop trailing elements and compute an incorrect inner-extent product. At minimum this should assert alloc.size() == contiguity.size() before zipping (or fetch contiguity for the allocation domain explicitly).
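A minimal guard in the spirit of the comment's suggestion, assuming `NVF_ERROR_EQ` is available here as it is elsewhere in the file (a sketch, not a verified patch):

```cpp
  const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
  const std::vector<std::optional<bool>>& contiguity = tv->getContiguity();
  // Fail loudly on a rank mismatch instead of letting zip silently stop at
  // the shorter range and under-count the inner-extent product.
  NVF_ERROR_EQ(alloc.size(), contiguity.size());
```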
This makes the code less error-prone and removes the reliance on isValidDeviceSplit, in order to support non-outermost sharding in the future.
Should be an NFC.