Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/include/garel/GARelAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def AggregateFunc : I64EnumAttr<
I64EnumAttrCase<"MAX", 2>,
I64EnumAttrCase<"LOR", 3>, /* Logical OR (over i1) */
I64EnumAttrCase<"ARGMIN", 4>,
I64EnumAttrCase<"COUNT", 5>,
]
> {
let cppNamespace = "::garel";
Expand All @@ -49,7 +50,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> {

let parameters = (ins
"AggregateFunc":$func,
ArrayRefParameter<"ColumnIdx">:$inputs);
OptionalArrayRefParameter<"ColumnIdx">:$inputs);

let assemblyFormat = [{
`<` $func $inputs `>`
Expand Down
2 changes: 1 addition & 1 deletion compiler/include/garel/GARelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> {

let arguments = (ins
Variadic<Relation>:$init,
I64Attr:$iters,
I64Relation:$iters,
I64Attr:$resultIdx);

let regions = (region
Expand Down
3 changes: 3 additions & 0 deletions compiler/include/garel/GARelTypes.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Types.h>

#include "garel/GARelAttr.h"
Expand All @@ -11,4 +12,6 @@ namespace garel {

bool isColumnType(mlir::Type t);

RelationType getI64RelationType(mlir::MLIRContext *ctx);

} // namespace garel
6 changes: 6 additions & 0 deletions compiler/include/garel/GARelTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,10 @@ def Tuple : GARel_Type<"Tuple", "tuple"> {

def ColumnType : Type<CPred<"::garel::isColumnType($_self)">, "column type">;

def I64Relation : Type<
CPred<"::garel::getI64RelationType($_self.getContext()) == $_self">,
"relation with a single i64 column",
"RelationType">,
BuildableType<"::garel::getI64RelationType($_builder.getContext())">;

#endif // GAREL_TYPES
65 changes: 26 additions & 39 deletions compiler/include/graphalg/GraphAlgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,22 @@ def BroadcastOp : Core_Op<"broadcast", [
let hasVerifier = 1;
}

// Not core according to spec, but we don't want to unroll in the general case.
def ForConstOp : Core_Op<"for_const", [
def ForOp : Core_Op<"for", [
Pure,
AllTypesMatch<["rangeBegin", "rangeEnd"]>,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> {
let summary = "For loop with constant bounds";
let summary = "For loop with dynamic bounds";

let description = [{
A loop iterating over the integer range starting at `rangeBegin`
(inclusive) and ending at `rangeEnd` (exclusive).
A loop iterating over one of three ranges:
1) `dynBegin` (inclusive) to `dynEnd` (exclusive)
2) `begin` to `begin` + `iters`, where `iters` is an integer
3) `begin` to `begin` + `iters`, where `iters` is a matrix dimension

Only instances with range types 2 or 3 are considered part of GraphAlg
Core. Constant propagation is expected to transform range type 1 into
either 2 or 3.

The `body` region is executed once for every value in the integer range
(that value is passed as the first block argument).
At the first iteration of the loop, the other block arguments take the
Expand All @@ -432,58 +438,39 @@ def ForConstOp : Core_Op<"for_const", [

let arguments = (ins
Variadic<Matrix>:$initArgs,
I64Scalar:$rangeBegin,
I64Scalar:$rangeEnd);
Optional<I64Scalar>:$dynBegin,
Optional<I64Scalar>:$dynEnd,
OptionalAttr<I64Attr>:$begin,
OptionalAttr<DimAttr>:$iters);

let results = (outs Variadic<Matrix>:$results);

let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until);

let assemblyFormat = [{
`range` `(`
$rangeBegin `,`
$rangeEnd
`)` `:` type($rangeEnd)
(`dyn_begin` `` `=` `` $dynBegin^)?
(`dyn_end` `` `=` `` $dynEnd^)?
(`begin` `` `=` `` $begin^)?
(`iters` `` `=` `` $iters^)?
`init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict
`body` $body
`until` $until
}];

let hasVerifier = 1;
let hasRegionVerifier = 1;
}

def ForDimOp : Core_Op<"for_dim", [
Pure,
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> {
let summary = "For loop over a matrix dimension";

let description = [{
A loop iterating over the half-open range [0..`dim`).

This op is otherwise equivalent to `ForConstOp`.
}];

let arguments = (ins Variadic<Matrix>:$initArgs, DimAttr:$dim);

let results = (outs Variadic<Matrix>:$results);

let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$until);
let hasFolder = 1;

let assemblyFormat = [{
`range` `(` custom<BareAttr>($dim) `)`
`init` `(` $initArgs `)` `:` type($initArgs) `->` type($results) attr-dict
`body` $body
`until` $until
let extraClassDeclaration = [{
/** Whether at least one of `dyn_begin` and `dyn_end` is set. */
bool isDynamicRange();
}];

let hasRegionVerifier = 1;
let hasCanonicalizer = 1;
}

def YieldOp : Core_Op<"yield", [
Pure,
Terminator,
ParentOneOf<["ForConstOp", "ForDimOp"]>,
HasParent<"ForOp">,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "Yield from a loop body";

Expand Down
10 changes: 10 additions & 0 deletions compiler/include/graphalg/GraphAlgPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

include "mlir/Pass/PassBase.td"

def GraphAlgSetConstArg : Pass<"graphalg-set-const-arg", "::mlir::ModuleOp"> {
let summary = "Propagate a constant integer argument into a function";

let options = [
Option<"functionName", "func", "std::string", /*default=*/"\"\"", "Name of the function to call">,
Option<"argumentNumber", "argNum", "int", /*default=*/"-1", "The argument number that is constant">,
Option<"value", "value", "std::int64_t", /*default=*/"0", "The value to propagate">,
];
}

def GraphAlgPrepareInline : Pass<"graphalg-prepare-inline", "::mlir::ModuleOp"> {
let summary = "Prepares the IR for function inlining";

Expand Down
31 changes: 20 additions & 11 deletions compiler/src/garel/GARelAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,32 @@ mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) {
case AggregateFunc::ARGMIN:
// NOTE: argmin(arg, val) also uses first input column as output type.
return llvm::cast<RelationType>(inputRel).getColumns()[getInputs()[0]];
case AggregateFunc::COUNT:
return mlir::IntegerType::get(inputRel.getContext(), 64);
}
}

static std::size_t expectedNumInputs(AggregateFunc f) {
switch (f) {
case AggregateFunc::SUM:
case AggregateFunc::MIN:
case AggregateFunc::MAX:
case AggregateFunc::LOR:
return 1;
case AggregateFunc::ARGMIN:
return 2;
case AggregateFunc::COUNT:
return 0;
}
}

mlir::LogicalResult
AggregatorAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
AggregateFunc func, llvm::ArrayRef<ColumnIdx> inputs) {
if (func == AggregateFunc::ARGMIN) {
if (inputs.size() != 2) {
return emitError() << stringifyAggregateFunc(func)
<< " expects exactly two inputs (arg, val), got "
<< inputs.size();
}
} else {
if (inputs.size() != 1) {
return emitError() << stringifyAggregateFunc(func)
<< " expects exactly one input, got " << inputs.size();
}
if (inputs.size() != expectedNumInputs(func)) {
return emitError() << stringifyAggregateFunc(func) << " expects exactly "
<< expectedNumInputs(func) << " inputs, got "
<< inputs.size();
}

return mlir::success();
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/garel/GARelTypes.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/TypeSwitch.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/OpImplementation.h>

Expand All @@ -18,6 +19,11 @@ bool isColumnType(mlir::Type t) {
t.isIndex();
}

RelationType getI64RelationType(mlir::MLIRContext *ctx) {
return RelationType::get(
ctx, mlir::ArrayRef<mlir::Type>{mlir::IntegerType::get(ctx, 64)});
}

// Need to define this here to avoid depending on IPRTypes in
// IPRDialect and creating a cycle.
void GARelDialect::registerTypes() {
Expand Down
Loading