Skip to content

Attention只能使用flash_attn方式计算吗? #12

@goldlee

Description

@goldlee

def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
backend: str = "flash_attn",
*,
causal: bool = False,
softmax_scale: float = None,
attn_kwargs: dict = None,
):
"""
Args:
q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads
"""
if "flash_attn" == get_preferred_attention_backend():
assert backend in ["flash_attn"], f"Unsupported attention backend: {backend}"
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Input tensors must be 4D"
batch_size = q.shape[0]

    cu_seqlens_q = attn_kwargs['cu_seqlens_q']
    cu_seqlens_kv = attn_kwargs['cu_seqlens_kv']
    max_seqlen_q = attn_kwargs['max_seqlen_q']
    max_seqlen_kv = attn_kwargs['max_seqlen_kv']
    x = flash_attn_varlen_func(
        q.view(q.shape[0] * q.shape[1], *q.shape[2:]),
        k.view(k.shape[0] * k.shape[1], *k.shape[2:]),
        v.view(v.shape[0] * v.shape[1], *v.shape[2:]),
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )
    output = x.view(
        batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
    )
else:
    from torch.nn.functional import scaled_dot_product_attention

    # 转换维度: [batch, seq_len, heads, dim] -> [batch, heads, seq_len, dim]
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    output = scaled_dot_product_attention(
        q, k, v,
        is_causal=causal,
        scale=softmax_scale
    )

    # 转回 [batch, seq_len, heads, dim]
    output = output.transpose(1, 2)

return output

我添加了torch原生的Attention计算,但是使用inference.py进行推理时,结果完全不对
我使用了https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers里面的input.png,和output1_predicted.png完全对不上,是哪里有问题吗?参数如下:
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Run local inference without FastAPI.')
parser.add_argument('--ckpt-root', default="JoyAI-Image-Edit", help='Checkpoint root.')
parser.add_argument('--prompt', default="Remove the construction structure from the top of the crane.", help='Edit prompt or T2I prompt.')
parser.add_argument('--image', default="test_images/input.png", help='Optional input image path for image editing.')
parser.add_argument('--output', default='example.png', help='Output image path.')
parser.add_argument('--height', type=int, default=1024, help='Only used for text-to-image inference.')
parser.add_argument('--width', type=int, default=1024, help='Only used for text-to-image inference.')
parser.add_argument('--steps', type=int, default=30)
parser.add_argument('--guidance-scale', type=float, default=4.0)
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--neg-prompt', default='')
parser.add_argument('--basesize', type=int, default=1024, help='Resize bucket base size for image editing inputs.')
parser.add_argument('--rewrite-prompt', action='store_true')
parser.add_argument('--config', help='Optional config path. Defaults to /infer_config.py.')
parser.add_argument('--rewrite-model', default='gpt-5')
parser.add_argument('--hsdp-shard-dim', type=int, help='Override config hsdp_shard_dim for multi-GPU FSDP inference.')
return parser.parse_args()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions