Continue refactoring StructuredOps utilities

This CL adds more common information to StructuredOpsUtils.h
The n_views attribute is retired in favor of args_in + args_out but the CL is otherwise NFC.

PiperOrigin-RevId: 285000621
This commit is contained in:
Nicolas Vasilache 2019-12-11 09:26:51 -08:00 committed by A. Unique TensorFlower
parent c5fb4c1303
commit 508d4e672e
17 changed files with 368 additions and 250 deletions

View File

@ -26,37 +26,17 @@
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
class LinalgParametricNativeOpTrait<string prop, string parameters> :
NativeOpTrait<"linalg::" # prop # parameters>
{}
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
LinalgParametricNativeOpTrait<
prop,
!strconcat("<",
!cast<string>(!head(parameters)),
!foldl("",
!tail(parameters),
sum,
param,
sum # "," # !cast<string>(param)),
">::Impl")>
{}
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
// to have a specified number of inputs and outputs, all passed as operands.
// The Linalg `NInputs` trait provides the API for ops that are known
// to have a specified number of inputs, all passed as operands.
// See Linalg/LinalgTraits.h for implementation details and usage.
class NInputsAndOutputs<int n_ins, int n_outs> :
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
{}
class NInputs<int args_in> :
NativeOpTrait<"linalg::NInputs<" # !cast<string>(args_in) # ">::Impl"> {}
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
// loops.
// The Linalg `NOutputs` trait provides the API for ops that are known
// to have a specified number of outputs, all passed as operands.
// See Linalg/LinalgTraits.h for implementation details and usage.
class NLoopTypes<int n_par, int n_red, int n_win> :
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
{}
class NOutputs<int args_out> :
NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
@ -88,6 +68,14 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
"Query the input and output operands from the current operation.",
"Operation::operand_range", "getInputsAndOutputs"
>,
InterfaceMethod<
"Query the iterator types attribute within the current operation.",
"ArrayAttr", "iterator_types"
>,
InterfaceMethod<
"Query the indexing maps attribute within the current operation.",
"ArrayAttr", "indexing_maps"
>,
InterfaceMethod<
"Query the number of parallel loops within the current operation.",
"unsigned", "getNumParallelLoops"
@ -102,10 +90,7 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
>,
InterfaceMethod<
"Query the number of loops within the current operation.",
"unsigned", "getNumLoops", (ins), [{
return op.getNumParallelLoops() + op.getNumReductionLoops() +
op.getNumWindowLoops();
}]>,
"unsigned", "getNumLoops">,
InterfaceMethod<"Query the input view at the given index.",
"Value *", "getInput", (ins "unsigned":$i)
>,
@ -188,7 +173,7 @@ class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
////////////////////////////////////////////////////////////////////////////////
// Concrete Linalg ops.
////////////////////////////////////////////////////////////////////////////////
def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> {
let description = [{
Copies the data in the input view into the output view.
@ -248,61 +233,87 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<MemRefType>().getRank();
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
MLIRContext *ctx = getContext();
SmallVector<Attribute, 8> iters(
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
return ArrayAttr::get(iters, ctx);
}
unsigned getNumReductionLoops() { return 0; }
unsigned getNumWindowLoops() { return 0; }
}];
let verifier = [{ return ::verify(*this); }];
}
def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
let arguments = (ins AnyStridedMemRef,
AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> {
let arguments = (ins AnyStridedMemRef:$input,
AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<MemRefType>().getRank();
}
unsigned getNumReductionLoops() { return 0; }
unsigned getNumWindowLoops() { return 0; }
Value *getValue() {
return *(getOperands().begin() + getNumInputsAndOutputs());
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
MLIRContext *ctx = getContext();
SmallVector<Attribute, 8> iters(
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
return ArrayAttr::get(iters, ctx);
}
}];
let verifier = [{ return ::verify(*this); }];
}
def DotOp : LinalgLibrary_Op<"dot",
[NInputsAndOutputs<2, 1>,
NLoopTypes<0, 1, 0>]> {
def DotOp : LinalgLibrary_Op<"dot", [NInputs<2>, NOutputs<1>]> {
let arguments = (ins AnyStridedMemRefOfRank<1>,
AnyStridedMemRefOfRank<1>,
AnyStridedMemRefOfRank<0>);
let extraClassDeclaration = libraryCallName;
let extraClassDeclaration = libraryCallName # [{
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
MLIRContext *ctx = getContext();
return ArrayAttr::get(
StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
}
}];
}
def MatvecOp : LinalgLibrary_Op<"matvec",
[NInputsAndOutputs<2, 1>,
NLoopTypes<1, 1, 0>]> {
def MatvecOp : LinalgLibrary_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
let arguments = (ins AnyStridedMemRefOfRank<2>,
AnyStridedMemRefOfRank<1>,
AnyStridedMemRefOfRank<1>);
let extraClassDeclaration = libraryCallName;
let extraClassDeclaration = libraryCallName # [{
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
MLIRContext *ctx = getContext();
Attribute iters[2]{
StringAttr::get(getParallelIteratorTypeName(), ctx),
StringAttr::get(getReductionIteratorTypeName(), ctx)};
return ArrayAttr::get(iters, ctx);
}
}];
}
def MatmulOp : LinalgLibrary_Op<"matmul",
[NInputsAndOutputs<2, 1>,
NLoopTypes<2, 1, 0>]> {
def MatmulOp : LinalgLibrary_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
let arguments = (ins AnyStridedMemRefOfRank<2>,
AnyStridedMemRefOfRank<2>,
AnyStridedMemRefOfRank<2>);
let extraClassDeclaration = libraryCallName;
let extraClassDeclaration = libraryCallName # [{
ArrayAttr indexing_maps();
ArrayAttr iterator_types() {
MLIRContext *ctx = getContext();
Attribute iters[3]{
StringAttr::get(getParallelIteratorTypeName(), ctx),
StringAttr::get(getParallelIteratorTypeName(), ctx),
StringAttr::get(getReductionIteratorTypeName(), ctx)};
return ArrayAttr::get(iters, ctx);
}
}];
}
def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
def ConvOp : LinalgLibrary_Op<"conv", [NInputs<2>, NOutputs<1>]> {
let description = [{
Generic n-D convolution as described in the TF documentation:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
@ -333,21 +344,27 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
unsigned getNumInputFeatureDimensions() { return 1; }
unsigned getNumOutputFeatureDimensions() { return 1; }
// Outer parallel loops are always the number of output dimensions; i.e.
// [ b, xs, q] in the TF notation above.
unsigned getNumParallelLoops() { return getOutputViewType(0).getRank(); }
ArrayAttr indexing_maps();
// Window loops are a special kind of reduction that is neither tiled or
// parallelized across; i.e. [zs] in the TF notation above whose number
// match `xs` (i.e. 1 window loop per "image" dimension).
unsigned getNumWindowLoops() {
return getNumParallelLoops() - getNumBatchDimensions() -
getNumInputFeatureDimensions(); }
// Reduction loops are exactly the non-parallel, non-window loops (i.e. `q`)
// We distinguish between reduction loops and convolution window loops for
// now. That distinction may disappear in the future.
unsigned getNumReductionLoops() { return getNumInputFeatureDimensions(); }
ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions; i.e.
// [ b, xs, q] in the TF notation above.
unsigned nPar = getOutputViewType(0).getRank();
unsigned nRed = getNumInputFeatureDimensions();
// Window loops are a special kind of reduction that is never tiled or
// parallelized across; i.e. [zs] in the TF notation above whose number
// match `xs` (i.e. 1 window loop per "image" dimension).
// This may evolve in the future.
unsigned nWin =
nPar - getNumBatchDimensions() - getNumInputFeatureDimensions();
MLIRContext *ctx = getContext();
SmallVector<Attribute, 8> iters(
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
iters.reserve(nPar + nRed + nWin);
iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx));
iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx));
return ArrayAttr::get(iters, ctx);
}
int64_t getStride(unsigned i) {
assert(i < getNumWindowLoops());
@ -368,9 +385,10 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
let arguments = (ins Variadic<AnyStridedMemRef>:$views,
I64Attr:$args_in,
I64Attr:$args_out,
AffineMapArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
I64ArrayAttr:$n_views,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<FlatSymbolRefAttr>:$fun,
OptionalAttr<StrAttr>:$library_call);
@ -378,57 +396,13 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
"doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
getIteratorTypesAttrName()
};
}
unsigned getNumInputs() {
if (!getAttr("n_views") || n_views().getValue().size() != 2)
return 0;
auto val = n_views().getValue()[0].cast<IntegerAttr>().getValue();
assert(val.getSExtValue() >= 0);
return val.getZExtValue();
}
unsigned getNumOutputs() {
if (!getAttr("n_views") || n_views().getValue().size() != 2)
return 0;
auto val = n_views().getValue()[1].cast<IntegerAttr>().getValue();
assert(val.getSExtValue() >= 0);
return val.getZExtValue();
}
unsigned getNumParallelLoops() {
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
unsigned nPar = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "parallel")
nPar++;
}
return nPar;
}
unsigned getNumReductionLoops() {
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
unsigned nRed = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "reduction")
nRed++;
}
return nRed;
}
unsigned getNumWindowLoops() {
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
unsigned nWin = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "window")
nWin++;
}
return nWin;
}
unsigned getNumLoops() {
return getNumParallelLoops() + getNumReductionLoops() +
getNumWindowLoops();
}
unsigned getNumInputs() { return args_in().getSExtValue(); }
unsigned getNumOutputs() { return args_out().getSExtValue(); }
FuncOp getFunction() {
auto moduleOp = getParentOfType<ModuleOp>();
return fun().hasValue() ?
@ -468,10 +442,12 @@ def GenericOp : GenericOpBase<"generic"> {
```
Where #trait_attributes is an alias of a dictionary attribute containing:
- args_in: an I64Attr representing the number of input (readonly) views
- args_out: an I64Attr representing the number of output (readwrite) views
- doc [optional]: a documentation string
- fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
To support inplace updates in a generic fashion, the signature of the
function must be:
- fun: a FlatSymbolRefAttr that must resolve to an existing function
symbol. To support inplace updates in a generic fashion, the signature
of the function must be:
```
fun([input views element types], [output views element types])
-> ([output views element types])
@ -488,8 +464,6 @@ def GenericOp : GenericOpBase<"generic"> {
Each element of the list represents an iterator of one of the following
types:
parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.
Example:
Defining a #matmul_trait attribute in MLIR can be done as follows:
@ -564,12 +538,14 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
```
Where #trait_attributes is an alias of a dictionary attribute containing:
- args_in: an I64Attr representing the number of input (readonly) views
- args_out: an I64Attr representing the number of output (readwrite) views
- doc [optional]: a documentation string
- fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
To support inplace updates in a generic fashion, the signature of the
function must be:
- fun: a FlatSymbolRefAttr that must resolve to an existing function
symbol. To support inplace updates in a generic fashion, the signature
of the function must be:
```
fun([index types for induction variables], [input views element types],
fun([index types of induction variables], [input views element types],
[output views element types]) -> ([output views element types])
```
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
@ -580,16 +556,17 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
maps to. The external library is assumed to be dynamically linked and
no strong compile-time guarantees are provided. In the absence of such
a library call, linalg.indexed_generic will always lower to loops.
- iterator_types: an ArrayAttr specifying the type of the enclosing loops; each
element of the list represents an iterator of one of the following types:
parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.
- iterator_types: an ArrayAttr specifying the type of the enclosing loops; each
element of the list represents an iterator of one of the following
types:
parallel, reduction, window
Example:
Defining a #matmul_trait attribute in MLIR can be done as follows:
```mlir
func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
-> f32
{
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
return %e: f32

View File

@ -20,6 +20,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@ -85,7 +86,6 @@ SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.h.inc"
} // namespace linalg
} // namespace mlir

View File

@ -19,6 +19,7 @@
#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
@ -28,23 +29,30 @@ namespace OpTrait {
namespace linalg {
/// This class provides the API for ops that are known to have a specified
/// number of inputs and outputs, all passed as operands. This is used as a
/// trait like this:
/// number of inputs, all passed as operands. This is used as a trait like this:
///
/// class DotOp : public Op<DotOp, OpTrait::NInputsAndOutputs<2, 1>::Impl> {
/// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
///
template <unsigned NInputs, unsigned NOutputs> class NInputsAndOutputs {
template <unsigned N> class NInputs {
public:
template <typename ConcreteType>
class Impl
: public OpTrait::TraitBase<ConcreteType,
NInputsAndOutputs<NInputs, NOutputs>::Impl> {
class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
public:
static unsigned getNumInputs() { return NInputs; }
static unsigned getNumOutputs() { return NOutputs; }
static LogicalResult verifyTrait(Operation *op) {
return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
}
static unsigned getNumInputs() { return N; }
};
};
/// This class provides the API for ops that are known to have a specified
/// number of outputs, all passed as operands. This is used as a trait like this:
///
/// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
///
template <unsigned N> class NOutputs {
public:
template <typename ConcreteType>
class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
public:
static unsigned getNumOutputs() { return N; }
};
};
@ -124,6 +132,25 @@ public:
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + getNumInputsAndOutputs()};
}
unsigned getNumParallelLoops() {
return getNumIterators(
getParallelIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumReductionLoops() {
return getNumIterators(
getReductionIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumWindowLoops() {
return getNumIterators(
getWindowIteratorTypeName(),
cast<ConcreteType>(this->getOperation()).iterator_types());
}
unsigned getNumLoops() {
return getNumIterators(
cast<ConcreteType>(this->getOperation()).iterator_types());
}
static LogicalResult verifyTrait(Operation *op) {
auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
@ -132,27 +159,6 @@ public:
}
};
/// This class provides the API for ops that are known to have a specified
/// number of parallel, reduction and window loops. This is used as a trait like
/// this:
///
/// class MatmulOp : public Op<MatmulOp, OpTrait::NLoopTypes<2, 1, 0>::Impl> {
///
template <unsigned NParallel, unsigned NReduction, unsigned NWindow = 0>
class NLoopTypes {
public:
template <typename ConcreteType>
class Impl
: public OpTrait::TraitBase<
ConcreteType, NLoopTypes<NParallel, NReduction, NWindow>::Impl> {
public:
static unsigned getNumParallelLoops() { return NParallel; }
static unsigned getNumReductionLoops() { return NReduction; }
static unsigned getNumWindowLoops() { return NWindow; }
static unsigned getNumLoops() { return NParallel + NReduction + NWindow; }
};
};
} // namespace linalg
} // namespace OpTrait
} // namespace mlir

View File

@ -40,7 +40,7 @@ class IsProducedByOpOfType<string str> :
CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
class AffineMapDomainHasDim<int n> : CPred<[{
$0.getAttrOfType<ArrayAttr>("indexing_maps").getValue()[0].
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
//===----------------------------------------------------------------------===//
@ -67,11 +67,12 @@ class TileAndFuseLinalgOp<
// `permutation` is an optional parameter to specify the ordering of the
// tiled loops. If provided, it must be a list of integers with the same number
// of elements as `sizes`.
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall<
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
NativeCodeCall<
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg to loop patterns.
@ -96,7 +97,8 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
//===----------------------------------------------------------------------===//
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
"if (failed(permuteGenericLinalgOp($_builder, $0, {" #
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS

View File

@ -26,6 +26,7 @@
#ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
#include "mlir/IR/Attributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
@ -42,6 +43,32 @@ static constexpr StringLiteral getIteratorTypesAttrName() {
return StringLiteral("iterator_types");
}
/// Attribute name for the IntegerAttr which encodes the number of input buffer
/// arguments.
static constexpr StringLiteral getArgsInAttrName() {
return StringLiteral("args_in");
}
/// Attribute name for the IntegerAttr which encodes the number of output buffer
/// arguments.
static constexpr StringLiteral getArgsOutAttrName() {
return StringLiteral("args_out");
}
/// Attribute name for the StringAttr which encodes an optional documentation
/// string of the structured op.
static constexpr StringLiteral getDocAttrName() { return StringLiteral("doc"); }
/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the
/// MLIR function that implements the body of the structured op.
static constexpr StringLiteral getFunAttrName() { return StringLiteral("fun"); }
/// Attribute name for the StrArrayAttr which encodes the external library
/// function that implements the structured op.
static constexpr StringLiteral getLibraryCallAttrName() {
return StringLiteral("library_call");
}
/// Use to encode that a particular iterator type has parallel semantics.
inline static constexpr StringLiteral getParallelIteratorTypeName() {
return StringLiteral("parallel");
@ -51,6 +78,37 @@ inline static constexpr StringLiteral getParallelIteratorTypeName() {
inline static constexpr StringLiteral getReductionIteratorTypeName() {
return StringLiteral("reduction");
}
/// Use to encode that a particular iterator type has window semantics.
inline static constexpr StringLiteral getWindowIteratorTypeName() {
return StringLiteral("window");
}
/// Returns the list of all known iterator type names.
inline static ArrayRef<StringRef> getAllIteratorTypeNames() {
static const StringRef names[3] = {getParallelIteratorTypeName(),
getReductionIteratorTypeName(),
getWindowIteratorTypeName()};
return names;
}
/// Returns the number of iterators of the given type in `iteratorTypes`.
inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
auto names = getAllIteratorTypeNames();
(void)names;
assert(llvm::is_contained(names, name));
return llvm::count_if(iteratorTypes, [name](Attribute a) {
return a.cast<StringAttr>().getValue() == name;
});
}
inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
unsigned res = 0;
for (auto n : getAllIteratorTypeNames())
res += getNumIterators(n, iteratorTypes);
return res;
}
} // end namespace mlir
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H

View File

@ -96,8 +96,8 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
Region &region = *result.addRegion();
SmallVector<Type, 8> operandTypes, regionTypes;
// Optional attributes may be added.
// Either Optional "fun" attribute or region must be specified.
if (!dictAttr.get("fun") &&
// Either Optional getFunAttrName() attribute or region must be specified.
if (!dictAttr.get(getFunAttrName()) &&
parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) ||
@ -557,7 +557,7 @@ static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputViewType(0);
auto fillType = op.getValue()->getType();
auto fillType = op.value()->getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
return success();
@ -813,3 +813,30 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
[&]() { ss << "_"; });
return ss.str();
}
static ArrayAttr getIndexingMaps(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
SmallVector<Attribute, 4> maps;
maps.reserve(linalgOp.getNumInputsAndOutputs());
for (AffineMap map : loopToOperandRangesMaps(op))
maps.push_back(AffineMapAttr::get(map));
return ArrayAttr::get(maps, op->getContext());
}
ArrayAttr mlir::linalg::ConvOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::CopyOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::DotOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::FillOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
return getIndexingMaps(getOperation());
}
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
return getIndexingMaps(getOperation());
}

View File

@ -130,8 +130,8 @@ public:
IndexedValueType O(fillOp.getOutput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
: O() = ValueHandle(fillOp.getValue());
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
: O() = ValueHandle(fillOp.value());
}
};

View File

@ -211,20 +211,20 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
auto permutationMap = inversePermutation(
AffineMap::getPermutationMap(permutation, rewriter.getContext()));
SmallVector<AffineMap, 4> newIndexingMap;
auto indexingMaps =
linOp.getAttrOfType<ArrayAttr>("indexing_maps").getValue();
auto indexingMaps = linOp.indexing_maps().getValue();
for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue().compose(
permutationMap);
newIndexingMap.push_back(m);
}
auto itTypes = linOp.getAttrOfType<ArrayAttr>("iterator_types").getValue();
SmallVector<StringRef, 4> itTypesVector;
auto itTypes = linOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
itTypesVector.push_back(itTypes[i].cast<StringAttr>().getValue());
itTypesVector.push_back(itTypes[i]);
applyPermutationToVector(itTypesVector, permutation);
op->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(newIndexingMap));
op->setAttr("iterator_types", rewriter.getStrArrayAttr(itTypesVector));
op->setAttr(getIndexingMapsAttrName(),
rewriter.getAffineMapArrayAttr(newIndexingMap));
op->setAttr(getIteratorTypesAttrName(), rewriter.getArrayAttr(itTypesVector));
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());

View File

@ -305,9 +305,10 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"],
n_views = [2, 1]
iterator_types = ["parallel", "parallel"]
}
func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
%c1 = constant 1 : index

View File

@ -57,9 +57,10 @@ func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
func @generic_at_least_2_operands(%arg0: memref<f32>) {
// expected-error @+1 {{op expected 2 or more operands}}
linalg.generic {
args_in = 1,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
iterator_types = []
} %arg0: memref<f32>
}
@ -69,9 +70,10 @@ func @generic_at_least_2_operands(%arg0: memref<f32>) {
func @generic_exactly_2_views(%arg0: memref<f32>) {
// expected-error @+1 {{op expected exactly 2 view operands}}
linalg.generic {
args_in = 1,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
iterator_types = []
} %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
}
@ -81,9 +83,10 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
func @generic_undefined_fun(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
linalg.generic {
args_in = 1,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
iterator_types = []
} %arg0, %arg0: memref<f32>, memref<f32>
}
@ -95,9 +98,10 @@ func @foo() { return }
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun arguments to match number of views}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<f32>
}
@ -109,9 +113,10 @@ func @foo(%0: i32) { return }
func @generic_mismatched_num_returns(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun results to match number of output views}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<f32>
}
@ -123,9 +128,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_symbol_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ ()[N] -> (0) ],
n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0: memref<i32>
}
@ -137,9 +143,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0: memref<i32>
}
@ -151,9 +158,10 @@ func @foo(%0: i32) -> i32 { return %0: i32 }
func @generic_zero_d_view(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (1) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<i32>
}
@ -165,9 +173,10 @@ func @foo(%0: f32) -> f32 { return %0: f32 }
func @generic_one_d_view(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, (d0)[s0] -> (d0 + s0)>'}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0, 0) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@ -182,9 +191,10 @@ func @foo(%0: i32) -> f32 {
func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@ -199,9 +209,10 @@ func @foo(%0: f32) -> i4 {
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0: memref<?xf32, (i)[off]->(off + i)>
}
@ -213,12 +224,13 @@ func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
linalg.generic {
args_in = 1,
args_out = 1,
fun = @foo,
indexing_maps = [
(i, j) -> (i + j) ,
(i, j) -> (i + j)
],
n_views = [1, 1],
iterator_types = ["parallel","parallel"]
} %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
}
@ -232,8 +244,9 @@ func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: me
func @generic_empty_region(%arg0: memref<f32>) {
// expected-error @+1 {{op expected region with 1 block}}
linalg.generic {
args_in = 1,
args_out = 1,
indexing_maps = [ () -> (0) ],
n_views = [1, 1],
iterator_types = []
} %arg0, %arg0 {
^bb1:
@ -246,8 +259,9 @@ func @generic_empty_region(%arg0: memref<f32>) {
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views}}
linalg.generic {
args_in = 0,
args_out = 1,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0 {
^bb:
@ -259,8 +273,9 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
func @generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}}
linalg.generic {
args_in = 0,
args_out = 1,
indexing_maps = [ () -> (0) ],
n_views = [0, 1],
iterator_types = []
} %arg0 {
^bb(%i: i1):
@ -272,8 +287,9 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%f: f32):
@ -285,8 +301,9 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 0 to be of IndexType}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f64, %f: f32):
@ -298,8 +315,9 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref<f32>'}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: index, %f: i1):
@ -314,8 +332,9 @@ func @foo(%f: f32) -> (f32) {
func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun arguments to match number of views + number of loops}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@ -329,7 +348,8 @@ func @foo(%i: i32, %val: f32) -> (f32) {
func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
linalg.indexed_generic {
n_views = [0, 1],
args_in = 0,
args_out = 1,
iterator_types = ["parallel"],
indexing_maps = [ (i) -> (i) ],
fun = @foo
@ -344,8 +364,9 @@ func @foo(%i: index, %val: i1) -> (i1) {
func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@ -359,8 +380,9 @@ func @foo(%i: index, %val: i1) -> (i1, i1) {
func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected fun results to match number of output views}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<f32>
@ -375,8 +397,9 @@ func @foo(%i: index, %val: i32) -> (f32) {
func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
// expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}}
linalg.indexed_generic {
args_in = 0,
args_out = 1,
indexing_maps = [ (d0) -> (d0) ],
n_views = [0, 1],
iterator_types = ["parallel"],
fun = @foo
} %arg0: memref<i32>
@ -385,10 +408,11 @@ func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
// -----
func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
// expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
linalg.generic {
indexing_maps = [ (i) -> (i) ],
n_views = [0, 1],
args_in = 0,
args_out = 1,
indexing_maps = [ (i) -> (i) ],
iterator_types = ["parallel"]
} %arg0 {
^bb(%i: f32):
@ -399,6 +423,13 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
// -----
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
}
// -----
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown

View File

@ -138,7 +138,8 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
(m, n, k) -> (m, n)
]
#matmul_trait = {
n_views = [2, 1],
args_in = 2,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
@ -175,7 +176,8 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
#indexed_matmul_trait = {
n_views = [2, 1],
args_in = 2,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_indexed_outerproduct_matmul"

View File

@ -222,7 +222,8 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
(i, j, k) -> (i, k, j)
]
#trait = {
n_views = [1, 2],
args_in = 1,
args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @foo,
@ -247,7 +248,8 @@ func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
#trait2 = {
n_views = [1, 2],
args_in = 1,
args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",
@ -280,7 +282,8 @@ func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) ->
return %i_float, %i_float : f32, f32
}
#trait3 = {
n_views = [1, 2],
args_in = 1,
args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
fun = @indexed_foo,
@ -310,7 +313,8 @@ func @indexed_generic_function(
// CHECK: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
#trait4 = {
n_views = [1, 2],
args_in = 1,
args_out = 2,
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_function_name_2",

View File

@ -120,8 +120,9 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
(i, j, k) -> (i, k, i + j)
]
#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
n_views = [1, 1],
iterator_types = ["parallel", "parallel", "parallel"],
fun = @foo,
library_call = "some_external_function_name_1"
@ -136,11 +137,12 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar
}
// CHECK-LABEL: func @foo
// CHECK-LABEL: func @generic
// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
#trait2 = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
n_views = [1, 1],
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_2"
}
@ -152,7 +154,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
return
}
// CHECK-LABEL: func @generic_region
// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
@ -166,7 +168,7 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
return
}
// CHECK-LABEL: func @indexed_generic
// CHECK: linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} {
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>

View File

@ -213,9 +213,10 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
#id_2d = (i, j) -> (i, j)
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"],
n_views = [2, 1]
iterator_types = ["parallel", "parallel"]
}
func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,

View File

@ -4,9 +4,10 @@
#id_1d = (i) -> (i)
#pointwise_1d_trait = {
args_in = 1,
args_out = 1,
indexing_maps = [#id_1d, #id_1d],
iterator_types = ["parallel"],
n_views = [1, 1]
iterator_types = ["parallel"]
}
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
@ -43,12 +44,13 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
// TILE-0n25: linalg.indexed_generic
#combined_indices_trait = {
args_in = 1,
args_out = 1,
indexing_maps = [
(i, j) -> (j, i + j),
(i, j) -> (i, j)
],
iterator_types = ["parallel", "parallel"],
n_views = [1, 1]
iterator_types = ["parallel", "parallel"]
}
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
linalg.indexed_generic #combined_indices_trait %operand, %result {

View File

@ -83,11 +83,12 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#some_generic_trait = {
args_in = 1,
args_out = 1,
indexing_maps = [
(i, j) -> (i, j),
(i, j) -> (i, j)
],
n_views = [1, 1],
iterator_types = ["parallel", "parallel"]
}
func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@ -164,12 +165,13 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
#matmul_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
],
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
__internal_linalg_transform__ = "_marked_matmul_"
}
@ -204,10 +206,11 @@ func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
(m, n, k) -> (m, n)
]
#generic_matmul_trait = {
args_in = 2,
args_out = 1,
fun = @fma,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
@ -220,7 +223,7 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
// CHECK-LABEL : func @fma
// CHECK-LABEL : func @permute_generic
// CHECK : linalg.generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
@ -228,10 +231,11 @@ func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) ->
return %e: f32
}
#indexed_matmul_trait = {
args_in = 2,
args_out = 1,
fun = @fma_indexed,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul_indexed",
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"]
}
func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
@ -242,7 +246,7 @@ func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
// CHECK-LABEL : func @fma_indexed
// CHECK-LABEL : func @permute_generic_indexed
// CHECK : linalg.indexed_generic {fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed", n_views = [2, 1]} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,

View File

@ -108,7 +108,7 @@ def : Pattern<(DotOp:$op $a, $b, $c),
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
[(LinalgOpToVectorContraction<"GenericOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
@ -116,13 +116,14 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
def : Pat<(IndexedGenericOp:$op $input, $imap, $itypes, $nviews, $doc, $fun, $libcall),
def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS