[mlir][Linalg] Refactor StructuredOpInterface - NFC

This revision refactors and cleans up a bunch of things to simplify StructuredOpInterface
before work can proceed on Linalg on tensors:
- break out pieces of the StructuredOps trait that are part of the StructuredOpInterface,
- drop referenceIterators and referenceIndexingMaps that end up being more confusing than useful,
- drop NamedStructuredOpTrait
This commit is contained in:
Nicolas Vasilache 2020-09-11 06:19:07 -04:00
parent 7527898fef
commit e6f2f17f05
7 changed files with 490 additions and 497 deletions

View File

@ -130,21 +130,22 @@ def CopyOp : LinalgStructured_Op<"copy", [
let extraClassDeclaration = libraryCallName # [{
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
unsigned nPar = input().getType().cast<ShapedType>().getRank();
return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
ArrayAttr iterator_types() {
unsigned nPar = getInputShapedType(0).getRank();
return Builder(getContext()).getStrArrayAttr(
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
}
// I(input_perm(ivs)) -> O(output_perm(ivs))
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto maybeInputMap = inputPermutation();
auto maybeOutputMap = outputPermutation();
unsigned inputRank = getInputShapedType(0).getRank();
unsigned outputRank = getOutputShapedType(0).getRank();
return SmallVector<AffineMap, 8>{
return Builder(getContext()).getAffineMapArrayAttr({
extractOrIdentityMap(maybeInputMap, inputRank, context),
extractOrIdentityMap(maybeOutputMap, outputRank, context)};
extractOrIdentityMap(maybeOutputMap, outputRank, context)});
}
Value getSource() { return input();}
@ -163,16 +164,17 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
let extraClassDeclaration = libraryCallName # [{
// Rank-polymorphic.
// filling_value -> O(ivs) with parallel iterators.
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
unsigned nPar = output().getType().cast<ShapedType>().getRank();
return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
ArrayAttr iterator_types() {
unsigned nPar = getOutputShapedType(0).getRank();
return Builder(getContext()).getStrArrayAttr(
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
}
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
// filling_value -> O(ivs)
return SmallVector<AffineMap, 8>{
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)};
return Builder(getContext()).getAffineMapArrayAttr({
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
}];
@ -295,7 +297,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
getNumOutputFeatureDimensions();
}
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions; i.e.
// [b, xs, q] in the TF notation above.
unsigned nPar = getOutputShapedType(0).getRank();
@ -310,7 +312,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
iters.reserve(nPar + nRed + nWin);
iters.append(nRed, getReductionIteratorTypeName());
iters.append(nWin, getWindowIteratorTypeName());
return iters;
return Builder(getContext()).getStrArrayAttr(iters);
}
// F(z0, ..., zN-1, q, k) *
@ -318,7 +320,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// -> O(b, x0, ..., xN-1, k)
// for N equal to `nWindow`. If there is no padding attribute, it will be
// ignored.
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto nWin = getNumWindowLoops();
assert(nWin > 0 && "expected at least one window dimension");
@ -343,7 +345,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
auto zs = makeAffineDimExprs(nWin, idx, context);
// Construct the weightedSum expression.
auto ws = weightedPoolingInputIndex(*this, xs, zs);
return SmallVector<AffineMap, 8>{
return Builder(getContext()).getAffineMapArrayAttr({
// filter[z[0], ..., z[N-1], q, k]
AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context),
// input[b,
@ -353,7 +355,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// q]
AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context),
// output[b, x[0], ..., x[N-1], k]
AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)};
AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)});
}
}];
@ -384,7 +386,7 @@ class SingleInputPoolingBase_Op<string mnemonic>
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = commonUtils# [{
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions.
unsigned nPar = getOutputShapedType(0).getRank();
// The window loops has the same number loops with output dimensions.
@ -392,10 +394,10 @@ class SingleInputPoolingBase_Op<string mnemonic>
SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
iters.reserve(nPar + nWin);
iters.append(nWin, getWindowIteratorTypeName());
return iters;
return Builder(getContext()).getStrArrayAttr(iters);
}
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
ArrayAttr indexing_maps() {
MLIRContext *context = getContext();
auto nPar = getNumParallelLoops();
auto nWin = getNumWindowLoops();
@ -406,14 +408,13 @@ class SingleInputPoolingBase_Op<string mnemonic>
// Construct the weightedSum expression.
auto inputDims =
weightedPoolingInputIndex(*this, outputDims, windowDims);
return SmallVector<AffineMap, 8>{
return Builder(getContext()).getAffineMapArrayAttr({
// input
AffineMap::get(idx, 0, inputDims, context),
// windowDims
AffineMap::get(idx, 0, windowDims, context),
// output
AffineMap::get(idx, 0, outputDims, context)
};
AffineMap::get(idx, 0, outputDims, context)});
}
}];
@ -466,7 +467,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
OptionalAttr<StrAttr>:$library_call,
Confined<OptionalAttr<I64Attr>,
[IntMinValue<0>]>:$symbol_source);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
@ -485,16 +486,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
return library_call().hasValue() ? library_call().getValue() : "";
}
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
llvm_unreachable(
"No such thing as reference iterator types for a generic op.");
}
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
llvm_unreachable(
"No such thing as reference indexing maps for a generic op.");
}
llvm::Optional<unsigned> getSymbolSource() {
auto ss = symbol_source();
return ss.hasValue() ?
@ -807,8 +798,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
: LinalgStructuredBase_Op<mnemonic, props> {
string spec = ?;

View File

@ -23,168 +23,486 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// Loop types handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
"Return the number of parallel loops within the current operation.",
"unsigned", "getNumParallelLoops"
/*desc=*/[{
Return the number of parallel loops within the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumParallelLoops",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getParallelIteratorTypeName(),
$_op.iterator_types());
}]
>,
InterfaceMethod<
"Return the number of reduction loops within the current operation.",
"unsigned", "getNumReductionLoops"
/*desc=*/[{
Return the number of reduction loops within the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumReductionLoops",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getReductionIteratorTypeName(),
$_op.iterator_types());
}]
>,
InterfaceMethod<
"Return the number of window loops within the current operation.",
"unsigned", "getNumWindowLoops"
/*desc=*/[{
Return the number of window loops within the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumWindowLoops",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getWindowIteratorTypeName(),
$_op.iterator_types());
}]
>,
InterfaceMethod<
"Return the number of loops within the current operation.",
"unsigned", "getNumLoops">,
/*desc=*/[{
Return the total number of loops within the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumLoops",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators($_op.iterator_types());
}]
>,
InterfaceMethod<
[{Returns true if the current operation has only one loop and it's a
reduction loop}],
"bool", "hasSingleReductionLoop">,
/*desc=*/[{
Returns true if the current operation has only one loop and it's a
reduction loop.
}],
/*retTy=*/"bool",
/*methodName=*/"hasSingleReductionLoop",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto iters = $_op.iterator_types();
return iters.size() == 1 &&
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
}]>,
//===------------------------------------------------------------------===//
// Num input/output arguments handling.
//===------------------------------------------------------------------===//
// These special methods must be defined by each op that wants to implement
// the LinalgStructuredInterface. For now, this is either:
// - inherited statically by using the NInputs<unsigned> or
// NOutputs<unsigned> traits.
// - derived from args_in/args_out attributes (for linalg.generic and
// linalg.indexed_generic ops).
InterfaceMethod<
/*desc=*/[{
Return the number of inputs from the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumInputs"
>,
InterfaceMethod<
/*desc=*/[{
Return the number of outputs from the current operation.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumOutputs"
>,
//===------------------------------------------------------------------===//
// Input arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
"Return the number of inputs from the current operation.",
"unsigned", "getNumInputs"
/*desc=*/[{
Return the `i`-th input value.
The `i^th` input argument is always the `i^th` operand regardless of
whether we have tensors or buffers.
}],
/*retTy=*/"Value",
/*methodName=*/"getInput",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumInputs());
return this->getOperation()->getOperand(i);
}]
>,
InterfaceMethod<"Return the input view at the given index.",
"Value", "getInput", (ins "unsigned":$i)
>,
InterfaceMethod<[{
InterfaceMethod<
/*desc=*/[{
Return the index of the given input value `v`, or `None` if the value is
not an input.
}],
"llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value":$v)
/*retTy=*/"llvm::Optional<unsigned>",
/*methodName=*/"getIndexOfInput",
/*args=*/(ins "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto it = llvm::find(getInputs(), value);
if (it != getInputs().end())
return it - getInputs().begin();
return llvm::None;
}]
>,
InterfaceMethod<
"Return the input operands from the current operation.",
"Operation::operand_range", "getInputs"
>,
InterfaceMethod<[{
/*desc=*/[{
Return the `i`-th input shaped type, irrespective of buffer or tensor
type.
}], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
InterfaceMethod<[{
}],
/*retTy=*/"ShapedType",
/*methodName=*/"getInputShapedType",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getInput(i).getType().template cast<ShapedType>();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the input operands from the current operation.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + $_op.getNumInputs()};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of input operands that are of ranked tensor type.
}], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
}],
/*retTy=*/"SmallVector<RankedTensorType, 4>",
/*methodName=*/"getInputTensorTypes" ,
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<RankedTensorType, 4> res;
for (Type type : getInputs().getTypes())
if (auto t = type.template dyn_cast<RankedTensorType>())
res.push_back(t);
return res;
}]
>,
//===------------------------------------------------------------------===//
// Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
"Return the number of outputs from the current operation.",
"unsigned", "getNumOutputs"
/*desc=*/[{
Return the output buffer at the given index, asserts that this is a
buffer operand and not a tensor result.
The `i^th` output argument is an operand (resp. a return value) iff it
is a value of buffer type (resp. a return value of tensor type).
}],
/*retTy=*/"Value",
/*methodName=*/"getOutputBuffer",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Output buffers are passed as output buffer operands (side-effecting).
// Output tensors are results.
// The union of the 2 are all the outputs and we want to ensure i does
// not overflow the buffer operands.
assert(i + this->getOperation()->getNumResults() < $_op.getNumOutputs()
&& "overflowing output buffer index");
return this->getOperation()->getOperand($_op.getNumInputs() + i);
}]
>,
InterfaceMethod<"Return the output buffer at the given index.",
"Value", "getOutputBuffer", (ins "unsigned":$i)
>,
InterfaceMethod<[{
InterfaceMethod<
/*desc=*/[{
Return the index of the given buffer value, or `None` if the value is
not part of the output buffers.
}],
"llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value":$view)
/*retTy=*/"llvm::Optional<unsigned>",
/*methodName=*/"getIndexOfOutputBuffer",
/*args=*/(ins "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto it = llvm::find(getOutputBuffers(), value);
if (it != getOutputBuffers().end())
return it - getOutputBuffers().begin();
return llvm::None;
}]
>,
InterfaceMethod<[{
InterfaceMethod<
/*desc=*/[{
Return the type of the output buffer at the given index.
}], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
InterfaceMethod<[{
}],
/*retTy=*/"MemRefType",
/*methodName=*/"getOutputBufferType",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getOutputBuffer(i).getType().template cast<MemRefType>();
}]>,
InterfaceMethod<
/*desc=*/[{
Return the `i`-th output shaped type, irrespective of buffer or tensor
type.
}], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
InterfaceMethod<[{
Return the results that are of ranked tensor type.
}], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
}],
/*retTy=*/"ShapedType",
/*methodName=*/"getOutputShapedType",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getShapedType(i + $_op.getNumInputs());
}]>,
InterfaceMethod<
"Return the output buffers (operands) from the current operation.",
"Operation::operand_range", "getOutputBuffers"
/*desc=*/[{
Return the results that are of ranked tensor type.
}],
/*retTy=*/"SmallVector<RankedTensorType, 4>",
/*methodName=*/"getOutputTensorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<RankedTensorType, 4> res;
for (Type type : this->getOperation()->getResults().getTypes())
res.push_back(type.template cast<RankedTensorType>());
return res;
}]>,
InterfaceMethod<
/*desc=*/[{
Return the output buffers (operands) from the current operation.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getOutputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = this->getOperation()->getOperands();
return {range.begin() + $_op.getNumInputs(),
range.begin() + getNumInputsAndOutputBuffers()};
}]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
"Return one single buffer at position `$i`.",
"Value", "getBuffer", (ins "unsigned":$i)
/*desc=*/[{
Return one single buffer at position `$i`.
}],
/*retTy=*/"Value",
/*methodName=*/"getBuffer",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
return this->getOperation()->getOperand(i);
}]
>,
InterfaceMethod<
"Return the number of inputs and outputs, irrespective of their buffer "
"or tensor type.",
"unsigned", "getNumInputsAndOutputs"
/*desc=*/[{
Return the number of inputs and outputs, irrespective of their buffer or
tensor type.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumInputsAndOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getNumInputs() + $_op.getNumOutputs();
}]
>,
InterfaceMethod<
"Return the number of inputs, irrespective of their buffer or tensor "
"type, and output buffers",
"unsigned", "getNumInputsAndOutputBuffers"
/*desc=*/[{
Return the number of inputs, irrespective of their buffer or tensor
type, plus the number of output buffers.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumInputsAndOutputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getNumInputs() + $_op.getNumOutputs() -
this->getOperation()->getNumResults();
}]
>,
InterfaceMethod<
"Return the range over inputs (irrespective of type) and output buffers.",
"Operation::operand_range", "getInputsAndOutputBuffers"
/*desc=*/[{
Return the range over inputs (irrespective of type) and output buffers.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getInputsAndOutputBuffers",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
}]
>,
InterfaceMethod<
"Return the shaped types for all the inputs and outputs",
"SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
/*desc=*/[{
Return the `i`-th shaped type, there are 3 cases:
1. if `i < $_op.getNumInputs()` then return `getInputShapedType(i)`;
otherwise
2. if `i < getNumInputsAndOutputBuffers()` then return the
`getOutputBufferType(i - $_op.getNumInputs())`; otherwise
3. return the `i - getNumInputsAndOutputBuffers()` result type.
}],
/*retTy=*/"ShapedType",
/*methodName=*/"getShapedType",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (i < $_op.getNumInputs())
return getInputShapedType(i);
if (i < getNumInputsAndOutputBuffers())
return getOutputBufferType(i - $_op.getNumInputs());
return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
}]>,
InterfaceMethod<
/*desc=*/[{
Return the shaped types for all the inputs and outputs
}],
/*retTy=*/"SmallVector<ShapedType, 4>",
/*methodName=*/"getInputOutputShapedTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<Type, 4> inputOutputTypes(
this->getOperation()->operand_type_begin(),
this->getOperation()->operand_type_end());
inputOutputTypes.append(this->getOperation()->result_type_begin(),
this->getOperation()->result_type_end());
return llvm::to_vector<4>(
llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
return type.cast<ShapedType>();
}));
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
"Return the reference iterators for this named op (if any are "
"specified). These reference iterators are used to specify the default "
"behavior of the op. Typically this would be a static method but in "
"order to allow rank-polymorphic ops, this needs to be per object "
"instance. Named ops must define referenceIterators, even if empty for "
"the 0-D case. Generic ops on the other hand have a None "
"`referenceIterators`",
"llvm::Optional<SmallVector<StringRef, 8>>", "referenceIterators"
/*desc=*/[{
Return the iterator types attribute within the current operation.
}],
/*retTy=*/"ArrayAttr",
/*methodName=*/"iterator_types",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.iterator_types();
}]
>,
InterfaceMethod<
"Return the reference indexing maps for this named op (if any are "
"specified). Typically this would be a static method but in order to "
"allow rank-polymorphic ops, this needs to be per object instance. Named "
"ops must define referenceIterators, even if empty for the 0-D case. "
"Generic ops on the other hand have a None `referenceIndexingMaps`",
"llvm::Optional<SmallVector<AffineMap, 8>>", "referenceIndexingMaps"
/*desc=*/[{
Return the indexing maps attribute within the current operation.
}],
/*retTy=*/"ArrayAttr",
/*methodName=*/"indexing_maps"
>,
InterfaceMethod<
"Return the iterator types attribute within the current operation.",
"ArrayAttr", "iterator_types"
/*desc=*/[{
Return the indexing maps within the current operation.
}],
/*retTy=*/"SmallVector<AffineMap, 4>",
/*methodName=*/"getIndexingMaps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::to_vector<4>(
llvm::map_range($_op.indexing_maps(),
[](Attribute attr) -> AffineMap {
return attr.cast<AffineMapAttr>().getValue();
}));
}]
>,
InterfaceMethod<
"Return the indexing maps attribute within the current operation.",
"ArrayAttr", "indexing_maps"
/*desc=*/[{
Return the input or output indexing map at index `i`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getIndexingMap",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < getNumInputsAndOutputs());
return $_op.indexing_maps()
.getValue()[i]
.template cast<AffineMapAttr>()
.getValue();
}]
>,
InterfaceMethod<
"Return the indexing maps within the current operation.",
"SmallVector<AffineMap, 4>", "getIndexingMaps"
/*desc=*/[{
Return the input indexing map at index `i`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getInputIndexingMap",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumInputs());
return $_op.indexing_maps()
.getValue()[i]
.template cast<AffineMapAttr>()
.getValue();
}]
>,
InterfaceMethod<"Return the input or output indexing map at index `i`.",
"AffineMap", "getIndexingMap", (ins "unsigned":$i)
InterfaceMethod<
/*desc=*/[{
Return the output indexing map at index `i`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getOutputIndexingMap",
/*args=*/(ins "unsigned":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumOutputs());
return $_op.indexing_maps()
.getValue()[i + $_op.getNumInputs()]
.template cast<AffineMapAttr>()
.getValue();
}]
>,
InterfaceMethod<"Return the input indexing map at index `i`.",
"AffineMap", "getInputIndexingMap", (ins "unsigned":$i)
>,
InterfaceMethod<"Return the output indexing map at index `i`.",
"AffineMap", "getOutputIndexingMap", (ins "unsigned":$i)
>,
InterfaceMethod<[{
InterfaceMethod<
/*desc=*/[{
Return whether the op has only MemRef input and outputs.
}], "bool", "hasBufferSemantics">,
InterfaceMethod<[{
}],
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputs(),
[](Value v) { return v.getType().isa<MemRefType>(); });
}]
>,
InterfaceMethod<
/*desc=*/[{
Return whether the op has only RankedTensor input and outputs.
}], "bool", "hasTensorSemantics">,
}],
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto isTensorType = [](Value v) {
return v.getType().isa<RankedTensorType>();
};
return llvm::all_of(getInputs(), isTensorType) &&
llvm::all_of(this->getOperation()->getResults(), isTensorType);
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
StaticInterfaceMethod<[{
StaticInterfaceMethod<
/*desc=*/[{
Create an operation of the current type with the given location,
operands, and attributes.
}],
"Operation *", "create",
/*retTy=*/"Operation *",
/*methodName=*/"create",
(ins "OpBuilder &":$builder, "Location":$loc,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
@ -192,11 +510,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
attributes);
}]
>,
InterfaceMethod<[{
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation.
}],
"Operation *", "clone",
/*retTy=*/"Operation *",
/*methodName=*/"clone",
(ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
BlockAndValueMapping map;
unsigned numRegions = $_op.getOperation()->getNumRegions();

View File

@ -49,8 +49,8 @@ public:
};
};
/// This class provides the API for structured ops that are known to operate on
/// buffers or tensors. This trait must be used in conjunction with an op
/// This class provides a verifier for structured ops that are known to operate
/// on buffers or tensors. This trait must be used in conjunction with an op
/// definition or a trait that provides the methods `getNumInputs` and
/// `getNumOutputs`. Use as a trait as follows:
///
@ -59,324 +59,18 @@ public:
template <typename ConcreteType>
class StructuredOpTraits
: public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
private:
/// Return the number of inputs, irrespective of their buffer or tensor type.
/// For internal use only.
unsigned nInputs() {
return cast<ConcreteType>(this->getOperation()).getNumInputs();
}
/// Return the number of outputs, irrespective of their buffer or tensor type.
/// For internal use only.
unsigned nOutputs() {
return cast<ConcreteType>(this->getOperation()).getNumOutputs();
}
public:
//==========================================================================//
// Loop types handling.
//==========================================================================//
unsigned getNumParallelLoops() {
return getNumIterators(
getParallelIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumReductionLoops() {
return getNumIterators(
getReductionIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumWindowLoops() {
return getNumIterators(
getWindowIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumLoops() {
return getNumIterators(
cast<ConcreteType>(this->getOperation()).iterator_types());
}
bool hasSingleReductionLoop() {
auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types();
return iterators.size() == 1 &&
getNumIterators(getReductionIteratorTypeName(), iterators);
}
//==========================================================================//
// Input arguments handling.
//==========================================================================//
// The `i^th` input argument is always the `i^th` operand regardless of
// whether we have tensors or buffers.
//
/// Return the `i`-th input value.
Value getInput(unsigned i) {
assert(i < nInputs());
return this->getOperation()->getOperand(i);
}
/// Return the index of `value` in the list of inputs if found, llvm::None
/// otherwise.
Optional<unsigned> getIndexOfInput(Value value) {
auto it = llvm::find(getInputs(), value);
if (it != getInputs().end())
return it - getInputs().begin();
return llvm::None;
}
/// Return the `i`-th input shaped type, irrespective of buffer or tensor
/// type.
ShapedType getInputShapedType(unsigned i) {
return getInput(i).getType().template cast<ShapedType>();
}
/// Return the range over inputs.
Operation::operand_range getInputs() {
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + nInputs()};
}
/// Query the subset of input operands that are of ranked tensor type.
SmallVector<RankedTensorType, 4> getInputTensorTypes() {
SmallVector<RankedTensorType, 4> res;
for (Type type : getInputs().getTypes())
if (auto t = type.template dyn_cast<RankedTensorType>())
res.push_back(t);
return res;
}
//==========================================================================//
// Output arguments handling.
//==========================================================================//
// The `i^th` output argument is an operand (resp. a return value) iff it is
// a value of buffer type (resp. a return value of tensor type).
/// Return the `i`-th output, asserts that this is a buffer operand and not
/// a tensor result.
Value getOutputBuffer(unsigned i) {
assert(i + this->getOperation()->getNumResults() < nOutputs() &&
"overflowing output buffer index");
return this->getOperation()->getOperand(nInputs() + i);
}
/// Return the index of `value` in the list of output buffers if found,
/// llvm::None otherwise.
Optional<unsigned> getIndexOfOutputBuffer(Value value) {
auto it = llvm::find(getOutputBuffers(), value);
if (it != getOutputBuffers().end())
return it - getOutputBuffers().begin();
return llvm::None;
}
/// Return the `i`-th output buffer type.
MemRefType getOutputBufferType(unsigned i) {
return getOutputBuffer(i).getType().template cast<MemRefType>();
}
/// Return the `i`-th output shaped type, irrespective of buffer or tensor
/// type.
ShapedType getOutputShapedType(unsigned i) {
return getShapedType(i + nInputs());
}
/// Query the subset of results that are of ranked tensor type.
SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
SmallVector<RankedTensorType, 4> res;
for (Type type : this->getOperation()->getResults().getTypes())
res.push_back(type.template cast<RankedTensorType>());
return res;
}
/// Return the range over outputs.
Operation::operand_range getOutputBuffers() {
auto range = this->getOperation()->getOperands();
return {range.begin() + nInputs(),
range.begin() + getNumInputsAndOutputBuffers()};
}
//==========================================================================//
// Input and Output arguments handling.
//==========================================================================//
Value getBuffer(unsigned i) {
assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
return this->getOperation()->getOperand(i);
}
/// Return the number of inputs and outputs, irrespective of their buffer or
/// tensor type.
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
/// Return the number of inputs, irrespective of their buffer or tensor type,
/// and output buffers.
unsigned getNumInputsAndOutputBuffers() {
assert(this->getOperation()->getNumResults() <= nOutputs());
return nInputs() + nOutputs() - this->getOperation()->getNumResults();
}
/// Return the range over inputs (irrespective of type) and output buffers.
Operation::operand_range getInputsAndOutputBuffers() {
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
}
/// Return the `i`-th shaped type. Three cases, in order:
/// 1. `i < nInputs()`: the `i`-th input shaped type;
/// 2. `i < getNumInputsAndOutputBuffers()`: the output buffer type at
///    position `i - nInputs()`;
/// 3. otherwise: the result (tensor) type at position
///    `i - getNumInputsAndOutputBuffers()`.
ShapedType getShapedType(unsigned i) {
  if (i < nInputs())
    return getInputShapedType(i);
  auto numInputsAndOutputBuffers = getNumInputsAndOutputBuffers();
  if (i < numInputsAndOutputBuffers)
    return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
  auto resultIdx = i - numInputsAndOutputBuffers;
  return getOutputTensorTypes()[resultIdx].template cast<ShapedType>();
}
/// Return the shaped types of all inputs and outputs: first every operand
/// type, then every result type, each cast to ShapedType (asserts otherwise).
SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
  auto *op = this->getOperation();
  SmallVector<Type, 4> allTypes(op->operand_type_begin(),
                                op->operand_type_end());
  allTypes.append(op->result_type_begin(), op->result_type_end());
  SmallVector<ShapedType, 4> shapedTypes;
  shapedTypes.reserve(allTypes.size());
  for (Type type : allTypes)
    shapedTypes.push_back(type.cast<ShapedType>());
  return shapedTypes;
}
//==========================================================================//
// Other interface methods.
//==========================================================================//
// Get or build the `iterator_types` ArrayAttr: prefer the attribute stored
// on the op; otherwise synthesize it from the ConcreteType's reference
// iterator list.
ArrayAttr iterator_types() {
  // Return the attribute if it is present.
  if (auto attr = this->getOperation()->getAttr("iterator_types"))
    return attr.template cast<ArrayAttr>();
  // If not, form the attribute using the reference iterator types for the
  // ConcreteType.
  auto maybeReferenceIteratorTypes =
      cast<ConcreteType>(this->getOperation()).referenceIterators();
  // If there is no reference, this must be a generic op.
  // TODO: Traits are used to define ops. Split into cpp to avoid cyclic
  // dependency.
  auto name = this->getOperation()->getName().getStringRef();
  if (!maybeReferenceIteratorTypes && name != "generic" &&
      name != "indexed_generic") {
    this->getOperation()->dump();
    llvm_unreachable("Op missing referenceIterators");
  }
  // If we have a reference, build the reference attribute and set it in the
  // op before returning.
  // NOTE(review): a generic/indexed_generic op with no reference iterators
  // would reach the dereference below with an empty Optional — presumably
  // those ops always carry the attribute and return early above; confirm.
  auto *ctx = this->getOperation()->getContext();
  auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes,
                                   [ctx](StringRef str) -> Attribute {
                                     return StringAttr::get(str, ctx);
                                   });
  auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx);
  // TODO: Need to memoize this. Can't just store as an attribute atm as it
  // will impact parser, printer and tests.
  // this->getOperation()->setAttr("iterator_types", attr);
  return attr;
}
// Get or build the `indexing_maps` ArrayAttr: prefer the attribute stored
// on the op; otherwise synthesize it from the ConcreteType's reference
// indexing maps.
ArrayAttr indexing_maps() {
  // Return the attribute if it is present.
  if (auto attr = this->getOperation()->getAttr("indexing_maps"))
    return attr.template cast<ArrayAttr>();
  // If not, form the attribute using the reference indexing map for the
  // ConcreteType.
  auto maybeReferenceIndexingMaps =
      cast<ConcreteType>(this->getOperation()).referenceIndexingMaps();
  // If there is no reference, this must be a generic op.
  auto name = this->getOperation()->getName().getStringRef();
  if (!maybeReferenceIndexingMaps && name != "generic" &&
      name != "indexed_generic") {
    this->getOperation()->dump();
    llvm_unreachable("Op missing referenceIndexingMaps");
  }
  // If we have a reference, build the reference attribute and set it in the
  // op before returning.
  // NOTE(review): a generic/indexed_generic op with no reference maps would
  // reach the dereference below with an empty Optional — presumably those
  // ops always carry the attribute and return early above; confirm.
  auto *ctx = this->getOperation()->getContext();
  auto attrRange =
      llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) {
        // 0-D corner case because there is no such thing as a concrete empty
        // map type.
        if (!map)
          map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
        return AffineMapAttr::get(map);
      });
  SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()};
  auto attr = ArrayAttr::get(attrs, ctx);
  // TODO: Need to memoize this. Can't just store as an attribute atm as it
  // will impact parser, printer and tests.
  // this->getOperation()->setAttr("indexing_maps", attr);
  return attr;
}
/// Return all indexing maps as AffineMap values, unwrapped from the
/// `indexing_maps` ArrayAttr (built on demand if not stored on the op).
SmallVector<AffineMap, 4> getIndexingMaps() {
  auto mapsAttr = indexing_maps();
  SmallVector<AffineMap, 4> maps;
  maps.reserve(mapsAttr.size());
  for (Attribute attr : mapsAttr)
    maps.push_back(attr.cast<AffineMapAttr>().getValue());
  return maps;
}
/// Return the indexing map for the `i`-th input-or-output position.
AffineMap getIndexingMap(unsigned i) {
  assert(i < getNumInputsAndOutputs());
  auto mapAttr = indexing_maps().getValue()[i];
  return mapAttr.template cast<AffineMapAttr>().getValue();
}
/// Return the indexing map for the `i`-th input.
AffineMap getInputIndexingMap(unsigned i) {
  assert(i < nInputs());
  auto mapAttr = indexing_maps().getValue()[i];
  return mapAttr.template cast<AffineMapAttr>().getValue();
}
/// Return the indexing map for the `i`-th output; outputs follow the
/// `nInputs()` input entries in the `indexing_maps` attribute.
AffineMap getOutputIndexingMap(unsigned i) {
  assert(i < nOutputs());
  auto mapAttr = indexing_maps().getValue()[i + nInputs()];
  return mapAttr.template cast<AffineMapAttr>().getValue();
}
/// Query whether the op has only buffer inputs and no returns.
bool hasBufferSemantics() {
  if (this->getOperation()->getNumResults() != 0)
    return false;
  auto isMemRef = [](Value v) { return v.getType().isa<MemRefType>(); };
  return llvm::all_of(getInputs(), isMemRef);
}
/// Query whether the op has only tensor inputs and outputs.
bool hasTensorSemantics() {
auto isTensorType = [](Value v) {
return v.getType().isa<RankedTensorType>();
};
return llvm::all_of(getInputs(), isTensorType) &&
llvm::all_of(this->getOperation()->getResults(), isTensorType);
}
//==========================================================================//
// Other static interface methods.
//==========================================================================//
/// Verify invariants shared by all structured ops: there are at least enough
/// operands to cover every input and output buffer, and the op does not
/// return more results than it declares outputs.
static LogicalResult verifyTrait(Operation *op) {
  ConcreteType concreteOp = cast<ConcreteType>(op);
  // Reuse the already-cast `concreteOp` instead of re-casting `op` a second
  // time for the operand-count query.
  auto nOperands = concreteOp.getNumInputsAndOutputBuffers();
  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
    return failure();
  if (op->getNumResults() > concreteOp.getNumOutputs())
    return op->emitError("unexpected #results > #outputs");
  return success();
}
};
/// This class provides the API for named Linalg StructuredOps.
template <typename ConcreteType>
class NamedStructuredOpTraits
    : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
public:
  // Static builders for the reference iterator types / indexing maps of a
  // named op. Declarations only: definitions are emitted per concrete op
  // (presumably by the linalg-ods-gen tooling — confirm against the
  // generated sources).
  static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
                                                      TypeRange outputTypes);
  static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
                                                         TypeRange outputTypes);
};
} // namespace linalg
} // namespace OpTrait
} // namespace mlir

View File

@ -260,13 +260,14 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
return failure();
auto attr = op.template getAttrOfType<IntegerAttr>("symbol_source");
int64_t targetRank = 0;
if (attr) {
unsigned index = attr.getInt();
auto symbolSourceAttr =
op.template getAttrOfType<IntegerAttr>("symbol_source");
int64_t expectedNumSymbols = 0;
if (symbolSourceAttr) {
unsigned index = symbolSourceAttr.getInt();
if (index >= op.getNumOperands())
return op.emitOpError("symbol_source index out of range");
targetRank = op.getShapedType(index).getRank();
expectedNumSymbols = op.getShapedType(index).getRank();
}
SmallVector<AffineMap, 4> indexingMaps;
@ -278,9 +279,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
: op.getOutputShapedType(idx - nInputViews);
if (m.getNumSymbols() != targetRank)
if (m.getNumSymbols() != expectedNumSymbols)
return op.emitOpError("expected the number of symbols in indexing_map #")
<< idx << " to match target rank";
<< idx << " to match rank of operand `symbol_source`";
if (m.getNumDims() != nLoops)
return op.emitOpError("expected indexing_map #")
@ -1246,15 +1247,9 @@ void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body);
auto indexingMaps = builder.getAffineMapArrayAttr(
NamedStructuredOpType::referenceIndexingMaps(operandTypes,
tensorResultTypes));
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
// indexing_maps is an auto-generated method.
auto iterators =
builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
operandTypes, tensorResultTypes));
result.addAttribute(getIteratorTypesAttrName(), iterators);
// iterator_types is an auto-generated method.
}
template <typename NamedStructuredOpType>

