forked from OSchip/llvm-project
[mlir][Linalg] Refactor StructuredOpInterface - NFC
This revision refactors and cleans up a bunch of things to simplify StructuredOpInterface before work can proceed on Linalg on tensors: - break out pieces of the StructuredOps trait that are part of the StructuredOpInterface, - drop referenceIterators and referenceIndexingMaps that end up being more confusing than useful, - drop NamedStructuredOpTrait
This commit is contained in:
parent
7527898fef
commit
e6f2f17f05
|
@ -130,21 +130,22 @@ def CopyOp : LinalgStructured_Op<"copy", [
|
|||
let extraClassDeclaration = libraryCallName # [{
|
||||
// Rank-polymorphic.
|
||||
// filling_value -> O(ivs) with parallel iterators.
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
unsigned nPar = input().getType().cast<ShapedType>().getRank();
|
||||
return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = getInputShapedType(0).getRank();
|
||||
return Builder(getContext()).getStrArrayAttr(
|
||||
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
|
||||
}
|
||||
|
||||
// I(input_perm(ivs)) -> O(output_perm(ivs))
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
ArrayAttr indexing_maps() {
|
||||
MLIRContext *context = getContext();
|
||||
auto maybeInputMap = inputPermutation();
|
||||
auto maybeOutputMap = outputPermutation();
|
||||
unsigned inputRank = getInputShapedType(0).getRank();
|
||||
unsigned outputRank = getOutputShapedType(0).getRank();
|
||||
return SmallVector<AffineMap, 8>{
|
||||
return Builder(getContext()).getAffineMapArrayAttr({
|
||||
extractOrIdentityMap(maybeInputMap, inputRank, context),
|
||||
extractOrIdentityMap(maybeOutputMap, outputRank, context)};
|
||||
extractOrIdentityMap(maybeOutputMap, outputRank, context)});
|
||||
}
|
||||
|
||||
Value getSource() { return input();}
|
||||
|
@ -163,16 +164,17 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
|
|||
let extraClassDeclaration = libraryCallName # [{
|
||||
// Rank-polymorphic.
|
||||
// filling_value -> O(ivs) with parallel iterators.
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
unsigned nPar = output().getType().cast<ShapedType>().getRank();
|
||||
return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = getOutputShapedType(0).getRank();
|
||||
return Builder(getContext()).getStrArrayAttr(
|
||||
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
ArrayAttr indexing_maps() {
|
||||
MLIRContext *context = getContext();
|
||||
// filling_value -> O(ivs)
|
||||
return SmallVector<AffineMap, 8>{
|
||||
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)};
|
||||
return Builder(getContext()).getAffineMapArrayAttr({
|
||||
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
|
||||
}
|
||||
}];
|
||||
|
||||
|
@ -295,7 +297,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
|||
getNumOutputFeatureDimensions();
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
ArrayAttr iterator_types() {
|
||||
// Outer parallel loops are always the number of output dimensions; i.e.
|
||||
// [b, xs, q] in the TF notation above.
|
||||
unsigned nPar = getOutputShapedType(0).getRank();
|
||||
|
@ -310,7 +312,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
|||
iters.reserve(nPar + nRed + nWin);
|
||||
iters.append(nRed, getReductionIteratorTypeName());
|
||||
iters.append(nWin, getWindowIteratorTypeName());
|
||||
return iters;
|
||||
return Builder(getContext()).getStrArrayAttr(iters);
|
||||
}
|
||||
|
||||
// F(z0, ..., zN-1, q, k) *
|
||||
|
@ -318,7 +320,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
|||
// -> O(b, x0, ..., xN-1, k)
|
||||
// for N equal to `nWindow`. If there is no padding attribute, it will be
|
||||
// ignored.
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
ArrayAttr indexing_maps() {
|
||||
MLIRContext *context = getContext();
|
||||
auto nWin = getNumWindowLoops();
|
||||
assert(nWin > 0 && "expected at least one window dimension");
|
||||
|
@ -343,7 +345,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
|||
auto zs = makeAffineDimExprs(nWin, idx, context);
|
||||
// Construct the weighedSum expression.
|
||||
auto ws = weightedPoolingInputIndex(*this, xs, zs);
|
||||
return SmallVector<AffineMap, 8>{
|
||||
return Builder(getContext()).getAffineMapArrayAttr({
|
||||
// filter[z[0], ..., z[N-1], q, k]
|
||||
AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context),
|
||||
// input[b,
|
||||
|
@ -353,7 +355,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
|||
// q]
|
||||
AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context),
|
||||
// output[b, x[0], ..., x[N-1], k]
|
||||
AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)};
|
||||
AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)});
|
||||
}
|
||||
}];
|
||||
|
||||
|
@ -384,7 +386,7 @@ class SingleInputPoolingBase_Op<string mnemonic>
|
|||
OptionalAttr<I64ElementsAttr>:$padding);
|
||||
|
||||
let extraClassDeclaration = commonUtils# [{
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
ArrayAttr iterator_types() {
|
||||
// Outer parallel loops are always the number of output dimensions.
|
||||
unsigned nPar = getOutputShapedType(0).getRank();
|
||||
// The window loops has the same number loops with output dimensions.
|
||||
|
@ -392,10 +394,10 @@ class SingleInputPoolingBase_Op<string mnemonic>
|
|||
SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
|
||||
iters.reserve(nPar + nWin);
|
||||
iters.append(nWin, getWindowIteratorTypeName());
|
||||
return iters;
|
||||
return Builder(getContext()).getStrArrayAttr(iters);
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
ArrayAttr indexing_maps() {
|
||||
MLIRContext *context = getContext();
|
||||
auto nPar = getNumParallelLoops();
|
||||
auto nWin = getNumWindowLoops();
|
||||
|
@ -406,14 +408,13 @@ class SingleInputPoolingBase_Op<string mnemonic>
|
|||
// Construct the weighedSum expression.
|
||||
auto inputDims =
|
||||
weightedPoolingInputIndex(*this, outputDims, windowDims);
|
||||
return SmallVector<AffineMap, 8>{
|
||||
return Builder(getContext()).getAffineMapArrayAttr({
|
||||
// input
|
||||
AffineMap::get(idx, 0, inputDims, context),
|
||||
// windowDims
|
||||
AffineMap::get(idx, 0, windowDims, context),
|
||||
// output
|
||||
AffineMap::get(idx, 0, outputDims, context)
|
||||
};
|
||||
AffineMap::get(idx, 0, outputDims, context)});
|
||||
}
|
||||
}];
|
||||
|
||||
|
@ -466,7 +467,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
|
|||
OptionalAttr<StrAttr>:$library_call,
|
||||
Confined<OptionalAttr<I64Attr>,
|
||||
[IntMinValue<0>]>:$symbol_source);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$output_lis);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<StringRef, 8> linalgTraitAttrNames() {
|
||||
|
@ -485,16 +486,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
|
|||
return library_call().hasValue() ? library_call().getValue() : "";
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
llvm_unreachable(
|
||||
"No such thing as reference iterator types for a generic op.");
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
llvm_unreachable(
|
||||
"No such thing as reference indexing maps for a generic op.");
|
||||
}
|
||||
|
||||
llvm::Optional<unsigned> getSymbolSource() {
|
||||
auto ss = symbol_source();
|
||||
return ss.hasValue() ?
|
||||
|
@ -807,8 +798,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
|
|||
// Named Linalg ops, implemented as a declarative configurations of generic ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
|
||||
|
||||
class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
|
||||
: LinalgStructuredBase_Op<mnemonic, props> {
|
||||
string spec = ?;
|
||||
|
|
|
@ -23,168 +23,486 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
// Loop types handling.
|
||||
//===------------------------------------------------------------------===//
|
||||
InterfaceMethod<
|
||||
"Return the number of parallel loops within the current operation.",
|
||||
"unsigned", "getNumParallelLoops"
|
||||
/*desc=*/[{
|
||||
Return the number of parallel loops within the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumParallelLoops",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getNumIterators(getParallelIteratorTypeName(),
|
||||
$_op.iterator_types());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the number of reduction loops within the current operation.",
|
||||
"unsigned", "getNumReductionLoops"
|
||||
/*desc=*/[{
|
||||
Return the number of reduction loops within the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumReductionLoops",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getNumIterators(getReductionIteratorTypeName(),
|
||||
$_op.iterator_types());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the number of window loops within the current operation.",
|
||||
"unsigned", "getNumWindowLoops"
|
||||
/*desc=*/[{
|
||||
Return the number of window loops within the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumWindowLoops",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getNumIterators(getWindowIteratorTypeName(),
|
||||
$_op.iterator_types());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the number of loops within the current operation.",
|
||||
"unsigned", "getNumLoops">,
|
||||
|
||||
/*desc=*/[{
|
||||
Return the total number of loops within the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumLoops",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getNumIterators($_op.iterator_types());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Returns true if the current operation has only one loop and it's a
|
||||
reduction loop}],
|
||||
"bool", "hasSingleReductionLoop">,
|
||||
|
||||
/*desc=*/[{
|
||||
Returns true if the current operation has only one loop and it's a
|
||||
reduction loop.
|
||||
}],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"hasSingleReductionLoop",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto iters = $_op.iterator_types();
|
||||
return iters.size() == 1 &&
|
||||
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
|
||||
}]>,
|
||||
//===------------------------------------------------------------------===//
|
||||
// Num input/output arguments handling.
|
||||
//===------------------------------------------------------------------===//
|
||||
// These special methods must be defined by each op that wants to implement
|
||||
// the LinalgStructuredInterface. For now, this is either:
|
||||
// - inherited statically by using the NInputs<unsigned> or
|
||||
// NOutputs<unsigned> traits.
|
||||
// - derived from args_in/args_out attributes (for linalg.generic and
|
||||
// linalg.indexed_generic ops).
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the number of inputs from the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumInputs"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the number of outputs from the current operation.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumOutputs"
|
||||
>,
|
||||
//===------------------------------------------------------------------===//
|
||||
// Input arguments handling.
|
||||
//===------------------------------------------------------------------===//
|
||||
InterfaceMethod<
|
||||
"Return the number of inputs from the current operation.",
|
||||
"unsigned", "getNumInputs"
|
||||
/*desc=*/[{
|
||||
Return the `i`-th input value.
|
||||
The `i^th` input argument is always the `i^th` operand regardless of
|
||||
whether we have tensors or buffers.
|
||||
}],
|
||||
/*retTy=*/"Value",
|
||||
/*methodName=*/"getInput",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(i < $_op.getNumInputs());
|
||||
return this->getOperation()->getOperand(i);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Return the input view at the given index.",
|
||||
"Value", "getInput", (ins "unsigned":$i)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the index of the given input value `v`, or `None` if the value is
|
||||
not an input.
|
||||
}],
|
||||
"llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value":$v)
|
||||
/*retTy=*/"llvm::Optional<unsigned>",
|
||||
/*methodName=*/"getIndexOfInput",
|
||||
/*args=*/(ins "Value":$value),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto it = llvm::find(getInputs(), value);
|
||||
if (it != getInputs().end())
|
||||
return it - getInputs().begin();
|
||||
return llvm::None;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the input operands from the current operation.",
|
||||
"Operation::operand_range", "getInputs"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
/*desc=*/[{
|
||||
Return the `i`-th input shaped type, irrespective of buffer or tensor
|
||||
type.
|
||||
}], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
|
||||
InterfaceMethod<[{
|
||||
}],
|
||||
/*retTy=*/"ShapedType",
|
||||
/*methodName=*/"getInputShapedType",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getInput(i).getType().template cast<ShapedType>();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the input operands from the current operation.
|
||||
}],
|
||||
/*retTy=*/"Operation::operand_range",
|
||||
/*methodName=*/"getInputs",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + $_op.getNumInputs()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the subset of input operands that are of ranked tensor type.
|
||||
}], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
|
||||
}],
|
||||
/*retTy=*/"SmallVector<RankedTensorType, 4>",
|
||||
/*methodName=*/"getInputTensorTypes" ,
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
SmallVector<RankedTensorType, 4> res;
|
||||
for (Type type : getInputs().getTypes())
|
||||
if (auto t = type.template dyn_cast<RankedTensorType>())
|
||||
res.push_back(t);
|
||||
return res;
|
||||
}]
|
||||
>,
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Output arguments handling.
|
||||
//===------------------------------------------------------------------===//
|
||||
InterfaceMethod<
|
||||
"Return the number of outputs from the current operation.",
|
||||
"unsigned", "getNumOutputs"
|
||||
/*desc=*/[{
|
||||
Return the output buffer at the given index, asserts that this is a
|
||||
buffer operand and not a tensor result.
|
||||
The `i^th` output argument is an operand (resp. a return value) iff it
|
||||
is a value of buffer type (resp. a return value of tensor type).
|
||||
}],
|
||||
/*retTy=*/"Value",
|
||||
/*methodName=*/"getOutputBuffer",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
// Output buffers are passed as output buffer operands (side-effecting).
|
||||
// Output tensors are results.
|
||||
// The union of the 2 are all the outputs and we want to ensure i does
|
||||
// not overflow the buffer operands.
|
||||
assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
|
||||
&& "overflowing output buffer index");
|
||||
return this->getOperation()->getOperand($_op.getNumInputs() + i);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Return the output buffer at the given index.",
|
||||
"Value", "getOutputBuffer", (ins "unsigned":$i)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the index of the given buffer value, or `None` if the value is
|
||||
not part of the output buffers.
|
||||
}],
|
||||
"llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value":$view)
|
||||
/*retTy=*/"llvm::Optional<unsigned>",
|
||||
/*methodName=*/"getIndexOfOutputBuffer",
|
||||
/*args=*/(ins "Value":$value),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto it = llvm::find(getOutputBuffers(), value);
|
||||
if (it != getOutputBuffers().end())
|
||||
return it - getOutputBuffers().begin();
|
||||
return llvm::None;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the type of the output buffer at the given index.
|
||||
}], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
|
||||
InterfaceMethod<[{
|
||||
}],
|
||||
/*retTy=*/"MemRefType",
|
||||
/*methodName=*/"getOutputBufferType",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getOutputBuffer(i).getType().template cast<MemRefType>();
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the `i`-th output shaped type, irrespective of buffer or tensor
|
||||
type.
|
||||
}], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
|
||||
InterfaceMethod<[{
|
||||
Return the results that are of ranked tensor type.
|
||||
}], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
|
||||
}],
|
||||
/*retTy=*/"ShapedType",
|
||||
/*methodName=*/"getOutputShapedType",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return getShapedType(i + $_op.getNumInputs());
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
"Return the output buffers (operands) from the current operation.",
|
||||
"Operation::operand_range", "getOutputBuffers"
|
||||
/*desc=*/[{
|
||||
Return the results that are of ranked tensor type.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<RankedTensorType, 4>",
|
||||
/*methodName=*/"getOutputTensorTypes",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
SmallVector<RankedTensorType, 4> res;
|
||||
for (Type type : this->getOperation()->getResults().getTypes())
|
||||
res.push_back(type.template cast<RankedTensorType>());
|
||||
return res;
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the output buffers (operands) from the current operation.
|
||||
}],
|
||||
/*retTy=*/"Operation::operand_range",
|
||||
/*methodName=*/"getOutputBuffers",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin() + $_op.getNumInputs(),
|
||||
range.begin() + getNumInputsAndOutputBuffers()};
|
||||
}]
|
||||
>,
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Input and Output arguments handling.
|
||||
//===------------------------------------------------------------------===//
|
||||
InterfaceMethod<
|
||||
"Return one single buffer at position `$i`.",
|
||||
"Value", "getBuffer", (ins "unsigned":$i)
|
||||
/*desc=*/[{
|
||||
Return one single buffer at position `$i`.
|
||||
}],
|
||||
/*retTy=*/"Value",
|
||||
/*methodName=*/"getBuffer",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
|
||||
return this->getOperation()->getOperand(i);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the number of inputs and outputs, irrespective of their buffer "
|
||||
"or tensor type.",
|
||||
"unsigned", "getNumInputsAndOutputs"
|
||||
/*desc=*/[{
|
||||
Return the number of inputs and outputs, irrespective of their buffer or
|
||||
tensor type.
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumInputsAndOutputs",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op.getNumInputs() + $_op.getNumOutputs();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the number of inputs, irrespective of their buffer or tensor "
|
||||
"type, and output buffers",
|
||||
"unsigned", "getNumInputsAndOutputBuffers"
|
||||
/*desc=*/[{
|
||||
Return the number of inputs, irrespective of their buffer or tensor type
|
||||
and output buffers
|
||||
}],
|
||||
/*retTy=*/"unsigned",
|
||||
/*methodName=*/"getNumInputsAndOutputBuffers",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op.getNumInputs() + $_op.getNumOutputs() -
|
||||
this->getOperation()->getNumResults();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the range over inputs (irrespective of type) and output buffers.",
|
||||
"Operation::operand_range", "getInputsAndOutputBuffers"
|
||||
/*desc=*/[{
|
||||
Return the range over inputs (irrespective of type) and output buffers.
|
||||
}],
|
||||
/*retTy=*/"Operation::operand_range",
|
||||
/*methodName=*/"getInputsAndOutputBuffers",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the shaped types for all the inputs and outputs",
|
||||
"SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
|
||||
/*desc=*/[{
|
||||
Return the `i`-th shaped type, there are 3 cases:
|
||||
1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
|
||||
otherwise
|
||||
2. if `i < getNumInputsAndOutputBuffers()` then return the
|
||||
`getOutputBufferType(i - $_op.getNumInputs())`; otherwise
|
||||
3. return the `i - getNumInputsAndOutputBuffers()` result type.
|
||||
}],
|
||||
/*retTy=*/"ShapedType",
|
||||
/*methodName=*/"getShapedType",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
if (i < $_op.getNumInputs())
|
||||
return getInputShapedType(i);
|
||||
if (i < getNumInputsAndOutputBuffers())
|
||||
return getOutputBufferType(i - $_op.getNumInputs());
|
||||
return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the shaped types for all the inputs and outputs
|
||||
}],
|
||||
/*retTy=*/"SmallVector<ShapedType, 4>",
|
||||
/*methodName=*/"getInputOutputShapedTypes",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
SmallVector<Type, 4> inputOutputTypes(
|
||||
this->getOperation()->operand_type_begin(),
|
||||
this->getOperation()->operand_type_end());
|
||||
inputOutputTypes.append(this->getOperation()->result_type_begin(),
|
||||
this->getOperation()->result_type_end());
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
|
||||
return type.cast<ShapedType>();
|
||||
}));
|
||||
}]
|
||||
>,
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Other interface methods.
|
||||
//===------------------------------------------------------------------===//
|
||||
InterfaceMethod<
|
||||
"Return the reference iterators for this named op (if any are "
|
||||
"specified). These reference iterators are used to specify the default "
|
||||
"behavior of the op. Typically this would be a static method but in "
|
||||
"order to allow rank-polymorphic ops, this needs to be per object "
|
||||
"instance. Named ops must define referenceIterators, even if empty for "
|
||||
"the 0-D case. Generic ops on the other hand have a None "
|
||||
"`referenceIterators`",
|
||||
"llvm::Optional<SmallVector<StringRef, 8>>", "referenceIterators"
|
||||
/*desc=*/[{
|
||||
Return the iterator types attribute within the current operation.
|
||||
}],
|
||||
/*retTy=*/"ArrayAttr",
|
||||
/*methodName=*/"iterator_types",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op.iterator_types();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the reference indexing maps for this named op (if any are "
|
||||
"specified). Typically this would be a static method but in order to "
|
||||
"allow rank-polymorphic ops, this needs to be per object instance. Named "
|
||||
"ops must define referenceIterators, even if empty for the 0-D case. "
|
||||
"Generic ops on the other hand have a None `referenceIndexingMaps`",
|
||||
"llvm::Optional<SmallVector<AffineMap, 8>>", "referenceIndexingMaps"
|
||||
/*desc=*/[{
|
||||
Return the indexing maps attribute within the current operation.
|
||||
}],
|
||||
/*retTy=*/"ArrayAttr",
|
||||
/*methodName=*/"indexing_maps"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the iterator types attribute within the current operation.",
|
||||
"ArrayAttr", "iterator_types"
|
||||
/*desc=*/[{
|
||||
Return the indexing maps within the current operation.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<AffineMap, 4>",
|
||||
/*methodName=*/"getIndexingMaps",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range($_op.indexing_maps(),
|
||||
[](Attribute attr) -> AffineMap {
|
||||
return attr.cast<AffineMapAttr>().getValue();
|
||||
}));
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the indexing maps attribute within the current operation.",
|
||||
"ArrayAttr", "indexing_maps"
|
||||
/*desc=*/[{
|
||||
Return the input or output indexing map at index `i`.
|
||||
}],
|
||||
/*retTy=*/"AffineMap",
|
||||
/*methodName=*/"getIndexingMap",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(i < getNumInputsAndOutputs());
|
||||
return $_op.indexing_maps()
|
||||
.getValue()[i]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Return the indexing maps within the current operation.",
|
||||
"SmallVector<AffineMap, 4>", "getIndexingMaps"
|
||||
/*desc=*/[{
|
||||
Return the input indexing map at index `i`.
|
||||
}],
|
||||
/*retTy=*/"AffineMap",
|
||||
/*methodName=*/"getInputIndexingMap",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(i < $_op.getNumInputs());
|
||||
return $_op.indexing_maps()
|
||||
.getValue()[i]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Return the input or output indexing map at index `i`.",
|
||||
"AffineMap", "getIndexingMap", (ins "unsigned":$i)
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the output indexing map at index `i`.
|
||||
}],
|
||||
/*retTy=*/"AffineMap",
|
||||
/*methodName=*/"getOutputIndexingMap",
|
||||
/*args=*/(ins "unsigned":$i),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
assert(i < $_op.getNumOutputs());
|
||||
return $_op.indexing_maps()
|
||||
.getValue()[i + $_op.getNumInputs()]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Return the input indexing map at index `i`.",
|
||||
"AffineMap", "getInputIndexingMap", (ins "unsigned":$i)
|
||||
>,
|
||||
InterfaceMethod<"Return the output indexing map at index `i`.",
|
||||
"AffineMap", "getOutputIndexingMap", (ins "unsigned":$i)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return whether the op has only MemRef input and outputs.
|
||||
}], "bool", "hasBufferSemantics">,
|
||||
InterfaceMethod<[{
|
||||
}],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"hasBufferSemantics",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return this->getOperation()->getNumResults() == 0 &&
|
||||
llvm::all_of(getInputs(),
|
||||
[](Value v) { return v.getType().isa<MemRefType>(); });
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return whether the op has only RankedTensor input and outputs.
|
||||
}], "bool", "hasTensorSemantics">,
|
||||
}],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"hasTensorSemantics",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto isTensorType = [](Value v) {
|
||||
return v.getType().isa<RankedTensorType>();
|
||||
};
|
||||
return llvm::all_of(getInputs(), isTensorType) &&
|
||||
llvm::all_of(this->getOperation()->getResults(), isTensorType);
|
||||
}]
|
||||
>,
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Other static interface methods.
|
||||
//===------------------------------------------------------------------===//
|
||||
StaticInterfaceMethod<[{
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Create an operation of the current type with the given location,
|
||||
operands, and attributes.
|
||||
}],
|
||||
"Operation *", "create",
|
||||
/*retTy=*/"Operation *",
|
||||
/*methodName=*/"create",
|
||||
(ins "OpBuilder &":$builder, "Location":$loc,
|
||||
"ValueRange":$operands,
|
||||
"ArrayRef<NamedAttribute>":$attributes), [{
|
||||
|
@ -192,11 +510,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
attributes);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Clone the current operation with the given location and operands. This
|
||||
is used to abstract away the optional underlying region creation.
|
||||
}],
|
||||
"Operation *", "clone",
|
||||
/*retTy=*/"Operation *",
|
||||
/*methodName=*/"clone",
|
||||
(ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
|
||||
BlockAndValueMapping map;
|
||||
unsigned numRegions = $_op.getOperation()->getNumRegions();
|
||||
|
|
|
@ -49,8 +49,8 @@ public:
|
|||
};
|
||||
};
|
||||
|
||||
/// This class provides the API for structured ops that are known to operate on
|
||||
/// buffers or tensors. This trait must be used in conjunction with an op
|
||||
/// This class provides a verifier for structured ops that are known to operate
|
||||
/// on buffers or tensors. This trait must be used in conjunction with an op
|
||||
/// definition or a trait that provides the methods `getNumInputs` and
|
||||
/// `getNumOutputs`. Use as a trait as follows:
|
||||
///
|
||||
|
@ -59,324 +59,18 @@ public:
|
|||
template <typename ConcreteType>
|
||||
class StructuredOpTraits
|
||||
: public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
|
||||
private:
|
||||
/// Return the number of inputs, irrespective of their buffer or tensor type.
|
||||
/// For internal use only.
|
||||
unsigned nInputs() {
|
||||
return cast<ConcreteType>(this->getOperation()).getNumInputs();
|
||||
}
|
||||
/// Return the number of outputs, irrespective of their buffer or tensor type.
|
||||
/// For internal use only.
|
||||
unsigned nOutputs() {
|
||||
return cast<ConcreteType>(this->getOperation()).getNumOutputs();
|
||||
}
|
||||
|
||||
public:
|
||||
//==========================================================================//
|
||||
// Loop types handling.
|
||||
//==========================================================================//
|
||||
unsigned getNumParallelLoops() {
|
||||
return getNumIterators(
|
||||
getParallelIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumReductionLoops() {
|
||||
return getNumIterators(
|
||||
getReductionIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumWindowLoops() {
|
||||
return getNumIterators(
|
||||
getWindowIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumLoops() {
|
||||
return getNumIterators(
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
|
||||
bool hasSingleReductionLoop() {
|
||||
auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types();
|
||||
return iterators.size() == 1 &&
|
||||
getNumIterators(getReductionIteratorTypeName(), iterators);
|
||||
}
|
||||
|
||||
//==========================================================================//
|
||||
// Input arguments handling.
|
||||
//==========================================================================//
|
||||
// The `i^th` input argument is always the `i^th` operand regardless of
|
||||
// whether we have tensors or buffers.
|
||||
//
|
||||
/// Return the `i`-th input value.
|
||||
Value getInput(unsigned i) {
|
||||
assert(i < nInputs());
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
/// Return the index of `value` in the list of inputs if found, llvm::None
|
||||
/// otherwise.
|
||||
Optional<unsigned> getIndexOfInput(Value value) {
|
||||
auto it = llvm::find(getInputs(), value);
|
||||
if (it != getInputs().end())
|
||||
return it - getInputs().begin();
|
||||
return llvm::None;
|
||||
}
|
||||
/// Return the `i`-th input shaped type, irrespective of buffer or tensor
|
||||
/// type.
|
||||
ShapedType getInputShapedType(unsigned i) {
|
||||
return getInput(i).getType().template cast<ShapedType>();
|
||||
}
|
||||
/// Return the range over inputs.
|
||||
Operation::operand_range getInputs() {
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + nInputs()};
|
||||
}
|
||||
/// Query the subset of input operands that are of ranked tensor type.
|
||||
SmallVector<RankedTensorType, 4> getInputTensorTypes() {
|
||||
SmallVector<RankedTensorType, 4> res;
|
||||
for (Type type : getInputs().getTypes())
|
||||
if (auto t = type.template dyn_cast<RankedTensorType>())
|
||||
res.push_back(t);
|
||||
return res;
|
||||
}
|
||||
|
||||
//==========================================================================//
|
||||
// Output arguments handling.
|
||||
//==========================================================================//
|
||||
// The `i^th` output argument is an operand (resp. a return value) iff it is
|
||||
// a value of buffer type (resp. a return value of tensor type).
|
||||
|
||||
/// Return the `i`-th output, asserts that this is a buffer operand and not
|
||||
/// a tensor result.
|
||||
Value getOutputBuffer(unsigned i) {
|
||||
assert(i + this->getOperation()->getNumResults() < nOutputs() &&
|
||||
"overflowing output buffer index");
|
||||
return this->getOperation()->getOperand(nInputs() + i);
|
||||
}
|
||||
/// Return the index of `value` in the list of output buffers if found,
|
||||
/// llvm::None otherwise.
|
||||
Optional<unsigned> getIndexOfOutputBuffer(Value value) {
|
||||
auto it = llvm::find(getOutputBuffers(), value);
|
||||
if (it != getOutputBuffers().end())
|
||||
return it - getOutputBuffers().begin();
|
||||
return llvm::None;
|
||||
}
|
||||
/// Return the `i`-th output buffer type.
|
||||
MemRefType getOutputBufferType(unsigned i) {
|
||||
return getOutputBuffer(i).getType().template cast<MemRefType>();
|
||||
}
|
||||
/// Return the `i`-th output shaped type, irrespective of buffer of tensor
|
||||
/// type.
|
||||
ShapedType getOutputShapedType(unsigned i) {
|
||||
return getShapedType(i + nInputs());
|
||||
}
|
||||
/// Query the subset of results that are of ranked tensor type.
|
||||
SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
|
||||
SmallVector<RankedTensorType, 4> res;
|
||||
for (Type type : this->getOperation()->getResults().getTypes())
|
||||
res.push_back(type.template cast<RankedTensorType>());
|
||||
return res;
|
||||
}
|
||||
/// Return the range over outputs.
|
||||
Operation::operand_range getOutputBuffers() {
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin() + nInputs(),
|
||||
range.begin() + getNumInputsAndOutputBuffers()};
|
||||
}
|
||||
|
||||
//==========================================================================//
|
||||
// Input and Output arguments handling.
|
||||
//==========================================================================//
|
||||
Value getBuffer(unsigned i) {
|
||||
assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
/// Return the number of inputs and outputs, irrespective of their buffer or
|
||||
/// tensor type.
|
||||
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
|
||||
/// Return the number of inputs, irrespective of their buffer or tensor type,
|
||||
/// and output buffers.
|
||||
unsigned getNumInputsAndOutputBuffers() {
|
||||
assert(this->getOperation()->getNumResults() <= nOutputs());
|
||||
return nInputs() + nOutputs() - this->getOperation()->getNumResults();
|
||||
}
|
||||
/// Return the range over inputs (irrespective of type) and output buffers.
|
||||
Operation::operand_range getInputsAndOutputBuffers() {
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
|
||||
}
|
||||
/// Return the `i`-th shaped type, there are 3 cases:
|
||||
/// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise
|
||||
/// 2. if `i < getNumInputsAndOutputBuffers()` then return the
|
||||
/// `getOutputBufferType(i - nInputs())`; otherwise
|
||||
/// 3. return the `i - getNumInputsAndOutputBuffers()` result type.
|
||||
ShapedType getShapedType(unsigned i) {
|
||||
if (i < nInputs())
|
||||
return getInputShapedType(i);
|
||||
if (i < getNumInputsAndOutputBuffers())
|
||||
return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
|
||||
return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
|
||||
.template cast<ShapedType>();
|
||||
}
|
||||
/// Return the shaped types for all the inputs and outputs
|
||||
SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
|
||||
SmallVector<Type, 4> inputOutputTypes(
|
||||
this->getOperation()->operand_type_begin(),
|
||||
this->getOperation()->operand_type_end());
|
||||
inputOutputTypes.append(this->getOperation()->result_type_begin(),
|
||||
this->getOperation()->result_type_end());
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
|
||||
return type.cast<ShapedType>();
|
||||
}));
|
||||
}
|
||||
|
||||
//==========================================================================//
|
||||
// Other interface methods.
|
||||
//==========================================================================//
|
||||
|
||||
// Get or build the indexing_maps ArrayAttr.
|
||||
ArrayAttr iterator_types() {
|
||||
// Return the attribute if it is present.
|
||||
if (auto attr = this->getOperation()->getAttr("iterator_types"))
|
||||
return attr.template cast<ArrayAttr>();
|
||||
|
||||
// If not, form the attribute using the reference iterator types for the
|
||||
// ConcreteType.
|
||||
auto maybeReferenceIteratorTypes =
|
||||
cast<ConcreteType>(this->getOperation()).referenceIterators();
|
||||
|
||||
// If there is no reference, this must be a generic op.
|
||||
// TODO: Traits are used to define ops. Split into cpp to avoid cyclic
|
||||
// dependency.
|
||||
auto name = this->getOperation()->getName().getStringRef();
|
||||
if (!maybeReferenceIteratorTypes && name != "generic" &&
|
||||
name != "indexed_generic") {
|
||||
this->getOperation()->dump();
|
||||
llvm_unreachable("Op missing referenceIterators");
|
||||
}
|
||||
|
||||
// If we have a reference, build the reference attribute and set it in the
|
||||
// op before returning.
|
||||
auto *ctx = this->getOperation()->getContext();
|
||||
auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes,
|
||||
[ctx](StringRef str) -> Attribute {
|
||||
return StringAttr::get(str, ctx);
|
||||
});
|
||||
auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx);
|
||||
// TODO: Need to memoize this. Can't just store as an attribute atm as it
|
||||
// will impact parser, printer and tests.
|
||||
// this->getOperation()->setAttr("iterator_types", attr);
|
||||
return attr;
|
||||
}
|
||||
|
||||
// Get or build the indexing_maps ArrayAttr.
|
||||
ArrayAttr indexing_maps() {
|
||||
// Return the attribute if it is present.
|
||||
if (auto attr = this->getOperation()->getAttr("indexing_maps"))
|
||||
return attr.template cast<ArrayAttr>();
|
||||
|
||||
// If not, form the attribute using the reference indexing map for the
|
||||
// ConcreteType.
|
||||
auto maybeReferenceIndexingMaps =
|
||||
cast<ConcreteType>(this->getOperation()).referenceIndexingMaps();
|
||||
|
||||
// If there is no reference, this must be a generic op.
|
||||
auto name = this->getOperation()->getName().getStringRef();
|
||||
if (!maybeReferenceIndexingMaps && name != "generic" &&
|
||||
name != "indexed_generic") {
|
||||
this->getOperation()->dump();
|
||||
llvm_unreachable("Op missing referenceIndexingMaps");
|
||||
}
|
||||
|
||||
// If we have a reference, build the reference attribute and set it in the
|
||||
// op before returning.
|
||||
auto *ctx = this->getOperation()->getContext();
|
||||
auto attrRange =
|
||||
llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) {
|
||||
// 0-D corner case because there is no such thing as a concrete empty
|
||||
// map type.
|
||||
if (!map)
|
||||
map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
|
||||
return AffineMapAttr::get(map);
|
||||
});
|
||||
SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()};
|
||||
auto attr = ArrayAttr::get(attrs, ctx);
|
||||
// TODO: Need to memoize this. Can't just store as an attribute atm as it
|
||||
// will impact parser, printer and tests.
|
||||
// this->getOperation()->setAttr("indexing_maps", attr);
|
||||
return attr;
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 4> getIndexingMaps() {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
|
||||
return attr.cast<AffineMapAttr>().getValue();
|
||||
}));
|
||||
}
|
||||
|
||||
AffineMap getIndexingMap(unsigned i) {
|
||||
assert(i < getNumInputsAndOutputs());
|
||||
return indexing_maps()
|
||||
.getValue()[i]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}
|
||||
|
||||
AffineMap getInputIndexingMap(unsigned i) {
|
||||
assert(i < nInputs());
|
||||
return indexing_maps()
|
||||
.getValue()[i]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}
|
||||
|
||||
AffineMap getOutputIndexingMap(unsigned i) {
|
||||
assert(i < nOutputs());
|
||||
return indexing_maps()
|
||||
.getValue()[i + nInputs()]
|
||||
.template cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
}
|
||||
|
||||
/// Query whether the op has only buffer inputs and no returns.
|
||||
bool hasBufferSemantics() {
|
||||
return this->getOperation()->getNumResults() == 0 &&
|
||||
llvm::all_of(getInputs(),
|
||||
[](Value v) { return v.getType().isa<MemRefType>(); });
|
||||
}
|
||||
|
||||
/// Query whether the op has only tensor inputs and outputs.
|
||||
bool hasTensorSemantics() {
|
||||
auto isTensorType = [](Value v) {
|
||||
return v.getType().isa<RankedTensorType>();
|
||||
};
|
||||
return llvm::all_of(getInputs(), isTensorType) &&
|
||||
llvm::all_of(this->getOperation()->getResults(), isTensorType);
|
||||
}
|
||||
|
||||
//==========================================================================//
|
||||
// Other static interface methods.
|
||||
//==========================================================================//
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
ConcreteType concreteOp = cast<ConcreteType>(op);
|
||||
auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
|
||||
return failure();
|
||||
if (op->getNumResults() > concreteOp.getNumOutputs())
|
||||
return op->emitError("unexpected #results > #outputs");
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for named Linalg StructuredOps.
|
||||
template <typename ConcreteType>
|
||||
class NamedStructuredOpTraits
|
||||
: public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
|
||||
public:
|
||||
static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
|
||||
TypeRange outputTypes);
|
||||
|
||||
static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
|
||||
TypeRange outputTypes);
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
|
|
@ -260,13 +260,14 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
|
||||
return failure();
|
||||
|
||||
auto attr = op.template getAttrOfType<IntegerAttr>("symbol_source");
|
||||
int64_t targetRank = 0;
|
||||
if (attr) {
|
||||
unsigned index = attr.getInt();
|
||||
auto symbolSourceAttr =
|
||||
op.template getAttrOfType<IntegerAttr>("symbol_source");
|
||||
int64_t expectedNumSymbols = 0;
|
||||
if (symbolSourceAttr) {
|
||||
unsigned index = symbolSourceAttr.getInt();
|
||||
if (index >= op.getNumOperands())
|
||||
return op.emitOpError("symbol_source index out of range");
|
||||
targetRank = op.getShapedType(index).getRank();
|
||||
expectedNumSymbols = op.getShapedType(index).getRank();
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 4> indexingMaps;
|
||||
|
@ -278,9 +279,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
|
||||
: op.getOutputShapedType(idx - nInputViews);
|
||||
|
||||
if (m.getNumSymbols() != targetRank)
|
||||
if (m.getNumSymbols() != expectedNumSymbols)
|
||||
return op.emitOpError("expected the number of symbols in indexing_map #")
|
||||
<< idx << " to match target rank";
|
||||
<< idx << " to match rank of operand `symbol_source`";
|
||||
|
||||
if (m.getNumDims() != nLoops)
|
||||
return op.emitOpError("expected indexing_map #")
|
||||
|
@ -1246,15 +1247,9 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
|
|||
mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
|
||||
NamedStructuredOpType::regionBuilder(*body);
|
||||
|
||||
auto indexingMaps = builder.getAffineMapArrayAttr(
|
||||
NamedStructuredOpType::referenceIndexingMaps(operandTypes,
|
||||
tensorResultTypes));
|
||||
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
|
||||
// indexing_maps is an auto-generated method.
|
||||
|
||||
auto iterators =
|
||||
builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
|
||||
operandTypes, tensorResultTypes));
|
||||
result.addAttribute(getIteratorTypesAttrName(), iterators);
|
||||
// iterator_types is an auto-generated method.
|
||||
}
|
||||
|
||||
template <typename NamedStructuredOpType>
|
||||
|
|
|
@ -113,7 +113,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
|
|||
// -----
|
||||
|
||||
func @generic_symbol_in_map(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}}
|
||||
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
|
@ -514,3 +514,20 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
|
|||
linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @generic(%arg0: tensor<?x?xi4>) {
|
||||
// expected-error @+1 {{unexpected #results > #outputs}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [ affine_map<(i) -> (i)> ],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0 {
|
||||
^bb(%0: i4) :
|
||||
%1 = std.addi %0, %0: i4
|
||||
linalg.yield %1, %1: i4, i4
|
||||
} : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,16 +4,15 @@
|
|||
// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
|
||||
// ODS-NEXT: NInputs<2>
|
||||
// ODS-NEXT: NOutputs<1>
|
||||
// ODS-NEXT: NamedStructuredOpTraits
|
||||
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
|
||||
//
|
||||
// IMPL-LABEL: SmallVector<StringRef, 8> Test1Op::referenceIterators
|
||||
// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
|
||||
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
|
||||
//
|
||||
// IMPL: SmallVector<AffineMap, 8> Test1Op::referenceIndexingMaps
|
||||
// IMPL: ArrayAttr Test1Op::indexing_maps() {
|
||||
// IMPL: AffineMap::get(2, 0, {d0, d1}, context),
|
||||
// IMPL-NEXT: AffineMap::get(2, 0, {d1}, context),
|
||||
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
|
||||
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) });
|
||||
//
|
||||
// IMPL: void Test1Op::regionBuilder(Block &block) {
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
|
@ -29,16 +28,15 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
|
|||
// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
|
||||
// ODS-NEXT: NInputs<2>
|
||||
// ODS-NEXT: NOutputs<1>
|
||||
// ODS-NEXT: NamedStructuredOpTraits
|
||||
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
|
||||
//
|
||||
// IMPL-LABEL: SmallVector<StringRef, 8> Test2Op::referenceIterators
|
||||
// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
|
||||
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
|
||||
//
|
||||
// IMPL: SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps
|
||||
// IMPL: ArrayAttr Test2Op::indexing_maps() {
|
||||
// IMPL: AffineMap::get(3, 0, {d0, d2}, context),
|
||||
// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context),
|
||||
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
|
||||
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) });
|
||||
//
|
||||
// IMPL: Test2Op::regionBuilder(Block &block) {
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
|
@ -54,16 +52,15 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
|
|||
// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
|
||||
// ODS-NEXT: NInputs<2>
|
||||
// ODS-NEXT: NOutputs<1>
|
||||
// ODS-NEXT: NamedStructuredOpTraits
|
||||
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
|
||||
//
|
||||
// IMPL-LABEL: SmallVector<StringRef, 8> Test3Op::referenceIterators
|
||||
// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
|
||||
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
|
||||
//
|
||||
// IMPL: SmallVector<AffineMap, 8> Test3Op::referenceIndexingMaps
|
||||
// IMPL: ArrayAttr Test3Op::indexing_maps() {
|
||||
// IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context),
|
||||
// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context),
|
||||
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
|
||||
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) });
|
||||
//
|
||||
// IMPL: Test3Op::regionBuilder(Block &block) {
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
|
|
|
@ -974,19 +974,19 @@ public:
|
|||
/// Parse and print the information for a TC def.
|
||||
/// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
|
||||
/// When `gen-impl` is used, this prints the C++ implementation for the extra
|
||||
/// methods defined in ODS (referenceIterators, referenceIndexingMaps and
|
||||
/// regionBuilder).
|
||||
/// methods defined in ODS (`iterator_types`, `indexing_maps` and
|
||||
/// `regionBuilder`).
|
||||
LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
|
||||
|
||||
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
|
||||
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||
StringRef linalgOpName);
|
||||
|
||||
/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
|
||||
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
|
||||
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
|
||||
ComprehensionParsingState &state);
|
||||
|
||||
/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
|
||||
/// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
|
||||
void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
|
||||
ComprehensionParsingState &state);
|
||||
|
||||
|
@ -1446,7 +1446,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [
|
||||
NInputs<{2}>,
|
||||
NOutputs<{3}>,
|
||||
NamedStructuredOpTraits,
|
||||
SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let arguments = (ins Variadic<LinalgOperand>:$views);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
|
||||
|
@ -1465,16 +1464,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
return ::parseNamedStructuredOp<{0}>(parser, result);
|
||||
}];
|
||||
let extraClassDeclaration = [{{
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
|
||||
static SmallVector<StringRef, 8> referenceIterators(
|
||||
TypeRange inputTypes, TypeRange outputTypes);
|
||||
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
|
||||
static SmallVector<AffineMap, 8> referenceIndexingMaps(
|
||||
TypeRange inputTypes, TypeRange outputTypes);
|
||||
|
||||
ArrayAttr iterator_types();
|
||||
ArrayAttr indexing_maps();
|
||||
static void regionBuilder(Block &block);
|
||||
|
||||
std::string getLibraryCallName() {{
|
||||
return generateLibraryCallName(getOperation());
|
||||
}
|
||||
|
@ -1492,20 +1484,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
|
||||
}
|
||||
|
||||
/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
|
||||
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
|
||||
void TCParser::printReferenceIterators(llvm::raw_ostream &os,
|
||||
StringRef cppOpName,
|
||||
ComprehensionParsingState &state) {
|
||||
const char *referenceReferenceIteratorsFmt =
|
||||
R"FMT(
|
||||
// This is temporary until we transition out of manually specified ops
|
||||
// that should be auto-generated with linalg-ods-gen.
|
||||
llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {{
|
||||
llvm_unreachable("Unexpected missing `iterator_types` attribute.");
|
||||
}
|
||||
SmallVector<StringRef, 8> {0}::referenceIterators(
|
||||
TypeRange inputTypes, TypeRange outputTypes) {
|
||||
return SmallVector<StringRef, 8>{{ {1} };
|
||||
ArrayAttr {0}::iterator_types() {
|
||||
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
|
||||
})FMT";
|
||||
|
||||
std::string iteratorsStr;
|
||||
|
@ -1542,16 +1528,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
R"FMT(
|
||||
// This is temporary until we transition out of manually specified ops that
|
||||
// should be auto-generated with linalg-ods-gen.
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {{
|
||||
llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
|
||||
}
|
||||
SmallVector<AffineMap, 8> {0}::referenceIndexingMaps(
|
||||
TypeRange inputTypes, TypeRange outputTypes) {
|
||||
assert(!inputTypes.empty() && "At least one input expected");
|
||||
MLIRContext *context = (*inputTypes.begin()).getContext();
|
||||
ArrayAttr {0}::indexing_maps() {
|
||||
MLIRContext *context = getContext();
|
||||
AffineExpr {1};
|
||||
bindDims(context, {1});
|
||||
return SmallVector<AffineMap, 8>{{ {2} };
|
||||
return Builder(context).getAffineMapArrayAttr({ {2} });
|
||||
})FMT";
|
||||
|
||||
// 2. Print a comma-separated list of identifiers for the AffineExpr in
|
||||
|
|
Loading…
Reference in New Issue