
remove input cached smem bank conflicts in transpose scheduler #5930

Draft
liqiangxl wants to merge 17 commits into llu/use_pointwise_not_transpose from llu/transpose_bank_conflict

Conversation

@liqiangxl
Collaborator

No description provided.

@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 6, 2026

Review updated until commit e8fa871

Description

  • Add a memory bandwidth optimization that doubles tile_size2 when there are insufficient bytes in flight

  • Implement shared memory swizzling to reduce bank conflicts for cached input tensors

  • Modify hasSmallTransposeDimensions to accept a raw pointer instead of a unique_ptr reference

  • Change one test from the Transpose to the PointWise scheduler due to transpose scheduler limitations

Changes walkthrough

Relevant files

Enhancement: csrc/scheduler/transpose.cpp (Add memory bandwidth optimization and smem swizzling, +90/-10)

  • Change the hasSmallTransposeDimensions function signature to accept a raw pointer instead of a unique_ptr reference
  • Add a memory bandwidth optimization based on Little's law that doubles tile_size2 when needed
  • Implement shared memory swizzling for cached input tensors to reduce bank conflicts
  • Add conditional logic to disable swizzling for non-square tiles or cached outputs
  • Update all call sites to use .get() on the unique_ptr when calling hasSmallTransposeDimensions

Tests: tests/cpp/test_rng.cpp (Update test to use the PointWise scheduler, +1/-1)

  • Change the test's scheduling from the Transpose to the PointWise scheduler; this test case cannot be handled by the transpose scheduler due to its limitations

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Debug print statements

    Multiple std::cout debug print statements were added (lines 747-752, 1354, 1371); they should be removed, or made conditional on a debug flag, before this is merged.

    std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem
              << std::endl;
    std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl;
    std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl;
    std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl;
    std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl;
    Complex swizzle conditions

    The new shared memory swizzling logic introduces multiple conditions that disable swizzling (lines 964, 971-979). These conditions should be well-documented and tested to ensure they don't negatively impact performance in edge cases.

    bool use_smem_swizzle = !hasSmallTransposeDimensions(tparams);
    // set cached outputs of group 2 to shared memory
    for (const auto& [cached_output, output_idx] : cached_outputs) {
      auto output = fusion->outputs()[output_idx]->as<TensorView>();
      if (group2_and_cached_inputs.count(output) > 0) {
        cached_output->setMemoryType(MemoryType::Shared);
        // current smem swizzle only works for cached input
        use_smem_swizzle = false;
      }
    }
    // For non-square tile, can't create smem swizzle chunks if tile2 is larger
    // and not vectorized
    if (tparams->tile_size2 > tparams->tile_size1 &&
        tparams->vectorize_factor2 == 1) {
      use_smem_swizzle = false;
    }
    Memory bandwidth optimization

    The new memory bandwidth optimization logic (lines 724-755) that doubles tile_size2 based on Little's law calculations should be validated with performance benchmarks to ensure it provides the expected benefits across different tensor sizes and hardware configurations.

    // Double tile_size2 if the default configuration doesn't provide enough
    // bytes in flight to saturate memory bandwidth. This is based on Little's
    // law: bytes_in_flight = bandwidth * latency. We estimate the bits in
    // flight per SM as: (sum of input tensor element sizes in bits) *
    // elements_per_tile * blocks_per_sm. If this is less than the required
    // bits in flight (derived from hardware bandwidth and memory latency),
    // we double tile_size2 to increase the data in flight. Doubling
    // tile_size1 instead would also double the shared memory bank conflict,
    // e.g. from 8-way to 16-way when increasing from 32 to 64: with a
    // vectorization factor of 4, either 8 or 16 threads load per column.
    const auto dev_prop = at::cuda::getCurrentDeviceProperties();
    const int64_t max_blocks_per_sm = dev_prop->maxThreadsPerMultiProcessor /
        TransposeParams::getMaxThreadsPerBlock();
    const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2;
    const int64_t required_bits_per_sm =
        scheduler_utils::getRequiredBitsInFlight();
    int64_t total_input_bits_per_elem = 0;
    for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
      total_input_bits_per_elem +=
          dataTypeSizeBit(tv->getDataType().value(), index_type);
    }
    const int64_t bits_in_flight_per_sm =
        total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm;
    std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem
              << std::endl;
    std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl;
    std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl;
    std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl;
    std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl;
    if (bits_in_flight_per_sm < required_bits_per_sm) {
      tparams->tile_size2 *= 2;
    }

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl liqiangxl force-pushed the llu/transpose_bank_conflict branch from 5bbb6fa to 6e53109 Compare February 8, 2026 16:37
    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl liqiangxl changed the base branch from llu/transpose_tile_size to main February 12, 2026 16:11
    @liqiangxl liqiangxl force-pushed the llu/transpose_bank_conflict branch from bd2d2b8 to a8f82f1 Compare February 17, 2026 01:14
    @liqiangxl liqiangxl changed the base branch from main to llu/bcast_alias February 17, 2026 01:17
    @liqiangxl
    Collaborator Author

    !test

    Base automatically changed from llu/bcast_alias to main February 17, 2026 03:51
    @liqiangxl liqiangxl changed the base branch from main to llu/use_pointwise_not_transpose February 17, 2026 17:09
    @liqiangxl
    Collaborator Author

    !test
