From 9b401b767fa1e4f5b6d1408a1ec198326337cd6f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Tue, 21 Apr 2026 10:51:53 +0000 Subject: [PATCH 1/6] Start dynamic deps. --- compiler/src/garel/GraphAlgToRel.cpp | 92 ++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index ee60821..e599479 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -29,6 +30,9 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" namespace garel { @@ -256,6 +260,9 @@ MatrixTypeConverter::MatrixTypeConverter( addConversion( [this](graphalg::MatrixType t) { return convertMatrixType(t); }); + + // No need to convert. + addConversion([](RelationType t) { return t; }); } // ============================================================================= @@ -471,6 +478,82 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +static void addDimensionInputs( + mlir::func::FuncOp funcOp, + llvm::SmallDenseMap &dimToValue) { + // Function type + llvm::SmallVector inputs(funcOp.getFunctionType().getInputs()); + llvm::SmallVector dims; + for (auto t : funcOp.getFunctionType().getInputs()) { + auto matType = llvm::cast(t); + for (auto d : {matType.getRows(), matType.getCols()}) { + if (d.isAbstract() && !llvm::is_contained(dims, d)) { + dims.push_back(d); + } + } + } + + // Add dimension inputs to function type + auto *ctx = funcOp.getContext(); + auto dimReadType = RelationType::get( + ctx, mlir::ArrayRef{mlir::IndexType::get(ctx)}); + inputs.append(dims.size(), dimReadType); + auto newType = mlir::FunctionType::get(ctx, inputs, + funcOp.getFunctionType().getResults()); + funcOp.setFunctionType(newType); + + // Add dimension inputs as 
block args. + auto &block = funcOp.getFunctionBody().front(); + for (auto d : dims) { + auto arg = block.addArgument(dimReadType, funcOp.getLoc()); + dimToValue[d] = arg; + } +} + +/* +static mlir::LogicalResult convertFunc(mlir::func::FuncOp funcOp, + mlir::IRRewriter &rewriter, + MatrixTypeConverter &typeConverter) { + rewriter.setInsertionPointAfter(funcOp); + + // Function type + llvm::SmallVector inputs; + llvm::SmallVector dims; + for (auto t : funcOp.getFunctionType().getInputs()) { + auto relType = typeConverter.convertType(t); + if (!relType) { + return funcOp.emitOpError("input type ") << t << " cannot be converted"; + } + + inputs.push_back(relType); + + auto matType = llvm::cast(t); + for (auto d : {matType.getRows(), matType.getCols()}) { + if (d.isAbstract() && !llvm::is_contained(dims, d)) { + dims.push_back(d); + } + } + } + + // Add dimension inputs. + auto dimReadType = rewriter.getType( + mlir::ArrayRef{rewriter.getIndexType()}); + for (auto _d : dims) { + inputs.push_back(dimReadType); + } + + llvm::SmallVector results; + for (auto t : funcOp.getFunctionType().getResults()) { + auto relType = typeConverter.convertType(t); + if (!relType) { + return funcOp.emitOpError("result type ") << t << " cannot be converted"; + } + + results.push_back(relType); + } +} +*/ + template <> mlir::LogicalResult OpConversion::matchAndRewrite( mlir::func::ReturnOp op, OpAdaptor adaptor, @@ -1280,6 +1363,15 @@ static bool hasRelationOperands(mlir::Operation *op) { } void GraphAlgToRel::runOnOperation() { + // Add dimension inputs + llvm::SmallVector funcOps( + getOperation().getOps()); + mlir::IRRewriter rewriter(getOperation()); + for (auto op : funcOps) { + llvm::SmallDenseMap dimToValue; + addDimensionInputs(op, dimToValue); + } + mlir::ConversionTarget target(getContext()); // Eliminate all graphalg ops target.addIllegalDialect(); From d7724c050b10bb64574114d680c71a94a344149f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Tue, 21 Apr 2026 14:59:40 +0000 
Subject: [PATCH 2/6] It works. --- compiler/include/garel/GARelAttr.td | 3 +- compiler/include/garel/GARelOps.td | 2 +- compiler/include/garel/GARelTypes.h | 3 + compiler/include/garel/GARelTypes.td | 6 + compiler/src/garel/GARelAttr.cpp | 31 +- compiler/src/garel/GARelTypes.cpp | 6 + compiler/src/garel/GraphAlgToRel.cpp | 358 ++++++++++++------- compiler/test/graphalg-to-rel/cast-dim.mlir | 11 + compiler/test/graphalg-to-rel/for-const.mlir | 3 +- compiler/test/graphalg-to-rel/for-dim.mlir | 24 ++ 10 files changed, 310 insertions(+), 137 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/cast-dim.mlir create mode 100644 compiler/test/graphalg-to-rel/for-dim.mlir diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index c09797c..4dac0ca 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -38,6 +38,7 @@ def AggregateFunc : I64EnumAttr< I64EnumAttrCase<"MAX", 2>, I64EnumAttrCase<"LOR", 3>, /* Logical OR (over i1) */ I64EnumAttrCase<"ARGMIN", 4>, + I64EnumAttrCase<"COUNT", 5>, ] > { let cppNamespace = "::garel"; @@ -49,7 +50,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { let parameters = (ins "AggregateFunc":$func, - ArrayRefParameter<"ColumnIdx">:$inputs); + OptionalArrayRefParameter<"ColumnIdx">:$inputs); let assemblyFormat = [{ `<` $func $inputs `>` diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index f8ccaa3..e1e94ec 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -143,7 +143,7 @@ def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { let arguments = (ins Variadic:$init, - I64Attr:$iters, + I64Relation:$iters, I64Attr:$resultIdx); let regions = (region diff --git a/compiler/include/garel/GARelTypes.h b/compiler/include/garel/GARelTypes.h index e6a4c90..689cd3a 100644 --- a/compiler/include/garel/GARelTypes.h +++ b/compiler/include/garel/GARelTypes.h @@ -1,5 +1,6 @@ #pragma once 
+#include #include #include "garel/GARelAttr.h" @@ -11,4 +12,6 @@ namespace garel { bool isColumnType(mlir::Type t); +RelationType getI64RelationType(mlir::MLIRContext *ctx); + } // namespace garel diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 5d848b5..d5eefed 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -34,4 +34,10 @@ def Tuple : GARel_Type<"Tuple", "tuple"> { def ColumnType : Type, "column type">; +def I64Relation : Type< + CPred<"::garel::getI64RelationType($_self.getContext()) == $_self">, + "relation with a single i64 column", + "RelationType">, + BuildableType<"::garel::getI64RelationType($_builder.getContext())">; + #endif // GAREL_TYPES diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp index abb69cc..d228595 100644 --- a/compiler/src/garel/GARelAttr.cpp +++ b/compiler/src/garel/GARelAttr.cpp @@ -24,23 +24,32 @@ mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) { case AggregateFunc::ARGMIN: // NOTE: argmin(arg, val) also uses first input column as output type. 
return llvm::cast(inputRel).getColumns()[getInputs()[0]]; + case AggregateFunc::COUNT: + return mlir::IntegerType::get(inputRel.getContext(), 64); + } +} + +static std::size_t expectedNumInputs(AggregateFunc f) { + switch (f) { + case AggregateFunc::SUM: + case AggregateFunc::MIN: + case AggregateFunc::MAX: + case AggregateFunc::LOR: + return 1; + case AggregateFunc::ARGMIN: + return 2; + case AggregateFunc::COUNT: + return 0; } } mlir::LogicalResult AggregatorAttr::verify(llvm::function_ref emitError, AggregateFunc func, llvm::ArrayRef inputs) { - if (func == AggregateFunc::ARGMIN) { - if (inputs.size() != 2) { - return emitError() << stringifyAggregateFunc(func) - << " expects exactly two inputs (arg, val), got " - << inputs.size(); - } - } else { - if (inputs.size() != 1) { - return emitError() << stringifyAggregateFunc(func) - << " expects exactly one input, got " << inputs.size(); - } + if (inputs.size() != expectedNumInputs(func)) { + return emitError() << stringifyAggregateFunc(func) << " expects exactly " + << expectedNumInputs(func) << " inputs, got " + << inputs.size(); } return mlir::success(); diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp index 618642d..d807ddc 100644 --- a/compiler/src/garel/GARelTypes.cpp +++ b/compiler/src/garel/GARelTypes.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,11 @@ bool isColumnType(mlir::Type t) { t.isIndex(); } +RelationType getI64RelationType(mlir::MLIRContext *ctx) { + return RelationType::get( + ctx, mlir::ArrayRef{mlir::IntegerType::get(ctx, 64)}); +} + // Need to define this here to avoid depending on IPRTypes in // IPRDialect and creating a cycle. 
void GARelDialect::registerTypes() { diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index e599479..8c67b9d 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -30,9 +30,14 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" namespace garel { @@ -269,14 +274,64 @@ MatrixTypeConverter::MatrixTypeConverter( // ============================== Helper Methods =============================== // ============================================================================= +static llvm::StringLiteral INPUT_DIMS_ATTR_KEY = "garel.input_dims"; + /** * Create a relation with all indices for a matrix dimension. * * Used to broadcast scalar values to a larger matrix. */ -static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, - mlir::OpBuilder &builder) { - return builder.create(loc, dim.getConcreteDim()); +static mlir::Value createDimRead(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + if (dim.isConcrete()) { + return builder.create(loc, dim.getConcreteDim()); + } + + assert(dim.isAbstract()); + + // Find owning func + mlir::func::FuncOp funcOp; + auto parentOp = builder.getInsertionBlock()->getParentOp(); + funcOp = llvm::dyn_cast(parentOp); + if (!funcOp) { + funcOp = parentOp->getParentOfType(); + } + + assert(funcOp && "not contained inside a function"); + auto inputDims = funcOp->getAttrOfType(INPUT_DIMS_ATTR_KEY); + if (!inputDims) { + funcOp.emitOpError("missing input dims attribute"); + return nullptr; + } + + // Find the function argument for this dimension. 
+ auto args = funcOp.getFunctionBody().getArguments(); + for (auto [d, v] : llvm::zip(inputDims.getAsRange(), + llvm::reverse(args))) { + if (dim == d) { + return v; + } + } + + funcOp.emitOpError("dimension ") + << dim << "is not defined within this function"; + return nullptr; +} + +/** + * Create a relation with one tuple that contains the number of valid indices + * for a matrix dimension. + */ +static mlir::Value createDimInt(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + auto input = createDimRead(loc, dim, builder); + llvm::ArrayRef groupBy; + std::array aggregators{ + builder.getAttr(AggregateFunc::COUNT, + llvm::ArrayRef{}), + }; + + return builder.create(loc, input, groupBy, aggregators); } static void @@ -413,15 +468,6 @@ createAggregator(mlir::Operation *op, graphalg::SemiringTypeInterface sring, return AggregatorAttr::get(ctx, func, inputs); } -static mlir::IntegerAttr tryGetConstantInt(mlir::Value v) { - mlir::Attribute attr; - if (!mlir::matchPattern(v, mlir::m_Constant(&attr))) { - return nullptr; - } - - return llvm::cast(attr); -} - static mlir::FailureOr createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, mlir::Value lhs, mlir::Value rhs, mlir::OpBuilder &builder) { @@ -448,6 +494,16 @@ createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, << sring << " is not supported"; } +static bool isConstantZeroI64(mlir::Value rangeBegin) { + if (auto constOp = rangeBegin.getDefiningOp()) { + auto zero = mlir::IntegerAttr::get( + mlir::IntegerType::get(rangeBegin.getContext(), 64), 0); + return constOp.getValue() == zero; + } + + return false; +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -478,82 +534,6 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -static void 
addDimensionInputs( - mlir::func::FuncOp funcOp, - llvm::SmallDenseMap &dimToValue) { - // Function type - llvm::SmallVector inputs(funcOp.getFunctionType().getInputs()); - llvm::SmallVector dims; - for (auto t : funcOp.getFunctionType().getInputs()) { - auto matType = llvm::cast(t); - for (auto d : {matType.getRows(), matType.getCols()}) { - if (d.isAbstract() && !llvm::is_contained(dims, d)) { - dims.push_back(d); - } - } - } - - // Add dimension inputs to function type - auto *ctx = funcOp.getContext(); - auto dimReadType = RelationType::get( - ctx, mlir::ArrayRef{mlir::IndexType::get(ctx)}); - inputs.append(dims.size(), dimReadType); - auto newType = mlir::FunctionType::get(ctx, inputs, - funcOp.getFunctionType().getResults()); - funcOp.setFunctionType(newType); - - // Add dimension inputs as block args. - auto &block = funcOp.getFunctionBody().front(); - for (auto d : dims) { - auto arg = block.addArgument(dimReadType, funcOp.getLoc()); - dimToValue[d] = arg; - } -} - -/* -static mlir::LogicalResult convertFunc(mlir::func::FuncOp funcOp, - mlir::IRRewriter &rewriter, - MatrixTypeConverter &typeConverter) { - rewriter.setInsertionPointAfter(funcOp); - - // Function type - llvm::SmallVector inputs; - llvm::SmallVector dims; - for (auto t : funcOp.getFunctionType().getInputs()) { - auto relType = typeConverter.convertType(t); - if (!relType) { - return funcOp.emitOpError("input type ") << t << " cannot be converted"; - } - - inputs.push_back(relType); - - auto matType = llvm::cast(t); - for (auto d : {matType.getRows(), matType.getCols()}) { - if (d.isAbstract() && !llvm::is_contained(dims, d)) { - dims.push_back(d); - } - } - } - - // Add dimension inputs. 
- auto dimReadType = rewriter.getType( - mlir::ArrayRef{rewriter.getIndexType()}); - for (auto _d : dims) { - inputs.push_back(dimReadType); - } - - llvm::SmallVector results; - for (auto t : funcOp.getFunctionType().getResults()) { - auto relType = typeConverter.convertType(t); - if (!relType) { - return funcOp.emitOpError("result type ") << t << " cannot be converted"; - } - - results.push_back(relType); - } -} -*/ - template <> mlir::LogicalResult OpConversion::matchAndRewrite( mlir::func::ReturnOp op, OpAdaptor adaptor, @@ -646,6 +626,10 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all rows. auto rowsOp = createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + joinChildren.push_back(rowsOp); rowColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, @@ -659,6 +643,10 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all columns. auto colsOp = createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + joinChildren.push_back(colsOp); colColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, @@ -762,8 +750,13 @@ mlir::LogicalResult OpConversion::matchAndRewrite( rowColumnIdx = input.rowColumn(); } else if (output.hasRowColumn()) { // Broadcast over all rows. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + + joinChildren.push_back(rowsOp); rowColumnIdx = currentColIdx++; } @@ -772,8 +765,13 @@ mlir::LogicalResult OpConversion::matchAndRewrite( colColumnIdx = input.colColumn(); } else if (output.hasColColumn()) { // Broadcast over all columns. 
- joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + + joinChildren.push_back(colsOp); colColumnIdx = currentColIdx++; } @@ -821,14 +819,24 @@ mlir::LogicalResult OpConversion::matchAndRewrite( llvm::SmallVector joinChildren; if (!output.matrixType().getRows().isOne()) { // Broadcast over all rows. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + + joinChildren.push_back(rowsOp); } if (!output.matrixType().getCols().isOne()) { // Broadcast over all columns. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + + joinChildren.push_back(colsOp); } joinChildren.push_back(constantOp); @@ -885,25 +893,44 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -template <> -mlir::LogicalResult OpConversion::matchAndRewrite( - graphalg::ForConstOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - auto rangeBegin = tryGetConstantInt(op.getRangeBegin()); - auto rangeEnd = tryGetConstantInt(op.getRangeEnd()); - if (!rangeBegin || !rangeEnd) { - return op->emitOpError("iter range is not constant"); +// Sharing logic between ForConstOp and ForDimOp +static mlir::LogicalResult +convertFor(mlir::Operation *op, mlir::ValueRange adaptorInitArgs, + mlir::Value rangeBegin, mlir::Value rangeEnd, mlir::Region &body, + mlir::Region &until, const mlir::TypeConverter *typeConverter, + mlir::ConversionPatternRewriter &rewriter) { + llvm::SmallVector initArgs{rangeBegin}; + 
initArgs.append(adaptorInitArgs.begin(), adaptorInitArgs.end()); + + auto blockSignature = typeConverter->convertBlockSignature(&body.front()); + if (!blockSignature) { + return op->emitOpError("Failed to convert iter args"); } - auto iters = rangeEnd.getInt() - rangeBegin.getInt(); + mlir::Value iters; + if (isConstantZeroI64(rangeBegin)) { + iters = rangeEnd; + } else { + // Subtract rangeBegin from rangeEnd + auto joinOp = rewriter.create( + op->getLoc(), mlir::ValueRange{rangeBegin, rangeEnd}, + rewriter.getAttr( + llvm::ArrayRef{})); + auto projOp = rewriter.create( + op->getLoc(), getI64RelationType(op->getContext()), joinOp); - llvm::SmallVector initArgs{adaptor.getRangeBegin()}; - initArgs.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); + auto &block = projOp.createProjectionsBlock(); + mlir::OpBuilder::InsertionGuard guard{rewriter}; + rewriter.setInsertionPointToStart(&block); - auto blockSignature = - typeConverter->convertBlockSignature(&op.getBody().front()); - if (!blockSignature) { - return op->emitOpError("Failed to convert iter args"); + auto begin = + rewriter.create(op->getLoc(), 0, block.getArgument(0)); + auto end = + rewriter.create(op->getLoc(), 1, block.getArgument(0)); + auto res = rewriter.create(op->getLoc(), end, begin); + rewriter.create(op->getLoc(), mlir::ValueRange{res}); + + iters = projOp; } // The relational version of this op can only have a single output value. @@ -913,25 +940,24 @@ mlir::LogicalResult OpConversion::matchAndRewrite( auto result = op->getResult(i); if (result.use_empty()) { // Not used. Take init arg as a dummy value. - resultValues.push_back(adaptor.getInitArgs()[i]); + resultValues.push_back(adaptorInitArgs[i]); continue; } // We are adding the iteration count variable as a first argument, so offset // the result index accordingly. 
std::int64_t resultIdx = i + 1; - auto resultType = adaptor.getInitArgs()[i].getType(); - auto forOp = rewriter.create(op.getLoc(), resultType, initArgs, + auto resultType = adaptorInitArgs[i].getType(); + auto forOp = rewriter.create(op->getLoc(), resultType, initArgs, iters, resultIdx); // body block - rewriter.cloneRegionBefore(op.getBody(), forOp.getBody(), - forOp.getBody().begin()); + rewriter.cloneRegionBefore(body, forOp.getBody(), forOp.getBody().begin()); rewriter.applySignatureConversion(&forOp.getBody().front(), *blockSignature); // until block - if (!op.getUntil().empty()) { - rewriter.cloneRegionBefore(op.getUntil(), forOp.getUntil(), + if (!until.empty()) { + rewriter.cloneRegionBefore(until, forOp.getUntil(), forOp.getUntil().begin()); rewriter.applySignatureConversion(&forOp.getUntil().front(), *blockSignature); @@ -944,6 +970,28 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForConstOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + return convertFor(op, adaptor.getInitArgs(), adaptor.getRangeBegin(), + adaptor.getRangeEnd(), op.getBody(), op.getUntil(), + typeConverter, rewriter); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForDimOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto ctx = op.getContext(); + auto rangeBegin = + rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); + auto rangeEnd = createDimInt(op.getLoc(), op.getDim(), rewriter); + + return convertFor(op, adaptor.getInitArgs(), rangeBegin, rangeEnd, + op.getBody(), op.getUntil(), typeConverter, rewriter); +} + template <> mlir::LogicalResult OpConversion::matchAndRewrite( graphalg::YieldOp op, OpAdaptor adaptor, @@ -1161,6 +1209,14 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult 
OpConversion::matchAndRewrite( + graphalg::CastDimOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOp(op, createDimInt(op->getLoc(), op.getInput(), rewriter)); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1362,14 +1418,69 @@ static bool hasRelationOperands(mlir::Operation *op) { [](auto t) { return llvm::isa(t); }); } +static void addDimensionInputs(mlir::func::FuncOp funcOp) { + // Function type + llvm::SmallVector inputs(funcOp.getFunctionType().getInputs()); + llvm::SmallVector dims; + for (auto t : funcOp.getFunctionType().getInputs()) { + auto matType = llvm::cast(t); + for (auto d : {matType.getRows(), matType.getCols()}) { + if (d.isAbstract() && !llvm::is_contained(dims, d)) { + dims.push_back(d); + } + } + } + + // Add dimension inputs to function type + auto *ctx = funcOp.getContext(); + auto dimReadType = RelationType::get( + ctx, mlir::ArrayRef{mlir::IndexType::get(ctx)}); + inputs.append(dims.size(), dimReadType); + auto newType = mlir::FunctionType::get(ctx, inputs, + funcOp.getFunctionType().getResults()); + funcOp.setFunctionType(newType); + + // Add dimension inputs as block args. + auto &block = funcOp.getFunctionBody().front(); + for (auto d : dims) { + auto arg = block.addArgument(dimReadType, funcOp.getLoc()); + } + + // Annotate with input_dims attribute. 
+ llvm::SmallVector dimsErased; + for (auto d : dims) { + dimsErased.push_back(d); + } + funcOp->setAttr(INPUT_DIMS_ATTR_KEY, mlir::ArrayAttr::get(ctx, dimsErased)); +} + +/* +static mlir::LogicalResult forDimToConst(graphalg::ForDimOp op, + mlir::PatternRewriter &rewriter) { + auto ctx = op.getContext(); + auto rangeBegin = rewriter.create( + op.getLoc(), + graphalg::MatrixType::scalarOf(graphalg::SemiringTypes::forInt(ctx)), + rewriter.getI64IntegerAttr(0)); + auto rangeEnd = + rewriter.create(op.getLoc(), op.getDim()); + auto newOp = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInitArgs(), rangeBegin, + rangeEnd); + newOp.getBody().takeBody(op.getBody()); + newOp.getUntil().takeBody(op.getUntil()); + rewriter.replaceOp(op, newOp); + return mlir::success(); +} +*/ + void GraphAlgToRel::runOnOperation() { // Add dimension inputs llvm::SmallVector funcOps( getOperation().getOps()); mlir::IRRewriter rewriter(getOperation()); for (auto op : funcOps) { - llvm::SmallDenseMap dimToValue; - addDimensionInputs(op, dimToValue); + addDimensionInputs(op); } mlir::ConversionTarget target(getContext()); @@ -1394,9 +1505,10 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion, - OpConversion, OpConversion, - OpConversion, OpConversion>( + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion>( matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/cast-dim.mlir b/compiler/test/graphalg-to-rel/cast-dim.mlir new file mode 100644 index 0000000..f48fca8 --- /dev/null +++ b/compiler/test/graphalg-to-rel/cast-dim.mlir @@ -0,0 +1,11 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +#dim = #graphalg.dim> +// CHECK-LABEL: @CastDim +func.func @CastDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<1 x 
1 x i64> { + // CHECK: %[[#AGG:]] = garel.aggregate %arg1 : group_by=[] aggregators=[] + %0 = graphalg.cast_dim #dim + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for-const.mlir index 2d418d8..936fffc 100644 --- a/compiler/test/graphalg-to-rel/for-const.mlir +++ b/compiler/test/graphalg-to-rel/for-const.mlir @@ -6,7 +6,8 @@ func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 - // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=10 result_idx=1 { + // CHECK: %[[#ITERS:]] = garel.const 10 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=%[[#ITERS]] result_idx=1 { %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): // CHECK: %[[#PROJ:]] = garel.project %arg1 diff --git a/compiler/test/graphalg-to-rel/for-dim.mlir b/compiler/test/graphalg-to-rel/for-dim.mlir new file mode 100644 index 0000000..191565e --- /dev/null +++ b/compiler/test/graphalg-to-rel/for-dim.mlir @@ -0,0 +1,24 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +#dim = #graphalg.dim> +// CHECK-LABEL: @ForDim +func.func @ForDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#INIT:]] = garel.const 42 : i64 + %0 = graphalg.const_mat 42 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#ITERS:]] = garel.aggregate %arg1 : group_by=[] aggregators=[] + // CHECK: %[[#FOR:]] = garel.for + // CHECK-SAME: %[[#BEGIN]], %[[#INIT]] : !garel.rel, !garel.rel + // CHECK-SAME: iters=%[[#ITERS]] result_idx=1 + %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 
x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): + // CHECK: %[[#INC:]] = garel.project %arg2 + // CHECK: garel.for.yield %[[#INC]], %arg2 + graphalg.yield %arg1 : !graphalg.mat<1 x 1 x i64> + } until { + } + + // CHECK: return %[[#FOR]] + return %1 : !graphalg.mat<1 x 1 x i64> +} From fcce5f1aa13164112cd5bd9dcbe71c8957c84bff Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Tue, 21 Apr 2026 15:08:26 +0000 Subject: [PATCH 3/6] Cleaning up. --- compiler/src/garel/GraphAlgToRel.cpp | 34 ++++------------------------ 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 8c67b9d..cc58e52 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,20 +1,24 @@ #include -#include #include #include #include #include +#include +#include #include #include +#include #include #include #include #include #include #include +#include #include #include +#include #include #include #include @@ -30,14 +34,6 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" namespace garel { @@ -1454,26 +1450,6 @@ static void addDimensionInputs(mlir::func::FuncOp funcOp) { funcOp->setAttr(INPUT_DIMS_ATTR_KEY, mlir::ArrayAttr::get(ctx, dimsErased)); } -/* -static mlir::LogicalResult forDimToConst(graphalg::ForDimOp op, - mlir::PatternRewriter &rewriter) { - auto ctx = op.getContext(); - auto rangeBegin = rewriter.create( - op.getLoc(), - graphalg::MatrixType::scalarOf(graphalg::SemiringTypes::forInt(ctx)), - rewriter.getI64IntegerAttr(0)); - auto rangeEnd = - 
rewriter.create(op.getLoc(), op.getDim()); - auto newOp = rewriter.create( - op.getLoc(), op->getResultTypes(), op.getInitArgs(), rangeBegin, - rangeEnd); - newOp.getBody().takeBody(op.getBody()); - newOp.getUntil().takeBody(op.getUntil()); - rewriter.replaceOp(op, newOp); - return mlir::success(); -} -*/ - void GraphAlgToRel::runOnOperation() { // Add dimension inputs llvm::SmallVector funcOps( From 88f76ae7284c4a826652181e6c1d0e70915ae031 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 30 Apr 2026 13:40:09 +0000 Subject: [PATCH 4/6] New for loop design. --- compiler/include/graphalg/GraphAlgOps.td | 66 ++++++++++++- .../src/graphalg/GraphAlgCanonicalize.cpp | 33 +++++++ compiler/src/graphalg/GraphAlgOps.cpp | 35 +++++++ .../src/graphalg/GraphAlgSetDimensions.cpp | 67 ++++++------- compiler/src/graphalg/GraphAlgToCore.cpp | 16 ++++ compiler/src/graphalg/evaluate/Evaluator.cpp | 94 ++++++++++++++++++- compiler/src/graphalg/parse/Parser.cpp | 12 ++- compiler/test/golden/bfs.mlir.ref | 2 +- compiler/test/golden/cdlp.mlir.ref | 2 +- compiler/test/golden/pr.mlir.ref | 2 +- compiler/test/golden/sssp.mlir.ref | 2 +- compiler/test/golden/update_goldens.sh | 2 +- compiler/test/golden/wcc.mlir.ref | 2 +- compiler/test/parse/loop.gr | 4 +- compiler/test/set-dimensions/for-dim.mlir | 6 +- 15 files changed, 287 insertions(+), 58 deletions(-) diff --git a/compiler/include/graphalg/GraphAlgOps.td b/compiler/include/graphalg/GraphAlgOps.td index 0cdf7bf..f80deb8 100644 --- a/compiler/include/graphalg/GraphAlgOps.td +++ b/compiler/include/graphalg/GraphAlgOps.td @@ -403,6 +403,70 @@ def BroadcastOp : Core_Op<"broadcast", [ let hasVerifier = 1; } +def ForOp : Core_Op<"for", [ + Pure, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "For loop with dynamic bounds"; + + let description = [{ + A loop iterating over one of three ranges: + 1) `dynBegin` (inclusive) to `dynEnd` (exclusive) + 2) `begin` to `begin` + `iters`, where `iters` is an integer + 
3) `begin` to `begin` + `iters`, where `iters` is a matrix dimension + + Only instances with range types 2 or 3 are considered part of GraphAlg + Core. Constant propagation is expected to transform range type 1 into + either 2 or 3. + + The `body` region is executed once for every value in the integer range + (that value is passed as the first block argument). + At the first iteration of the loop, the other block arguments take the + values of `initArgs`. For subsequent iterations, results from the + previous iteration (produced by `YieldOp`) are taken instead. + The `until` region, if present, is executed after `body`, and produces a + single boolean scalar indicating whether the loop should terminate + early. + + In more imperative terms, `initArgs` can be seen as the set of variables + that are updated in the loop body. + Within the loop body, those variables can be accessed through the block + arguments, and their updated values are set through `YieldOp`. + Finally, `results` represents the new state of those variables after the + loop terminates. + }]; + + let arguments = (ins + Variadic:$initArgs, + Optional:$dynBegin, + Optional:$dynEnd, + OptionalAttr:$begin, + OptionalAttr:$iters); + + let results = (outs Variadic:$results); + + let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until); + + let assemblyFormat = [{ + (`dyn_begin` `` `=` `` $dynBegin^)? + (`dyn_end` `` `=` `` $dynEnd^)? + (`begin` `` `=` `` $begin^)? + (`iters` `` `=` `` $iters^)? + `init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict + `body` $body + `until` $until + }]; + + let hasVerifier = 1; + let hasRegionVerifier = 1; + let hasFolder = 1; + + let extraClassDeclaration = [{ + /** Whether at least one of `dyn_begin` and `dyn_end` is set. */ + bool isDynamicRange(); + }]; +} + // Not core according to spec, but we don't want to unroll in the general case. 
def ForConstOp : Core_Op<"for_const", [ Pure, @@ -483,7 +547,7 @@ def ForDimOp : Core_Op<"for_dim", [ def YieldOp : Core_Op<"yield", [ Pure, Terminator, - ParentOneOf<["ForConstOp", "ForDimOp"]>, + ParentOneOf<["ForOp", "ForConstOp", "ForDimOp"]>, DeclareOpInterfaceMethods]> { let summary = "Yield from a loop body"; diff --git a/compiler/src/graphalg/GraphAlgCanonicalize.cpp b/compiler/src/graphalg/GraphAlgCanonicalize.cpp index e89ff7e..1454d0d 100644 --- a/compiler/src/graphalg/GraphAlgCanonicalize.cpp +++ b/compiler/src/graphalg/GraphAlgCanonicalize.cpp @@ -242,6 +242,39 @@ mlir::OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { return nullptr; } +mlir::LogicalResult +ForOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results) { + if (!getBegin() && adaptor.getDynBegin()) { + // Can infer a constant begin to the range. + auto begin = llvm::cast(adaptor.getDynBegin()); + setBeginAttr(begin); + getDynBeginMutable().clear(); + return mlir::success(); + } + + if (!getIters() && getBegin() && adaptor.getDynEnd()) { + // Can infer a constant number of iterations. + auto begin = *getBegin(); + auto end = llvm::cast(adaptor.getDynEnd()) + .getValue() + .getZExtValue(); + // If end < begin, drop to 0 iterations. + std::size_t iters = 0; + if (begin < end) { + iters = end - begin; + } + + // NOTE: number of iterations is encoded as a DimAttr. 
+ auto dim = DimAttr::getConcrete(getContext(), iters); + setItersAttr(dim); + getDynEndMutable().clear(); + return mlir::success(); + } + + return mlir::failure(); +} + static mlir::LogicalResult forDimConst(ForDimOp op, mlir::PatternRewriter &rewriter) { if (!op.getDim().isConcrete()) { diff --git a/compiler/src/graphalg/GraphAlgOps.cpp b/compiler/src/graphalg/GraphAlgOps.cpp index 01764bc..9176b8e 100644 --- a/compiler/src/graphalg/GraphAlgOps.cpp +++ b/compiler/src/graphalg/GraphAlgOps.cpp @@ -505,6 +505,41 @@ mlir::LogicalResult BroadcastOp::verify() { return mlir::success(); } +// === ForOp === +mlir::LogicalResult ForOp::verify() { + if (getDynBegin() && getBegin()) { + return emitOpError("begin and dyn_begin are mutually exclusive"); + } else if (!getDynBegin() && !getBegin()) { + return emitOpError("no loop start: must have 'begin' or 'dyn_begin'"); + } + + if (getDynEnd() && getIters()) { + return emitOpError("begin and iters are mutually exclusive"); + } else if (!getDynEnd() && !getIters()) { + return emitOpError("no loop end: must have 'iters' or 'dyn_end'"); + } + + return mlir::success(); +} + +mlir::LogicalResult ForOp::verifyRegions() { + return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); +} + +void ForOp::getSuccessorRegions( + mlir::RegionBranchPoint point, + llvm::SmallVectorImpl ®ions) { + getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, + regions); +} + +mlir::OperandRange +ForOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { + return getInitArgs(); +} + +bool ForOp::isDynamicRange() { return getDynBegin() || getDynEnd(); } + // === ForConstOp === mlir::LogicalResult ForConstOp::verifyRegions() { return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); diff --git a/compiler/src/graphalg/GraphAlgSetDimensions.cpp b/compiler/src/graphalg/GraphAlgSetDimensions.cpp index 8578816..490c8ef 100644 --- a/compiler/src/graphalg/GraphAlgSetDimensions.cpp +++ 
b/compiler/src/graphalg/GraphAlgSetDimensions.cpp @@ -125,20 +125,6 @@ class DimConversionPattern : public mlir::ConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override; }; -/** Template for rewrites (without type conversion). */ -template -class DimOpRewritePattern : public mlir::OpRewritePattern { -private: - const DimMapper &_mapper; - - mlir::LogicalResult - matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override; - -public: - DimOpRewritePattern(const DimMapper &mapper, mlir::MLIRContext *ctx) - : mlir::OpRewritePattern(ctx), _mapper(mapper) {} -}; - } // namespace mlir::FailureOr @@ -309,31 +295,33 @@ mlir::LogicalResult DimConversionPattern::matchAndRewrite( return mlir::success(); } -template <> -mlir::LogicalResult DimOpRewritePattern::matchAndRewrite( - CastDimOp op, mlir::PatternRewriter &rewriter) const { - auto dim = _mapper.convertAttr(op.getInput()); - if (!dim) { - return mlir::failure(); +static mlir::LogicalResult updateDim(ForOp op, DimMapper &mapper) { + if (!op.getIters() || !op.getIters()->isAbstract()) { + // No update needed + return mlir::success(); } - // The folder on CastDimOp should turn this into a constant. 
- auto newOp = rewriter.createOrFold(op->getLoc(), dim); - rewriter.replaceOp(op, newOp); + auto dim = mapper.convertAttr(*op.getIters()); + if (!dim) { + return op.emitOpError("no mapping for ") << *op.getIters(); + } + op.setItersAttr(dim); return mlir::success(); } -template <> -mlir::LogicalResult DimOpRewritePattern::matchAndRewrite( - ForDimOp op, mlir::PatternRewriter &rewriter) const { - auto dim = _mapper.convertAttr(op.getDim()); - if (!dim) { - return mlir::failure(); +static mlir::LogicalResult updateDim(CastDimOp op, DimMapper &mapper) { + if (!op.getInput().isAbstract()) { + // No update needed + return mlir::success(); } - rewriter.modifyOpInPlace(op, [&]() { op.setDimAttr(dim); }); + auto dim = mapper.convertAttr(op.getInput()); + if (!dim) { + return op.emitOpError("no mapping for ") << op.getInput(); + } + op.setInputAttr(dim); return mlir::success(); } @@ -356,6 +344,19 @@ void GraphAlgSetDimensions::runOnOperation() { return signalPassFailure(); } + // Update direct references to dimensions. + bool failedDirectUpdate = false; + func->walk([&](ForOp op) { + if (mlir::failed(updateDim(op, *dimMapper))) { + failedDirectUpdate = true; + } + }); + func->walk([&](CastDimOp op) { + if (mlir::failed(updateDim(op, *dimMapper))) { + failedDirectUpdate = true; + } + }); + mlir::ConversionTarget target(getContext()); target.addDynamicallyLegalDialect( doesNotUseAbstractDimensions); @@ -374,12 +375,6 @@ void GraphAlgSetDimensions::runOnOperation() { // Convert all result types and block argument types. patterns.add(typeConverter, &getContext()); - // Convert ops that have a special dependency on DimAttr. - patterns.add, DimOpRewritePattern>( - *dimMapper, &getContext()); - // Use the canonicalization pattern to rewrite ForDimOp into ForConstOp. 
- ForDimOp::getCanonicalizationPatterns(patterns, &getContext()); - if (mlir::failed( mlir::applyPartialConversion(func, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/compiler/src/graphalg/GraphAlgToCore.cpp b/compiler/src/graphalg/GraphAlgToCore.cpp index dbc8993..fdee0c4 100644 --- a/compiler/src/graphalg/GraphAlgToCore.cpp +++ b/compiler/src/graphalg/GraphAlgToCore.cpp @@ -254,6 +254,8 @@ void GraphAlgToCore::runOnOperation() { target.addIllegalDialect(); target.addDynamicallyLegalDialect( [](mlir::Operation *op) { return op->hasTrait(); }); + target.addDynamicallyLegalOp( + [](ForOp op) { return !op.isDynamicRange(); }); mlir::RewritePatternSet patterns(&getContext()); patterns.add(convertVecMatMul); @@ -266,6 +268,20 @@ void GraphAlgToCore::runOnOperation() { patterns.add(convertTriu); patterns.add(convertLiteral); + // Conversion will give very unclear errors about dynamic range for loops, so + // do our own analysis first. + bool haveDynamicRangeLoops = false; + getOperation()->walk([&](ForOp op) { + if (op.isDynamicRange()) { + op.emitOpError("loop bound must be a constant in GraphAlg Core"); + haveDynamicRangeLoops = true; + } + }); + + if (haveDynamicRangeLoops) { + return signalPassFailure(); + } + if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/src/graphalg/evaluate/Evaluator.cpp b/compiler/src/graphalg/evaluate/Evaluator.cpp index 8545c4b..335beb6 100644 --- a/compiler/src/graphalg/evaluate/Evaluator.cpp +++ b/compiler/src/graphalg/evaluate/Evaluator.cpp @@ -35,6 +35,7 @@ class Evaluator { mlir::LogicalResult evaluate(ReduceOp op); mlir::LogicalResult evaluate(BroadcastOp op); mlir::LogicalResult evaluate(ConstantMatrixOp op); + mlir::LogicalResult evaluate(ForOp op); mlir::LogicalResult evaluate(ForConstOp op); mlir::LogicalResult evaluate(ApplyOp op); mlir::LogicalResult evaluate(PickAnyOp op); @@ -187,6 +188,88 @@ mlir::LogicalResult 
Evaluator::evaluate(ConstantMatrixOp op) { return mlir::success(); } +mlir::LogicalResult Evaluator::evaluate(ForOp op) { + auto begin = op.getBegin(); + if (!begin) { + return op.emitOpError("loop begin is not constant"); + } + + auto iters = op.getIters(); + if (!iters || iters->isAbstract()) { + return op.emitOpError("number of loop iterations is not constant"); + } + + auto rangeBegin = *begin; + auto rangeEnd = rangeBegin + iters->getConcreteDim(); + + auto &body = op.getBody().front(); + auto *ctx = op.getContext(); + + // Initialize block arguments + for (auto [init, blockArg] : + llvm::zip_equal(op.getInitArgs(), body.getArguments().drop_front())) { + _values[blockArg] = _values[init]; + } + + for (auto i : llvm::seq(rangeBegin, rangeEnd)) { + // Iteration variable. + auto iterAttr = mlir::IntegerAttr::get(SemiringTypes::forInt(ctx), i); + auto iterArg = body.getArgument(0); + auto iterType = llvm::cast(iterArg.getType()); + MatrixAttrBuilder iterBuilder(iterType); + iterBuilder.set(0, 0, iterAttr); + _values[body.getArgument(0)] = iterBuilder.build(); + + for (auto &op : body) { + if (auto yieldOp = llvm::dyn_cast(op)) { + // Update block arguments + for (auto [value, blockArg] : llvm::zip_equal( + yieldOp.getInputs(), body.getArguments().drop_front())) { + _values[blockArg] = _values[value]; + } + } else if (mlir::failed(evaluate(&op))) { + return mlir::failure(); + } + } + + bool breakFromUntil = false; + if (!op.getUntil().empty()) { + // Have an until clause to evaluate. + auto &until = op.getUntil().front(); + + // Use current state of loop variables as input to until block. 
+ for (auto [bodyArg, untilArg] : + llvm::zip_equal(body.getArguments(), until.getArguments())) { + _values[untilArg] = _values[bodyArg]; + } + + for (auto &op : until) { + if (auto yieldOp = llvm::dyn_cast(op)) { + // Check break condition + assert(yieldOp->getNumOperands() == 1); + MatrixAttrReader condMat(_values[yieldOp.getInputs().front()]); + breakFromUntil = + llvm::cast(condMat.at(0, 0)).getValue(); + } else if (mlir::failed(evaluate(&op))) { + return mlir::failure(); + } + } + } + + if (breakFromUntil) { + break; + } + } + + // Set loop results. + for (auto [value, result] : + llvm::zip_equal(body.getArguments().drop_front(), op->getResults())) { + _values[result] = _values[value]; + } + + return mlir::success(); +} + mlir::LogicalResult Evaluator::evaluate(ForConstOp op) { MatrixAttrReader rangeBeginMat(_values[op.getRangeBegin()]); MatrixAttrReader rangeEndMat(_values[op.getRangeEnd()]); @@ -332,12 +415,13 @@ mlir::LogicalResult Evaluator::evaluate(mlir::Operation *op) { return llvm::TypeSwitch(op) #define GA_CASE(Op) .Case([&](Op op) { return evaluate(op); }) GA_CASE(TransposeOp) GA_CASE(DiagOp) GA_CASE(MatMulOp) GA_CASE(ReduceOp) - GA_CASE(BroadcastOp) GA_CASE(ConstantMatrixOp) GA_CASE(ForConstOp) - GA_CASE(ApplyOp) GA_CASE(PickAnyOp) GA_CASE(TrilOp) + GA_CASE(BroadcastOp) GA_CASE(ConstantMatrixOp) GA_CASE(ForOp) + GA_CASE(ForConstOp) GA_CASE(ApplyOp) GA_CASE(PickAnyOp) + GA_CASE(TrilOp) #undef GA_CASE - .Default([](mlir::Operation *op) { - return op->emitOpError("unsupported op"); - }); + .Default([](mlir::Operation *op) { + return op->emitOpError("unsupported op"); + }); } MatrixAttr Evaluator::evaluate(mlir::func::FuncOp funcOp, diff --git a/compiler/src/graphalg/parse/Parser.cpp b/compiler/src/graphalg/parse/Parser.cpp index e1013ef..79f8030 100644 --- a/compiler/src/graphalg/parse/Parser.cpp +++ b/compiler/src/graphalg/parse/Parser.cpp @@ -713,14 +713,18 @@ mlir::ParseResult Parser::parseStmtFor() { mlir::Region *untilRegion; mlir::ValueRange 
results; if (range.dim) { - auto forOp = _builder.create(loc, varTypes, initArgs, range.dim); + auto begin = _builder.getI64IntegerAttr(0); + auto forOp = + _builder.create(loc, varTypes, initArgs, /*dynBegin=*/nullptr, + /*dynEnd=*/nullptr, begin, range.dim); bodyRegion = &forOp.getBody(); untilRegion = &forOp.getUntil(); results = forOp->getResults(); } else { assert(range.begin && range.end); - auto forOp = _builder.create(loc, varTypes, initArgs, - range.begin, range.end); + auto forOp = _builder.create( + loc, varTypes, initArgs, /*dynBegin=*/range.begin, /*dynEnd=*/range.end, + /*begin=*/nullptr, /*end=*/nullptr); bodyRegion = &forOp.getBody(); untilRegion = &forOp.getUntil(); results = forOp->getResults(); @@ -827,7 +831,7 @@ mlir::ParseResult Parser::parseStmtReturn() { // Check if return is inside a loop auto *parentOp = _builder.getInsertionBlock()->getParentOp(); - if (llvm::isa(parentOp)) { + if (llvm::isa(parentOp)) { return mlir::emitError(loc) << "return statement inside a loop is not allowed"; } diff --git a/compiler/test/golden/bfs.mlir.ref b/compiler/test/golden/bfs.mlir.ref index dd33de2..aa66d71 100644 --- a/compiler/test/golden/bfs.mlir.ref +++ b/compiler/test/golden/bfs.mlir.ref @@ -14,7 +14,7 @@ module @"" { %3 = graphalg.broadcast %2 : <1 x 1 x i64> -> <#dim x 1 x i64> %4 = graphalg.mask %1<%arg1 : <#dim x 1 x i1>> = %3 : <#dim x 1 x i64> {complement = false} %5 = graphalg.cast_dim #dim - %6:3 = graphalg.for_dim range(#dim) init(%4, %arg1, %arg1) : !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> -> !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> body { + %6:3 = graphalg.for begin=0 iters=#dim init(%4, %arg1, %arg1) : !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> -> !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: 
!graphalg.mat<#dim x 1 x i64>, %arg4: !graphalg.mat<#dim x 1 x i1>, %arg5: !graphalg.mat<#dim x 1 x i1>): %7 = graphalg.cast_dim #dim %8 = graphalg.const_mat false -> <#dim x 1 x i1> diff --git a/compiler/test/golden/cdlp.mlir.ref b/compiler/test/golden/cdlp.mlir.ref index a3955d9..6a9666b 100644 --- a/compiler/test/golden/cdlp.mlir.ref +++ b/compiler/test/golden/cdlp.mlir.ref @@ -8,7 +8,7 @@ module @"" { %4 = graphalg.broadcast %3 : <1 x 1 x i1> -> <#dim x 1 x i1> %5 = graphalg.diag %4 : !graphalg.mat<#dim x 1 x i1> %6 = graphalg.literal 0 : i64 - %7 = graphalg.for_const range(%6, %0) : <1 x 1 x i64> init(%5) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { + %7 = graphalg.for dyn_begin=%6 dyn_end=%0 init(%5) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x #dim x i1>): %19 = graphalg.cast %arg0 : <#dim x #dim x i1> -> <#dim x #dim x i64> %20 = graphalg.cast %arg2 : <#dim x #dim x i1> -> <#dim x #dim x i64> diff --git a/compiler/test/golden/pr.mlir.ref b/compiler/test/golden/pr.mlir.ref index 59b2785..d5ff80e 100644 --- a/compiler/test/golden/pr.mlir.ref +++ b/compiler/test/golden/pr.mlir.ref @@ -26,7 +26,7 @@ module @"" { %17 = graphalg.ewise %15 DIV %16 : <1 x 1 x f64> %18 = graphalg.broadcast %17 : <1 x 1 x f64> -> <#dim x 1 x f64> %19 = graphalg.literal 0 : i64 - %20 = graphalg.for_const range(%19, %arg2) : <1 x 1 x i64> init(%18) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> body { + %20 = graphalg.for dyn_begin=%19 dyn_end=%arg2 init(%18) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> body { ^bb0(%arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<#dim x 1 x f64>): %21 = graphalg.const_mat 0.000000e+00 : f64 -> <#dim x 1 x f64> %22 = graphalg.mask %21<%13 : <#dim x 1 x i1>> = %arg4 : <#dim x 1 x f64> {complement = false} diff --git a/compiler/test/golden/sssp.mlir.ref 
b/compiler/test/golden/sssp.mlir.ref index d4b21db..b32f2e7 100644 --- a/compiler/test/golden/sssp.mlir.ref +++ b/compiler/test/golden/sssp.mlir.ref @@ -3,7 +3,7 @@ module @"" { func.func @SSSP(%arg0: !graphalg.mat<#dim x #dim x !graphalg.trop_f64>, %arg1: !graphalg.mat<#dim x 1 x i1>) -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> { %0 = graphalg.cast %arg1 : <#dim x 1 x i1> -> <#dim x 1 x !graphalg.trop_f64> %1 = graphalg.cast_dim #dim - %2 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<#dim x 1 x !graphalg.trop_f64> -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> body { + %2 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<#dim x 1 x !graphalg.trop_f64> -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<#dim x 1 x !graphalg.trop_f64>): %3 = graphalg.vxm %arg3, %arg0 : <#dim x 1 x !graphalg.trop_f64>, <#dim x #dim x !graphalg.trop_f64> %4 = graphalg.ewise %arg3 ADD %3 : <#dim x 1 x !graphalg.trop_f64> diff --git a/compiler/test/golden/update_goldens.sh b/compiler/test/golden/update_goldens.sh index 01c9183..306153d 100755 --- a/compiler/test/golden/update_goldens.sh +++ b/compiler/test/golden/update_goldens.sh @@ -7,6 +7,6 @@ GOLDEN_DIR=$(dirname $0) for f in $GOLDEN_DIR/*.gr; do out="${f%.gr}.mlir.ref" - "$BUILD_DIR/graphalg-translate" --import-graphalg < "$f" > "$out" + "$BUILD_DIR/tools/graphalg-translate" --import-graphalg < "$f" > "$out" echo "Wrote $out" done diff --git a/compiler/test/golden/wcc.mlir.ref b/compiler/test/golden/wcc.mlir.ref index a078151..7ed86d9 100644 --- a/compiler/test/golden/wcc.mlir.ref +++ b/compiler/test/golden/wcc.mlir.ref @@ -7,7 +7,7 @@ module @"" { %3 = graphalg.broadcast %2 : <1 x 1 x i1> -> <#dim x 1 x i1> %4 = graphalg.diag %3 : !graphalg.mat<#dim x 1 x i1> %5 = graphalg.cast_dim #dim - %6 = graphalg.for_dim range(#dim) init(%4) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { + %6 = graphalg.for begin=0 iters=#dim 
init(%4) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x #dim x i1>): %7 = graphalg.mxm %arg0, %arg2 : <#dim x #dim x i1>, <#dim x #dim x i1> %8 = graphalg.ewise %arg2 ADD %7 : <#dim x #dim x i1> diff --git a/compiler/test/parse/loop.gr b/compiler/test/parse/loop.gr index 07b13fc..5120de5 100644 --- a/compiler/test/parse/loop.gr +++ b/compiler/test/parse/loop.gr @@ -6,7 +6,7 @@ func ForConst() -> int { a = int(42); // CHECK: %[[#BEGIN:]] = graphalg.literal 1 // CHECK: %[[#END:]] = graphalg.literal 10 - // CHECK: %[[#LOOP:]] = graphalg.for_const range(%[[#BEGIN]], %[[#END]]) + // CHECK: %[[#LOOP:]] = graphalg.for dyn_begin=%[[#BEGIN]] dyn_end=%[[#END]] // CHECK: init(%[[#INIT]]) // CHECK: -> !graphalg.mat<1 x 1 x i64> body { // CHECK: ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, @@ -26,7 +26,7 @@ func ForConst() -> int { func ForDim(m:Matrix) -> int { // CHECK: %[[#INIT:]] = graphalg.literal 42 a = int(42); - // CHECK: %[[#LOOP:]] = graphalg.for_dim range(#dim) + // CHECK: %[[#LOOP:]] = graphalg.for begin=0 iters=#dim // CHECK: init(%[[#INIT]]) // CHECK: -> !graphalg.mat<1 x 1 x i64> body { // CHECK: ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, diff --git a/compiler/test/set-dimensions/for-dim.mlir b/compiler/test/set-dimensions/for-dim.mlir index 515758e..c6ba788 100644 --- a/compiler/test/set-dimensions/for-dim.mlir +++ b/compiler/test/set-dimensions/for-dim.mlir @@ -3,10 +3,8 @@ func.func @ForDimConst(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = graphalg.const_mat 0 - // CHECK: %[[#END:]] = graphalg.const_mat 42 - // CHECK: graphalg.for_const range(%[[#BEGIN]], %[[#END]]) - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + // CHECK: graphalg.for begin=0 iters=<42> + %1 = graphalg.for begin=0 iters=#dim init(%0) : 
!graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { From adc035816627bfa789db01c91e2217cdb7686415 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 30 Apr 2026 15:03:39 +0000 Subject: [PATCH 5/6] Tests passing again. --- compiler/include/graphalg/GraphAlgOps.td | 79 +--------- compiler/src/garel/GraphAlgToRel.cpp | 93 ++++-------- .../src/graphalg/GraphAlgCanonicalize.cpp | 33 +--- .../src/graphalg/GraphAlgLoopAggregate.cpp | 30 +--- compiler/src/graphalg/GraphAlgOps.cpp | 143 +++++++----------- .../src/graphalg/GraphAlgScalarizeApply.cpp | 74 ++------- .../src/graphalg/GraphAlgSetDimensions.cpp | 1 - .../src/graphalg/analysis/DenseAnalysis.cpp | 2 +- compiler/src/graphalg/evaluate/Evaluator.cpp | 85 +---------- compiler/src/graphalg/parse/Parser.cpp | 2 +- compiler/test/canonicalize/for-dim.mlir | 19 --- compiler/test/exec/for.mlir | 18 +-- compiler/test/exec/until.mlir | 23 ++- compiler/test/graphalg-to-rel/for-dim.mlir | 24 --- .../{for-const.mlir => for.mlir} | 29 ++-- compiler/test/infer-density/for.mlir | 28 +--- .../test/loop-aggregate/add-init-reduce.mlir | 4 +- compiler/test/scalarize-apply/for-dim.mlir | 25 --- .../{for-const.mlir => for.mlir} | 12 +- compiler/test/verify-dimensions/for-dim.mlir | 6 +- 20 files changed, 148 insertions(+), 582 deletions(-) delete mode 100644 compiler/test/canonicalize/for-dim.mlir delete mode 100644 compiler/test/graphalg-to-rel/for-dim.mlir rename compiler/test/graphalg-to-rel/{for-const.mlir => for.mlir} (63%) delete mode 100644 compiler/test/scalarize-apply/for-dim.mlir rename compiler/test/scalarize-apply/{for-const.mlir => for.mlir} (62%) diff --git a/compiler/include/graphalg/GraphAlgOps.td b/compiler/include/graphalg/GraphAlgOps.td index f80deb8..8e99851 100644 --- a/compiler/include/graphalg/GraphAlgOps.td +++ b/compiler/include/graphalg/GraphAlgOps.td @@ -467,87 
+467,10 @@ def ForOp : Core_Op<"for", [ }]; } -// Not core according to spec, but we don't want to unroll in the general case. -def ForConstOp : Core_Op<"for_const", [ - Pure, - AllTypesMatch<["rangeBegin", "rangeEnd"]>, - DeclareOpInterfaceMethods]> { - let summary = "For loop with constant bounds"; - - let description = [{ - A loop iterating over the integer range starting at `rangeBegin` - (inclusive) and ending at `rangeEnd` (exclusive). - The `body` region is executed once for every value in the integer range - (that value is passed as the first block argument). - At the first iteration of the loop, the other block arguments take the - values of `initArgs`. For subsequent iterations, results from the - previous iteration (produced by `YieldOp`) are taken instead. - The `until` region, if present, is executed after `body`, and produces a - single boolean scalar indicating whether the loop should terminate - early. - - In more imperative terms, `initArgs` can be seen as the set of variables - that are updated in the loop body. - Within the loop body, those variables can be accessed through the block - arguments, and their updated values are set through `YieldOp`. - Finally, `results` represents the new state of those variables after the - loop terminates. - }]; - - let arguments = (ins - Variadic:$initArgs, - I64Scalar:$rangeBegin, - I64Scalar:$rangeEnd); - - let results = (outs Variadic:$results); - - let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until); - - let assemblyFormat = [{ - `range` `(` - $rangeBegin `,` - $rangeEnd - `)` `:` type($rangeEnd) - `init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict - `body` $body - `until` $until - }]; - - let hasRegionVerifier = 1; -} - -def ForDimOp : Core_Op<"for_dim", [ - Pure, - DeclareOpInterfaceMethods]> { - let summary = "For loop over a matrix dimension"; - - let description = [{ - A loop iterating over the half-open range [0..`dim`). 
- - This op is otherwise equivalent to `ForConstOp`. - }]; - - let arguments = (ins Variadic:$initArgs, DimAttr:$dim); - - let results = (outs Variadic:$results); - - let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until); - - let assemblyFormat = [{ - `range` `(` custom($dim) `)` - `init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict - `body` $body - `until` $until - }]; - - let hasRegionVerifier = 1; - let hasCanonicalizer = 1; -} - def YieldOp : Core_Op<"yield", [ Pure, Terminator, - ParentOneOf<["ForOp", "ForConstOp", "ForDimOp"]>, + HasParent<"ForOp">, DeclareOpInterfaceMethods]> { let summary = "Yield from a loop body"; diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index cc58e52..cc94417 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -320,6 +320,11 @@ static mlir::Value createDimRead(mlir::Location loc, graphalg::DimAttr dim, */ static mlir::Value createDimInt(mlir::Location loc, graphalg::DimAttr dim, mlir::OpBuilder &builder) { + if (dim.isConcrete()) { + return builder.create( + loc, builder.getI64IntegerAttr(dim.getConcreteDim())); + } + auto input = createDimRead(loc, dim, builder); llvm::ArrayRef groupBy; std::array aggregators{ @@ -889,46 +894,22 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -// Sharing logic between ForConstOp and ForDimOp -static mlir::LogicalResult -convertFor(mlir::Operation *op, mlir::ValueRange adaptorInitArgs, - mlir::Value rangeBegin, mlir::Value rangeEnd, mlir::Region &body, - mlir::Region &until, const mlir::TypeConverter *typeConverter, - mlir::ConversionPatternRewriter &rewriter) { - llvm::SmallVector initArgs{rangeBegin}; - initArgs.append(adaptorInitArgs.begin(), adaptorInitArgs.end()); +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto begin = + 
rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); + auto iters = createDimInt(op.getLoc(), *op.getIters(), rewriter); + llvm::SmallVector initArgs{begin}; + initArgs.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); - auto blockSignature = typeConverter->convertBlockSignature(&body.front()); + auto blockSignature = + typeConverter->convertBlockSignature(&op.getBody().front()); if (!blockSignature) { return op->emitOpError("Failed to convert iter args"); } - mlir::Value iters; - if (isConstantZeroI64(rangeBegin)) { - iters = rangeEnd; - } else { - // Subtract rangeBegin from rangeEnd - auto joinOp = rewriter.create( - op->getLoc(), mlir::ValueRange{rangeBegin, rangeEnd}, - rewriter.getAttr( - llvm::ArrayRef{})); - auto projOp = rewriter.create( - op->getLoc(), getI64RelationType(op->getContext()), joinOp); - - auto &block = projOp.createProjectionsBlock(); - mlir::OpBuilder::InsertionGuard guard{rewriter}; - rewriter.setInsertionPointToStart(&block); - - auto begin = - rewriter.create(op->getLoc(), 0, block.getArgument(0)); - auto end = - rewriter.create(op->getLoc(), 1, block.getArgument(0)); - auto res = rewriter.create(op->getLoc(), end, begin); - rewriter.create(op->getLoc(), mlir::ValueRange{res}); - - iters = projOp; - } - // The relational version of this op can only have a single output value. // For loops with multiple results, duplicate. llvm::SmallVector resultValues; @@ -936,24 +917,25 @@ convertFor(mlir::Operation *op, mlir::ValueRange adaptorInitArgs, auto result = op->getResult(i); if (result.use_empty()) { // Not used. Take init arg as a dummy value. - resultValues.push_back(adaptorInitArgs[i]); + resultValues.push_back(adaptor.getInitArgs()[i]); continue; } // We are adding the iteration count variable as a first argument, so offset // the result index accordingly. 
std::int64_t resultIdx = i + 1; - auto resultType = adaptorInitArgs[i].getType(); + auto resultType = adaptor.getInitArgs()[i].getType(); auto forOp = rewriter.create(op->getLoc(), resultType, initArgs, iters, resultIdx); // body block - rewriter.cloneRegionBefore(body, forOp.getBody(), forOp.getBody().begin()); + rewriter.cloneRegionBefore(op.getBody(), forOp.getBody(), + forOp.getBody().begin()); rewriter.applySignatureConversion(&forOp.getBody().front(), *blockSignature); // until block - if (!until.empty()) { - rewriter.cloneRegionBefore(until, forOp.getUntil(), + if (!op.getUntil().empty()) { + rewriter.cloneRegionBefore(op.getUntil(), forOp.getUntil(), forOp.getUntil().begin()); rewriter.applySignatureConversion(&forOp.getUntil().front(), *blockSignature); @@ -966,28 +948,6 @@ convertFor(mlir::Operation *op, mlir::ValueRange adaptorInitArgs, return mlir::success(); } -template <> -mlir::LogicalResult OpConversion::matchAndRewrite( - graphalg::ForConstOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - return convertFor(op, adaptor.getInitArgs(), adaptor.getRangeBegin(), - adaptor.getRangeEnd(), op.getBody(), op.getUntil(), - typeConverter, rewriter); -} - -template <> -mlir::LogicalResult OpConversion::matchAndRewrite( - graphalg::ForDimOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { - auto ctx = op.getContext(); - auto rangeBegin = - rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); - auto rangeEnd = createDimInt(op.getLoc(), op.getDim(), rewriter); - - return convertFor(op, adaptor.getInitArgs(), rangeBegin, rangeEnd, - op.getBody(), op.getUntil(), typeConverter, rewriter); -} - template <> mlir::LogicalResult OpConversion::matchAndRewrite( graphalg::YieldOp op, OpAdaptor adaptor, @@ -1481,11 +1441,10 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion, - OpConversion, OpConversion, - OpConversion, 
OpConversion, - OpConversion, OpConversion>( - matrixTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion>(matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/src/graphalg/GraphAlgCanonicalize.cpp b/compiler/src/graphalg/GraphAlgCanonicalize.cpp index 1454d0d..f430b3c 100644 --- a/compiler/src/graphalg/GraphAlgCanonicalize.cpp +++ b/compiler/src/graphalg/GraphAlgCanonicalize.cpp @@ -272,38 +272,9 @@ ForOp::fold(FoldAdaptor adaptor, return mlir::success(); } - return mlir::failure(); -} - -static mlir::LogicalResult forDimConst(ForDimOp op, - mlir::PatternRewriter &rewriter) { - if (!op.getDim().isConcrete()) { - return mlir::failure(); - } - - // The number of iterations is known, so we can replace with a ForConstOp. - auto end = op.getDim().getConcreteDim(); - - // Range from 0 to dim. - auto intType = - MatrixType::scalarOf(SemiringTypes::forInt(rewriter.getContext())); - auto beginOp = rewriter.create( - op->getLoc(), intType, rewriter.getI64IntegerAttr(0)); - auto endOp = rewriter.create( - op->getLoc(), intType, rewriter.getI64IntegerAttr(end)); + // TODO: Fold if iters=0 - auto forConstOp = rewriter.create( - op->getLoc(), op->getResultTypes(), op.getInitArgs(), beginOp, endOp); - rewriter.inlineRegionBefore(op.getBody(), forConstOp.getBody(), - forConstOp.getBody().begin()); - rewriter.replaceOp(op, forConstOp); - - return mlir::success(); -} - -void ForDimOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, - mlir::MLIRContext *context) { - patterns.add(forDimConst); + return mlir::failure(); } mlir::OpFoldResult PickAnyOp::fold(FoldAdaptor adaptor) { diff --git a/compiler/src/graphalg/GraphAlgLoopAggregate.cpp b/compiler/src/graphalg/GraphAlgLoopAggregate.cpp index cfac2af..aaa2bfe 100644 --- a/compiler/src/graphalg/GraphAlgLoopAggregate.cpp +++ 
b/compiler/src/graphalg/GraphAlgLoopAggregate.cpp @@ -28,19 +28,10 @@ class GraphAlgLoopAggregate // If the loop body ends with an aggregation op, ensure the init arg is also an // aggregation. Later on in the AvantGraph query pipeline, this will signal to // the optimizer that the iteration state can be kept as an aggregate table. -static void addInitReduce(mlir::Operation *op, mlir::IRRewriter &rewriter) { - mlir::Block *body; - llvm::SmallVector newInitArgs; - if (auto constOp = llvm::dyn_cast(op)) { - body = &constOp.getBody().front(); - newInitArgs = constOp.getInitArgs(); - } else { - auto dimOp = llvm::cast(op); - body = &dimOp.getBody().front(); - newInitArgs = dimOp.getInitArgs(); - } +static void addInitReduce(ForOp op, mlir::IRRewriter &rewriter) { + llvm::SmallVector newInitArgs(op.getInitArgs()); - auto yieldOp = llvm::cast(body->getTerminator()); + auto yieldOp = llvm::cast(op.getBody().front().getTerminator()); for (auto [i, iterResult] : llvm::enumerate(yieldOp.getInputs())) { auto iterLastOp = iterResult.getDefiningOp(); if (llvm::isa_and_present(iterLastOp)) { @@ -50,20 +41,13 @@ static void addInitReduce(mlir::Operation *op, mlir::IRRewriter &rewriter) { } } - rewriter.modifyOpInPlace(op, [&]() { - if (auto constOp = llvm::dyn_cast(op)) { - constOp.getInitArgsMutable().assign(newInitArgs); - } else { - auto dimOp = llvm::cast(op); - dimOp.getInitArgsMutable().assign(newInitArgs); - } - }); + rewriter.modifyOpInPlace( + op, [&]() { op.getInitArgsMutable().assign(newInitArgs); }); } void GraphAlgLoopAggregate::runOnOperation() { - llvm::SmallVector loopOps; - getOperation()->walk([&](ForConstOp op) { loopOps.emplace_back(op); }); - getOperation()->walk([&](ForDimOp op) { loopOps.emplace_back(op); }); + llvm::SmallVector loopOps; + getOperation()->walk([&](ForOp op) { loopOps.emplace_back(op); }); mlir::IRRewriter rewriter(&getContext()); for (auto op : loopOps) { diff --git a/compiler/src/graphalg/GraphAlgOps.cpp 
b/compiler/src/graphalg/GraphAlgOps.cpp index 9176b8e..472693f 100644 --- a/compiler/src/graphalg/GraphAlgOps.cpp +++ b/compiler/src/graphalg/GraphAlgOps.cpp @@ -398,66 +398,6 @@ mlir::LogicalResult ConstantMatrixOp::verify() { return mlir::success(); } -mlir::LogicalResult verifyLoop(mlir::Operation *op, mlir::ValueRange initArgs, - mlir::Region &bodyRegion, - mlir::Region &untilRegion) { - llvm::SmallVector initArgTypes; - for (auto arg : initArgs) { - initArgTypes.emplace_back(arg.getType()); - } - - if (op->getResultTypes() != initArgTypes) { - return op->emitOpError("result types ") - << op->getResultTypes() << " do not match init args " - << initArgTypes; - } - - // Block arguments must start with an iteration counter - auto *ctx = op->getContext(); - auto iterType = MatrixType::scalarOf(SemiringTypes::forInt(ctx)); - for (auto ®ion : op->getRegions()) { - if (region.empty()) { - continue; - } - - mlir::TypeRange argTypes = region.getArgumentTypes(); - if (argTypes.empty() || argTypes.front() != iterType) { - return op->emitOpError("region types ") - << argTypes << "do not include the iteration variable"; - } - } - - // Body must have YieldOp as terminator - auto &body = bodyRegion.front(); - if (!body.mightHaveTerminator()) { - return op->emitOpError("body region does not have a terminator"); - } - auto bodyYield = llvm::dyn_cast_if_present(body.getTerminator()); - if (!bodyYield) { - return op->emitOpError("body region is not terminated with a YieldOp"); - } - - // If there is an until block, it should return a boolean. 
- if (!untilRegion.empty()) { - auto &until = untilRegion.front(); - if (!until.mightHaveTerminator()) { - return op->emitOpError("until region does not have a terminator"); - } - auto untilYield = llvm::dyn_cast_if_present(until.getTerminator()); - if (!untilYield) { - return op->emitOpError("until region is not terminated with a YieldOp"); - } - - auto expectedType = MatrixType::scalarOf(SemiringTypes::forBool(ctx)); - if (untilYield->getOperandTypes() != mlir::TypeRange{expectedType}) { - return op->emitOpError("until block does not return a bool scalar: ") - << untilYield->getOperandTypes(); - } - } - - return mlir::success(); -} - void getLoopSuccessorRegions( mlir::Operation *op, mlir::Region &body, mlir::Region &until, mlir::RegionBranchPoint point, @@ -523,46 +463,63 @@ mlir::LogicalResult ForOp::verify() { } mlir::LogicalResult ForOp::verifyRegions() { - return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); -} + llvm::SmallVector initArgTypes; + for (auto arg : getInitArgs()) { + initArgTypes.emplace_back(arg.getType()); + } -void ForOp::getSuccessorRegions( - mlir::RegionBranchPoint point, - llvm::SmallVectorImpl ®ions) { - getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, - regions); -} + if (getResultTypes() != initArgTypes) { + return emitOpError("result types ") + << getResultTypes() << " do not match init args " << initArgTypes; + } -mlir::OperandRange -ForOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { - return getInitArgs(); -} + // Block arguments must start with an iteration counter + auto *ctx = getContext(); + auto iterType = MatrixType::scalarOf(SemiringTypes::forInt(ctx)); + for (auto ®ion : getOperation()->getRegions()) { + if (region.empty()) { + continue; + } -bool ForOp::isDynamicRange() { return getDynBegin() || getDynEnd(); } + mlir::TypeRange argTypes = region.getArgumentTypes(); + if (argTypes.empty() || argTypes.front() != iterType) { + return emitOpError("region types ") + << 
argTypes << "do not include the iteration variable"; + } + } -// === ForConstOp === -mlir::LogicalResult ForConstOp::verifyRegions() { - return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); -} + // Body must have YieldOp as terminator + auto &body = getBody().front(); + if (!body.mightHaveTerminator()) { + return emitOpError("body region does not have a terminator"); + } + auto bodyYield = llvm::dyn_cast_if_present(body.getTerminator()); + if (!bodyYield) { + return emitOpError("body region is not terminated with a YieldOp"); + } -void ForConstOp::getSuccessorRegions( - mlir::RegionBranchPoint point, - llvm::SmallVectorImpl ®ions) { - getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, - regions); -} + // If there is an until block, it should return a boolean. + if (!getUntil().empty()) { + auto &until = getUntil().front(); + if (!until.mightHaveTerminator()) { + return emitOpError("until region does not have a terminator"); + } + auto untilYield = llvm::dyn_cast_if_present(until.getTerminator()); + if (!untilYield) { + return emitOpError("until region is not terminated with a YieldOp"); + } -mlir::OperandRange -ForConstOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { - return getInitArgs(); -} + auto expectedType = MatrixType::scalarOf(SemiringTypes::forBool(ctx)); + if (untilYield->getOperandTypes() != mlir::TypeRange{expectedType}) { + return emitOpError("until block does not return a bool scalar: ") + << untilYield->getOperandTypes(); + } + } -// === ForDimOp === -mlir::LogicalResult ForDimOp::verifyRegions() { - return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); + return mlir::success(); } -void ForDimOp::getSuccessorRegions( +void ForOp::getSuccessorRegions( mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, @@ -570,10 +527,12 @@ void ForDimOp::getSuccessorRegions( } mlir::OperandRange 
-ForDimOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { +ForOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { return getInitArgs(); } +bool ForOp::isDynamicRange() { return getDynBegin() || getDynEnd(); } + // === YieldOp === mlir::MutableOperandRange YieldOp::getMutableSuccessorOperands(mlir::RegionBranchPoint point) { diff --git a/compiler/src/graphalg/GraphAlgScalarizeApply.cpp b/compiler/src/graphalg/GraphAlgScalarizeApply.cpp index 3f3008c..3821395 100644 --- a/compiler/src/graphalg/GraphAlgScalarizeApply.cpp +++ b/compiler/src/graphalg/GraphAlgScalarizeApply.cpp @@ -149,17 +149,17 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -// Unrolls a single iteration of the loop body. +// Unroll one iteration of the loop. static mlir::LogicalResult -unrollLoopBody(mlir::Operation *loopOp, mlir::Block &body, - const llvm::APInt &iter, mlir::ValueRange iterArgs, +unrollLoopBody(ForOp forOp, std::uint64_t iter, mlir::ValueRange iterArgs, llvm::SmallVectorImpl &results, mlir::OpBuilder &builder) { mlir::IRMapping mapping; // Map iter var to a constant. - auto zeroOp = builder.create( - loopOp->getLoc(), mlir::IntegerAttr::get(builder.getI64Type(), iter)); + auto zeroOp = builder.create(forOp.getLoc(), + builder.getI64IntegerAttr(iter)); + auto &body = forOp.getBody().front(); mapping.map(body.getArgument(0), zeroOp); // Map the remainder of the body arguments to the init args. 
@@ -184,68 +184,23 @@ unrollLoopBody(mlir::Operation *loopOp, mlir::Block &body, mapping.map(&oldOp, newOp); } - return loopOp->emitOpError("does not have a terminator ") + return forOp.emitOpError("does not have a terminator ") << YieldOp::getOperationName() << ", so cannot be inlined"; } -static mlir::LogicalResult unrollForDimOne(ForDimOp op, - mlir::PatternRewriter &rewriter) { - // Well-formed GraphAlg programs do not use dimension symbols inside - // functions that do not have this dimension symbol as an argument (but this - // is not currently enforced by verifiers). Inside of an apply all - // dimensions are one, so this should cover all well-formed programs. - if (!op.getDim().isOne()) { - // We need a constant size to be able to do loop unrolling. - return op->emitOpError("is contained in an ") - << ApplyInlineOp::getOperationName() - << ", but contains a non-1 dimension symbol reference"; - } - - // Inline the body - llvm::APInt iter(64, std::int64_t(0)); - llvm::SmallVector resultValues; - if (mlir::failed(unrollLoopBody(op, op.getBody().front(), iter, - op.getInitArgs(), resultValues, rewriter))) { - return mlir::failure(); - } - - rewriter.replaceOp(op, resultValues); - return mlir::success(); -} - -static std::optional tryGetConstantRangeValue(mlir::Value v) { - auto litOp = v.getDefiningOp(); - if (!litOp) { - return std::nullopt; - } - - if (auto intAttr = llvm::dyn_cast(litOp.getValue())) { - return intAttr.getValue(); - } - - return std::nullopt; -} - -static mlir::LogicalResult unrollForConst(ForConstOp op, - mlir::PatternRewriter &rewriter) { - auto begin = tryGetConstantRangeValue(op.getRangeBegin()); - auto end = tryGetConstantRangeValue(op.getRangeEnd()); - if (!begin || !end) { - // TODO: Handle loops where the bound is not yet constant. - // Ranges should be constant in GraphAlg, but that constant may be - // provided by the caller of a function, so we could encounter - // \c mlir::BlockArgument here in a valid program. 
- return op->emitOpError("does not have a constant range") +static mlir::LogicalResult unrollFor(ForOp op, + mlir::PatternRewriter &rewriter) { + if (op.isDynamicRange() || op.getIters()->isAbstract()) { + return op.emitOpError("does not have a constant range") << ", so cannot be unrolled"; } llvm::SmallVector iterArgs(op.getInitArgs()); - auto one = llvm::APInt(begin->getBitWidth(), 1); - for (auto i = *begin; i.slt(*end); i = i + one) { + auto end = *op.getBegin() + op.getIters()->getConcreteDim(); + for (auto i = *op.getBegin(); i < end; i++) { llvm::SmallVector results; - if (mlir::failed(unrollLoopBody(op, op.getBody().front(), i, iterArgs, - results, rewriter))) { + if (mlir::failed(unrollLoopBody(op, i, iterArgs, results, rewriter))) { return mlir::failure(); } @@ -457,8 +412,7 @@ void GraphAlgScalarizeApply::runOnOperation() { OpConversion, OpConversion>( typeConverter, &getContext()); // Simplifications into ops that can be converted using rules above. - conversions.add(unrollForDimOne); - conversions.add(unrollForConst); + conversions.add(unrollFor); conversions.add(convertMask); conversions.add(convertMatMul); conversions.add(convertNeg); diff --git a/compiler/src/graphalg/GraphAlgSetDimensions.cpp b/compiler/src/graphalg/GraphAlgSetDimensions.cpp index 490c8ef..1c83839 100644 --- a/compiler/src/graphalg/GraphAlgSetDimensions.cpp +++ b/compiler/src/graphalg/GraphAlgSetDimensions.cpp @@ -364,7 +364,6 @@ void GraphAlgSetDimensions::runOnOperation() { doesNotUseAbstractDimensions); target.addDynamicallyLegalOp(doesNotHaveAbstractInputs); target.addIllegalOp(); - target.addIllegalOp(); mlir::RewritePatternSet patterns(&getContext()); diff --git a/compiler/src/graphalg/analysis/DenseAnalysis.cpp b/compiler/src/graphalg/analysis/DenseAnalysis.cpp index da2dfd1..36fec72 100644 --- a/compiler/src/graphalg/analysis/DenseAnalysis.cpp +++ b/compiler/src/graphalg/analysis/DenseAnalysis.cpp @@ -52,7 +52,7 @@ DenseAnalysis::visitOperation(mlir::Operation *op, void 
DenseAnalysis::visitNonControlFlowArguments( mlir::Operation *op, const mlir::RegionSuccessor &successor, llvm::ArrayRef argLattices, unsigned firstIndex) { - if (llvm::isa(op)) { + if (llvm::isa(op)) { // Iteration counter is dense. assert(firstIndex == 1); auto arg = argLattices[0]; diff --git a/compiler/src/graphalg/evaluate/Evaluator.cpp b/compiler/src/graphalg/evaluate/Evaluator.cpp index 335beb6..8e08b30 100644 --- a/compiler/src/graphalg/evaluate/Evaluator.cpp +++ b/compiler/src/graphalg/evaluate/Evaluator.cpp @@ -36,7 +36,6 @@ class Evaluator { mlir::LogicalResult evaluate(BroadcastOp op); mlir::LogicalResult evaluate(ConstantMatrixOp op); mlir::LogicalResult evaluate(ForOp op); - mlir::LogicalResult evaluate(ForConstOp op); mlir::LogicalResult evaluate(ApplyOp op); mlir::LogicalResult evaluate(PickAnyOp op); mlir::LogicalResult evaluate(TrilOp op); @@ -270,81 +269,6 @@ mlir::LogicalResult Evaluator::evaluate(ForOp op) { return mlir::success(); } -mlir::LogicalResult Evaluator::evaluate(ForConstOp op) { - MatrixAttrReader rangeBeginMat(_values[op.getRangeBegin()]); - MatrixAttrReader rangeEndMat(_values[op.getRangeEnd()]); - auto rangeBegin = - llvm::cast(rangeBeginMat.at(0, 0)).getInt(); - auto rangeEnd = llvm::cast(rangeEndMat.at(0, 0)).getInt(); - - auto &body = op.getBody().front(); - auto *ctx = op.getContext(); - - // Initialize block arguments - for (auto [init, blockArg] : - llvm::zip_equal(op.getInitArgs(), body.getArguments().drop_front())) { - _values[blockArg] = _values[init]; - } - - for (auto i : llvm::seq(rangeBegin, rangeEnd)) { - // Iteration variable. 
- auto iterAttr = mlir::IntegerAttr::get(SemiringTypes::forInt(ctx), i); - auto iterArg = body.getArgument(0); - auto iterType = llvm::cast(iterArg.getType()); - MatrixAttrBuilder iterBuilder(iterType); - iterBuilder.set(0, 0, iterAttr); - _values[body.getArgument(0)] = iterBuilder.build(); - - for (auto &op : body) { - if (auto yieldOp = llvm::dyn_cast(op)) { - // Update block arguments - for (auto [value, blockArg] : llvm::zip_equal( - yieldOp.getInputs(), body.getArguments().drop_front())) { - _values[blockArg] = _values[value]; - } - } else if (mlir::failed(evaluate(&op))) { - return mlir::failure(); - } - } - - bool breakFromUntil = false; - if (!op.getUntil().empty()) { - // Have an until clause to evaluate. - auto &until = op.getUntil().front(); - - // Use current state of loop variables as input to until block. - for (auto [bodyArg, untilArg] : - llvm::zip_equal(body.getArguments(), until.getArguments())) { - _values[untilArg] = _values[bodyArg]; - } - - for (auto &op : until) { - if (auto yieldOp = llvm::dyn_cast(op)) { - // Check break condition - assert(yieldOp->getNumOperands() == 1); - MatrixAttrReader condMat(_values[yieldOp.getInputs().front()]); - breakFromUntil = - llvm::cast(condMat.at(0, 0)).getValue(); - } else if (mlir::failed(evaluate(&op))) { - return mlir::failure(); - } - } - } - - if (breakFromUntil) { - break; - } - } - - // Set loop results. 
- for (auto [value, result] : - llvm::zip_equal(body.getArguments().drop_front(), op->getResults())) { - _values[result] = _values[value]; - } - - return mlir::success(); -} - mlir::LogicalResult Evaluator::evaluate(ApplyOp op) { llvm::SmallVector inputs; for (auto input : op.getInputs()) { @@ -416,12 +340,11 @@ mlir::LogicalResult Evaluator::evaluate(mlir::Operation *op) { #define GA_CASE(Op) .Case([&](Op op) { return evaluate(op); }) GA_CASE(TransposeOp) GA_CASE(DiagOp) GA_CASE(MatMulOp) GA_CASE(ReduceOp) GA_CASE(BroadcastOp) GA_CASE(ConstantMatrixOp) GA_CASE(ForOp) - GA_CASE(ForConstOp) GA_CASE(ApplyOp) GA_CASE(PickAnyOp) - GA_CASE(TrilOp) + GA_CASE(ApplyOp) GA_CASE(PickAnyOp) GA_CASE(TrilOp) #undef GA_CASE - .Default([](mlir::Operation *op) { - return op->emitOpError("unsupported op"); - }); + .Default([](mlir::Operation *op) { + return op->emitOpError("unsupported op"); + }); } MatrixAttr Evaluator::evaluate(mlir::func::FuncOp funcOp, diff --git a/compiler/src/graphalg/parse/Parser.cpp b/compiler/src/graphalg/parse/Parser.cpp index 79f8030..d466c26 100644 --- a/compiler/src/graphalg/parse/Parser.cpp +++ b/compiler/src/graphalg/parse/Parser.cpp @@ -831,7 +831,7 @@ mlir::ParseResult Parser::parseStmtReturn() { // Check if return is inside a loop auto *parentOp = _builder.getInsertionBlock()->getParentOp(); - if (llvm::isa(parentOp)) { + if (llvm::isa(parentOp)) { return mlir::emitError(loc) << "return statement inside a loop is not allowed"; } diff --git a/compiler/test/canonicalize/for-dim.mlir b/compiler/test/canonicalize/for-dim.mlir deleted file mode 100644 index 167a493..0000000 --- a/compiler/test/canonicalize/for-dim.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: graphalg-opt --canonicalize < %s | FileCheck %s - -func.func @ForDimConst() -> !graphalg.mat<1 x 1 x i64> { - - // CHECK: %[[#ZERO:]] = graphalg.const_mat 0 - // CHECK: %[[#END:]] = graphalg.const_mat 42 - // CHECK: %[[#FOR:]] = graphalg.for_const range(%[[#ZERO]], %[[#END]]) - // CHECK-SAME: 
init(%[[#ZERO]]) - // CHECK: graphalg.yield %arg1 : !graphalg.mat<1 x 1 x i64> - // CHECK: return %[[#FOR]] - - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.for_dim range(42) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> - } until { - } - return %1 : !graphalg.mat<1 x 1 x i64> -} diff --git a/compiler/test/exec/for.mlir b/compiler/test/exec/for.mlir index 6516987..2450078 100644 --- a/compiler/test/exec/for.mlir +++ b/compiler/test/exec/for.mlir @@ -10,21 +10,19 @@ //--- input.mlir func.func @Reach(%arg0: !graphalg.mat<3 x 3 x i1>, %arg1: !graphalg.mat<3 x 1 x i1>) -> !graphalg.mat<3 x 1 x i1> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 3 : i64 -> <1 x 1 x i64> - %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg1) : !graphalg.mat<3 x 1 x i1> -> !graphalg.mat<3 x 1 x i1> body { + %0 = graphalg.for begin=0 iters=<3> init(%arg1) : !graphalg.mat<3 x 1 x i1> -> !graphalg.mat<3 x 1 x i1> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<3 x 1 x i1>): - %3 = graphalg.transpose %arg0 : <3 x 3 x i1> - %4 = graphalg.mxm %3, %arg3 : <3 x 3 x i1>, <3 x 1 x i1> - %5 = graphalg.apply %arg3, %4 : !graphalg.mat<3 x 1 x i1>, !graphalg.mat<3 x 1 x i1> -> <3 x 1 x i1> { + %1 = graphalg.transpose %arg0 : <3 x 3 x i1> + %2 = graphalg.mxm %1, %arg3 : <3 x 3 x i1>, <3 x 1 x i1> + %3 = graphalg.apply %arg3, %2 : !graphalg.mat<3 x 1 x i1>, !graphalg.mat<3 x 1 x i1> -> <3 x 1 x i1> { ^bb0(%arg4: i1, %arg5: i1): - %6 = graphalg.add %arg4, %arg5 : i1 - graphalg.apply.return %6 : i1 + %4 = graphalg.add %arg4, %arg5 : i1 + graphalg.apply.return %4 : i1 } - graphalg.yield %5 : !graphalg.mat<3 x 1 x i1> + graphalg.yield %3 : !graphalg.mat<3 x 1 x i1> } until { } - return %2 : !graphalg.mat<3 x 1 x i1> + return %0 : !graphalg.mat<3 x 1 x i1> } //--- 
output.m diff --git a/compiler/test/exec/until.mlir b/compiler/test/exec/until.mlir index 7ad1e26..cb10fc5 100644 --- a/compiler/test/exec/until.mlir +++ b/compiler/test/exec/until.mlir @@ -5,26 +5,25 @@ func.func @Fib() -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> %1 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> - %2 = graphalg.const_mat 1000000 : i64 -> <1 x 1 x i64> - %3:2 = graphalg.for_const range(%0, %2) : <1 x 1 x i64> init(%0, %1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> body { + %2:2 = graphalg.for begin=0 iters=<1000000> init(%0, %1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - %4 = graphalg.apply %arg1, %arg2 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + %3 = graphalg.apply %arg1, %arg2 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { ^bb0(%arg3: i64, %arg4: i64): - %5 = graphalg.add %arg3, %arg4 : i64 - graphalg.apply.return %5 : i64 + %4 = graphalg.add %arg3, %arg4 : i64 + graphalg.apply.return %4 : i64 } - graphalg.yield %arg2, %4 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> + graphalg.yield %arg2, %3 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> } until { ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - %4 = graphalg.apply %arg2 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + %3 = graphalg.apply %arg2 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { ^bb0(%arg3: i64): - %5 = graphalg.const 34 : i64 - %6 = graphalg.eq %arg3, %5 : i64 - graphalg.apply.return %6 : i1 + %4 = graphalg.const 34 : i64 + %5 = graphalg.eq %arg3, %4 : i64 + graphalg.apply.return %5 : i1 } - graphalg.yield %4 : !graphalg.mat<1 x 1 x i1> + 
graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> } - return %3#1 : !graphalg.mat<1 x 1 x i64> + return %2#1 : !graphalg.mat<1 x 1 x i64> } //--- output.m diff --git a/compiler/test/graphalg-to-rel/for-dim.mlir b/compiler/test/graphalg-to-rel/for-dim.mlir deleted file mode 100644 index 191565e..0000000 --- a/compiler/test/graphalg-to-rel/for-dim.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s - -#dim = #graphalg.dim> -// CHECK-LABEL: @ForDim -func.func @ForDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<1 x 1 x i64> { - // CHECK: %[[#INIT:]] = garel.const 42 : i64 - %0 = graphalg.const_mat 42 : i64 -> <1 x 1 x i64> - - // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 - // CHECK: %[[#ITERS:]] = garel.aggregate %arg1 : group_by=[] aggregators=[] - // CHECK: %[[#FOR:]] = garel.for - // CHECK-SAME: %[[#BEGIN]], %[[#INIT]] : !garel.rel, !garel.rel - // CHECK-SAME: iters=%[[#ITERS]] result_idx=1 - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - // CHECK: %[[#INC:]] = garel.project %arg2 - // CHECK: garel.for.yield %[[#INC]], %arg2 - graphalg.yield %arg1 : !graphalg.mat<1 x 1 x i64> - } until { - } - - // CHECK: return %[[#FOR]] - return %1 : !graphalg.mat<1 x 1 x i64> -} diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for.mlir similarity index 63% rename from compiler/test/graphalg-to-rel/for-const.mlir rename to compiler/test/graphalg-to-rel/for.mlir index 936fffc..6ba8936 100644 --- a/compiler/test/graphalg-to-rel/for-const.mlir +++ b/compiler/test/graphalg-to-rel/for.mlir @@ -1,14 +1,11 @@ // RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s -// CHECK-LABEL: @ForConst -func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 
10 : i64 -> <1 x 1 x i64> - +// CHECK-LABEL: @For +func.func @For(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 // CHECK: %[[#ITERS:]] = garel.const 10 : i64 // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=%[[#ITERS]] result_idx=1 { - %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + %0 = graphalg.for begin=0 iters=<10> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): // CHECK: %[[#PROJ:]] = garel.project %arg1 // CHECK: %[[#EXT:]] = garel.extract 0 @@ -20,18 +17,15 @@ func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x } // CHECK: return %[[#FOR]] - return %2 : !graphalg.mat<1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @ForResultUnused func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0, %arg1 - %2:2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { + %0:2 = graphalg.for begin=0 iters=<10> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<1 x 1 x f64>): // CHECK: %[[#PROJ:]] = garel.project %arg2 // CHECK: garel.for.yield %[[#PROJ]], %arg3, %arg4 @@ -40,17 +34,14 @@ func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, 
%arg1: !graphalg.m } // CHECK: return %[[#FOR]] - return %2#1 : !graphalg.mat<1 x 1 x f64> + return %0#1 : !graphalg.mat<1 x 1 x f64> } // CHECK-LABEL: @Until func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x i1> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 - %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { + %0 = graphalg.for begin=0 iters=<10> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): // CHECK: %[[#PROJ:]] = garel.project %arg1 // CHECK: garel.for.yield %[[#PROJ]], %arg2 @@ -59,9 +50,9 @@ func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x } until { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): // CHECK: %[[#AGG:]] = garel.aggregate %arg2 - %3 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> + %1 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> // CHECK: garel.for.yield %[[#AGG]] - graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> + graphalg.yield %1 : !graphalg.mat<1 x 1 x i1> } - return %2 : !graphalg.mat<42 x 42 x i1> + return %0 : !graphalg.mat<42 x 42 x i1> } diff --git a/compiler/test/infer-density/for.mlir b/compiler/test/infer-density/for.mlir index 364dc15..0e31b0e 100644 --- a/compiler/test/infer-density/for.mlir +++ b/compiler/test/infer-density/for.mlir @@ -1,36 +1,12 @@ // RUN: graphalg-opt --test-print-dense --verify-diagnostics %s #dim = #graphalg.dim> -func.func @ForConst() -> !graphalg.mat<#dim x 1 x f64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - %2 = graphalg.const_mat 42.0 : f64 -> 
<#dim x 1 x f64> - // expected-remark@below {{for}} - // expected-note@below {{operand #0: dense}} - // expected-note@below {{operand #1: dense}} - // expected-note@below {{operand #2: dense}} - // expected-note@below {{result #0: dense}} - %3 = graphalg.for_const - range(%0, %1) : <1 x 1 x i64> - init(%2) : !graphalg.mat<#dim x 1 x f64> - -> !graphalg.mat<#dim x 1 x f64> {tag = "for" } body { - // expected-note@below {{arg #0:0:0: dense}} - // expected-note@below {{arg #0:0:1: dense}} - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x 1 x f64>): - %4 = graphalg.const_mat 42.0 : f64 -> <#dim x 1 x f64> - graphalg.yield %4 : !graphalg.mat<#dim x 1 x f64> - } until { - } - return %3 : !graphalg.mat<#dim x 1 x f64> -} - -func.func @ForDim() -> !graphalg.mat<#dim x 1 x f64> { +func.func @For() -> !graphalg.mat<#dim x 1 x f64> { %0 = graphalg.const_mat 42.0 : f64 -> <#dim x 1 x f64> // expected-remark@below {{for}} // expected-note@below {{operand #0: dense}} // expected-note@below {{result #0: dense}} - %1 = graphalg.for_dim range(#dim) - init(%0) : !graphalg.mat<#dim x 1 x f64> + %1 = graphalg.for begin=0 iters=<10> init(%0) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> {tag = "for" } body { // expected-note@below {{arg #0:0:0: dense}} // expected-note@below {{arg #0:0:1: dense}} diff --git a/compiler/test/loop-aggregate/add-init-reduce.mlir b/compiler/test/loop-aggregate/add-init-reduce.mlir index 75914cb..9ce16d4 100644 --- a/compiler/test/loop-aggregate/add-init-reduce.mlir +++ b/compiler/test/loop-aggregate/add-init-reduce.mlir @@ -8,8 +8,8 @@ func.func @AddInitReduce(%arg0: !vec) -> !vec { // CHECK-SAME: %arg0 : !graphalg.mat<#dim x 1 x f64> // CHECK-SAME: -> <#dim x 1 x f64> // - // CHECK: graphalg.for_dim range(#dim) init(%[[#INIT]]) - %0 = graphalg.for_dim range(#dim) init(%arg0) : !vec -> !vec body { + // CHECK: graphalg.for begin=0 iters=#dim init(%[[#INIT]]) + %0 = graphalg.for begin=0 iters=#dim init(%arg0) : !vec -> 
!vec body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !vec): %1 = graphalg.deferred_reduce %arg2 : !vec -> !vec graphalg.yield %1 : !vec diff --git a/compiler/test/scalarize-apply/for-dim.mlir b/compiler/test/scalarize-apply/for-dim.mlir deleted file mode 100644 index 558acf8..0000000 --- a/compiler/test/scalarize-apply/for-dim.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: graphalg-opt --graphalg-scalarize-apply < %s | FileCheck %s - -#dim = #graphalg.dim> - -func.func @ForDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { - // CHECK: %[[#APPLY:]] = graphalg.apply %arg0 - %0 = graphalg.apply_inline %arg0 : !graphalg.mat<#dim x #dim x i64> -> <#dim x #dim x i64> { - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>): - // CHECK: %[[#ONE:]] = graphalg.const 1 - // CHECK: %[[#ADD:]] = graphalg.add %arg1, %[[#ONE]] - %1 = graphalg.literal 1 : i64 - %2 = graphalg.for_dim range(1) init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { - ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>): - %3 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> - graphalg.yield %3 : !graphalg.mat<1 x 1 x i64> - } until { - } - - // CHECK: graphalg.apply.return %[[#ADD]] - graphalg.apply_inline.return %2 : <1 x 1 x i64> - } - - // CHECK: return %[[#APPLY]] - return %0 : !graphalg.mat<#dim x #dim x i64> -} diff --git a/compiler/test/scalarize-apply/for-const.mlir b/compiler/test/scalarize-apply/for.mlir similarity index 62% rename from compiler/test/scalarize-apply/for-const.mlir rename to compiler/test/scalarize-apply/for.mlir index dc62cfa..7e23771 100644 --- a/compiler/test/scalarize-apply/for-const.mlir +++ b/compiler/test/scalarize-apply/for.mlir @@ -2,7 +2,7 @@ #dim = #graphalg.dim> -func.func @ForConst(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { +func.func @For(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { // CHECK: %[[#APPLY:]] = graphalg.apply %arg0 
%0 = graphalg.apply_inline %arg0 : !graphalg.mat<#dim x #dim x i64> -> <#dim x #dim x i64> { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>): @@ -11,17 +11,15 @@ func.func @ForConst(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#d // CHECK: %[[#ADD1:]] = graphalg.add %[[#ADD0]], %[[#ONE]] // CHECK: %[[#ADD2:]] = graphalg.add %[[#ADD1]], %[[#ONE]] %1 = graphalg.literal 1 : i64 - %2 = graphalg.literal 2 : i64 - %3 = graphalg.literal 5 : i64 - %4 = graphalg.for_const range(%2, %3) : <1 x 1 x i64> init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + %2 = graphalg.for begin=2 iters=<3> init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>): - %5 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> - graphalg.yield %5 : !graphalg.mat<1 x 1 x i64> + %3 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> + graphalg.yield %3 : !graphalg.mat<1 x 1 x i64> } until { } // CHECK: graphalg.apply.return %[[#ADD2]] - graphalg.apply_inline.return %4 : <1 x 1 x i64> + graphalg.apply_inline.return %2 : <1 x 1 x i64> } // CHECK: return %[[#APPLY]] diff --git a/compiler/test/verify-dimensions/for-dim.mlir b/compiler/test/verify-dimensions/for-dim.mlir index 1a447fb..a7e5a73 100644 --- a/compiler/test/verify-dimensions/for-dim.mlir +++ b/compiler/test/verify-dimensions/for-dim.mlir @@ -3,7 +3,7 @@ func.func @Ok(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + %1 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { @@ -16,8 +16,8 @@ func.func @Ok(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 
x i64 func.func @IllegalRange() -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - // expected-error@below{{'graphalg.for_dim' op attribute "dim" has value #graphalg.dim> which has not been marked as legal}} - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + // expected-error@below{{'graphalg.for' op attribute "iters" has value #graphalg.dim> which has not been marked as legal}} + %1 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { From f82eb4ebb841aceb4664f5d4b0e09597e2d88f91 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 30 Apr 2026 15:35:15 +0000 Subject: [PATCH 6/6] Set constant args pass. --- compiler/include/graphalg/GraphAlgPasses.td | 10 +++ compiler/src/graphalg/CMakeLists.txt | 1 + compiler/src/graphalg/GraphAlgSetConstArg.cpp | 70 +++++++++++++++++++ compiler/test/e2e/cdlp.gr | 10 +-- compiler/test/e2e/pr.gr | 10 +-- 5 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 compiler/src/graphalg/GraphAlgSetConstArg.cpp diff --git a/compiler/include/graphalg/GraphAlgPasses.td b/compiler/include/graphalg/GraphAlgPasses.td index d6a9d00..4d7ae71 100644 --- a/compiler/include/graphalg/GraphAlgPasses.td +++ b/compiler/include/graphalg/GraphAlgPasses.td @@ -3,6 +3,16 @@ include "mlir/Pass/PassBase.td" +def GraphAlgSetConstArg : Pass<"graphalg-set-const-arg", "::mlir::ModuleOp"> { + let summary = "Propagate a constant integer argument into a function"; + + let options = [ + Option<"functionName", "func", "std::string", /*default=*/"\"\"", "Name of the function to call">, + Option<"argumentNumber", "argNum", "int", /*default=*/"-1", "The argument number that is constant">, + Option<"value", "value", "std::int64_t", /*default=*/"0", "The value to propagate">, 
+ ]; +} + def GraphAlgPrepareInline : Pass<"graphalg-prepare-inline", "::mlir::ModuleOp"> { let summary = "Prepares the IR for function inlining"; diff --git a/compiler/src/graphalg/CMakeLists.txt b/compiler/src/graphalg/CMakeLists.txt index 7fc3fb4..0ba69a8 100644 --- a/compiler/src/graphalg/CMakeLists.txt +++ b/compiler/src/graphalg/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(GraphAlgPasses GraphAlgLoopAggregate.cpp GraphAlgPrepareInlinePass.cpp GraphAlgScalarizeApply.cpp + GraphAlgSetConstArg.cpp GraphAlgSetDimensions.cpp GraphAlgSplitAggregate.cpp GraphAlgToCore.cpp diff --git a/compiler/src/graphalg/GraphAlgSetConstArg.cpp b/compiler/src/graphalg/GraphAlgSetConstArg.cpp new file mode 100644 index 0000000..a3d58df --- /dev/null +++ b/compiler/src/graphalg/GraphAlgSetConstArg.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include "graphalg/GraphAlgOps.h" +#include "graphalg/GraphAlgPasses.h" +#include "graphalg/GraphAlgTypes.h" +#include "graphalg/SemiringTypes.h" + +namespace graphalg { + +#define GEN_PASS_DEF_GRAPHALGSETCONSTARG +#include "graphalg/GraphAlgPasses.h.inc" + +namespace { + +/** + * Propagate integer constant arguments into function bodies. + * + * Run canonicalization after this pass to propagate the constant through the + * program. This pass does not change the function signature (that is, it does + * not remove the original argument). 
+ */
+class GraphAlgSetConstArg
+    : public impl::GraphAlgSetConstArgBase<GraphAlgSetConstArg> {
+  using impl::GraphAlgSetConstArgBase<
+      GraphAlgSetConstArg>::GraphAlgSetConstArgBase;
+
+  void runOnOperation() final;
+};
+
+} // namespace
+
+void GraphAlgSetConstArg::runOnOperation() {
+  if (functionName.empty()) {
+    getOperation().emitError("missing value for required option 'func'");
+    return signalPassFailure();
+  }
+
+  if (argumentNumber < 0) {
+    getOperation().emitError("missing value for required option 'argNum'");
+    return signalPassFailure();
+  }
+
+  auto func = llvm::dyn_cast_if_present<mlir::func::FuncOp>(
+      getOperation().lookupSymbol(functionName));
+  if (!func) {
+    getOperation().emitOpError("does not contain a function named '")
+        << functionName << "'";
+    return signalPassFailure();
+  }
+
+  auto &body = func.getBody().front();
+  auto numArgs = body.getNumArguments();
+  if (argumentNumber >= numArgs) {
+    getOperation().emitOpError("argument number ")
+        << argumentNumber << " is out of bounds for function " << functionName
+        << ", which only has " << numArgs << " arguments";
+    return signalPassFailure();
+  }
+
+  mlir::IRRewriter rewriter(func);
+  rewriter.setInsertionPointToStart(&body);
+  auto constOp = rewriter.create(
+      func.getLoc(), MatrixType::scalarOf(SemiringTypes::forInt(&getContext())),
+      rewriter.getI64IntegerAttr(value));
+  rewriter.replaceAllUsesWith(body.getArgument(argumentNumber), constOp);
+}
+
+} // namespace graphalg
diff --git a/compiler/test/e2e/cdlp.gr b/compiler/test/e2e/cdlp.gr
index 1e57991..86c8f38 100644
--- a/compiler/test/e2e/cdlp.gr
+++ b/compiler/test/e2e/cdlp.gr
@@ -1,7 +1,7 @@
 // RUN: split-file %s %t
 // RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir
-// RUN: graphalg-opt --graphalg-to-core-pipeline --graphalg-set-dimensions='func=CDLP args=8x8' %t/parsed.mlir > %t/exec.mlir
-// RUN: graphalg-exec %t/exec.mlir CDLP %t/graph.m | diff - %t/output.m
+// RUN: graphalg-opt --graphalg-set-const-arg='func=CDLP argNum=1 value=5' 
--graphalg-to-core-pipeline --graphalg-set-dimensions='func=CDLP args=8x8,1x1' %t/parsed.mlir > %t/exec.mlir +// RUN: graphalg-exec %t/exec.mlir CDLP %t/graph.m %t/iters.m | diff - %t/output.m //--- graph.m 0 1 @@ -23,14 +23,16 @@ 6 7 7 5 +//--- iters.m +0 0 5 + //--- input.gr func isMax(v: int, max: trop_max_int) -> bool { return (cast(v) == max) * (v != zero(int)); } -func CDLP(graph: Matrix) -> Matrix { - iterations = int(5); +func CDLP(graph: Matrix, iterations:int) -> Matrix { id = Vector(graph.nrows); id[:] = bool(true); L = diag(id); diff --git a/compiler/test/e2e/pr.gr b/compiler/test/e2e/pr.gr index 374eb7d..c9b3c29 100644 --- a/compiler/test/e2e/pr.gr +++ b/compiler/test/e2e/pr.gr @@ -1,7 +1,7 @@ // RUN: split-file %s %t // RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir -// RUN: graphalg-opt --graphalg-to-core-pipeline --graphalg-set-dimensions='func=PR args=50x50' %t/parsed.mlir > %t/exec.mlir -// RUN: graphalg-exec %t/exec.mlir PR %t/graph.m | diff - %t/output.m +// RUN: graphalg-opt --graphalg-set-const-arg='func=PR argNum=1 value=10' --graphalg-to-core-pipeline --graphalg-set-dimensions='func=PR args=50x50,1x1' %t/parsed.mlir > %t/exec.mlir +// RUN: graphalg-exec %t/exec.mlir PR %t/graph.m %t/iters.m | diff - %t/output.m //--- graph.m 0 18 @@ -251,14 +251,16 @@ 49 27 49 46 +//--- iters.m +0 0 10 + //--- input.gr func withDamping(degree:int, damping:real) -> real { return cast(degree) / damping; } -func PR(graph: Matrix) -> Vector { +func PR(graph: Matrix, iterations:int) -> Vector { damping = real(0.85); - iterations = int(10); n = graph.nrows; teleport = (real(1.0) - damping) / cast(n); rdiff = real(1.0);