diff --git a/aiter/ops/mhc.py b/aiter/ops/mhc.py
index e733f88dc8..dff04e1c1f 100644
--- a/aiter/ops/mhc.py
+++ b/aiter/ops/mhc.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: MIT
 # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
-import torch
 import math
-from torch import Tensor
+
+import torch
 from aiter import dtypes
+from torch import Tensor
+
 from ..jit.core import compile_ops
 from ..jit.utils.chip_info import get_cu_num
@@ -87,18 +89,19 @@ def mhc_pre(
         if num_tg > meanwhile_tg * 4:
             break
 
+    device = residual.device
     out_pad = torch.empty(
-        selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32
+        selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device
     )
     out = out_pad[:, :, :hc_mult3]
-    sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32)
+    sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)
     mhc_pre_gemm_sqrsum(out, sqrsum, residual, fn, selected_tile_k)
     # out = out.sum(0)
     # sqrsum = sqrsum.sum(0)
-    post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32)
-    comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32)
-    layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16)
+    post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)
+    comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)
+    layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)
     mhc_pre_big_fuse(
         post_mix,
         comb_mix,
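For context, the functional effect of the patch is that every intermediate buffer in `mhc_pre` is allocated on the same device as `residual` rather than on the current default device. Below is a minimal sketch of that pattern, not part of the patch; `allocate_scratch` is a hypothetical helper used only for illustration, and the guarded `cuda` check simply assumes a GPU (ROCm builds of PyTorch also report it as available).

```python
import torch

def allocate_scratch(residual: torch.Tensor) -> torch.Tensor:
    # Without device=, torch.empty allocates on the current default device
    # (normally CPU), which later forces a copy or raises a device-mismatch
    # error when the buffer is used with a GPU-resident input. Passing
    # device=residual.device keeps the scratch buffer on the input's device,
    # as the patch does for out_pad, sqrsum, post_mix, comb_mix, layer_input.
    m, hidden = residual.shape
    return torch.empty(m, hidden, dtype=torch.float32, device=residual.device)

if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
    buf = allocate_scratch(x)
    assert buf.device == x.device
```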