skip reduction domain when compare allocation domain mapping in transpose scheduler #5938
!test
Description
| Relevant files | |
|---|---|
| Bug fix | |
| Tests | |
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| 🔒 No security concerns identified |
| ⚡ Recommended focus areas for review |
Performance consideration
`is_not_reduction` is defined inline within the function. Consider making it static or moving it outside the function scope to avoid potential overhead from repeated lambda creation, though this is likely negligible for this use case.
Test failures
- (Medium, 3) Shape mismatch in thunder.tests.test_update_aliases higher_order_inplace_alias_update with nvFuser CUDA
| Test Name | A100 | GB200 | H100 | Source |
|---|---|---|---|---|
| thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32 | ❌ | ❌ | ❌ | |
Greptile Overview
Greptile Summary
Fixed the transpose scheduler to skip reduction domains when comparing allocation domain mappings.
Analysis
The change correctly addresses the issue where reduction domains were incorrectly included in allocation domain comparisons, causing mismatches in sequence lengths. However, the codebase consistently uses the `TensorDomain::kNoReductions` pattern, which filters stride domains as well as reduction domains.
Confidence Score: 4/5
Sequence Diagram
sequenceDiagram
participant TS as TransposeScheduler::canSchedule
participant TDM as TransposeDomainMap::hasAtLeastTwoValidGroups
participant CAM as ComputeAtMap
TS->>TDM: Check if fusion has at least two valid groups
TDM->>TDM: groupInputsOutputsByInnerDim()
TDM->>TDM: findReferenceFor(group1) -> ref1
TDM->>TDM: findReferenceFor(group2) -> ref2
TDM->>TDM: Check if innermost dims are mapped
TDM->>TDM: Get allocation domains for ref1 and ref2
Note over TDM: Filter out reduction domains<br/>(NEW: this fix skips reductions)
TDM->>TDM: ref1_filtered = ref1_loop | filter(!isReduction)
TDM->>TDM: ref2_filtered = ref2_loop | filter(!isReduction)
TDM->>CAM: Compare filtered domains using areMapped()
CAM-->>TDM: all_mapped result
alt all_mapped is true
TDM->>TDM: Check for broadcast domains
Note over TDM: If all mapped with broadcast,<br/>use PointWise scheduler instead
TDM-->>TS: return false (not valid for Transpose)
else all_mapped is false
Note over TDM: Different mappings indicate<br/>valid transpose pattern
TDM-->>TS: return true (valid for Transpose)
end
    auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); };
    auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction);
    auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction);
Consider also filtering stride domains, for consistency with the rest of the codebase. The standard `TensorDomain::kNoReductions` filter (defined in csrc/ir/internal_base_nodes.h:861-862) filters both reductions and strides: `!id->isReduction() && !id->isStride()`. This pattern appears 66+ times across the codebase when filtering allocation domains.
Suggested change:
    - auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); };
    - auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction);
    - auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction);
    + // Filter out reduction and stride domains before comparing as they are
    + // ignored during scheduling.
    + auto is_not_reduction_or_stride = [](IterDomain* id) {
    +   return !id->isReduction() && !id->isStride();
    + };
    + auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction_or_stride);
    + auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction_or_stride);
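The difference between the two predicates discussed in this review comment can be illustrated with a small standalone sketch. `MockIterDomain`, `sampleDomains`, and `countKept` are hypothetical names invented for this illustration; they are not nvFuser APIs. Only the `isReduction()` / `isStride()` queries mirror what the comment quotes. The sketch uses `std::count_if` rather than `std::views::filter` so that it also compiles as C++17.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an IterDomain; NOT the real nvFuser class.
// It exposes only the two queries mentioned in the review comment.
struct MockIterDomain {
  bool reduction;
  bool stride;
  bool isReduction() const { return reduction; }
  bool isStride() const { return stride; }
};

// Illustrative input: one plain iteration domain, one reduction domain,
// and one stride domain.
std::vector<MockIterDomain> sampleDomains() {
  return {{false, false}, {true, false}, {false, true}};
}

// Counts the domains that survive filtering.
// also_skip_stride == false: the predicate from the PR (reductions only).
// also_skip_stride == true:  the predicate from the suggestion
//                            (reductions and strides).
std::size_t countKept(const std::vector<MockIterDomain>& ids,
                      bool also_skip_stride) {
  auto kept = std::count_if(
      ids.begin(), ids.end(), [also_skip_stride](const MockIterDomain& id) {
        if (id.isReduction()) {
          return false;
        }
        return !(also_skip_stride && id.isStride());
      });
  return static_cast<std::size_t>(kept);
}
```

On inputs without stride domains the two predicates keep exactly the same elements, so the suggestion only changes behavior when a stride domain is present in one of the compared allocation domains.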
Summary
- Update `TransposeDomainMap::hasAtLeastTwoValidGroups` to skip reduction domains when checking whether the allocation domains of the two reference tensors are all mapped.
Problem
When comparing allocation domains between two reference tensors to determine if they form valid transpose groups, the previous code compared all domains, including reduction domains. This caused incorrect behavior when one tensor had reduction domains that should not participate in the mapping check.
For example, when comparing `(iS11{i0}, iS12{i6})` with `(iS6{i0}, rS7{i1}, bS8{1})`, the reduction domain `rS7` was incorrectly included in the comparison, causing a mismatch in sequence lengths and incorrect scheduling decisions.
Solution
Filter out reduction domains using `std::views::filter` before performing the `all_mapped` comparison. This ensures only iteration and broadcast domains are considered when determining if tensors belong to the same transpose group.
Test plan
- Update the `ReductionIterDomainOnInputsIssue1659` test to expect the `PointWise` scheduler instead of the `Transpose` scheduler, reflecting the corrected behavior.
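The length mismatch described in the Problem section can be reproduced with a minimal standalone sketch. Everything here is an assumption made for illustration: `MockDomain`, `map_group`, `dropReductions`, and `allMapped` are hypothetical, and the integer `map_group` merely stands in for the graph-based mapping that `ComputeAtMap` performs in nvFuser. In particular, placing `iS12` and `bS8` in the same group is an assumption chosen to match the `PointWise` outcome in the test plan. The PR itself uses `std::views::filter` (C++20); this sketch swaps in `std::copy_if` so that it also compiles as C++17.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-in for an IterDomain; NOT the real nvFuser class.
struct MockDomain {
  std::string name;
  int map_group;      // assumption: equal map_group means "mapped"
  bool is_reduction;
};

// Stand-in for a mapping query; the real check is graph-based.
bool areMapped(const MockDomain& a, const MockDomain& b) {
  return a.map_group == b.map_group;
}

// Returns a copy of ids with reduction domains removed.
std::vector<MockDomain> dropReductions(const std::vector<MockDomain>& ids) {
  std::vector<MockDomain> kept;
  std::copy_if(ids.begin(), ids.end(), std::back_inserter(kept),
               [](const MockDomain& id) { return !id.is_reduction; });
  return kept;
}

// True only if both sequences have the same length and every pair of
// domains at the same position is mapped.
bool allMapped(const std::vector<MockDomain>& a,
               const std::vector<MockDomain>& b) {
  return std::equal(a.begin(), a.end(), b.begin(), b.end(), areMapped);
}

// The two allocation domains from the example in the PR description.
std::vector<MockDomain> exampleRef1() {
  return {{"iS11", 0, false}, {"iS12", 1, false}};
}
std::vector<MockDomain> exampleRef2() {
  return {{"iS6", 0, false}, {"rS7", 2, true}, {"bS8", 1, false}};
}
```

Under these assumptions, `allMapped(exampleRef1(), exampleRef2())` is false because the sequences have two and three elements, while `allMapped(exampleRef1(), dropReductions(exampleRef2()))` is true once `rS7` has been removed.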