View File

@ -113,7 +113,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
// -----
func @generic_symbol_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}}
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}}
linalg.generic {
args_in = 0,
args_out = 1,
@ -514,3 +514,20 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
return
}
// -----
func @generic(%arg0: tensor<?x?xi4>) {
// expected-error @+1 {{unexpected #results > #outputs}}
linalg.generic {
args_in = 1,
args_out = 1,
indexing_maps = [ affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]
} %arg0 {
^bb(%0: i4) :
%1 = std.addi %0, %0: i4
linalg.yield %1, %1: i4, i4
} : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
return
}

View File

@ -4,16 +4,15 @@
// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: SmallVector<StringRef, 8> Test1Op::referenceIterators
// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
// IMPL: SmallVector<AffineMap, 8> Test1Op::referenceIndexingMaps
// IMPL: ArrayAttr Test1Op::indexing_maps() {
// IMPL: AffineMap::get(2, 0, {d0, d1}, context),
// IMPL-NEXT: AffineMap::get(2, 0, {d1}, context),
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) });
//
// IMPL: void Test1Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@ -29,16 +28,15 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: SmallVector<StringRef, 8> Test2Op::referenceIterators
// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
// IMPL: SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps
// IMPL: ArrayAttr Test2Op::indexing_maps() {
// IMPL: AffineMap::get(3, 0, {d0, d2}, context),
// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context),
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) });
//
// IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
@ -54,16 +52,15 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
// ODS-NEXT: NInputs<2>
// ODS-NEXT: NOutputs<1>
// ODS-NEXT: NamedStructuredOpTraits
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: SmallVector<StringRef, 8> Test3Op::referenceIterators
// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
// IMPL: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
//
// IMPL: SmallVector<AffineMap, 8> Test3Op::referenceIndexingMaps
// IMPL: ArrayAttr Test3Op::indexing_maps() {
// IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context),
// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context),
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) });
//
// IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);

