Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/docs/DefiningDialects/Assembly.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ void MyDialect::initialize() {
}
```

* If `getAlias` provides an alias with a trailing digit, `AsmPrinter` appends an underscore to avoid conflicts with autogenerated IDs.
* If multiple types/attributes have the same alias from `getAlias`, a number is appended to the alias to avoid conflicts.
* If multiple types/attributes have the same alias from `getAlias`, a numeric suffix is appended to the alias to disambiguate. The suffix assignment automatically avoids collisions with other registered alias names.

## Suggesting SSA/Block Names

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// Modifications (c) Copyright 2025-2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -263,6 +263,8 @@ class BlockFloatQuantizedType

static std::optional<BlockMode> parseBlockMode(StringRef name);
static StringRef getBlockModeName(BlockMode blockMode);
/// An empty alias means no alias should be emitted.
static StringRef getBlockModeAlias(BlockMode blockMode);

static BlockFloatQuantizedType get(MLIRContext *ctx, BlockMode blockMode,
int32_t axis);
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
//
// This classes used by the implementation details of Op types.
Expand Down Expand Up @@ -1762,8 +1765,7 @@ class OpAsmDialectInterface

/// Hooks for getting an alias identifier alias for a given symbol, that is
/// not necessarily a part of this dialect. The identifier is used in place of
/// the symbol when printing textual IR. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
/// the symbol when printing textual IR. These aliases must not contain `.`.
virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
return AliasResult::NoAlias;
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Quant/IR/QuantOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
namespace mlir {
namespace quant {

namespace detail {
void addAsmInterface(QuantDialect *dialect);
} // namespace detail

namespace {

// Verify the integrity of per-axis quantization information, if present.
Expand Down Expand Up @@ -245,6 +249,7 @@ void QuantDialect::initialize() {
>();
detail::addBytecodeInterface(this);
addInterfaces<QuantInlinerInterface>();
detail::addAsmInterface(this);
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// Modifications (c) Copyright 2025-2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -38,6 +38,8 @@ double getMaxScale(Type expressedType) {

struct BlockFloatQuantizedTypeConfig {
StringRef mode;
// If empty, no type alias should be emitted for this mode.
StringRef alias;
unsigned singleElementBitWidth;
unsigned averageBitsPerElement;
unsigned blockSize;
Expand All @@ -48,12 +50,12 @@ getBlockFloatQuantizedTypeConfig(BlockFloatQuantizedType::BlockMode blockMode) {
switch (blockMode) {
case BlockFloatQuantizedType::BlockMode::BFP16: {
static constexpr BlockFloatQuantizedTypeConfig config = {
StringLiteral("BFP16"), 16, 9, 8};
StringLiteral("BFP16"), StringLiteral("bfp16"), 16, 9, 8};
return config;
}
case BlockFloatQuantizedType::BlockMode::MX6: {
static constexpr BlockFloatQuantizedTypeConfig config = {
StringLiteral("MX6"), 13, 6, 16};
StringLiteral("MX6"), StringLiteral("mx6"), 13, 6, 16};
return config;
}
}
Expand Down Expand Up @@ -463,6 +465,10 @@ StringRef BlockFloatQuantizedType::getBlockModeName(BlockMode blockMode) {
return getBlockFloatQuantizedTypeConfig(blockMode).mode;
}

StringRef BlockFloatQuantizedType::getBlockModeAlias(BlockMode blockMode) {
return getBlockFloatQuantizedTypeConfig(blockMode).alias;
}

UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
Expand Down
29 changes: 28 additions & 1 deletion mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// Modifications (c) Copyright 2025-2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
Expand All @@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Format.h"
Expand Down Expand Up @@ -708,3 +709,29 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
else
llvm_unreachable("Unhandled quantized type");
}

namespace {
struct QuantOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;

AliasResult getAlias(Type type, raw_ostream &os) const final {
const auto blockType = dyn_cast<BlockFloatQuantizedType>(type);
if (!blockType)
return AliasResult::NoAlias;
if (blockType.getAxis() != 1)
return AliasResult::NoAlias;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we want this to encode the axis in the type (e.g. mx6_1bf) some time in the future, but I see no blocker for this with the current design.

const StringRef alias =
BlockFloatQuantizedType::getBlockModeAlias(blockType.getBlockMode());
if (alias.empty())
return AliasResult::NoAlias;
os << alias;
return AliasResult::OverridableAlias;
}
};
} // namespace

namespace mlir::quant::detail {
void addAsmInterface(QuantDialect *dialect) {
dialect->addInterfaces<QuantOpAsmDialectInterface>();
}
} // namespace mlir::quant::detail
20 changes: 18 additions & 2 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
//
// This file implements the MLIR AsmPrinter class, which is used to implement
Expand Down Expand Up @@ -1109,12 +1112,26 @@ void AliasInitializer::initializeAliases(
return lhs.second < rhs.second;
});

llvm::StringSet<> usedNames;
for (auto &[symbol, aliasInfo] : unprocessedAliases)
if (aliasInfo.alias)
usedNames.insert(*aliasInfo.alias);

const auto tryClaimName = [&usedNames](StringRef alias,
unsigned index) -> bool {
return usedNames.insert((alias + Twine(index)).str()).second;
};

llvm::StringMap<unsigned> nameCounts;
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
if (!aliasInfo.alias)
continue;
StringRef alias = *aliasInfo.alias;
unsigned nameIndex = nameCounts[alias]++;

while (nameIndex > 0 && !tryClaimName(alias, nameIndex))
nameIndex = nameCounts[alias]++;

symbolToAlias.insert(
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
aliasInfo.canBeDeferred)});
Expand Down Expand Up @@ -1206,8 +1223,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,

SmallString<16> tempBuffer;
StringRef name =
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
/*allowTrailingDigit=*/false);
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
name = name.copy(aliasAllocator);
alias = InProgressAliasInfo(name);
}
Expand Down
46 changes: 46 additions & 0 deletions mlir/test/Dialect/Quant/print-block-float-aliases.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates

// RUN: mlir-opt %s -split-input-file | FileCheck %s

// CHECK: !mx6 = !quant.block_float<mode=MX6, axis=1>
// CHECK-LABEL: func.func @alias_positive(
// CHECK: tensor<4x!mx6>
// CHECK: return %arg0 : tensor<4x!mx6>
func.func @alias_positive(%arg0: tensor<4x!quant.block_float<mode=MX6, axis=1>>)
-> tensor<4x!quant.block_float<mode=MX6, axis=1>> {
return %arg0 : tensor<4x!quant.block_float<mode=MX6, axis=1>>
}

// -----

// CHECK-LABEL: func.func @no_alias_mx6_other_axis(
// CHECK-NOT: !mx6
// CHECK: tensor<4x!quant.block_float<mode=MX6, axis=2>>
func.func @no_alias_mx6_other_axis(
%arg0: tensor<4x!quant.block_float<mode=MX6, axis=2>>)
-> tensor<4x!quant.block_float<mode=MX6, axis=2>> {
return %arg0 : tensor<4x!quant.block_float<mode=MX6, axis=2>>
}

// -----

// CHECK: !bfp16 = !quant.block_float<mode=BFP16, axis=1>
// CHECK-LABEL: func.func @alias_bfp16_axis1(
// CHECK: tensor<4x!bfp16>
func.func @alias_bfp16_axis1(
%arg0: tensor<4x!quant.block_float<mode=BFP16, axis=1>>)
-> tensor<4x!quant.block_float<mode=BFP16, axis=1>> {
return %arg0 : tensor<4x!quant.block_float<mode=BFP16, axis=1>>
}

// -----

// CHECK-LABEL: func.func @no_alias_bfp16_other_axis(
// CHECK-NOT: !bfp16
// CHECK: tensor<4x!quant.block_float<mode=BFP16, axis=0>>
func.func @no_alias_bfp16_other_axis(
%arg0: tensor<4x!quant.block_float<mode=BFP16, axis=0>>)
-> tensor<4x!quant.block_float<mode=BFP16, axis=0>> {
return %arg0 : tensor<4x!quant.block_float<mode=BFP16, axis=0>>
}
28 changes: 21 additions & 7 deletions mlir/test/IR/print-attr-type-aliases.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
// RUN: mlir-opt %s -split-input-file -mlir-print-debuginfo | FileCheck %s
// Verify printer of type & attr aliases.
// RUN: mlir-opt %s -split-input-file -mlir-print-debuginfo | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s

// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()

// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()

// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
Expand All @@ -14,10 +16,22 @@
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()

// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
// CHECK-DAG: #test_alias_conflict0 = "alias_test:sanitize_conflict_a"
// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_b"
"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()

// CHECK-DAG: #collide = "alias_test:suffix_collision_a"
// CHECK-DAG: #collide2 = "alias_test:suffix_collision_b"
// CHECK-DAG: #collide1 = "alias_test:suffix_collision_c"
"test.op"() {alias_test = ["alias_test:suffix_collision_a", "alias_test:suffix_collision_b", "alias_test:suffix_collision_c"]} : () -> ()

// CHECK-DAG: #cross = "alias_test:cross_collision_a"
// CHECK-DAG: #cross2 = "alias_test:cross_collision_b"
// CHECK-DAG: #cross3 = "alias_test:cross_collision_c"
// CHECK-DAG: #cross1 = "alias_test:cross_collision_d"
// CHECK-DAG: #cross11 = "alias_test:cross_collision_e"
"test.op"() {alias_test = ["alias_test:cross_collision_a", "alias_test:cross_collision_b", "alias_test:cross_collision_c", "alias_test:cross_collision_d", "alias_test:cross_collision_e"]} : () -> ()

// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)

Expand All @@ -28,8 +42,8 @@
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">

// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
// CHECK-DAG: tensor<32x!test_ui8_>
// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
// CHECK-DAG: tensor<32x!test_ui8>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>

// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
Expand All @@ -47,8 +61,8 @@
// -----

// Ensure self type parameters get considered for aliases.
// CHECK: !test_ui8_ = !test.int<unsigned, 8>
// CHECK: #test.attr_with_self_type_param : !test_ui8_
// CHECK: !test_ui8 = !test.int<unsigned, 8>
// CHECK: #test.attr_with_self_type_param : !test_ui8
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()

// -----
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Modifications (c) Copyright 2026 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"
Expand Down Expand Up @@ -194,6 +197,14 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
.Case("alias_test:sanitize_conflict_b",
StringRef("test_alias_conflict0_"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Case("alias_test:suffix_collision_a", StringRef("collide"))
.Case("alias_test:suffix_collision_b", StringRef("collide"))
.Case("alias_test:suffix_collision_c", StringRef("collide1"))
.Case("alias_test:cross_collision_a", StringRef("cross"))
.Case("alias_test:cross_collision_b", StringRef("cross"))
.Case("alias_test:cross_collision_c", StringRef("cross"))
.Case("alias_test:cross_collision_d", StringRef("cross1"))
.Case("alias_test:cross_collision_e", StringRef("cross1"))
.Default(std::nullopt);
if (!aliasName)
return AliasResult::NoAlias;
Expand Down