skip reduction domain when compare allocation domain mapping in transpose scheduler #5938

Open
liqiangxl wants to merge 2 commits into main from llu/transpose_skip_reduction

@liqiangxl
Collaborator

Summary

  • Fix TransposeDomainMap::hasAtLeastTwoValidGroups to skip reduction domains when checking if allocation domains are all mapped between reference tensors
  • Reduction domains should be ignored during the comparison since they are not relevant for transpose scheduling decisions

Problem

When comparing allocation domains between two reference tensors to determine if they form valid transpose groups, the previous code compared all domains including reduction domains. This caused incorrect behavior when one tensor had reduction domains that shouldn't participate in the mapping check.

For example, when comparing (iS11{i0}, iS12{i6}) with (iS6{i0}, rS7{i1}, bS8{1}), the reduction domain rS7 was incorrectly included in the comparison, causing a mismatch in sequence lengths and incorrect scheduling decisions.

Solution

Filter out reduction domains using std::views::filter before performing the all_mapped comparison. This ensures only iteration and broadcast domains are considered when determining if tensors belong to the same transpose group.

Test plan

  • Updated ReductionIterDomainOnInputsIssue1659 test to expect PointWise scheduler instead of Transpose scheduler, reflecting the corrected behavior
  • Existing transpose tests continue to pass

@liqiangxl liqiangxl marked this pull request as ready for review February 8, 2026 15:49
@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 8, 2026

Description

  • Fix TransposeDomainMap::hasAtLeastTwoValidGroups to skip reduction domains during allocation domain mapping comparison

  • Filter out reduction domains using std::views::filter before performing domain mapping checks

  • Update test expectation from Transpose to PointWise scheduler to reflect corrected behavior

  • Prevent incorrect scheduling decisions when reduction domains are present in tensor comparisons

Changes walkthrough

Relevant files

Bug fix
domain_map.cpp: Filter reduction domains in transpose domain mapping (+10/-3)

csrc/scheduler/tools/domain_map.cpp

  • Added filtering logic to exclude reduction domains before allocation domain comparison
  • Applied std::views::filter to both reference tensor allocation domains
  • Updated all_mapped and any_bcast checks to use filtered domains
  • Added comments explaining the filtering behavior

Tests
test_transpose.cpp: Update test expectation for corrected transpose behavior (+1/-1)

tests/cpp/test_transpose.cpp

  • Updated ReductionIterDomainOnInputsIssue1659 test to expect PointWise scheduler
  • Changed expected heuristic from SchedulerType::Transpose to SchedulerType::PointWise
  • Reflects corrected behavior after fixing reduction domain handling

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Performance consideration

    The lambda function is_not_reduction is defined inline within the function. Consider making it static or moving it outside the function scope to avoid potential performance overhead from repeated lambda creation, though this is likely negligible for this use case.

    auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); };
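
For context on the performance note, a lambda with an empty capture list carries no captured state, and the language guarantees it converts to a plain function pointer. The sketch below checks that property at compile time; `MockIterDomain` and `keepsNonReduction` are hypothetical names introduced only for this illustration, not part of nvFuser.

```cpp
#include <type_traits>

// Hypothetical stand-in for nvFuser's IterDomain (assumption for this sketch).
struct MockIterDomain {
  bool reduction = false;
  bool isReduction() const { return reduction; }
};

// Applies the same predicate shape as the PR's is_not_reduction lambda.
inline bool keepsNonReduction(const MockIterDomain& domain) {
  auto is_not_reduction = [](const MockIterDomain* id) {
    return !id->isReduction();
  };
  // A lambda with no captures is guaranteed to convert to a function pointer.
  static_assert(std::is_convertible_v<
                decltype(is_not_reduction),
                bool (*)(const MockIterDomain*)>);
  return is_not_reduction(&domain);
}
```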
    Code consistency

    The filtered views ref1_filtered and ref2_filtered are used consistently throughout the updated code, which is good. The fix correctly addresses the core issue by excluding reduction domains from the allocation domain mapping comparison.

    auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction);
    auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction);
    
    const bool all_mapped = std::ranges::equal(
        ref1_filtered, ref2_filtered, [&](IterDomain* id1, IterDomain* id2) {
          return ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE);
        });
    if (all_mapped) {
      // Not required, just to validate the assumption that all_mapped implies
      // any_bcast
      const bool any_bcast =
          std::ranges::any_of(
              ref1_filtered, [](IterDomain* id) { return id->isBroadcast(); }) ||
          std::ranges::any_of(
              ref2_filtered, [](IterDomain* id) { return id->isBroadcast(); });

    Test failures

    • (Medium, 3) Shape mismatch in thunder.tests.test_update_aliases higher_order_inplace_alias_update with nvFuser CUDA

      Test Name A100 GB200 H100 Source
      thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 8, 2026

    Greptile Overview

    Greptile Summary

    Fixed TransposeDomainMap::hasAtLeastTwoValidGroups to filter out reduction domains when comparing allocation domains between reference tensors for transpose scheduling decisions.

    Key Changes

    • Added std::views::filter to exclude reduction domains before performing the std::ranges::equal comparison
    • Updated ReductionIterDomainOnInputsIssue1659 test expectation from Transpose to PointWise scheduler

    Analysis

    The change correctly addresses the issue where reduction domains were incorrectly included in allocation domain comparisons, causing mismatches in sequence lengths. However, the codebase consistently uses the pattern !id->isReduction() && !id->isStride() when filtering allocation domains (see TensorDomain::kNoReductions in csrc/ir/internal_base_nodes.h:861-862). Consider also filtering stride domains for consistency with the broader codebase patterns, as stride domains can exist in allocation domains and are typically excluded during scheduling operations.

    Confidence Score: 4/5

    • This PR is safe to merge with the fix correctly addressing the reduction domain filtering issue
    • The core fix is correct and properly addresses the stated problem. The test update reflects the expected behavior change. One style consideration is whether stride domains should also be filtered for consistency with common codebase patterns, though this may not affect correctness in practice
    • No files require special attention - the changes are well-scoped and the test validates the fix

    Important Files Changed

    Filename | Overview
    csrc/scheduler/tools/domain_map.cpp | Added filtering for reduction domains before comparing allocation domains, but may need to also filter stride domains for consistency with codebase patterns
    tests/cpp/test_transpose.cpp | Updated test expectation from Transpose to PointWise scheduler to reflect corrected behavior after fixing reduction domain filtering

    Sequence Diagram

    sequenceDiagram
        participant TS as TransposeScheduler::canSchedule
        participant TDM as TransposeDomainMap::hasAtLeastTwoValidGroups
        participant CAM as ComputeAtMap
        
        TS->>TDM: Check if fusion has at least two valid groups
        TDM->>TDM: groupInputsOutputsByInnerDim()
        TDM->>TDM: findReferenceFor(group1) -> ref1
        TDM->>TDM: findReferenceFor(group2) -> ref2
        TDM->>TDM: Check if innermost dims are mapped
        
        TDM->>TDM: Get allocation domains for ref1 and ref2
        Note over TDM: Filter out reduction domains<br/>(NEW: this fix skips reductions)
        TDM->>TDM: ref1_filtered = ref1_loop | filter(!isReduction)
        TDM->>TDM: ref2_filtered = ref2_loop | filter(!isReduction)
        
        TDM->>CAM: Compare filtered domains using areMapped()
        CAM-->>TDM: all_mapped result
        
        alt all_mapped is true
            TDM->>TDM: Check for broadcast domains
            Note over TDM: If all mapped with broadcast,<br/>use PointWise scheduler instead
            TDM-->>TS: return false (not valid for Transpose)
        else all_mapped is false
            Note over TDM: Different mappings indicate<br/>valid transpose pattern
            TDM-->>TS: return true (valid for Transpose)
        end
    
    @greptile-apps greptile-apps bot left a comment

    2 files reviewed, 1 comment

    Comment on lines +546 to +548
    auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); };
    auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction);
    auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction);

    Consider also filtering stride domains for consistency with codebase patterns. The standard TensorDomain::kNoReductions filter (defined in csrc/ir/internal_base_nodes.h:861-862) filters both reductions and strides: !id->isReduction() && !id->isStride(). This pattern appears 66+ times across the codebase when filtering allocation domains.

    Suggested change
    - auto is_not_reduction = [](IterDomain* id) { return !id->isReduction(); };
    - auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction);
    - auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction);
    + // Filter out reduction and stride domains before comparing as they are
    + // ignored during scheduling.
    + auto is_not_reduction_or_stride = [](IterDomain* id) {
    +   return !id->isReduction() && !id->isStride();
    + };
    + auto ref1_filtered = ref1_loop | std::views::filter(is_not_reduction_or_stride);
    + auto ref2_filtered = ref2_loop | std::views::filter(is_not_reduction_or_stride);