View File

@ -974,19 +974,19 @@ public:
/// Parse and print the information for a TC def.
/// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
/// When `gen-impl` is used, this prints the C++ implementation for the extra
/// methods defined in ODS (referenceIterators, referenceIndexingMaps and
/// regionBuilder).
/// methods defined in ODS (`iterator_types`, `indexing_maps` and
/// `regionBuilder`).
LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
StringRef linalgOpName);
/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state);
/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
/// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state);
@ -1446,7 +1446,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [
NInputs<{2}>,
NOutputs<{3}>,
NamedStructuredOpTraits,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins Variadic<LinalgOperand>:$views);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
@ -1465,16 +1464,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
return ::parseNamedStructuredOp<{0}>(parser, result);
}];
let extraClassDeclaration = [{{
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
static SmallVector<StringRef, 8> referenceIterators(
TypeRange inputTypes, TypeRange outputTypes);
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
static SmallVector<AffineMap, 8> referenceIndexingMaps(
TypeRange inputTypes, TypeRange outputTypes);
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
@ -1492,20 +1484,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
}
/// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void TCParser::printReferenceIterators(llvm::raw_ostream &os,
StringRef cppOpName,
ComprehensionParsingState &state) {
const char *referenceReferenceIteratorsFmt =
R"FMT(
// This is temporary until we transition out of manually specified ops
// that should be auto-generated with linalg-ods-gen.
llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {{
llvm_unreachable("Unexpected missing `iterator_types` attribute.");
}
SmallVector<StringRef, 8> {0}::referenceIterators(
TypeRange inputTypes, TypeRange outputTypes) {
return SmallVector<StringRef, 8>{{ {1} };
ArrayAttr {0}::iterator_types() {
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
})FMT";
std::string iteratorsStr;
@ -1542,16 +1528,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
R"FMT(
// This is temporary until we transition out of manually specified ops that
// should be auto-generated with linalg-ods-gen.
llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {{
llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
}
SmallVector<AffineMap, 8> {0}::referenceIndexingMaps(
TypeRange inputTypes, TypeRange outputTypes) {
assert(!inputTypes.empty() && "At least one input expected");
MLIRContext *context = (*inputTypes.begin()).getContext();
ArrayAttr {0}::indexing_maps() {
MLIRContext *context = getContext();
AffineExpr {1};
bindDims(context, {1});
return SmallVector<AffineMap, 8>{{ {2} };
return Builder(context).getAffineMapArrayAttr({ {2} });
})FMT";
// 2. Print a comma-separated list of identifiers for the AffineExpr in