forked from OSchip/llvm-project
Continue refactoring StructuredOps utilities
This CL adds more common information to StructuredOpsUtils.h The n_view attribute is retired in favor of args_in + args_out but the CL is otherwise NFC. PiperOrigin-RevId: 285000621
This commit is contained in:
parent
c5fb4c1303
commit
508d4e672e
|
@ -26,37 +26,17 @@
|
|||
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
|
||||
class LinalgParametricNativeOpTrait<string prop, string parameters> :
|
||||
NativeOpTrait<"linalg::" # prop # parameters>
|
||||
{}
|
||||
|
||||
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
|
||||
LinalgParametricNativeOpTrait<
|
||||
prop,
|
||||
!strconcat("<",
|
||||
!cast<string>(!head(parameters)),
|
||||
!foldl("",
|
||||
!tail(parameters),
|
||||
sum,
|
||||
param,
|
||||
sum # "," # !cast<string>(param)),
|
||||
">::Impl")>
|
||||
{}
|
||||
|
||||
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
|
||||
// to have a specified number of inputs and outputs, all passed as operands.
|
||||
// The Linalg `NInputs` trait provides the API for ops that are known
|
||||
// to have a specified number of inputs, all passed as operands.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NInputsAndOutputs<int n_ins, int n_outs> :
|
||||
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
|
||||
{}
|
||||
class NInputs<int args_in> :
|
||||
NativeOpTrait<"linalg::NInputs<" # !cast<string>(args_in) # ">::Impl"> {}
|
||||
|
||||
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
|
||||
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
|
||||
// loops.
|
||||
// The Linalg `NOutputs` trait provides the API for ops that are known
|
||||
// to have a specified number of outputs, all passed as operands.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NLoopTypes<int n_par, int n_red, int n_win> :
|
||||
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
|
||||
{}
|
||||
class NOutputs<int args_out> :
|
||||
NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
|
||||
|
||||
def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
|
||||
|
||||
|
@ -88,6 +68,14 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
|
|||
"Query the input and output operands from the current operation.",
|
||||
"Operation::operand_range", "getInputsAndOutputs"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Query the iterator types attribute within the current operation.",
|
||||
"ArrayAttr", "iterator_types"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Query the indexing maps attribute within the current operation.",
|
||||
"ArrayAttr", "indexing_maps"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Query the number of parallel loops within the current operation.",
|
||||
"unsigned", "getNumParallelLoops"
|
||||
|
@ -102,10 +90,7 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
|
|||
>,
|
||||
InterfaceMethod<
|
||||
"Query the number of loops within the current operation.",
|
||||
"unsigned", "getNumLoops", (ins), [{
|
||||
return op.getNumParallelLoops() + op.getNumReductionLoops() +
|
||||
op.getNumWindowLoops();
|
||||
}]>,
|
||||
"unsigned", "getNumLoops">,
|
||||
InterfaceMethod<"Query the input view at the given index.",
|
||||
"Value *", "getInput", (ins "unsigned":$i)
|
||||
>,
|
||||
|
@ -188,7 +173,7 @@ class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Concrete Linalg ops.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
|
||||
def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> {
|
||||
let description = [{
|
||||
Copies the data in the input view into the output view.
|
||||
|
||||
|
@ -248,61 +233,87 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
|
|||
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
|
||||
}]>];
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
unsigned getNumParallelLoops() {
|
||||
auto *view = *(getOperands().begin());
|
||||
return view->getType().cast<MemRefType>().getRank();
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
|
||||
MLIRContext *ctx = getContext();
|
||||
SmallVector<Attribute, 8> iters(
|
||||
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
|
||||
return ArrayAttr::get(iters, ctx);
|
||||
}
|
||||
unsigned getNumReductionLoops() { return 0; }
|
||||
unsigned getNumWindowLoops() { return 0; }
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
|
||||
let arguments = (ins AnyStridedMemRef,
|
||||
AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
|
||||
def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> {
|
||||
let arguments = (ins AnyStridedMemRef:$input,
|
||||
AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
unsigned getNumParallelLoops() {
|
||||
auto *view = *(getOperands().begin());
|
||||
return view->getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
unsigned getNumReductionLoops() { return 0; }
|
||||
unsigned getNumWindowLoops() { return 0; }
|
||||
Value *getValue() {
|
||||
return *(getOperands().begin() + getNumInputsAndOutputs());
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
|
||||
MLIRContext *ctx = getContext();
|
||||
SmallVector<Attribute, 8> iters(
|
||||
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
|
||||
return ArrayAttr::get(iters, ctx);
|
||||
}
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def DotOp : LinalgLibrary_Op<"dot",
|
||||
[NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<0, 1, 0>]> {
|
||||
def DotOp : LinalgLibrary_Op<"dot", [NInputs<2>, NOutputs<1>]> {
|
||||
let arguments = (ins AnyStridedMemRefOfRank<1>,
|
||||
AnyStridedMemRefOfRank<1>,
|
||||
AnyStridedMemRefOfRank<0>);
|
||||
let extraClassDeclaration = libraryCallName;
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
MLIRContext *ctx = getContext();
|
||||
return ArrayAttr::get(
|
||||
StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def MatvecOp : LinalgLibrary_Op<"matvec",
|
||||
[NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<1, 1, 0>]> {
|
||||
def MatvecOp : LinalgLibrary_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
|
||||
let arguments = (ins AnyStridedMemRefOfRank<2>,
|
||||
AnyStridedMemRefOfRank<1>,
|
||||
AnyStridedMemRefOfRank<1>);
|
||||
let extraClassDeclaration = libraryCallName;
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
MLIRContext *ctx = getContext();
|
||||
Attribute iters[2]{
|
||||
StringAttr::get(getParallelIteratorTypeName(), ctx),
|
||||
StringAttr::get(getReductionIteratorTypeName(), ctx)};
|
||||
return ArrayAttr::get(iters, ctx);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def MatmulOp : LinalgLibrary_Op<"matmul",
|
||||
[NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<2, 1, 0>]> {
|
||||
def MatmulOp : LinalgLibrary_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
|
||||
let arguments = (ins AnyStridedMemRefOfRank<2>,
|
||||
AnyStridedMemRefOfRank<2>,
|
||||
AnyStridedMemRefOfRank<2>);
|
||||
let extraClassDeclaration = libraryCallName;
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
MLIRContext *ctx = getContext();
|
||||
Attribute iters[3]{
|
||||
StringAttr::get(getParallelIteratorTypeName(), ctx),
|
||||
StringAttr::get(getParallelIteratorTypeName(), ctx),
|
||||
StringAttr::get(getReductionIteratorTypeName(), ctx)};
|
||||
return ArrayAttr::get(iters, ctx);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
|
||||
def ConvOp : LinalgLibrary_Op<"conv", [NInputs<2>, NOutputs<1>]> {
|
||||
let description = [{
|
||||
Generic n-D convolution as described in the TF documentation:
|
||||
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
|
||||
|
@ -333,21 +344,27 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
|
|||
unsigned getNumInputFeatureDimensions() { return 1; }
|
||||
unsigned getNumOutputFeatureDimensions() { return 1; }
|
||||
|
||||
// Outer parallel loops are always the number of output dimensions; i.e.
|
||||
// [ b, xs, q] in the TF notation above.
|
||||
unsigned getNumParallelLoops() { return getOutputViewType(0).getRank(); }
|
||||
ArrayAttr indexing_maps();
|
||||
|
||||
// Window loops are a special kind of reduction that is neither tiled or
|
||||
// parallelized across; i.e. [zs] in the TF notation above whose number
|
||||
// match `xs` (i.e. 1 window loop per "image" dimension).
|
||||
unsigned getNumWindowLoops() {
|
||||
return getNumParallelLoops() - getNumBatchDimensions() -
|
||||
getNumInputFeatureDimensions(); }
|
||||
|
||||
// Reduction loops are exactly the non-parallel, non-window loops (i.e. `q`)
|
||||
// We distinguish between reduction loops and convolution window loops for
|
||||
// now. That distinction may disappear in the future.
|
||||
unsigned getNumReductionLoops() { return getNumInputFeatureDimensions(); }
|
||||
ArrayAttr iterator_types() {
|
||||
// Outer parallel loops are always the number of output dimensions; i.e.
|
||||
// [ b, xs, q] in the TF notation above.
|
||||
unsigned nPar = getOutputViewType(0).getRank();
|
||||
unsigned nRed = getNumInputFeatureDimensions();
|
||||
// Window loops are a special kind of reduction that is never tiled or
|
||||
// parallelized across; i.e. [zs] in the TF notation above whose number
|
||||
// match `xs` (i.e. 1 window loop per "image" dimension).
|
||||
// This may evolve in the future.
|
||||
unsigned nWin =
|
||||
nPar - getNumBatchDimensions() - getNumInputFeatureDimensions();
|
||||
MLIRContext *ctx = getContext();
|
||||
SmallVector<Attribute, 8> iters(
|
||||
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
|
||||
iters.reserve(nPar + nRed + nWin);
|
||||
iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx));
|
||||
iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx));
|
||||
return ArrayAttr::get(iters, ctx);
|
||||
}
|
||||
|
||||
int64_t getStride(unsigned i) {
|
||||
assert(i < getNumWindowLoops());
|
||||
|
@ -368,9 +385,10 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
|
|||
|
||||
class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
|
||||
let arguments = (ins Variadic<AnyStridedMemRef>:$views,
|
||||
I64Attr:$args_in,
|
||||
I64Attr:$args_out,
|
||||
AffineMapArrayAttr:$indexing_maps,
|
||||
ArrayAttr:$iterator_types,
|
||||
I64ArrayAttr:$n_views,
|
||||
OptionalAttr<StrAttr>:$doc,
|
||||
OptionalAttr<FlatSymbolRefAttr>:$fun,
|
||||
OptionalAttr<StrAttr>:$library_call);
|
||||
|
@ -378,57 +396,13 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
|
|||
let extraClassDeclaration = [{
|
||||
SmallVector<StringRef, 8> linalgTraitAttrNames() {
|
||||
return SmallVector<StringRef, 8>{
|
||||
"doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
|
||||
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
|
||||
getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
|
||||
getIteratorTypesAttrName()
|
||||
};
|
||||
}
|
||||
unsigned getNumInputs() {
|
||||
if (!getAttr("n_views") || n_views().getValue().size() != 2)
|
||||
return 0;
|
||||
auto val = n_views().getValue()[0].cast<IntegerAttr>().getValue();
|
||||
assert(val.getSExtValue() >= 0);
|
||||
return val.getZExtValue();
|
||||
}
|
||||
unsigned getNumOutputs() {
|
||||
if (!getAttr("n_views") || n_views().getValue().size() != 2)
|
||||
return 0;
|
||||
auto val = n_views().getValue()[1].cast<IntegerAttr>().getValue();
|
||||
assert(val.getSExtValue() >= 0);
|
||||
return val.getZExtValue();
|
||||
}
|
||||
unsigned getNumParallelLoops() {
|
||||
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
|
||||
return 0;
|
||||
unsigned nPar = 0;
|
||||
for (auto ty : iterator_types()) {
|
||||
if (ty.cast<StringAttr>().getValue() == "parallel")
|
||||
nPar++;
|
||||
}
|
||||
return nPar;
|
||||
}
|
||||
unsigned getNumReductionLoops() {
|
||||
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
|
||||
return 0;
|
||||
unsigned nRed = 0;
|
||||
for (auto ty : iterator_types()) {
|
||||
if (ty.cast<StringAttr>().getValue() == "reduction")
|
||||
nRed++;
|
||||
}
|
||||
return nRed;
|
||||
}
|
||||
unsigned getNumWindowLoops() {
|
||||
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
|
||||
return 0;
|
||||
unsigned nWin = 0;
|
||||
for (auto ty : iterator_types()) {
|
||||
if (ty.cast<StringAttr>().getValue() == "window")
|
||||
nWin++;
|
||||
}
|
||||
return nWin;
|
||||
}
|
||||
unsigned getNumLoops() {
|
||||
return getNumParallelLoops() + getNumReductionLoops() +
|
||||
getNumWindowLoops();
|
||||
}
|
||||
unsigned getNumInputs() { return args_in().getSExtValue(); }
|
||||
unsigned getNumOutputs() { return args_out().getSExtValue(); }
|
||||
FuncOp getFunction() {
|
||||
auto moduleOp = getParentOfType<ModuleOp>();
|
||||
return fun().hasValue() ?
|
||||
|
@ -468,10 +442,12 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
```
|
||||
|
||||
Where #trait_attributes is an alias of a dictionary attribute containing:
|
||||
- args_in: an I64Attr representing the number of input (readonly) views
|
||||
- args_out: an I64Attr representing the number of output (readwrite) views
|
||||
- doc [optional]: a documentation string
|
||||
- fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
|
||||
To support inplace updates in a generic fashion, the signature of the
|
||||
function must be:
|
||||
- fun: a FlatSymbolRefAttr that must resolve to an existing function
|
||||
symbol. To support inplace updates in a generic fashion, the signature
|
||||
of the function must be:
|
||||
```
|
||||
fun([input views element types], [output views element types])
|
||||
-> ([output views element types])
|
||||
|
@ -488,8 +464,6 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
Each element of the list represents and iterator of one of the following
|
||||
types:
|
||||
parallel, reduction, window
|
||||
- n_views: a pair of I64Attr representing the number of input (readonly)
|
||||
and output (readwrite) views.
|
||||
|
||||
Example:
|
||||
Defining a #matmul_trait attribute in MLIR can be done as follows:
|
||||
|
@ -564,12 +538,14 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
|
|||
```
|
||||
|
||||
Where #trait_attributes is an alias of a dictionary attribute containing:
|
||||
- args_in: an I64Attr representing the number of input (readonly) views
|
||||
- args_out: an I64Attr representing the number of output (readwrite) views
|
||||
- doc [optional]: a documentation string
|
||||
- fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
|
||||
To support inplace updates in a generic fashion, the signature of the
|
||||
function must be:
|
||||
- fun: a FlatSymbolRefAttr that must resolve to an existing function
|
||||
symbol. To support inplace updates in a generic fashion, the signature
|
||||
of the function must be:
|
||||
```
|
||||
fun([index types for induction variables], [input views element types],
|
||||
fun([index types of induction variables], [input views element types],
|
||||
[output views element types]) -> ([output views element types])
|
||||
```
|
||||
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
|
||||
|
@ -580,16 +556,17 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
|
|||
maps to. The external library is assumed to be dynamically linked and
|
||||
no strong compile-time guarantees are provided. In the absence of such
|
||||
a library call, linalg.indexed_generic will always lower to loops.
|
||||
- iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
|
||||
the list represents and iterator of one of the following types:
|
||||
parallel, reduction, window
|
||||
- n_views: a pair of I64Attr representing the number of input (readonly)
|
||||
and output (readwrite) views.
|
||||
- iterator_types: an ArrayAttr they type of the enclosing loops; Each
|
||||
element of the list represents and iterator of one of the following
|
||||
types:
|
||||
parallel, reduction, window
|
||||
|
||||
Example:
|
||||
Defining a #matmul_trait attribute in MLIR can be done as follows:
|
||||
```mlir
|
||||
func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
|
||||
func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
|
||||
-> f32
|
||||
{
|
||||
%d = mulf %a, %b: f32
|
||||
%e = addf %c, %d: f32
|
||||
return %e: f32
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -85,7 +86,6 @@ SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.h.inc"
|
||||
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
@ -28,23 +29,30 @@ namespace OpTrait {
|
|||
namespace linalg {
|
||||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of inputs and outputs, all passed as operands. This is used as a
|
||||
/// trait like this:
|
||||
/// number of inputs, all passed as operands. This is used as a trait like this:
|
||||
///
|
||||
/// class DotOp : public Op<DotOp, OpTrait::NInputsAndOutputs<2, 1>::Impl> {
|
||||
/// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
|
||||
///
|
||||
template <unsigned NInputs, unsigned NOutputs> class NInputsAndOutputs {
|
||||
template <unsigned N> class NInputs {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public OpTrait::TraitBase<ConcreteType,
|
||||
NInputsAndOutputs<NInputs, NOutputs>::Impl> {
|
||||
class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
|
||||
public:
|
||||
static unsigned getNumInputs() { return NInputs; }
|
||||
static unsigned getNumOutputs() { return NOutputs; }
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
|
||||
}
|
||||
static unsigned getNumInputs() { return N; }
|
||||
};
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of inputs, all passed as operands. This is used as a trait like this:
|
||||
///
|
||||
/// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class NOutputs {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
|
||||
public:
|
||||
static unsigned getNumOutputs() { return N; }
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -124,6 +132,25 @@ public:
|
|||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + getNumInputsAndOutputs()};
|
||||
}
|
||||
unsigned getNumParallelLoops() {
|
||||
return getNumIterators(
|
||||
getParallelIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumReductionLoops() {
|
||||
return getNumIterators(
|
||||
getReductionIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumWindowLoops() {
|
||||
return getNumIterators(
|
||||
getWindowIteratorTypeName(),
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
unsigned getNumLoops() {
|
||||
return getNumIterators(
|
||||
cast<ConcreteType>(this->getOperation()).iterator_types());
|
||||
}
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
|
||||
|
@ -132,27 +159,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of parallel, reduction and window loops. This is used as a trait like
|
||||
/// this:
|
||||
///
|
||||
/// class MatmulOp : public Op<MatmulOp, OpTrait::NLoopTypes<2, 1, 0>::Impl> {
|
||||
///
|
||||
template <unsigned NParallel, unsigned NReduction, unsigned NWindow = 0>
|
||||
class NLoopTypes {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class Impl
|
||||
: public OpTrait::TraitBase<
|
||||
ConcreteType, NLoopTypes<NParallel, NReduction, NWindow>::Impl> {
|
||||
public:
|
||||
static unsigned getNumParallelLoops() { return NParallel; }
|
||||
static unsigned getNumReductionLoops() { return NReduction; }
|
||||
static unsigned getNumWindowLoops() { return NWindow; }
|
||||
static unsigned getNumLoops() { return NParallel + NReduction + NWindow; }
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
|
|
@ -40,7 +40,7 @@ class IsProducedByOpOfType<string str> :
|
|||
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
|
||||
|
||||
class AffineMapDomainHasDim<int n> : CPred<[{
|
||||
$0.getAttrOfType<ArrayAttr>("indexing_maps").getValue()[0].
|
||||
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
|
||||
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -67,11 +67,12 @@ class TileAndFuseLinalgOp<
|
|||
// `permutation` is an optional parameter to specify the ordering of the
|
||||
// tiled loops. If provided, it must be a list of integers with the same number
|
||||
// of elements as `sizes`.
|
||||
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall<
|
||||
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
|
||||
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
|
||||
StrJoinInt<permutation>.result # "})))" #
|
||||
" return matchFailure();">;
|
||||
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
|
||||
NativeCodeCall<
|
||||
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
|
||||
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
|
||||
StrJoinInt<permutation>.result # "})))" #
|
||||
" return matchFailure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to loop patterns.
|
||||
|
@ -96,7 +97,8 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
|
|||
//===----------------------------------------------------------------------===//
|
||||
class PermuteGenericLinalgOp<list<int> permutation, string value> :
|
||||
NativeCodeCall<
|
||||
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
|
||||
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
||||
" return matchFailure();">;
|
||||
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
|
||||
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
|
||||
" return matchFailure();">;
|
||||
|
||||
#endif // LINALG_TRANSFORMS
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
|
||||
#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
|
@ -42,6 +43,32 @@ static constexpr StringLiteral getIteratorTypesAttrName() {
|
|||
return StringLiteral("iterator_types");
|
||||
}
|
||||
|
||||
/// Attribute name for the IntegerAttr which encodes the number of input buffer
|
||||
/// arguments.
|
||||
static constexpr StringLiteral getArgsInAttrName() {
|
||||
return StringLiteral("args_in");
|
||||
}
|
||||
|
||||
/// Attribute name for the IntegerAttr which encodes the number of input buffer
|
||||
/// arguments.
|
||||
static constexpr StringLiteral getArgsOutAttrName() {
|
||||
return StringLiteral("args_out");
|
||||
}
|
||||
|
||||
/// Attribute name for the StringAttr which encodes an optional documentation
|
||||
/// string of the structured op.
|
||||
static constexpr StringLiteral getDocAttrName() { return StringLiteral("doc"); }
|
||||
|
||||
/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the
|
||||
/// MLIR function that implements the body of the structured op.
|
||||
static constexpr StringLiteral getFunAttrName() { return StringLiteral("fun"); }
|
||||
|
||||
/// Attribute name for the StrArrayAttr which encodes the external library
|
||||
/// function that implements the structured op.
|
||||
static constexpr StringLiteral getLibraryCallAttrName() {
|
||||
return StringLiteral("library_call");
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has parallel semantics.
|
||||
inline static constexpr StringLiteral getParallelIteratorTypeName() {
|
||||
return StringLiteral("parallel");
|
||||
|
@ -51,6 +78,37 @@ inline static constexpr StringLiteral getParallelIteratorTypeName() {
|
|||
inline static constexpr StringLiteral getReductionIteratorTypeName() {
|
||||
return StringLiteral("reduction");
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has window semantics.
|
||||
inline static constexpr StringLiteral getWindowIteratorTypeName() {
|
||||
return StringLiteral("window");
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has window semantics.
|
||||
inline static ArrayRef<StringRef> getAllIteratorTypeNames() {
|
||||
static const StringRef names[3] = {getParallelIteratorTypeName(),
|
||||
getReductionIteratorTypeName(),
|
||||
getWindowIteratorTypeName()};
|
||||
return names;
|
||||
}
|
||||
|
||||
/// Returns the iterator of a certain type.
|
||||
inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
|
||||
auto names = getAllIteratorTypeNames();
|
||||
(void)names;
|
||||
assert(llvm::is_contained(names, name));
|
||||
return llvm::count_if(iteratorTypes, [name](Attribute a) {
|
||||
return a.cast<StringAttr>().getValue() == name;
|
||||
});
|
||||
}
|
||||
|
||||
inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
|
||||
unsigned res = 0;
|
||||
for (auto n : getAllIteratorTypeNames())
|
||||
res += getNumIterators(n, iteratorTypes);
|
||||
return res;
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
|
||||
|
|
|
@ -96,8 +96,8 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
|
|||
Region ®ion = *result.addRegion();
|
||||
SmallVector<Type, 8> operandTypes, regionTypes;
|
||||
// Optional attributes may be added.
|
||||
// Either Optional "fun" attribute or region must be specified.
|
||||
if (!dictAttr.get("fun") &&
|
||||
// Either Optional getFunAttrName() attribute or region must be specified.
|
||||
if (!dictAttr.get(getFunAttrName()) &&
|
||||
parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDict(result.attributes) ||
|
||||
|
@ -557,7 +557,7 @@ static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
|
|||
|
||||
static LogicalResult verify(FillOp op) {
|
||||
auto viewType = op.getOutputViewType(0);
|
||||
auto fillType = op.getValue()->getType();
|
||||
auto fillType = op.value()->getType();
|
||||
if (viewType.getElementType() != fillType)
|
||||
return op.emitOpError("expects fill type to match view elemental type");
|
||||
return success();
|
||||
|
@ -813,3 +813,30 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
|
|||
[&]() { ss << "_"; });
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
static ArrayAttr getIndexingMaps(Operation *op) {
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
SmallVector<Attribute, 4> maps;
|
||||
maps.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
for (AffineMap map : loopToOperandRangesMaps(op))
|
||||
maps.push_back(AffineMapAttr::get(map));
|
||||
return ArrayAttr::get(maps, op->getContext());
|
||||
}
|
||||
ArrayAttr mlir::linalg::ConvOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
ArrayAttr mlir::linalg::CopyOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
ArrayAttr mlir::linalg::DotOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
ArrayAttr mlir::linalg::FillOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
|
||||
return getIndexingMaps(getOperation());
|
||||
}
|
||||
|
|
|
@ -130,8 +130,8 @@ public:
|
|||
IndexedValueType O(fillOp.getOutput(0));
|
||||
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||
// an n-D loop nest; with or without permutations.
|
||||
nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
|
||||
: O() = ValueHandle(fillOp.getValue());
|
||||
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
|
||||
: O() = ValueHandle(fillOp.value());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -211,20 +211,20 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
|
|||
auto permutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
|
||||
SmallVector<AffineMap, 4> newIndexingMap;
|
||||
auto indexingMaps =
|
||||
linOp.getAttrOfType<ArrayAttr>("indexing_maps").getValue();
|
||||
auto indexingMaps = linOp.indexing_maps().getValue();
|
||||
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
|
||||
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
|
||||
permutationMap);
|
||||
newIndexingMap.push_back(m);
|
||||
}
|
||||
auto itTypes = linOp.getAttrOfType<ArrayAttr>("iterator_types").getValue();
|
||||
SmallVector<StringRef, 4> itTypesVector;
|
||||
auto itTypes = linOp.iterator_types().getValue();
|
||||
SmallVector<Attribute, 4> itTypesVector;
|
||||
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
|
||||
itTypesVector.push_back(itTypes[i].cast<StringAttr>().getValue());
|
||||
itTypesVector.push_back(itTypes[i]);
|
||||
applyPermutationToVector(itTypesVector, permutation);
|
||||
op->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMap));
|
||||
op->setAttr("iterator_types", rewriter.getStrArrayAttr(itTypesVector));
|
||||
op->setAttr(getIndexingMapsAttrName(),
|
||||
rewriter.getAffineMapArrayAttr(newIndexingMap));
|
||||
op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
|
||||
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
|
||||
rewriter.getStringAttr(linalgMarker));
|
||||
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
|
||||
|
|
|
@ -305,9 +305,10 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
|
|||
|
||||
#id_2d = (i, j) -> (i, j)
|
||||
#pointwise_2d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
n_views = [2, 1]
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
|
||||
%c1 = constant 1 : index
|
||||
|
|
|
@ -57,9 +57,10 @@ func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
|||
func @generic_at_least_2_operands(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected 2 or more operands}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [1, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<f32>
|
||||
}
|
||||
|
@ -69,9 +70,10 @@ func @generic_at_least_2_operands(%arg0: memref<f32>) {
|
|||
func @generic_exactly_2_views(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected exactly 2 view operands}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [1, 1],
|
||||
iterator_types = []
|
||||
} %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
|
||||
}
|
||||
|
@ -81,9 +83,10 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
|
|||
func @generic_undefined_fun(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [1, 1],
|
||||
iterator_types = []
|
||||
} %arg0, %arg0: memref<f32>, memref<f32>
|
||||
}
|
||||
|
@ -95,9 +98,10 @@ func @foo() { return }
|
|||
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun arguments to match number of views}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<f32>
|
||||
}
|
||||
|
@ -109,9 +113,10 @@ func @foo(%0: i32) { return }
|
|||
func @generic_mismatched_num_returns(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun results to match number of output views}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<f32>
|
||||
}
|
||||
|
@ -123,9 +128,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
|
|||
func @generic_symbol_in_map(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ ()[N] -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0: memref<i32>
|
||||
}
|
||||
|
@ -137,9 +143,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
|
|||
func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0: memref<i32>
|
||||
}
|
||||
|
@ -151,9 +158,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
|
|||
func @generic_zero_d_view(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (1) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<i32>
|
||||
}
|
||||
|
@ -165,9 +173,10 @@ func @foo(%0: f32) -> f32 { return %0: f32 }
|
|||
func @generic_one_d_view(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, (d0)[s0] -> (d0 + s0)>'}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0, 0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<?xf32, (i)[off]->(off + i)>
|
||||
}
|
||||
|
@ -182,9 +191,10 @@ func @foo(%0: i32) -> f32 {
|
|||
func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<?xf32, (i)[off]->(off + i)>
|
||||
}
|
||||
|
@ -199,9 +209,10 @@ func @foo(%0: f32) -> i4 {
|
|||
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0: memref<?xf32, (i)[off]->(off + i)>
|
||||
}
|
||||
|
@ -213,12 +224,13 @@ func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
|
|||
func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
fun = @foo,
|
||||
indexing_maps = [
|
||||
(i, j) -> (i + j) ,
|
||||
(i, j) -> (i + j)
|
||||
],
|
||||
n_views = [1, 1],
|
||||
iterator_types = ["parallel","parallel"]
|
||||
} %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
|
||||
}
|
||||
|
@ -232,8 +244,9 @@ func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: me
|
|||
func @generic_empty_region(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected region with 1 block}}
|
||||
linalg.generic {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [1, 1],
|
||||
iterator_types = []
|
||||
} %arg0, %arg0 {
|
||||
^bb1:
|
||||
|
@ -246,8 +259,9 @@ func @generic_empty_region(%arg0: memref<f32>) {
|
|||
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected number of block arguments to match number of views}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0 {
|
||||
^bb:
|
||||
|
@ -259,8 +273,9 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
|
|||
func @generic_block_arg_type(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = []
|
||||
} %arg0 {
|
||||
^bb(%i: i1):
|
||||
|
@ -272,8 +287,9 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
|
|||
func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0 {
|
||||
^bb(%f: f32):
|
||||
|
@ -285,8 +301,9 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
|
|||
func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected block argument 0 to be of IndexType}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0 {
|
||||
^bb(%i: f64, %f: f32):
|
||||
|
@ -298,8 +315,9 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
|
|||
func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref<f32>'}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0 {
|
||||
^bb(%i: index, %f: i1):
|
||||
|
@ -314,8 +332,9 @@ func @foo(%f: f32) -> (f32) {
|
|||
func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun arguments to match number of views + number of loops}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"],
|
||||
fun = @foo
|
||||
} %arg0: memref<f32>
|
||||
|
@ -329,7 +348,8 @@ func @foo(%i: i32, %val: f32) -> (f32) {
|
|||
func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
|
||||
linalg.indexed_generic {
|
||||
n_views = [0, 1],
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
iterator_types = ["parallel"],
|
||||
indexing_maps = [ (i) -> (i) ],
|
||||
fun = @foo
|
||||
|
@ -344,8 +364,9 @@ func @foo(%i: index, %val: i1) -> (i1) {
|
|||
func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"],
|
||||
fun = @foo
|
||||
} %arg0: memref<f32>
|
||||
|
@ -359,8 +380,9 @@ func @foo(%i: index, %val: i1) -> (i1, i1) {
|
|||
func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
|
||||
// expected-error @+1 {{op expected fun results to match number of output views}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"],
|
||||
fun = @foo
|
||||
} %arg0: memref<f32>
|
||||
|
@ -375,8 +397,9 @@ func @foo(%i: index, %val: i32) -> (f32) {
|
|||
func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}}
|
||||
linalg.indexed_generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (d0) -> (d0) ],
|
||||
n_views = [0, 1],
|
||||
iterator_types = ["parallel"],
|
||||
fun = @foo
|
||||
} %arg0: memref<i32>
|
||||
|
@ -385,10 +408,11 @@ func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
|
|||
// -----
|
||||
|
||||
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
|
||||
// expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ (i) -> (i) ],
|
||||
n_views = [0, 1],
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ (i) -> (i) ],
|
||||
iterator_types = ["parallel"]
|
||||
} %arg0 {
|
||||
^bb(%i: f32):
|
||||
|
@ -399,6 +423,13 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
|
|||
|
||||
// -----
|
||||
|
||||
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
|
||||
// expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
|
||||
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown Linalg type}}
|
||||
!invalid_type = type !linalg.unknown
|
||||
|
||||
|
|
|
@ -138,7 +138,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
|
|||
(m, n, k) -> (m, n)
|
||||
]
|
||||
#matmul_trait = {
|
||||
n_views = [2, 1],
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "external_outerproduct_matmul"
|
||||
|
@ -175,7 +176,8 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
|
|||
|
||||
|
||||
#indexed_matmul_trait = {
|
||||
n_views = [2, 1],
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "external_indexed_outerproduct_matmul"
|
||||
|
|
|
@ -222,7 +222,8 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
|
|||
(i, j, k) -> (i, k, j)
|
||||
]
|
||||
#trait = {
|
||||
n_views = [1, 2],
|
||||
args_in = 1,
|
||||
args_out = 2,
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
indexing_maps = #accesses,
|
||||
fun = @foo,
|
||||
|
@ -247,7 +248,8 @@ func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1
|
|||
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
|
||||
|
||||
#trait2 = {
|
||||
n_views = [1, 2],
|
||||
args_in = 1,
|
||||
args_out = 2,
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
indexing_maps = #accesses,
|
||||
library_call = "some_external_function_name_2",
|
||||
|
@ -280,7 +282,8 @@ func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) ->
|
|||
return %i_float, %i_float : f32, f32
|
||||
}
|
||||
#trait3 = {
|
||||
n_views = [1, 2],
|
||||
args_in = 1,
|
||||
args_out = 2,
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
indexing_maps = #accesses,
|
||||
fun = @indexed_foo,
|
||||
|
@ -310,7 +313,8 @@ func @indexed_generic_function(
|
|||
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
|
||||
|
||||
#trait4 = {
|
||||
n_views = [1, 2],
|
||||
args_in = 1,
|
||||
args_out = 2,
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
indexing_maps = #accesses,
|
||||
library_call = "some_external_function_name_2",
|
||||
|
|
|
@ -120,8 +120,9 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
|
|||
(i, j, k) -> (i, k, i + j)
|
||||
]
|
||||
#trait = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = #accesses,
|
||||
n_views = [1, 1],
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
fun = @foo,
|
||||
library_call = "some_external_function_name_1"
|
||||
|
@ -136,11 +137,12 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar
|
|||
}
|
||||
// CHECK-LABEL: func @foo
|
||||
// CHECK-LABEL: func @generic
|
||||
// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
|
||||
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
|
||||
|
||||
#trait2 = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = #accesses,
|
||||
n_views = [1, 1],
|
||||
iterator_types = ["parallel", "parallel", "parallel"],
|
||||
library_call = "some_external_function_name_2"
|
||||
}
|
||||
|
@ -152,7 +154,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @generic_region
|
||||
// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
|
||||
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
|
||||
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
|
||||
// CHECK: linalg.yield %{{.*}} : f32
|
||||
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
|
||||
|
@ -166,7 +168,7 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @indexed_generic
|
||||
// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
|
||||
// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
|
||||
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
|
||||
// CHECK: linalg.yield %{{.*}} : f32
|
||||
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
|
||||
|
|
|
@ -213,9 +213,10 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
|
|||
|
||||
#id_2d = (i, j) -> (i, j)
|
||||
#pointwise_2d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
n_views = [2, 1]
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
|
||||
func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
|
||||
#id_1d = (i) -> (i)
|
||||
#pointwise_1d_trait = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [#id_1d, #id_1d],
|
||||
iterator_types = ["parallel"],
|
||||
n_views = [1, 1]
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
|
||||
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
|
||||
|
@ -43,12 +44,13 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
|
|||
// TILE-0n25: linalg.indexed_generic
|
||||
|
||||
#combined_indices_trait = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [
|
||||
(i, j) -> (j, i + j),
|
||||
(i, j) -> (i, j)
|
||||
],
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
n_views = [1, 1]
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
|
||||
linalg.indexed_generic #combined_indices_trait %operand, %result {
|
||||
|
|
|
@ -83,11 +83,12 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
#some_generic_trait = {
|
||||
args_in = 1,
|
||||
args_out = 1,
|
||||
indexing_maps = [
|
||||
(i, j) -> (i, j),
|
||||
(i, j) -> (i, j)
|
||||
],
|
||||
n_views = [1, 1],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
|
@ -164,12 +165,13 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
#matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = [
|
||||
(m, n, k) -> (m, k),
|
||||
(m, n, k) -> (k, n),
|
||||
(m, n, k) -> (m, n)
|
||||
],
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
__internal_linalg_transform__ = "_marked_matmul_"
|
||||
}
|
||||
|
@ -204,10 +206,11 @@ func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
|
|||
(m, n, k) -> (m, n)
|
||||
]
|
||||
#generic_matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
fun = @fma,
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
|
||||
|
@ -220,7 +223,7 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
}
|
||||
// CHECK-LABEL : func @fma
|
||||
// CHECK-LABEL : func @permute_generic
|
||||
// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
|
||||
%d = mulf %a, %b: f32
|
||||
|
@ -228,10 +231,11 @@ func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) ->
|
|||
return %e: f32
|
||||
}
|
||||
#indexed_matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
fun = @fma_indexed,
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul_indexed",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
|
@ -242,7 +246,7 @@ func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
}
|
||||
// CHECK-LABEL : func @fma_indexed
|
||||
// CHECK-LABEL : func @permute_generic_indexed
|
||||
// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
|
||||
|
||||
func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
|
|
|
@ -108,7 +108,7 @@ def : Pattern<(DotOp:$op $a, $b, $c),
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
|
||||
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
|
||||
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
|
||||
|
||||
|
@ -116,13 +116,14 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
|
|||
// Linalg generic permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(GenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
|
||||
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
|
||||
def : Pat<(IndexedGenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
|
||||
def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
|
||||
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
|
||||
[(Constraint<And<[HasNoLinalgTransformMarker,
|
||||
AffineMapDomainHasDim<3>]>> $op)]>;
|
||||
|
||||
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
|
||||
|
|
Loading…
Reference in New Issue