level1_81_conv_transposed_2D_asymmetric_input_square_kernel___dilated____padded____strided__ execution error #94

@pdx1989

Description

Error: Triton run failed: 'tuple' object has no attribute 'dtype'

Code to reproduce:

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16,
                            'BLOCK_SIZE_K': 8}, num_warps=4, num_stages=4)],
    key=['M', 'N', 'K'])
@triton.jit
def _matmul_kernel(A_ptr, B_ptr, C_ptr, M, N, K,
                   stride_am, stride_ak, stride_bk, stride_bn,
                   stride_cm, stride_cn,
                   BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                   BLOCK_SIZE_K: tl.constexpr):
    # Map the 1D program id onto a 2D grid of (M, N) output tiles.
    pid = tl.program_id(0)
    pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
    pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Point A and B at the first K-slice of this output tile.
    A_ptr += offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    B_ptr += offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Accumulate A @ B over K, one BLOCK_SIZE_K slice per iteration.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(A_ptr, mask=(offs_m[:, None] < M) &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
        b = tl.load(B_ptr, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
                    (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
        A_ptr += BLOCK_SIZE_K * stride_ak
        B_ptr += BLOCK_SIZE_K * stride_bk
    # Write back the accumulated tile, masking the ragged edges.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    C_ptr += offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(C_ptr, acc, mask=c_mask)


class ModelNew(nn.Module):

    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, A, B):
        # Fold the batch dim of A into M so the batched matmul
        # (n, m, k) @ (k, l) runs as a single 2D GEMM.
        n, m, k = A.shape
        l = B.shape[1]
        A_flat = A.view(-1, k)
        C_flat = torch.empty((n * m, l), device=A.device, dtype=A.dtype)
        grid = lambda meta: (triton.cdiv(n * m, meta['BLOCK_SIZE_M']) *
                             triton.cdiv(l, meta['BLOCK_SIZE_N']),)
        _matmul_kernel[grid](A_flat, B, C_flat, n * m, l, k,
                             A_flat.stride(0), A_flat.stride(1),
                             B.stride(0), B.stride(1),
                             C_flat.stride(0), C_flat.stride(1))
        return C_flat.view(n, m, l)


N = 16
M = 1024
K = 2048
L = 768


def get_inputs():
    A = torch.randn(N, M, K, device='npu')
    B = torch.randn(K, L, device='npu')
    return [A, B]


def get_init_inputs():
    return []
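
For context, the kernel is meant to compute the batched matmul C = A @ B with output shape (N, M, L), so once the Triton path runs it can be checked against the PyTorch reference. A minimal sketch, assuming the same names as the script above (the tolerances are illustrative, not from the report):

def check_against_reference():
    A, B = get_inputs()
    ref = torch.matmul(A, B)  # expected result, shape (N, M, L)
    out = ModelNew()(A, B)    # Triton path that currently raises the error
    # Illustrative tolerances for float32 accumulation.
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

It may also help to rule out the autotuner: if the same kernel is defined without @triton.autotune and launched with the block sizes passed explicitly (BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=8 as keyword arguments), a successful run would point at the config/key handling rather than the kernel body.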
