diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index c09797c..4dac0ca 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -38,6 +38,7 @@ def AggregateFunc : I64EnumAttr< I64EnumAttrCase<"MAX", 2>, I64EnumAttrCase<"LOR", 3>, /* Logical OR (over i1) */ I64EnumAttrCase<"ARGMIN", 4>, + I64EnumAttrCase<"COUNT", 5>, ] > { let cppNamespace = "::garel"; @@ -49,7 +50,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { let parameters = (ins "AggregateFunc":$func, - ArrayRefParameter<"ColumnIdx">:$inputs); + OptionalArrayRefParameter<"ColumnIdx">:$inputs); let assemblyFormat = [{ `<` $func $inputs `>` diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index f8ccaa3..e1e94ec 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -143,7 +143,7 @@ def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { let arguments = (ins Variadic:$init, - I64Attr:$iters, + I64Relation:$iters, I64Attr:$resultIdx); let regions = (region diff --git a/compiler/include/garel/GARelTypes.h b/compiler/include/garel/GARelTypes.h index e6a4c90..689cd3a 100644 --- a/compiler/include/garel/GARelTypes.h +++ b/compiler/include/garel/GARelTypes.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include "garel/GARelAttr.h" @@ -11,4 +12,6 @@ namespace garel { bool isColumnType(mlir::Type t); +RelationType getI64RelationType(mlir::MLIRContext *ctx); + } // namespace garel diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 5d848b5..d5eefed 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -34,4 +34,10 @@ def Tuple : GARel_Type<"Tuple", "tuple"> { def ColumnType : Type, "column type">; +def I64Relation : Type< + CPred<"::garel::getI64RelationType($_self.getContext()) == $_self">, + "relation with a single i64 column", + "RelationType">, + 
BuildableType<"::garel::getI64RelationType($_builder.getContext())">; + #endif // GAREL_TYPES diff --git a/compiler/include/graphalg/GraphAlgOps.td b/compiler/include/graphalg/GraphAlgOps.td index 0cdf7bf..8e99851 100644 --- a/compiler/include/graphalg/GraphAlgOps.td +++ b/compiler/include/graphalg/GraphAlgOps.td @@ -403,16 +403,22 @@ def BroadcastOp : Core_Op<"broadcast", [ let hasVerifier = 1; } -// Not core according to spec, but we don't want to unroll in the general case. -def ForConstOp : Core_Op<"for_const", [ +def ForOp : Core_Op<"for", [ Pure, - AllTypesMatch<["rangeBegin", "rangeEnd"]>, + AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { - let summary = "For loop with constant bounds"; + let summary = "For loop with dynamic bounds"; let description = [{ - A loop iterating over the integer range starting at `rangeBegin` - (inclusive) and ending at `rangeEnd` (exclusive). + A loop iterating over one of three ranges: + 1) `dynBegin` (inclusive) to `dynEnd` (exclusive) + 2) `begin` to `begin` + `iters`, where `iters` is an integer + 3) `begin` to `begin` + `iters`, where `iters` is a matrix dimension + + Only instances with range types 2 or 3 are considered part of GraphAlg + Core. Constant propagation is expected to transform range type 1 into + either 2 or 3. + The `body` region is executed once for every value in the integer range (that value is passed as the first block argument). At the first iteration of the loop, the other block arguments take the @@ -432,58 +438,39 @@ def ForConstOp : Core_Op<"for_const", [ let arguments = (ins Variadic:$initArgs, - I64Scalar:$rangeBegin, - I64Scalar:$rangeEnd); + Optional:$dynBegin, + Optional:$dynEnd, + OptionalAttr:$begin, + OptionalAttr:$iters); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until); let assemblyFormat = [{ - `range` `(` - $rangeBegin `,` - $rangeEnd - `)` `:` type($rangeEnd) + (`dyn_begin` `` `=` `` $dynBegin^)? 
+ (`dyn_end` `` `=` `` $dynEnd^)? + (`begin` `` `=` `` $begin^)? + (`iters` `` `=` `` $iters^)? `init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict `body` $body `until` $until }]; + let hasVerifier = 1; let hasRegionVerifier = 1; -} - -def ForDimOp : Core_Op<"for_dim", [ - Pure, - DeclareOpInterfaceMethods]> { - let summary = "For loop over a matrix dimension"; - - let description = [{ - A loop iterating over the half-open range [0..`dim`). - - This op is otherwise equivalent to `ForConstOp`. - }]; - - let arguments = (ins Variadic:$initArgs, DimAttr:$dim); - - let results = (outs Variadic:$results); - - let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until); + let hasFolder = 1; - let assemblyFormat = [{ - `range` `(` custom($dim) `)` - `init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict - `body` $body - `until` $until + let extraClassDeclaration = [{ + /** Whether at least one of `dyn_begin` and `dyn_end` is set. */ + bool isDynamicRange(); }]; - - let hasRegionVerifier = 1; - let hasCanonicalizer = 1; } def YieldOp : Core_Op<"yield", [ Pure, Terminator, - ParentOneOf<["ForConstOp", "ForDimOp"]>, + HasParent<"ForOp">, DeclareOpInterfaceMethods]> { let summary = "Yield from a loop body"; diff --git a/compiler/include/graphalg/GraphAlgPasses.td b/compiler/include/graphalg/GraphAlgPasses.td index d6a9d00..4d7ae71 100644 --- a/compiler/include/graphalg/GraphAlgPasses.td +++ b/compiler/include/graphalg/GraphAlgPasses.td @@ -3,6 +3,16 @@ include "mlir/Pass/PassBase.td" +def GraphAlgSetConstArg : Pass<"graphalg-set-const-arg", "::mlir::ModuleOp"> { + let summary = "Propagate a constant integer argument into a function"; + + let options = [ + Option<"functionName", "func", "std::string", /*default=*/"\"\"", "Name of the function to call">, + Option<"argumentNumber", "argNum", "int", /*default=*/"-1", "The argument number that is constant">, + Option<"value", "value", "std::int64_t", /*default=*/"0", "The value 
to propagate">, + ]; +} + def GraphAlgPrepareInline : Pass<"graphalg-prepare-inline", "::mlir::ModuleOp"> { let summary = "Prepares the IR for function inlining"; diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp index abb69cc..d228595 100644 --- a/compiler/src/garel/GARelAttr.cpp +++ b/compiler/src/garel/GARelAttr.cpp @@ -24,23 +24,32 @@ mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) { case AggregateFunc::ARGMIN: // NOTE: argmin(arg, val) also uses first input column as output type. return llvm::cast(inputRel).getColumns()[getInputs()[0]]; + case AggregateFunc::COUNT: + return mlir::IntegerType::get(inputRel.getContext(), 64); + } +} + +static std::size_t expectedNumInputs(AggregateFunc f) { + switch (f) { + case AggregateFunc::SUM: + case AggregateFunc::MIN: + case AggregateFunc::MAX: + case AggregateFunc::LOR: + return 1; + case AggregateFunc::ARGMIN: + return 2; + case AggregateFunc::COUNT: + return 0; } } mlir::LogicalResult AggregatorAttr::verify(llvm::function_ref emitError, AggregateFunc func, llvm::ArrayRef inputs) { - if (func == AggregateFunc::ARGMIN) { - if (inputs.size() != 2) { - return emitError() << stringifyAggregateFunc(func) - << " expects exactly two inputs (arg, val), got " - << inputs.size(); - } - } else { - if (inputs.size() != 1) { - return emitError() << stringifyAggregateFunc(func) - << " expects exactly one input, got " << inputs.size(); - } + if (inputs.size() != expectedNumInputs(func)) { + return emitError() << stringifyAggregateFunc(func) << " expects exactly " + << expectedNumInputs(func) << " inputs, got " + << inputs.size(); } return mlir::success(); diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp index 618642d..d807ddc 100644 --- a/compiler/src/garel/GARelTypes.cpp +++ b/compiler/src/garel/GARelTypes.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,11 @@ bool isColumnType(mlir::Type t) { t.isIndex(); } 
+RelationType getI64RelationType(mlir::MLIRContext *ctx) { + return RelationType::get( + ctx, mlir::ArrayRef{mlir::IntegerType::get(ctx, 64)}); +} + // Need to define this here to avoid depending on IPRTypes in // IPRDialect and creating a cycle. void GARelDialect::registerTypes() { diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index ee60821..cc94417 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -4,16 +4,21 @@ #include #include #include +#include +#include #include #include +#include #include #include #include #include #include #include +#include #include #include +#include #include #include #include @@ -256,20 +261,78 @@ MatrixTypeConverter::MatrixTypeConverter( addConversion( [this](graphalg::MatrixType t) { return convertMatrixType(t); }); + + // No need to convert. + addConversion([](RelationType t) { return t; }); } // ============================================================================= // ============================== Helper Methods =============================== // ============================================================================= +static llvm::StringLiteral INPUT_DIMS_ATTR_KEY = "garel.input_dims"; + /** * Create a relation with all indices for a matrix dimension. * * Used to broadcast scalar values to a larger matrix. 
*/ -static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, - mlir::OpBuilder &builder) { - return builder.create(loc, dim.getConcreteDim()); +static mlir::Value createDimRead(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + if (dim.isConcrete()) { + return builder.create(loc, dim.getConcreteDim()); + } + + assert(dim.isAbstract()); + + // Find owning func + mlir::func::FuncOp funcOp; + auto parentOp = builder.getInsertionBlock()->getParentOp(); + funcOp = llvm::dyn_cast(parentOp); + if (!funcOp) { + funcOp = parentOp->getParentOfType(); + } + + assert(funcOp && "not contained inside a function"); + auto inputDims = funcOp->getAttrOfType(INPUT_DIMS_ATTR_KEY); + if (!inputDims) { + funcOp.emitOpError("missing input dims attribute"); + return nullptr; + } + + // Find the function argument for this dimension. + auto args = funcOp.getFunctionBody().getArguments(); + for (auto [d, v] : llvm::zip(inputDims.getAsRange(), + llvm::reverse(args))) { + if (dim == d) { + return v; + } + } + + funcOp.emitOpError("dimension ") + << dim << " is not defined within this function"; + return nullptr; +} + +/** + * Create a relation with one tuple that contains the number of valid indices + * for a matrix dimension. 
+ */ +static mlir::Value createDimInt(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + if (dim.isConcrete()) { + return builder.create( + loc, builder.getI64IntegerAttr(dim.getConcreteDim())); + } + + auto input = createDimRead(loc, dim, builder); + llvm::ArrayRef groupBy; + std::array aggregators{ + builder.getAttr(AggregateFunc::COUNT, + llvm::ArrayRef{}), + }; + + return builder.create(loc, input, groupBy, aggregators); } static void @@ -406,15 +469,6 @@ createAggregator(mlir::Operation *op, graphalg::SemiringTypeInterface sring, return AggregatorAttr::get(ctx, func, inputs); } -static mlir::IntegerAttr tryGetConstantInt(mlir::Value v) { - mlir::Attribute attr; - if (!mlir::matchPattern(v, mlir::m_Constant(&attr))) { - return nullptr; - } - - return llvm::cast(attr); -} - static mlir::FailureOr createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, mlir::Value lhs, mlir::Value rhs, mlir::OpBuilder &builder) { @@ -441,6 +495,16 @@ createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, << sring << " is not supported"; } +static bool isConstantZeroI64(mlir::Value rangeBegin) { + if (auto constOp = rangeBegin.getDefiningOp()) { + auto zero = mlir::IntegerAttr::get( + mlir::IntegerType::get(rangeBegin.getContext(), 64), 0); + return constOp.getValue() == zero; + } + + return false; +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -563,6 +627,10 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all rows. 
auto rowsOp = createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + joinChildren.push_back(rowsOp); rowColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, @@ -576,6 +644,10 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all columns. auto colsOp = createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + joinChildren.push_back(colsOp); colColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, @@ -679,8 +751,13 @@ mlir::LogicalResult OpConversion::matchAndRewrite( rowColumnIdx = input.rowColumn(); } else if (output.hasRowColumn()) { // Broadcast over all rows. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + + joinChildren.push_back(rowsOp); rowColumnIdx = currentColIdx++; } @@ -689,8 +766,13 @@ mlir::LogicalResult OpConversion::matchAndRewrite( colColumnIdx = input.colColumn(); } else if (output.hasColColumn()) { // Broadcast over all columns. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + + joinChildren.push_back(colsOp); colColumnIdx = currentColIdx++; } @@ -738,14 +820,24 @@ mlir::LogicalResult OpConversion::matchAndRewrite( llvm::SmallVector joinChildren; if (!output.matrixType().getRows().isOne()) { // Broadcast over all rows. 
- joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + if (!rowsOp) { + return mlir::failure(); + } + + joinChildren.push_back(rowsOp); } if (!output.matrixType().getCols().isOne()) { // Broadcast over all columns. - joinChildren.push_back( - createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + if (!colsOp) { + return mlir::failure(); + } + + joinChildren.push_back(colsOp); } joinChildren.push_back(constantOp); @@ -803,18 +895,13 @@ mlir::LogicalResult OpConversion::matchAndRewrite( } template <> -mlir::LogicalResult OpConversion::matchAndRewrite( - graphalg::ForConstOp op, OpAdaptor adaptor, +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - auto rangeBegin = tryGetConstantInt(op.getRangeBegin()); - auto rangeEnd = tryGetConstantInt(op.getRangeEnd()); - if (!rangeBegin || !rangeEnd) { - return op->emitOpError("iter range is not constant"); - } - - auto iters = rangeEnd.getInt() - rangeBegin.getInt(); - - llvm::SmallVector initArgs{adaptor.getRangeBegin()}; + auto begin = + rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); + auto iters = createDimInt(op.getLoc(), *op.getIters(), rewriter); + llvm::SmallVector initArgs{begin}; initArgs.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); auto blockSignature = @@ -838,7 +925,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( // the result index accordingly. 
std::int64_t resultIdx = i + 1; auto resultType = adaptor.getInitArgs()[i].getType(); - auto forOp = rewriter.create(op.getLoc(), resultType, initArgs, + auto forOp = rewriter.create(op->getLoc(), resultType, initArgs, iters, resultIdx); // body block rewriter.cloneRegionBefore(op.getBody(), forOp.getBody(), @@ -1078,6 +1165,14 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::CastDimOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOp(op, createDimInt(op->getLoc(), op.getInput(), rewriter)); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1279,7 +1374,51 @@ static bool hasRelationOperands(mlir::Operation *op) { [](auto t) { return llvm::isa(t); }); } +static void addDimensionInputs(mlir::func::FuncOp funcOp) { + // Function type + llvm::SmallVector inputs(funcOp.getFunctionType().getInputs()); + llvm::SmallVector dims; + for (auto t : funcOp.getFunctionType().getInputs()) { + auto matType = llvm::cast(t); + for (auto d : {matType.getRows(), matType.getCols()}) { + if (d.isAbstract() && !llvm::is_contained(dims, d)) { + dims.push_back(d); + } + } + } + + // Add dimension inputs to function type + auto *ctx = funcOp.getContext(); + auto dimReadType = RelationType::get( + ctx, mlir::ArrayRef{mlir::IndexType::get(ctx)}); + inputs.append(dims.size(), dimReadType); + auto newType = mlir::FunctionType::get(ctx, inputs, + funcOp.getFunctionType().getResults()); + funcOp.setFunctionType(newType); + + // Add dimension inputs as block args. 
+ auto &block = funcOp.getFunctionBody().front(); + for (auto d : dims) { + auto arg = block.addArgument(dimReadType, funcOp.getLoc()); + } + + // Annotate with input_dims attribute. + llvm::SmallVector dimsErased; + for (auto d : dims) { + dimsErased.push_back(d); + } + funcOp->setAttr(INPUT_DIMS_ATTR_KEY, mlir::ArrayAttr::get(ctx, dimsErased)); +} + void GraphAlgToRel::runOnOperation() { + // Add dimension inputs + llvm::SmallVector funcOps( + getOperation().getOps()); + mlir::IRRewriter rewriter(getOperation()); + for (auto op : funcOps) { + addDimensionInputs(op); + } + mlir::ConversionTarget target(getContext()); // Eliminate all graphalg ops target.addIllegalDialect(); @@ -1302,10 +1441,10 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion, + OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion>( - matrixTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion>(matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/src/graphalg/CMakeLists.txt b/compiler/src/graphalg/CMakeLists.txt index 7fc3fb4..0ba69a8 100644 --- a/compiler/src/graphalg/CMakeLists.txt +++ b/compiler/src/graphalg/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(GraphAlgPasses GraphAlgLoopAggregate.cpp GraphAlgPrepareInlinePass.cpp GraphAlgScalarizeApply.cpp + GraphAlgSetConstArg.cpp GraphAlgSetDimensions.cpp GraphAlgSplitAggregate.cpp GraphAlgToCore.cpp diff --git a/compiler/src/graphalg/GraphAlgCanonicalize.cpp b/compiler/src/graphalg/GraphAlgCanonicalize.cpp index e89ff7e..f430b3c 100644 --- a/compiler/src/graphalg/GraphAlgCanonicalize.cpp +++ b/compiler/src/graphalg/GraphAlgCanonicalize.cpp @@ -242,35 +242,39 @@ mlir::OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { return nullptr; } -static mlir::LogicalResult forDimConst(ForDimOp op, - mlir::PatternRewriter &rewriter) { - 
if (!op.getDim().isConcrete()) { - return mlir::failure(); +mlir::LogicalResult +ForOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results) { + if (!getBegin() && adaptor.getDynBegin()) { + // Can infer a constant begin to the range. + auto begin = llvm::cast(adaptor.getDynBegin()); + setBeginAttr(begin); + getDynBeginMutable().clear(); + return mlir::success(); } - // The number of iterations is known, so we can replace with a ForConstOp. - auto end = op.getDim().getConcreteDim(); - - // Range from 0 to dim. - auto intType = - MatrixType::scalarOf(SemiringTypes::forInt(rewriter.getContext())); - auto beginOp = rewriter.create( - op->getLoc(), intType, rewriter.getI64IntegerAttr(0)); - auto endOp = rewriter.create( - op->getLoc(), intType, rewriter.getI64IntegerAttr(end)); + if (!getIters() && getBegin() && adaptor.getDynEnd()) { + // Can infer a constant number of iterations. + auto begin = *getBegin(); + auto end = llvm::cast(adaptor.getDynEnd()) + .getValue() + .getZExtValue(); + // If end < begin, drop to 0 iterations. + std::size_t iters = 0; + if (begin < end) { + iters = end - begin; + } - auto forConstOp = rewriter.create( - op->getLoc(), op->getResultTypes(), op.getInitArgs(), beginOp, endOp); - rewriter.inlineRegionBefore(op.getBody(), forConstOp.getBody(), - forConstOp.getBody().begin()); - rewriter.replaceOp(op, forConstOp); + // NOTE: number of iterations is encoded as a DimAttr. 
+ auto dim = DimAttr::getConcrete(getContext(), iters); + setItersAttr(dim); + getDynEndMutable().clear(); + return mlir::success(); + } - return mlir::success(); -} + // TODO: Fold if iters=0 -void ForDimOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, - mlir::MLIRContext *context) { - patterns.add(forDimConst); + return mlir::failure(); } mlir::OpFoldResult PickAnyOp::fold(FoldAdaptor adaptor) { diff --git a/compiler/src/graphalg/GraphAlgLoopAggregate.cpp b/compiler/src/graphalg/GraphAlgLoopAggregate.cpp index cfac2af..aaa2bfe 100644 --- a/compiler/src/graphalg/GraphAlgLoopAggregate.cpp +++ b/compiler/src/graphalg/GraphAlgLoopAggregate.cpp @@ -28,19 +28,10 @@ class GraphAlgLoopAggregate // If the loop body ends with an aggregation op, ensure the init arg is also an // aggregation. Later on in the AvantGraph query pipeline, this will signal to // the optimizer that the iteration state can be kept as an aggregate table. -static void addInitReduce(mlir::Operation *op, mlir::IRRewriter &rewriter) { - mlir::Block *body; - llvm::SmallVector newInitArgs; - if (auto constOp = llvm::dyn_cast(op)) { - body = &constOp.getBody().front(); - newInitArgs = constOp.getInitArgs(); - } else { - auto dimOp = llvm::cast(op); - body = &dimOp.getBody().front(); - newInitArgs = dimOp.getInitArgs(); - } +static void addInitReduce(ForOp op, mlir::IRRewriter &rewriter) { + llvm::SmallVector newInitArgs(op.getInitArgs()); - auto yieldOp = llvm::cast(body->getTerminator()); + auto yieldOp = llvm::cast(op.getBody().front().getTerminator()); for (auto [i, iterResult] : llvm::enumerate(yieldOp.getInputs())) { auto iterLastOp = iterResult.getDefiningOp(); if (llvm::isa_and_present(iterLastOp)) { @@ -50,20 +41,13 @@ static void addInitReduce(mlir::Operation *op, mlir::IRRewriter &rewriter) { } } - rewriter.modifyOpInPlace(op, [&]() { - if (auto constOp = llvm::dyn_cast(op)) { - constOp.getInitArgsMutable().assign(newInitArgs); - } else { - auto dimOp = llvm::cast(op); - 
dimOp.getInitArgsMutable().assign(newInitArgs); - } - }); + rewriter.modifyOpInPlace( + op, [&]() { op.getInitArgsMutable().assign(newInitArgs); }); } void GraphAlgLoopAggregate::runOnOperation() { - llvm::SmallVector loopOps; - getOperation()->walk([&](ForConstOp op) { loopOps.emplace_back(op); }); - getOperation()->walk([&](ForDimOp op) { loopOps.emplace_back(op); }); + llvm::SmallVector loopOps; + getOperation()->walk([&](ForOp op) { loopOps.emplace_back(op); }); mlir::IRRewriter rewriter(&getContext()); for (auto op : loopOps) { diff --git a/compiler/src/graphalg/GraphAlgOps.cpp b/compiler/src/graphalg/GraphAlgOps.cpp index 01764bc..472693f 100644 --- a/compiler/src/graphalg/GraphAlgOps.cpp +++ b/compiler/src/graphalg/GraphAlgOps.cpp @@ -398,66 +398,6 @@ mlir::LogicalResult ConstantMatrixOp::verify() { return mlir::success(); } -mlir::LogicalResult verifyLoop(mlir::Operation *op, mlir::ValueRange initArgs, - mlir::Region &bodyRegion, - mlir::Region &untilRegion) { - llvm::SmallVector initArgTypes; - for (auto arg : initArgs) { - initArgTypes.emplace_back(arg.getType()); - } - - if (op->getResultTypes() != initArgTypes) { - return op->emitOpError("result types ") - << op->getResultTypes() << " do not match init args " - << initArgTypes; - } - - // Block arguments must start with an iteration counter - auto *ctx = op->getContext(); - auto iterType = MatrixType::scalarOf(SemiringTypes::forInt(ctx)); - for (auto ®ion : op->getRegions()) { - if (region.empty()) { - continue; - } - - mlir::TypeRange argTypes = region.getArgumentTypes(); - if (argTypes.empty() || argTypes.front() != iterType) { - return op->emitOpError("region types ") - << argTypes << "do not include the iteration variable"; - } - } - - // Body must have YieldOp as terminator - auto &body = bodyRegion.front(); - if (!body.mightHaveTerminator()) { - return op->emitOpError("body region does not have a terminator"); - } - auto bodyYield = llvm::dyn_cast_if_present(body.getTerminator()); - if 
(!bodyYield) { - return op->emitOpError("body region is not terminated with a YieldOp"); - } - - // If there is an until block, it should return a boolean. - if (!untilRegion.empty()) { - auto &until = untilRegion.front(); - if (!until.mightHaveTerminator()) { - return op->emitOpError("until region does not have a terminator"); - } - auto untilYield = llvm::dyn_cast_if_present(until.getTerminator()); - if (!untilYield) { - return op->emitOpError("until region is not terminated with a YieldOp"); - } - - auto expectedType = MatrixType::scalarOf(SemiringTypes::forBool(ctx)); - if (untilYield->getOperandTypes() != mlir::TypeRange{expectedType}) { - return op->emitOpError("until block does not return a bool scalar: ") - << untilYield->getOperandTypes(); - } - } - - return mlir::success(); -} - void getLoopSuccessorRegions( mlir::Operation *op, mlir::Region &body, mlir::Region &until, mlir::RegionBranchPoint point, @@ -505,29 +445,81 @@ mlir::LogicalResult BroadcastOp::verify() { return mlir::success(); } -// === ForConstOp === -mlir::LogicalResult ForConstOp::verifyRegions() { - return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); -} +// === ForOp === +mlir::LogicalResult ForOp::verify() { + if (getDynBegin() && getBegin()) { + return emitOpError("begin and dyn_begin are mutually exclusive"); + } else if (!getDynBegin() && !getBegin()) { + return emitOpError("no loop start: must have 'begin' or 'dyn_begin'"); + } -void ForConstOp::getSuccessorRegions( - mlir::RegionBranchPoint point, - llvm::SmallVectorImpl &regions) { - getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, - regions); -} + if (getDynEnd() && getIters()) { + return emitOpError("dyn_end and iters are mutually exclusive"); + } else if (!getDynEnd() && !getIters()) { + return emitOpError("no loop end: must have 'iters' or 'dyn_end'"); + } -mlir::OperandRange -ForConstOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { - return getInitArgs(); + return 
mlir::success(); } -// === ForDimOp === -mlir::LogicalResult ForDimOp::verifyRegions() { - return verifyLoop(getOperation(), getInitArgs(), getBody(), getUntil()); +mlir::LogicalResult ForOp::verifyRegions() { + llvm::SmallVector initArgTypes; + for (auto arg : getInitArgs()) { + initArgTypes.emplace_back(arg.getType()); + } + + if (getResultTypes() != initArgTypes) { + return emitOpError("result types ") + << getResultTypes() << " do not match init args " << initArgTypes; + } + + // Block arguments must start with an iteration counter + auto *ctx = getContext(); + auto iterType = MatrixType::scalarOf(SemiringTypes::forInt(ctx)); + for (auto ®ion : getOperation()->getRegions()) { + if (region.empty()) { + continue; + } + + mlir::TypeRange argTypes = region.getArgumentTypes(); + if (argTypes.empty() || argTypes.front() != iterType) { + return emitOpError("region types ") + << argTypes << "do not include the iteration variable"; + } + } + + // Body must have YieldOp as terminator + auto &body = getBody().front(); + if (!body.mightHaveTerminator()) { + return emitOpError("body region does not have a terminator"); + } + auto bodyYield = llvm::dyn_cast_if_present(body.getTerminator()); + if (!bodyYield) { + return emitOpError("body region is not terminated with a YieldOp"); + } + + // If there is an until block, it should return a boolean. 
+ if (!getUntil().empty()) { + auto &until = getUntil().front(); + if (!until.mightHaveTerminator()) { + return emitOpError("until region does not have a terminator"); + } + auto untilYield = llvm::dyn_cast_if_present(until.getTerminator()); + if (!untilYield) { + return emitOpError("until region is not terminated with a YieldOp"); + } + + auto expectedType = MatrixType::scalarOf(SemiringTypes::forBool(ctx)); + if (untilYield->getOperandTypes() != mlir::TypeRange{expectedType}) { + return emitOpError("until block does not return a bool scalar: ") + << untilYield->getOperandTypes(); + } + } + + return mlir::success(); } -void ForDimOp::getSuccessorRegions( +void ForOp::getSuccessorRegions( mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { getLoopSuccessorRegions(getOperation(), getBody(), getUntil(), point, @@ -535,10 +527,12 @@ void ForDimOp::getSuccessorRegions( } mlir::OperandRange -ForDimOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { +ForOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) { return getInitArgs(); } +bool ForOp::isDynamicRange() { return getDynBegin() || getDynEnd(); } + // === YieldOp === mlir::MutableOperandRange YieldOp::getMutableSuccessorOperands(mlir::RegionBranchPoint point) { diff --git a/compiler/src/graphalg/GraphAlgScalarizeApply.cpp b/compiler/src/graphalg/GraphAlgScalarizeApply.cpp index 3f3008c..3821395 100644 --- a/compiler/src/graphalg/GraphAlgScalarizeApply.cpp +++ b/compiler/src/graphalg/GraphAlgScalarizeApply.cpp @@ -149,17 +149,17 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -// Unrolls a single iteration of the loop body. +// Unroll one iteration of the loop. 
static mlir::LogicalResult -unrollLoopBody(mlir::Operation *loopOp, mlir::Block &body, - const llvm::APInt &iter, mlir::ValueRange iterArgs, +unrollLoopBody(ForOp forOp, std::uint64_t iter, mlir::ValueRange iterArgs, llvm::SmallVectorImpl &results, mlir::OpBuilder &builder) { mlir::IRMapping mapping; // Map iter var to a constant. - auto zeroOp = builder.create( - loopOp->getLoc(), mlir::IntegerAttr::get(builder.getI64Type(), iter)); + auto zeroOp = builder.create(forOp.getLoc(), + builder.getI64IntegerAttr(iter)); + auto &body = forOp.getBody().front(); mapping.map(body.getArgument(0), zeroOp); // Map the remainder of the body arguments to the init args. @@ -184,68 +184,23 @@ unrollLoopBody(mlir::Operation *loopOp, mlir::Block &body, mapping.map(&oldOp, newOp); } - return loopOp->emitOpError("does not have a terminator ") + return forOp.emitOpError("does not have a terminator ") << YieldOp::getOperationName() << ", so cannot be inlined"; } -static mlir::LogicalResult unrollForDimOne(ForDimOp op, - mlir::PatternRewriter &rewriter) { - // Well-formed GraphAlg programs do not use dimension symbols inside - // functions that do not have this dimension symbol as an argument (but this - // is not currently enforced by verifiers). Inside of an apply all - // dimensions are one, so this should cover all well-formed programs. - if (!op.getDim().isOne()) { - // We need a constant size to be able to do loop unrolling. 
- return op->emitOpError("is contained in an ") - << ApplyInlineOp::getOperationName() - << ", but contains a non-1 dimension symbol reference"; - } - - // Inline the body - llvm::APInt iter(64, std::int64_t(0)); - llvm::SmallVector resultValues; - if (mlir::failed(unrollLoopBody(op, op.getBody().front(), iter, - op.getInitArgs(), resultValues, rewriter))) { - return mlir::failure(); - } - - rewriter.replaceOp(op, resultValues); - return mlir::success(); -} - -static std::optional tryGetConstantRangeValue(mlir::Value v) { - auto litOp = v.getDefiningOp(); - if (!litOp) { - return std::nullopt; - } - - if (auto intAttr = llvm::dyn_cast(litOp.getValue())) { - return intAttr.getValue(); - } - - return std::nullopt; -} - -static mlir::LogicalResult unrollForConst(ForConstOp op, - mlir::PatternRewriter &rewriter) { - auto begin = tryGetConstantRangeValue(op.getRangeBegin()); - auto end = tryGetConstantRangeValue(op.getRangeEnd()); - if (!begin || !end) { - // TODO: Handle loops where the bound is not yet constant. - // Ranges should be constant in GraphAlg, but that constant may be - // provided by the caller of a function, so we could encounter - // \c mlir::BlockArgument here in a valid program. 
- return op->emitOpError("does not have a constant range") +static mlir::LogicalResult unrollFor(ForOp op, + mlir::PatternRewriter &rewriter) { + if (op.isDynamicRange() || op.getIters()->isAbstract()) { + return op.emitOpError("does not have a constant range") << ", so cannot be unrolled"; } llvm::SmallVector iterArgs(op.getInitArgs()); - auto one = llvm::APInt(begin->getBitWidth(), 1); - for (auto i = *begin; i.slt(*end); i = i + one) { + auto end = *op.getBegin() + op.getIters()->getConcreteDim(); + for (auto i = *op.getBegin(); i < end; i++) { llvm::SmallVector results; - if (mlir::failed(unrollLoopBody(op, op.getBody().front(), i, iterArgs, - results, rewriter))) { + if (mlir::failed(unrollLoopBody(op, i, iterArgs, results, rewriter))) { return mlir::failure(); } @@ -457,8 +412,7 @@ void GraphAlgScalarizeApply::runOnOperation() { OpConversion, OpConversion>( typeConverter, &getContext()); // Simplifications into ops that can be converted using rules above. - conversions.add(unrollForDimOne); - conversions.add(unrollForConst); + conversions.add(unrollFor); conversions.add(convertMask); conversions.add(convertMatMul); conversions.add(convertNeg); diff --git a/compiler/src/graphalg/GraphAlgSetConstArg.cpp b/compiler/src/graphalg/GraphAlgSetConstArg.cpp new file mode 100644 index 0000000..a3d58df --- /dev/null +++ b/compiler/src/graphalg/GraphAlgSetConstArg.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include "graphalg/GraphAlgOps.h" +#include "graphalg/GraphAlgPasses.h" +#include "graphalg/GraphAlgTypes.h" +#include "graphalg/SemiringTypes.h" + +namespace graphalg { + +#define GEN_PASS_DEF_GRAPHALGSETCONSTARG +#include "graphalg/GraphAlgPasses.h.inc" + +namespace { + +/** + * Propagate integer constant arguments into function bodies. + * + * Run canonicalization after this pass to propagate the constant through the + * program. This pass does not change the function signature (that is, it does + * not remove the original argument). 
+ */ +class GraphAlgSetConstArg + : public impl::GraphAlgSetConstArgBase { + using impl::GraphAlgSetConstArgBase< + GraphAlgSetConstArg>::GraphAlgSetConstArgBase; + + void runOnOperation() final; +}; + +} // namespace + +void GraphAlgSetConstArg::runOnOperation() { + if (functionName.empty()) { + getOperation().emitError("missing value for required option 'func'"); + return signalPassFailure(); + } + + if (argumentNumber < 0) { + getOperation().emitError("missing value for required option 'argNum'"); + return signalPassFailure(); + } + + auto func = llvm::dyn_cast_if_present( + getOperation().lookupSymbol(functionName)); + if (!func) { + getOperation().emitOpError("does not contain a function named '") + << functionName << "'"; + return signalPassFailure(); + } + + auto &body = func.getBody().front(); + auto numArgs = body.getNumArguments(); + if (argumentNumber >= numArgs) { + getOperation().emitOpError("argument number ") + << argumentNumber << " is out of bounds function " << functionName + << ", which only has " << numArgs << " parameters"; + return signalPassFailure(); + } + + mlir::IRRewriter rewriter(func); + rewriter.setInsertionPointToStart(&body); + auto constOp = rewriter.create( + func.getLoc(), MatrixType::scalarOf(SemiringTypes::forInt(&getContext())), + rewriter.getI64IntegerAttr(value)); + rewriter.replaceAllUsesWith(body.getArgument(argumentNumber), constOp); +} + +} // namespace graphalg diff --git a/compiler/src/graphalg/GraphAlgSetDimensions.cpp b/compiler/src/graphalg/GraphAlgSetDimensions.cpp index 8578816..1c83839 100644 --- a/compiler/src/graphalg/GraphAlgSetDimensions.cpp +++ b/compiler/src/graphalg/GraphAlgSetDimensions.cpp @@ -125,20 +125,6 @@ class DimConversionPattern : public mlir::ConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override; }; -/** Template for rewrites (without type conversion). 
*/ -template -class DimOpRewritePattern : public mlir::OpRewritePattern { -private: - const DimMapper &_mapper; - - mlir::LogicalResult - matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override; - -public: - DimOpRewritePattern(const DimMapper &mapper, mlir::MLIRContext *ctx) - : mlir::OpRewritePattern(ctx), _mapper(mapper) {} -}; - } // namespace mlir::FailureOr @@ -309,31 +295,33 @@ mlir::LogicalResult DimConversionPattern::matchAndRewrite( return mlir::success(); } -template <> -mlir::LogicalResult DimOpRewritePattern::matchAndRewrite( - CastDimOp op, mlir::PatternRewriter &rewriter) const { - auto dim = _mapper.convertAttr(op.getInput()); - if (!dim) { - return mlir::failure(); +static mlir::LogicalResult updateDim(ForOp op, DimMapper &mapper) { + if (!op.getIters() || !op.getIters()->isAbstract()) { + // No update needed + return mlir::success(); } - // The folder on CastDimOp should turn this into a constant. - auto newOp = rewriter.createOrFold(op->getLoc(), dim); - rewriter.replaceOp(op, newOp); + auto dim = mapper.convertAttr(*op.getIters()); + if (!dim) { + return op.emitOpError("no mapping for ") << *op.getIters(); + } + op.setItersAttr(dim); return mlir::success(); } -template <> -mlir::LogicalResult DimOpRewritePattern::matchAndRewrite( - ForDimOp op, mlir::PatternRewriter &rewriter) const { - auto dim = _mapper.convertAttr(op.getDim()); - if (!dim) { - return mlir::failure(); +static mlir::LogicalResult updateDim(CastDimOp op, DimMapper &mapper) { + if (!op.getInput().isAbstract()) { + // No update needed + return mlir::success(); } - rewriter.modifyOpInPlace(op, [&]() { op.setDimAttr(dim); }); + auto dim = mapper.convertAttr(op.getInput()); + if (!dim) { + return op.emitOpError("no mapping for ") << op.getInput(); + } + op.setInputAttr(dim); return mlir::success(); } @@ -356,6 +344,19 @@ void GraphAlgSetDimensions::runOnOperation() { return signalPassFailure(); } + // Update direct references to dimensions. 
+ bool failedDirectUpdate = false; + func->walk([&](ForOp op) { + if (mlir::failed(updateDim(op, *dimMapper))) { + failedDirectUpdate = true; + } + }); + func->walk([&](CastDimOp op) { + if (mlir::failed(updateDim(op, *dimMapper))) { + failedDirectUpdate = true; + } + }); + mlir::ConversionTarget target(getContext()); target.addDynamicallyLegalDialect( doesNotUseAbstractDimensions); @@ -363,7 +364,6 @@ void GraphAlgSetDimensions::runOnOperation() { doesNotUseAbstractDimensions); target.addDynamicallyLegalOp(doesNotHaveAbstractInputs); target.addIllegalOp(); - target.addIllegalOp(); mlir::RewritePatternSet patterns(&getContext()); @@ -374,12 +374,6 @@ void GraphAlgSetDimensions::runOnOperation() { // Convert all result types and block argument types. patterns.add(typeConverter, &getContext()); - // Convert ops that have a special dependency on DimAttr. - patterns.add, DimOpRewritePattern>( - *dimMapper, &getContext()); - // Use the canonicalization pattern to rewrite ForDimOp into ForConstOp. - ForDimOp::getCanonicalizationPatterns(patterns, &getContext()); - if (mlir::failed( mlir::applyPartialConversion(func, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/compiler/src/graphalg/GraphAlgToCore.cpp b/compiler/src/graphalg/GraphAlgToCore.cpp index dbc8993..fdee0c4 100644 --- a/compiler/src/graphalg/GraphAlgToCore.cpp +++ b/compiler/src/graphalg/GraphAlgToCore.cpp @@ -254,6 +254,8 @@ void GraphAlgToCore::runOnOperation() { target.addIllegalDialect(); target.addDynamicallyLegalDialect( [](mlir::Operation *op) { return op->hasTrait(); }); + target.addDynamicallyLegalOp( + [](ForOp op) { return !op.isDynamicRange(); }); mlir::RewritePatternSet patterns(&getContext()); patterns.add(convertVecMatMul); @@ -266,6 +268,20 @@ void GraphAlgToCore::runOnOperation() { patterns.add(convertTriu); patterns.add(convertLiteral); + // Conversion will give very unclear errors about dynamic range for loops, so + // do our own analysis first. 
+ bool haveDynamicRangeLoops = false; + getOperation()->walk([&](ForOp op) { + if (op.isDynamicRange()) { + op.emitOpError("loop bound must be a constant in GraphAlg Core"); + haveDynamicRangeLoops = true; + } + }); + + if (haveDynamicRangeLoops) { + return signalPassFailure(); + } + if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/src/graphalg/analysis/DenseAnalysis.cpp b/compiler/src/graphalg/analysis/DenseAnalysis.cpp index da2dfd1..36fec72 100644 --- a/compiler/src/graphalg/analysis/DenseAnalysis.cpp +++ b/compiler/src/graphalg/analysis/DenseAnalysis.cpp @@ -52,7 +52,7 @@ DenseAnalysis::visitOperation(mlir::Operation *op, void DenseAnalysis::visitNonControlFlowArguments( mlir::Operation *op, const mlir::RegionSuccessor &successor, llvm::ArrayRef argLattices, unsigned firstIndex) { - if (llvm::isa(op)) { + if (llvm::isa(op)) { // Iteration counter is dense. assert(firstIndex == 1); auto arg = argLattices[0]; diff --git a/compiler/src/graphalg/evaluate/Evaluator.cpp b/compiler/src/graphalg/evaluate/Evaluator.cpp index 8545c4b..8e08b30 100644 --- a/compiler/src/graphalg/evaluate/Evaluator.cpp +++ b/compiler/src/graphalg/evaluate/Evaluator.cpp @@ -35,7 +35,7 @@ class Evaluator { mlir::LogicalResult evaluate(ReduceOp op); mlir::LogicalResult evaluate(BroadcastOp op); mlir::LogicalResult evaluate(ConstantMatrixOp op); - mlir::LogicalResult evaluate(ForConstOp op); + mlir::LogicalResult evaluate(ForOp op); mlir::LogicalResult evaluate(ApplyOp op); mlir::LogicalResult evaluate(PickAnyOp op); mlir::LogicalResult evaluate(TrilOp op); @@ -187,12 +187,19 @@ mlir::LogicalResult Evaluator::evaluate(ConstantMatrixOp op) { return mlir::success(); } -mlir::LogicalResult Evaluator::evaluate(ForConstOp op) { - MatrixAttrReader rangeBeginMat(_values[op.getRangeBegin()]); - MatrixAttrReader rangeEndMat(_values[op.getRangeEnd()]); - auto rangeBegin = - llvm::cast(rangeBeginMat.at(0, 0)).getInt(); - 
auto rangeEnd = llvm::cast(rangeEndMat.at(0, 0)).getInt(); +mlir::LogicalResult Evaluator::evaluate(ForOp op) { + auto begin = op.getBegin(); + if (!begin) { + return op.emitOpError("loop begin is not constant"); + } + + auto iters = op.getIters(); + if (!iters || iters->isAbstract()) { + return op.emitOpError("number of loop iterations is not constant"); + } + + auto rangeBegin = *begin; + auto rangeEnd = rangeBegin + iters->getConcreteDim(); auto &body = op.getBody().front(); auto *ctx = op.getContext(); @@ -332,7 +339,7 @@ mlir::LogicalResult Evaluator::evaluate(mlir::Operation *op) { return llvm::TypeSwitch(op) #define GA_CASE(Op) .Case([&](Op op) { return evaluate(op); }) GA_CASE(TransposeOp) GA_CASE(DiagOp) GA_CASE(MatMulOp) GA_CASE(ReduceOp) - GA_CASE(BroadcastOp) GA_CASE(ConstantMatrixOp) GA_CASE(ForConstOp) + GA_CASE(BroadcastOp) GA_CASE(ConstantMatrixOp) GA_CASE(ForOp) GA_CASE(ApplyOp) GA_CASE(PickAnyOp) GA_CASE(TrilOp) #undef GA_CASE .Default([](mlir::Operation *op) { diff --git a/compiler/src/graphalg/parse/Parser.cpp b/compiler/src/graphalg/parse/Parser.cpp index e1013ef..d466c26 100644 --- a/compiler/src/graphalg/parse/Parser.cpp +++ b/compiler/src/graphalg/parse/Parser.cpp @@ -713,14 +713,18 @@ mlir::ParseResult Parser::parseStmtFor() { mlir::Region *untilRegion; mlir::ValueRange results; if (range.dim) { - auto forOp = _builder.create(loc, varTypes, initArgs, range.dim); + auto begin = _builder.getI64IntegerAttr(0); + auto forOp = + _builder.create(loc, varTypes, initArgs, /*dynBegin=*/nullptr, + /*dynEnd=*/nullptr, begin, range.dim); bodyRegion = &forOp.getBody(); untilRegion = &forOp.getUntil(); results = forOp->getResults(); } else { assert(range.begin && range.end); - auto forOp = _builder.create(loc, varTypes, initArgs, - range.begin, range.end); + auto forOp = _builder.create( + loc, varTypes, initArgs, /*dynBegin=*/range.begin, /*dynEnd=*/range.end, + /*begin=*/nullptr, /*end=*/nullptr); bodyRegion = &forOp.getBody(); untilRegion = 
&forOp.getUntil(); results = forOp->getResults(); @@ -827,7 +831,7 @@ mlir::ParseResult Parser::parseStmtReturn() { // Check if return is inside a loop auto *parentOp = _builder.getInsertionBlock()->getParentOp(); - if (llvm::isa(parentOp)) { + if (llvm::isa(parentOp)) { return mlir::emitError(loc) << "return statement inside a loop is not allowed"; } diff --git a/compiler/test/canonicalize/for-dim.mlir b/compiler/test/canonicalize/for-dim.mlir deleted file mode 100644 index 167a493..0000000 --- a/compiler/test/canonicalize/for-dim.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: graphalg-opt --canonicalize < %s | FileCheck %s - -func.func @ForDimConst() -> !graphalg.mat<1 x 1 x i64> { - - // CHECK: %[[#ZERO:]] = graphalg.const_mat 0 - // CHECK: %[[#END:]] = graphalg.const_mat 42 - // CHECK: %[[#FOR:]] = graphalg.for_const range(%[[#ZERO]], %[[#END]]) - // CHECK-SAME: init(%[[#ZERO]]) - // CHECK: graphalg.yield %arg1 : !graphalg.mat<1 x 1 x i64> - // CHECK: return %[[#FOR]] - - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.for_dim range(42) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> - } until { - } - return %1 : !graphalg.mat<1 x 1 x i64> -} diff --git a/compiler/test/e2e/cdlp.gr b/compiler/test/e2e/cdlp.gr index 1e57991..86c8f38 100644 --- a/compiler/test/e2e/cdlp.gr +++ b/compiler/test/e2e/cdlp.gr @@ -1,7 +1,7 @@ // RUN: split-file %s %t // RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir -// RUN: graphalg-opt --graphalg-to-core-pipeline --graphalg-set-dimensions='func=CDLP args=8x8' %t/parsed.mlir > %t/exec.mlir -// RUN: graphalg-exec %t/exec.mlir CDLP %t/graph.m | diff - %t/output.m +// RUN: graphalg-opt --graphalg-set-const-arg='func=CDLP argNum=1 value=5' --graphalg-to-core-pipeline --graphalg-set-dimensions='func=CDLP args=8x8,1x1' %t/parsed.mlir > %t/exec.mlir 
+// RUN: graphalg-exec %t/exec.mlir CDLP %t/graph.m %t/iters.m | diff - %t/output.m //--- graph.m 0 1 @@ -23,14 +23,16 @@ 6 7 7 5 +//--- iters.m +0 0 5 + //--- input.gr func isMax(v: int, max: trop_max_int) -> bool { return (cast(v) == max) * (v != zero(int)); } -func CDLP(graph: Matrix) -> Matrix { - iterations = int(5); +func CDLP(graph: Matrix, iterations:int) -> Matrix { id = Vector(graph.nrows); id[:] = bool(true); L = diag(id); diff --git a/compiler/test/e2e/pr.gr b/compiler/test/e2e/pr.gr index 374eb7d..c9b3c29 100644 --- a/compiler/test/e2e/pr.gr +++ b/compiler/test/e2e/pr.gr @@ -1,7 +1,7 @@ // RUN: split-file %s %t // RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir -// RUN: graphalg-opt --graphalg-to-core-pipeline --graphalg-set-dimensions='func=PR args=50x50' %t/parsed.mlir > %t/exec.mlir -// RUN: graphalg-exec %t/exec.mlir PR %t/graph.m | diff - %t/output.m +// RUN: graphalg-opt --graphalg-set-const-arg='func=PR argNum=1 value=10' --graphalg-to-core-pipeline --graphalg-set-dimensions='func=PR args=50x50,1x1' %t/parsed.mlir > %t/exec.mlir +// RUN: graphalg-exec %t/exec.mlir PR %t/graph.m %t/iters.m | diff - %t/output.m //--- graph.m 0 18 @@ -251,14 +251,16 @@ 49 27 49 46 +//--- iters.m +0 0 10 + //--- input.gr func withDamping(degree:int, damping:real) -> real { return cast(degree) / damping; } -func PR(graph: Matrix) -> Vector { +func PR(graph: Matrix, iterations:int) -> Vector { damping = real(0.85); - iterations = int(10); n = graph.nrows; teleport = (real(1.0) - damping) / cast(n); rdiff = real(1.0); diff --git a/compiler/test/exec/for.mlir b/compiler/test/exec/for.mlir index 6516987..2450078 100644 --- a/compiler/test/exec/for.mlir +++ b/compiler/test/exec/for.mlir @@ -10,21 +10,19 @@ //--- input.mlir func.func @Reach(%arg0: !graphalg.mat<3 x 3 x i1>, %arg1: !graphalg.mat<3 x 1 x i1>) -> !graphalg.mat<3 x 1 x i1> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 3 : i64 -> <1 x 1 x i64> - %2 = 
graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg1) : !graphalg.mat<3 x 1 x i1> -> !graphalg.mat<3 x 1 x i1> body { + %0 = graphalg.for begin=0 iters=<3> init(%arg1) : !graphalg.mat<3 x 1 x i1> -> !graphalg.mat<3 x 1 x i1> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<3 x 1 x i1>): - %3 = graphalg.transpose %arg0 : <3 x 3 x i1> - %4 = graphalg.mxm %3, %arg3 : <3 x 3 x i1>, <3 x 1 x i1> - %5 = graphalg.apply %arg3, %4 : !graphalg.mat<3 x 1 x i1>, !graphalg.mat<3 x 1 x i1> -> <3 x 1 x i1> { + %1 = graphalg.transpose %arg0 : <3 x 3 x i1> + %2 = graphalg.mxm %1, %arg3 : <3 x 3 x i1>, <3 x 1 x i1> + %3 = graphalg.apply %arg3, %2 : !graphalg.mat<3 x 1 x i1>, !graphalg.mat<3 x 1 x i1> -> <3 x 1 x i1> { ^bb0(%arg4: i1, %arg5: i1): - %6 = graphalg.add %arg4, %arg5 : i1 - graphalg.apply.return %6 : i1 + %4 = graphalg.add %arg4, %arg5 : i1 + graphalg.apply.return %4 : i1 } - graphalg.yield %5 : !graphalg.mat<3 x 1 x i1> + graphalg.yield %3 : !graphalg.mat<3 x 1 x i1> } until { } - return %2 : !graphalg.mat<3 x 1 x i1> + return %0 : !graphalg.mat<3 x 1 x i1> } //--- output.m diff --git a/compiler/test/exec/until.mlir b/compiler/test/exec/until.mlir index 7ad1e26..cb10fc5 100644 --- a/compiler/test/exec/until.mlir +++ b/compiler/test/exec/until.mlir @@ -5,26 +5,25 @@ func.func @Fib() -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> %1 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> - %2 = graphalg.const_mat 1000000 : i64 -> <1 x 1 x i64> - %3:2 = graphalg.for_const range(%0, %2) : <1 x 1 x i64> init(%0, %1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> body { + %2:2 = graphalg.for begin=0 iters=<1000000> init(%0, %1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - %4 
= graphalg.apply %arg1, %arg2 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + %3 = graphalg.apply %arg1, %arg2 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { ^bb0(%arg3: i64, %arg4: i64): - %5 = graphalg.add %arg3, %arg4 : i64 - graphalg.apply.return %5 : i64 + %4 = graphalg.add %arg3, %arg4 : i64 + graphalg.apply.return %4 : i64 } - graphalg.yield %arg2, %4 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> + graphalg.yield %arg2, %3 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> } until { ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): - %4 = graphalg.apply %arg2 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + %3 = graphalg.apply %arg2 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { ^bb0(%arg3: i64): - %5 = graphalg.const 34 : i64 - %6 = graphalg.eq %arg3, %5 : i64 - graphalg.apply.return %6 : i1 + %4 = graphalg.const 34 : i64 + %5 = graphalg.eq %arg3, %4 : i64 + graphalg.apply.return %5 : i1 } - graphalg.yield %4 : !graphalg.mat<1 x 1 x i1> + graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> } - return %3#1 : !graphalg.mat<1 x 1 x i64> + return %2#1 : !graphalg.mat<1 x 1 x i64> } //--- output.m diff --git a/compiler/test/golden/bfs.mlir.ref b/compiler/test/golden/bfs.mlir.ref index dd33de2..aa66d71 100644 --- a/compiler/test/golden/bfs.mlir.ref +++ b/compiler/test/golden/bfs.mlir.ref @@ -14,7 +14,7 @@ module @"" { %3 = graphalg.broadcast %2 : <1 x 1 x i64> -> <#dim x 1 x i64> %4 = graphalg.mask %1<%arg1 : <#dim x 1 x i1>> = %3 : <#dim x 1 x i64> {complement = false} %5 = graphalg.cast_dim #dim - %6:3 = graphalg.for_dim range(#dim) init(%4, %arg1, %arg1) : !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> -> !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> body { + %6:3 = graphalg.for begin=0 iters=#dim init(%4, %arg1, %arg1) : !graphalg.mat<#dim x 1 x 
i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> -> !graphalg.mat<#dim x 1 x i64>, !graphalg.mat<#dim x 1 x i1>, !graphalg.mat<#dim x 1 x i1> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<#dim x 1 x i64>, %arg4: !graphalg.mat<#dim x 1 x i1>, %arg5: !graphalg.mat<#dim x 1 x i1>): %7 = graphalg.cast_dim #dim %8 = graphalg.const_mat false -> <#dim x 1 x i1> diff --git a/compiler/test/golden/cdlp.mlir.ref b/compiler/test/golden/cdlp.mlir.ref index a3955d9..6a9666b 100644 --- a/compiler/test/golden/cdlp.mlir.ref +++ b/compiler/test/golden/cdlp.mlir.ref @@ -8,7 +8,7 @@ module @"" { %4 = graphalg.broadcast %3 : <1 x 1 x i1> -> <#dim x 1 x i1> %5 = graphalg.diag %4 : !graphalg.mat<#dim x 1 x i1> %6 = graphalg.literal 0 : i64 - %7 = graphalg.for_const range(%6, %0) : <1 x 1 x i64> init(%5) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { + %7 = graphalg.for dyn_begin=%6 dyn_end=%0 init(%5) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x #dim x i1>): %19 = graphalg.cast %arg0 : <#dim x #dim x i1> -> <#dim x #dim x i64> %20 = graphalg.cast %arg2 : <#dim x #dim x i1> -> <#dim x #dim x i64> diff --git a/compiler/test/golden/pr.mlir.ref b/compiler/test/golden/pr.mlir.ref index 59b2785..d5ff80e 100644 --- a/compiler/test/golden/pr.mlir.ref +++ b/compiler/test/golden/pr.mlir.ref @@ -26,7 +26,7 @@ module @"" { %17 = graphalg.ewise %15 DIV %16 : <1 x 1 x f64> %18 = graphalg.broadcast %17 : <1 x 1 x f64> -> <#dim x 1 x f64> %19 = graphalg.literal 0 : i64 - %20 = graphalg.for_const range(%19, %arg2) : <1 x 1 x i64> init(%18) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> body { + %20 = graphalg.for dyn_begin=%19 dyn_end=%arg2 init(%18) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> body { ^bb0(%arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<#dim x 1 x f64>): %21 = graphalg.const_mat 
0.000000e+00 : f64 -> <#dim x 1 x f64> %22 = graphalg.mask %21<%13 : <#dim x 1 x i1>> = %arg4 : <#dim x 1 x f64> {complement = false} diff --git a/compiler/test/golden/sssp.mlir.ref b/compiler/test/golden/sssp.mlir.ref index d4b21db..b32f2e7 100644 --- a/compiler/test/golden/sssp.mlir.ref +++ b/compiler/test/golden/sssp.mlir.ref @@ -3,7 +3,7 @@ module @"" { func.func @SSSP(%arg0: !graphalg.mat<#dim x #dim x !graphalg.trop_f64>, %arg1: !graphalg.mat<#dim x 1 x i1>) -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> { %0 = graphalg.cast %arg1 : <#dim x 1 x i1> -> <#dim x 1 x !graphalg.trop_f64> %1 = graphalg.cast_dim #dim - %2 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<#dim x 1 x !graphalg.trop_f64> -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> body { + %2 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<#dim x 1 x !graphalg.trop_f64> -> !graphalg.mat<#dim x 1 x !graphalg.trop_f64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<#dim x 1 x !graphalg.trop_f64>): %3 = graphalg.vxm %arg3, %arg0 : <#dim x 1 x !graphalg.trop_f64>, <#dim x #dim x !graphalg.trop_f64> %4 = graphalg.ewise %arg3 ADD %3 : <#dim x 1 x !graphalg.trop_f64> diff --git a/compiler/test/golden/update_goldens.sh b/compiler/test/golden/update_goldens.sh index 01c9183..306153d 100755 --- a/compiler/test/golden/update_goldens.sh +++ b/compiler/test/golden/update_goldens.sh @@ -7,6 +7,6 @@ GOLDEN_DIR=$(dirname $0) for f in $GOLDEN_DIR/*.gr; do out="${f%.gr}.mlir.ref" - "$BUILD_DIR/graphalg-translate" --import-graphalg < "$f" > "$out" + "$BUILD_DIR/tools/graphalg-translate" --import-graphalg < "$f" > "$out" echo "Wrote $out" done diff --git a/compiler/test/golden/wcc.mlir.ref b/compiler/test/golden/wcc.mlir.ref index a078151..7ed86d9 100644 --- a/compiler/test/golden/wcc.mlir.ref +++ b/compiler/test/golden/wcc.mlir.ref @@ -7,7 +7,7 @@ module @"" { %3 = graphalg.broadcast %2 : <1 x 1 x i1> -> <#dim x 1 x i1> %4 = graphalg.diag %3 : !graphalg.mat<#dim x 1 x i1> %5 = 
graphalg.cast_dim #dim - %6 = graphalg.for_dim range(#dim) init(%4) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { + %6 = graphalg.for begin=0 iters=#dim init(%4) : !graphalg.mat<#dim x #dim x i1> -> !graphalg.mat<#dim x #dim x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x #dim x i1>): %7 = graphalg.mxm %arg0, %arg2 : <#dim x #dim x i1>, <#dim x #dim x i1> %8 = graphalg.ewise %arg2 ADD %7 : <#dim x #dim x i1> diff --git a/compiler/test/graphalg-to-rel/cast-dim.mlir b/compiler/test/graphalg-to-rel/cast-dim.mlir new file mode 100644 index 0000000..f48fca8 --- /dev/null +++ b/compiler/test/graphalg-to-rel/cast-dim.mlir @@ -0,0 +1,11 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +#dim = #graphalg.dim> +// CHECK-LABEL: @CastDim +func.func @CastDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#AGG:]] = garel.aggregate %arg1 : group_by=[] aggregators=[] + %0 = graphalg.cast_dim #dim + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for.mlir similarity index 61% rename from compiler/test/graphalg-to-rel/for-const.mlir rename to compiler/test/graphalg-to-rel/for.mlir index 2d418d8..6ba8936 100644 --- a/compiler/test/graphalg-to-rel/for-const.mlir +++ b/compiler/test/graphalg-to-rel/for.mlir @@ -1,13 +1,11 @@ // RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s -// CHECK-LABEL: @ForConst -func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - +// CHECK-LABEL: @For +func.func @For(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 - // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=10 result_idx=1 { - %2 = 
graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + // CHECK: %[[#ITERS:]] = garel.const 10 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=%[[#ITERS]] result_idx=1 { + %0 = graphalg.for begin=0 iters=<10> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): // CHECK: %[[#PROJ:]] = garel.project %arg1 // CHECK: %[[#EXT:]] = garel.extract 0 @@ -19,18 +17,15 @@ func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x } // CHECK: return %[[#FOR]] - return %2 : !graphalg.mat<1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @ForResultUnused func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0, %arg1 - %2:2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { + %0:2 = graphalg.for begin=0 iters=<10> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<1 x 1 x f64>): // CHECK: %[[#PROJ:]] = garel.project %arg2 // CHECK: garel.for.yield %[[#PROJ]], %arg3, %arg4 @@ -39,17 +34,14 @@ func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.m } // CHECK: return %[[#FOR]] - return %2#1 : !graphalg.mat<1 x 1 x f64> + return %0#1 : !graphalg.mat<1 x 1 x f64> } // CHECK-LABEL: @Until func.func @Until(%arg0: 
!graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x i1> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 - %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { + %0 = graphalg.for begin=0 iters=<10> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): // CHECK: %[[#PROJ:]] = garel.project %arg1 // CHECK: garel.for.yield %[[#PROJ]], %arg2 @@ -58,9 +50,9 @@ func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x } until { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): // CHECK: %[[#AGG:]] = garel.aggregate %arg2 - %3 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> + %1 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> // CHECK: garel.for.yield %[[#AGG]] - graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> + graphalg.yield %1 : !graphalg.mat<1 x 1 x i1> } - return %2 : !graphalg.mat<42 x 42 x i1> + return %0 : !graphalg.mat<42 x 42 x i1> } diff --git a/compiler/test/infer-density/for.mlir b/compiler/test/infer-density/for.mlir index 364dc15..0e31b0e 100644 --- a/compiler/test/infer-density/for.mlir +++ b/compiler/test/infer-density/for.mlir @@ -1,36 +1,12 @@ // RUN: graphalg-opt --test-print-dense --verify-diagnostics %s #dim = #graphalg.dim> -func.func @ForConst() -> !graphalg.mat<#dim x 1 x f64> { - %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - %2 = graphalg.const_mat 42.0 : f64 -> <#dim x 1 x f64> - // expected-remark@below {{for}} - // expected-note@below {{operand #0: dense}} - // expected-note@below {{operand #1: dense}} - // expected-note@below {{operand #2: 
dense}} - // expected-note@below {{result #0: dense}} - %3 = graphalg.for_const - range(%0, %1) : <1 x 1 x i64> - init(%2) : !graphalg.mat<#dim x 1 x f64> - -> !graphalg.mat<#dim x 1 x f64> {tag = "for" } body { - // expected-note@below {{arg #0:0:0: dense}} - // expected-note@below {{arg #0:0:1: dense}} - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<#dim x 1 x f64>): - %4 = graphalg.const_mat 42.0 : f64 -> <#dim x 1 x f64> - graphalg.yield %4 : !graphalg.mat<#dim x 1 x f64> - } until { - } - return %3 : !graphalg.mat<#dim x 1 x f64> -} - -func.func @ForDim() -> !graphalg.mat<#dim x 1 x f64> { +func.func @For() -> !graphalg.mat<#dim x 1 x f64> { %0 = graphalg.const_mat 42.0 : f64 -> <#dim x 1 x f64> // expected-remark@below {{for}} // expected-note@below {{operand #0: dense}} // expected-note@below {{result #0: dense}} - %1 = graphalg.for_dim range(#dim) - init(%0) : !graphalg.mat<#dim x 1 x f64> + %1 = graphalg.for begin=0 iters=<10> init(%0) : !graphalg.mat<#dim x 1 x f64> -> !graphalg.mat<#dim x 1 x f64> {tag = "for" } body { // expected-note@below {{arg #0:0:0: dense}} // expected-note@below {{arg #0:0:1: dense}} diff --git a/compiler/test/loop-aggregate/add-init-reduce.mlir b/compiler/test/loop-aggregate/add-init-reduce.mlir index 75914cb..9ce16d4 100644 --- a/compiler/test/loop-aggregate/add-init-reduce.mlir +++ b/compiler/test/loop-aggregate/add-init-reduce.mlir @@ -8,8 +8,8 @@ func.func @AddInitReduce(%arg0: !vec) -> !vec { // CHECK-SAME: %arg0 : !graphalg.mat<#dim x 1 x f64> // CHECK-SAME: -> <#dim x 1 x f64> // - // CHECK: graphalg.for_dim range(#dim) init(%[[#INIT]]) - %0 = graphalg.for_dim range(#dim) init(%arg0) : !vec -> !vec body { + // CHECK: graphalg.for begin=0 iters=#dim init(%[[#INIT]]) + %0 = graphalg.for begin=0 iters=#dim init(%arg0) : !vec -> !vec body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !vec): %1 = graphalg.deferred_reduce %arg2 : !vec -> !vec graphalg.yield %1 : !vec diff --git a/compiler/test/parse/loop.gr 
b/compiler/test/parse/loop.gr index 07b13fc..5120de5 100644 --- a/compiler/test/parse/loop.gr +++ b/compiler/test/parse/loop.gr @@ -6,7 +6,7 @@ func ForConst() -> int { a = int(42); // CHECK: %[[#BEGIN:]] = graphalg.literal 1 // CHECK: %[[#END:]] = graphalg.literal 10 - // CHECK: %[[#LOOP:]] = graphalg.for_const range(%[[#BEGIN]], %[[#END]]) + // CHECK: %[[#LOOP:]] = graphalg.for dyn_begin=%[[#BEGIN]] dyn_end=%[[#END]] // CHECK: init(%[[#INIT]]) // CHECK: -> !graphalg.mat<1 x 1 x i64> body { // CHECK: ^bb0(%arg0: !graphalg.mat<1 x 1 x i64>, @@ -26,7 +26,7 @@ func ForConst() -> int { func ForDim(m:Matrix) -> int { // CHECK: %[[#INIT:]] = graphalg.literal 42 a = int(42); - // CHECK: %[[#LOOP:]] = graphalg.for_dim range(#dim) + // CHECK: %[[#LOOP:]] = graphalg.for begin=0 iters=#dim // CHECK: init(%[[#INIT]]) // CHECK: -> !graphalg.mat<1 x 1 x i64> body { // CHECK: ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, diff --git a/compiler/test/scalarize-apply/for-dim.mlir b/compiler/test/scalarize-apply/for-dim.mlir deleted file mode 100644 index 558acf8..0000000 --- a/compiler/test/scalarize-apply/for-dim.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: graphalg-opt --graphalg-scalarize-apply < %s | FileCheck %s - -#dim = #graphalg.dim> - -func.func @ForDim(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { - // CHECK: %[[#APPLY:]] = graphalg.apply %arg0 - %0 = graphalg.apply_inline %arg0 : !graphalg.mat<#dim x #dim x i64> -> <#dim x #dim x i64> { - ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>): - // CHECK: %[[#ONE:]] = graphalg.const 1 - // CHECK: %[[#ADD:]] = graphalg.add %arg1, %[[#ONE]] - %1 = graphalg.literal 1 : i64 - %2 = graphalg.for_dim range(1) init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { - ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>): - %3 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> - graphalg.yield %3 : !graphalg.mat<1 x 1 x i64> - } until { - } - - // CHECK: graphalg.apply.return 
%[[#ADD]] - graphalg.apply_inline.return %2 : <1 x 1 x i64> - } - - // CHECK: return %[[#APPLY]] - return %0 : !graphalg.mat<#dim x #dim x i64> -} diff --git a/compiler/test/scalarize-apply/for-const.mlir b/compiler/test/scalarize-apply/for.mlir similarity index 62% rename from compiler/test/scalarize-apply/for-const.mlir rename to compiler/test/scalarize-apply/for.mlir index dc62cfa..7e23771 100644 --- a/compiler/test/scalarize-apply/for-const.mlir +++ b/compiler/test/scalarize-apply/for.mlir @@ -2,7 +2,7 @@ #dim = #graphalg.dim> -func.func @ForConst(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { +func.func @For(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#dim x #dim x i64> { // CHECK: %[[#APPLY:]] = graphalg.apply %arg0 %0 = graphalg.apply_inline %arg0 : !graphalg.mat<#dim x #dim x i64> -> <#dim x #dim x i64> { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>): @@ -11,17 +11,15 @@ func.func @ForConst(%arg0: !graphalg.mat<#dim x #dim x i64>) -> !graphalg.mat<#d // CHECK: %[[#ADD1:]] = graphalg.add %[[#ADD0]], %[[#ONE]] // CHECK: %[[#ADD2:]] = graphalg.add %[[#ADD1]], %[[#ONE]] %1 = graphalg.literal 1 : i64 - %2 = graphalg.literal 2 : i64 - %3 = graphalg.literal 5 : i64 - %4 = graphalg.for_const range(%2, %3) : <1 x 1 x i64> init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + %2 = graphalg.for begin=2 iters=<3> init(%arg1) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>): - %5 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> - graphalg.yield %5 : !graphalg.mat<1 x 1 x i64> + %3 = graphalg.ewise %arg3 ADD %1 : <1 x 1 x i64> + graphalg.yield %3 : !graphalg.mat<1 x 1 x i64> } until { } // CHECK: graphalg.apply.return %[[#ADD2]] - graphalg.apply_inline.return %4 : <1 x 1 x i64> + graphalg.apply_inline.return %2 : <1 x 1 x i64> } // CHECK: return %[[#APPLY]] diff --git a/compiler/test/set-dimensions/for-dim.mlir 
b/compiler/test/set-dimensions/for-dim.mlir index 515758e..c6ba788 100644 --- a/compiler/test/set-dimensions/for-dim.mlir +++ b/compiler/test/set-dimensions/for-dim.mlir @@ -3,10 +3,8 @@ func.func @ForDimConst(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> - // CHECK: %[[#BEGIN:]] = graphalg.const_mat 0 - // CHECK: %[[#END:]] = graphalg.const_mat 42 - // CHECK: graphalg.for_const range(%[[#BEGIN]], %[[#END]]) - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + // CHECK: graphalg.for begin=0 iters=<42> + %1 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { diff --git a/compiler/test/verify-dimensions/for-dim.mlir b/compiler/test/verify-dimensions/for-dim.mlir index 1a447fb..a7e5a73 100644 --- a/compiler/test/verify-dimensions/for-dim.mlir +++ b/compiler/test/verify-dimensions/for-dim.mlir @@ -3,7 +3,7 @@ func.func @Ok(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - %1 = graphalg.for_dim range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + %1 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { @@ -16,8 +16,8 @@ func.func @Ok(%arg0: !graphalg.mat<#dim x 1 x i64>) -> !graphalg.mat<1 x 1 x i64 func.func @IllegalRange() -> !graphalg.mat<1 x 1 x i64> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> - // expected-error@below{{'graphalg.for_dim' op attribute "dim" has value #graphalg.dim> which has not been marked as legal}} - %1 = graphalg.for_dim 
range(#dim) init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + // expected-error@below{{'graphalg.for' op attribute "iters" has value #graphalg.dim> which has not been marked as legal}} + %1 = graphalg.for begin=0 iters=#dim init(%0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until {