forked from OSchip/llvm-project

[mlir][Linalg] Drop function attribute from generic ops.

The function attribute in generic ops is not paying for itself. A region is the more standardized way of specifying a custom computation. If needed, this region can call a function directly. This is deemed more natural than managing a dedicated function attribute, and it also simplifies named-ops generation by trimming unnecessary complexity.

Differential Revision: https://reviews.llvm.org/D78266

parent 2ec5520a54
commit f54312277c
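Before the per-file hunks, here is a minimal sketch of what the change means for users of `linalg.generic`, distilled from the examples in the op documentation and tests updated below (`#matmul_trait`, `%A`, `%B`, `%C`, and `stride_specification` are placeholders taken from those examples, not new API):

```mlir
// Old form: the scalar computation is an external function named by the
// `fun` attribute inside #matmul_trait (fun = @fma).
linalg.generic #matmul_trait %A, %B, %C :
  memref<?x?xf32, stride_specification>,
  memref<?x?xf32, stride_specification>,
  memref<?x?xf32, stride_specification>

// New form: the same computation is written inline as a region; the region
// body could still call @fma if a function indirection is desired.
linalg.generic #matmul_trait %A, %B, %C {
  ^bb(%a: f32, %b: f32, %c: f32):
    %d = mulf %a, %b : f32
    %e = addf %c, %d : f32
    linalg.yield %e : f32
}: memref<?x?xf32, stride_specification>,
   memref<?x?xf32, stride_specification>,
   memref<?x?xf32, stride_specification>
```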
@@ -523,7 +523,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
                        AffineMapArrayAttr:$indexing_maps,
                        ArrayAttr:$iterator_types,
                        OptionalAttr<StrAttr>:$doc,
-                       OptionalAttr<FlatSymbolRefAttr>:$fun,
                        OptionalAttr<StrAttr>:$library_call);
   let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
@@ -531,7 +530,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
       return SmallVector<StringRef, 8>{
         getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
-        getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
+        getIndexingMapsAttrName(), getLibraryCallAttrName(),
         getIteratorTypesAttrName()
       };
     }
@@ -540,12 +539,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {

     unsigned getNumOutputs() { return args_out().getSExtValue(); }

-    FuncOp getFunction() {
-      auto moduleOp = getParentOfType<ModuleOp>();
-      return fun().hasValue() ?
-        moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
-    }
-
     StringRef getLibraryCallName() {
       return library_call().hasValue() ? library_call().getValue() : "";
     }
@@ -581,13 +574,6 @@ def GenericOp : GenericOpBase<"generic"> {
       - args_in: an I64Attr representing the number of input (readonly) views
       - args_out: an I64Attr representing the number of output (readwrite) views
       - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([input views element types], [output views element types])
-            -> ([output views element types])
-        ```
       - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
         and output view. Such AffineMapAttr specifies the mapping between the
         loops and the indexing within each view.
@@ -604,11 +590,6 @@ def GenericOp : GenericOpBase<"generic"> {
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-    func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-      %d = mulf %a, %b: f32
-      %e = addf %c, %d: f32
-      return %e: f32
-    }
     #matmul_accesses = [
       (m, n, k) -> (m, k),
      (m, n, k) -> (k, n),
@@ -616,7 +597,6 @@ def GenericOp : GenericOpBase<"generic"> {
     ]
     #matmul_trait = {
       doc = "C(m, n) += A(m, k) * B(k, n)",
-      fun = @fma,
       indexing_maps = #matmul_accesses,
       library_call = "linalg_matmul",
       n_views = [2, 1],
@@ -626,10 +606,14 @@ def GenericOp : GenericOpBase<"generic"> {

     And can be reused in multiple places as:
     ```mlir
-    linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>
+    linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
+      (%a: f32, %b: f32, %c: f32) :
+        %d = mulf %a, %b: f32
+        %e = addf %c, %d: f32
+        linalg_yield %e : f32
+    } : memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>
     ```

     This may lower to either:
@@ -649,9 +633,9 @@ def GenericOp : GenericOpBase<"generic"> {
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
           %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
           %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements(%a, %b, %c)
-                 : (f32, f32, f32) -> (f32)
-          store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          store %e, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
         }
       }
     }
@@ -662,7 +646,7 @@ def GenericOp : GenericOpBase<"generic"> {
     mixing input and output ranked tensor values with input and output memrefs.

     ```mlir
-    %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
+    %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
       tensor<?x?xf32>,
       memref<?x?xf32, stride_specification>
       -> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
       - args_in: an I64Attr representing the number of input (readonly) views
       - args_out: an I64Attr representing the number of output (readwrite) views
       - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([index types of induction variables], [input views element types],
-              [output views element types]) -> ([output views element types])
-        ```
       - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
         and output view. Such AffineMapAttr specifies the mapping between the
         loops and the indexing within each view.
@@ -732,15 +709,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     Defining a #matmul_trait attribute in MLIR can be done as follows:

     ```mlir
-    func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
-              %a: f32, %b: f32, %c: f32)
-      -> f32
-    {
-      "some_optional_condition"(%offset_m, %offset_n, %offset_k)
-      %d = mulf %a, %b: f32
-      %e = addf %c, %d: f32
-      return %e: f32
-    }
     #matmul_accesses = [
       (m, n, k) -> (m, k),
       (m, n, k) -> (k, n),
@@ -748,7 +716,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     ]
     #matmul_trait = {
       doc = "C(m, n) += A(m, k) * B(k, n)",
-      fun = @fma,
       indexing_maps = #matmul_accesses,
       library_call = "linalg_matmul",
       n_views = [2, 1],
@@ -759,10 +726,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     And can be reused in multiple places as:

     ```mlir
-    linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>,
-      memref<?x?xf32, stride_specification>
+    linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
+      (%offset_m: index, %offset_n: index, %offset_k: index,
+       %a: f32, %b: f32, %c: f32) :
+        "some_optional_computation"(%offset_m, %offset_n, %offset_k)
+        %d = mulf %a, %b: f32
+        %e = addf %c, %d: f32
+        linalg_yield %e : f32
+    } : memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>
     ```

     This may lower to either:
@@ -784,8 +757,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
           %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
           %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
-                 : (index, index, index, f32, f32, f32) -> (f32)
+          "some_optional_computation"(%m, %n, %k)
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
           store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
         }
       }
@@ -66,10 +66,6 @@ constexpr StringRef getArgsOutAttrName() { return "args_out"; }
 /// string of the structured op.
 constexpr StringRef getDocAttrName() { return "doc"; }

-/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the
-/// MLIR function that implements the body of the structured op.
-constexpr StringRef getFunAttrName() { return "fun"; }
-
 /// Attribute name for the StrArrayAttr which encodes the external library
 /// function that implements the structured op.
 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
@@ -177,7 +177,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
       builder.getAffineMapArrayAttr(maps),
       builder.getStrArrayAttr(iteratorStrTypes),
       StringAttr() /*doc*/,
-      FlatSymbolRefAttr() /*fun*/,
       StringAttr() /*library_call*/
       /* TODO: other attributes in op */
       )
@@ -133,10 +133,11 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
       attrs.push_back(attr);

   auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
-  p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
+  p << op.getOperationName() << " " << dictAttr;
+  p.printOptionalAttrDict(op.getAttrs(), attrNames);
+  p << " " << op.getOperands();
   if (!op.region().empty())
     p.printRegion(op.region());
-  p.printOptionalAttrDict(op.getAttrs(), attrNames);
   p << ": " << op.getOperandTypes();
   auto outputTensorTypes = op.getResultTypes();
   if (!outputTensorTypes.empty())
@@ -156,21 +157,21 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
   // The name is unimportant as we will overwrite result.attributes.
   // The core linalg traits must contain the information necessary to pass the
   // verifier.
-  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
-      parser.parseOperandList(operandsInfo))
+  if (parser.parseAttribute(dictAttr, "_", result.attributes))
     return failure();
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());

+  // Optional attributes may be added.
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(operandsInfo))
+    return failure();
+
   Region &region = *result.addRegion();
   SmallVector<Type, 8> operandTypes, regionTypes;
-  // Optional attributes may be added.
-  // Either Optional getFunAttrName() attribute or region must be specified.
-  if (!dictAttr.get(getFunAttrName()) &&
-      parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes))
+  if (parser.parseRegion(region, regionOperandsInfo, regionTypes))
     return failure();
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonTypeList(operandTypes))
+  if (parser.parseColonTypeList(operandTypes))
     return failure();
   // Generic ops may specify that a subset of its outputs are tensors. Such
   // outputs are specified in the result type.
@@ -183,10 +184,7 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
                              parser.getCurrentLocation(), result.operands);
 }

-template <typename GenericOpType>
-static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
-
-template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
   auto nOperands = op.getNumOperands();
   if (block.getNumArguments() != nOperands)
     return op.emitOpError("expected number of block arguments to match number "
@@ -205,7 +203,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
   return success();
 }

-template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
   auto nOperands = op.getNumOperands();
@@ -234,81 +232,6 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
   return success();
 }

-template <typename GenericOpType>
-static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
-
-template <typename GenericOpType>
-static LogicalResult verifyFuncArgsGeneric(GenericOpType op,
-                                           FunctionType funType) {
-  auto res = verifyFuncArgs(op, funType);
-  if (failed(res))
-    return res;
-
-  auto nInputs = op.getNumInputs();
-  auto nOutputs = op.getNumOutputs();
-  // linalg.generic output element types are exactly the function results.
-  for (unsigned idx = 0; idx < nOutputs; ++idx) {
-    ShapedType shapedType = op.getShapedType(nInputs + idx);
-    if (funType.getResult(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function result ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of output " << (idx + 1);
-  }
-  return success();
-}
-
-template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands)
-    return op.emitOpError(
-        "expected function arguments to match number of operands");
-  if (funType.getNumResults() != op.getNumOutputs())
-    return op.emitOpError("expected function results(")
-           << funType.getNumResults() << ") to match number of outputs("
-           << op.getNumOutputs() << ")";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of operand " << (idx + 1);
-  }
-
-  return success();
-}
-
-template <>
-LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
-  auto nLoops = op.getNumLoops();
-  auto nOutputs = op.getNumOutputs();
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands + nLoops)
-    return op.emitOpError("expected function arguments to match number of "
-                          "loops + number of operands");
-  if (funType.getNumResults() != nOutputs)
-    return op.emitOpError(
-        "expected function results to match number of outputs");
-  for (unsigned i = 0; i < nLoops; ++i)
-    if (!funType.getInput(i).isIndex())
-      return op.emitOpError("expected function argument ")
-             << (i + 1) << " to be an index";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx + nLoops) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + nLoops + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of input " << (idx + 1);
-  }
-
-  return success();
-}
-
 template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
@@ -320,20 +243,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
            << " inputs (tensor or buffer) and output buffer operands";

   auto &region = op.region();
-  auto funOp = op.getFunction();
-  auto funType = funOp ? funOp.getType() : FunctionType();
-  if (!region.empty()) {
-    if (region.getBlocks().size() != 1)
-      return op.emitOpError("expected region with 1 block");
-    if (failed(verifyBlockArgs(op, region.getBlocks().front())))
-      return failure();
-  } else {
-    if (!funOp || !funOp.getType())
-      return op.emitOpError(
-          "expected function attribute to refer to a defined symbol");
-    if (failed(verifyFuncArgsGeneric(op, funType)))
-      return failure();
-  }
+  if (region.getBlocks().size() != 1)
+    return op.emitOpError("expected region with 1 block");
+  if (failed(verifyBlockArgs(op, region.getBlocks().front())))
+    return failure();

   SmallVector<AffineMap, 4> indexingMaps;
   indexingMaps.reserve(op.indexing_maps().size());
@@ -382,8 +382,7 @@ static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
   // - only handle ops that use regions for specifying the scalar operations.
   if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
       producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
-      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
-      producerOp.fun() || consumerOp.fun())
+      producerOp.getNumParallelLoops() != producerOp.getNumLoops())
     return false;

   // Get the consumer index map. The number of results of the consumer index map
@@ -472,7 +471,6 @@ Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
       b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
       b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
       /*doc=*/nullptr,
-      /*fun=*/nullptr,
       /*library_call=*/nullptr);

   // Build the region of the fused op.
@@ -400,21 +400,6 @@ public:
       indexedValues[nInputs + i] = std_load(output, indexing);
     }

-    auto funcOp = genericOp.getFunction();
-    if (funcOp) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == genericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = genericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.
@@ -495,20 +480,6 @@ public:
       indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
     }

-    if (auto funcOp = indexedGenericOp.getFunction()) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = indexedGenericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.
@@ -54,15 +54,26 @@ func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {

 // -----

+func @generic_no_region(%arg0: memref<f32>) {
+  // expected-error @+6 {{expected '{' to begin a region}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 1,
+    indexing_maps = [ affine_map<() -> (0)> ],
+    iterator_types = []
+  } %arg0 : memref<f32>
+}
+
+// -----
+
 func @generic_at_least_2_operands(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected 2 or more operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0: memref<f32>
+  } %arg0 {} : memref<f32>
 }

 // -----
@@ -72,182 +83,102 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
+  } %arg0, %arg0, %arg0 {}: memref<f32>, memref<f32>, memref<f32>
 }

 // -----

-func @generic_undefined_fun(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function attribute to refer to a defined symbol}}
-  linalg.generic {
-    args_in = 1,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0, %arg0: memref<f32>, memref<f32>
-}
-
-// -----
-
-func @foo() { return }
-
-func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of operands}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32) { return }
-
 func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results(0) to match number of outputs(1)}}
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (0)}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
+    indexing_maps = [ affine_map<() -> ()> ],
     iterator_types = []
-  } %arg0: memref<f32>
+  } %arg0 {
+    ^bb(%0: f32):
+      linalg.yield
+  }: memref<f32>
 }

 // -----

-func @foo(%0: i32, %1: i32, %2: i32) { return }
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}}
-  linalg.generic {
-    args_in = 3,
-    args_out = 0,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %0, %1, %1: memref<i32>, memref<f32>, memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32}
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
-  linalg.generic {
-    args_in = 2,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %0, %0, %1: memref<i32>, memref<i32>, memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_symbol_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<()[N] -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<i32>
+  } %arg0 {
+    ^bb(%i : i32):
+  }: memref<i32>
 }

 // -----

-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<1xi32>
+  } %arg0 {
+    ^bb(%i : i32):
+  }: memref<1xi32>
 }

 // -----

-func @foo(%0: f32) -> f32 { return %0: f32 }
-
 func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>'}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0, 0)> ],
     iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0 {
+    ^bb(%f : f32):
+      linalg.yield %f: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }

 // -----

-func @foo(%0: i32) -> f32 {
-  %1 = constant 0.0: f32
-  return %1: f32
-}
-
-func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}}
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+  // expected-error @+9 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+    indexing_maps = [ affine_map<(i) -> (i)> ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%0: f32):
+      %1 = constant 1: i4
+      linalg.yield %1: i4
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }

 // -----

-func @foo(%0: f32) -> i4 {
-  %1 = constant 1: i4
-  return %1: i4
-}
-
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
-}
-
-// -----
-
-func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
-
 func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [
       affine_map<(i, j) -> (i + j)>,
       affine_map<(i, j) -> (i + j)>
     ],
     iterator_types = ["parallel","parallel"]
-  } %arg0, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>, memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0, %arg1 {
+    ^bb(%0: f32, %1: f32):
+      linalg.yield %1: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
+     memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }

 ////////////////////////////////////////////////////////////////////////////////
@@ -341,88 +272,53 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {

 // -----

-func @foo(%f: f32) -> (f32) {
-  return %f : f32
-}
-
-func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of loops + number of operands}}
+func @indexed_generic_arg_count(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
+    indexing_maps = [ affine_map<()[] -> ()> ],
+    iterator_types = []
+  } %arg0 {
+    ^bb(%0: index, %1: f32):
+      linalg.yield %1: f32
+  } : memref<f32>
+  return
 }

 // -----

-func @foo(%i: i32, %val: f32) -> (f32) {
-  return %val : f32
-}
-
-func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 1 to be an index}}
+func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
     iterator_types = ["parallel"],
-    indexing_maps = [ affine_map<(i) -> (i)> ],
-    fun = @foo
-  } %arg0 : memref<f32>
+    indexing_maps = [ affine_map<(i) -> (i)> ]
+  } %arg0 {
+    ^bb(%0: i32, %1: f32):
+      linalg.yield %1: f32
+  } : memref<f32>
 }

 // -----

-func @foo(%i: index, %val: i1) -> (i1) {
-  return %val : i1
-}
-
-func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 1}}
+func @indexed_generic_result_count(%arg0: memref<?xf32>) {
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (2)}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
     indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%i: index, %val: i1) -> (i1, i1) {
-  return %val, %val : i1, i1
-}
-
-func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results to match number of outputs}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%i: index, %val: i32) -> (f32) {
-  %val_float = sitofp %val : i32 to f32
-  return %val_float : f32
-}
-
-func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'i32' of output 1}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<i32>
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: index, %val: f32):
+      linalg.yield %val, %val: f32, f32
+  }: memref<?xf32>
 }

 // -----

-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
@@ -453,7 +349,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off

 // -----

-func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
 }
@@ -533,51 +533,11 @@ func @pooling_sum(%arg0: memref<?x?xf32>,
 // CHECKPARALLEL: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 // CHECKPARALLEL: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>

-func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %f0 = constant 0.0 : f32
-  return %f0, %f0 : f32, f32
-}
 #accesses = [
   affine_map<(i, j, k) -> (i, j)>,
   affine_map<(i, j, k) -> (i, j, k)>,
   affine_map<(i, j, k) -> (i, k, j)>
 ]
-#trait = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @foo,
-  library_call = "some_external_function_name_1",
-  doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
-}
-func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1, %arg2:
-    memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECKLOOP-LABEL: @foo
-// CHECKLOOP-LABEL: @generic_function
-// CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: loop.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: loop.for %[[k:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-// CHECKLOOP: %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-// CHECKLOOP: store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
-// CHECKPARALLEL-LABEL: @foo
-// CHECKPARALLEL-LABEL: @generic_function
-// CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-// CHECKPARALLEL: %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-// CHECKPARALLEL: store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
 #trait2 = {
   args_in = 1,
   args_out = 2,
@@ -617,52 +577,6 @@ func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1:
 // CHECKPARALLEL: store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
 // CHECKPARALLEL: store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>

-func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %i_int = index_cast %i: index to i32
-  %i_float = sitofp %i_int : i32 to f32
-  return %i_float, %i_float : f32, f32
-}
-#trait3 = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @indexed_foo,
-  library_call = "some_external_function_name_1",
-  doc = "b(i,j,k), c(i,k,j) = foo(a(i, j), b(i,j,k), c(i,k,j))"
-}
-func @indexed_generic_function(
-    %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-    %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1, %arg2:
-    memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECKLOOP-LABEL: @indexed_foo
-// CHECKLOOP-LABEL: @indexed_generic_function
-// CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-// CHECKLOOP: loop.for %[[j:.*]] = {{.*}}
-// CHECKLOOP: loop.for %[[k:.*]] = {{.*}}
-// CHECKLOOP: %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-// CHECKLOOP: %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-// CHECKLOOP: store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKLOOP: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
-// CHECKPARALLEL-LABEL: @indexed_foo
-// CHECKPARALLEL-LABEL: @indexed_generic_function
-// CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL: %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
-// CHECKPARALLEL: %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-// CHECKPARALLEL: store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
-// CHECKPARALLEL: store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
-
 #trait4 = {
   args_in = 1,
   args_out = 2,
@@ -289,11 +289,6 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>

-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -304,46 +299,45 @@ func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }

 func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
               %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
-// CHECK-LABEL: func @foo
 // CHECK-LABEL: func @generic
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME: library_call = "some_external_function_name_1"
-// CHECK-SAME: {foo = 1 : i64}:
-// CHECK-SAME: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+// CHECK-SAME: {foo = 1 : i64}
+// CHECK: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>

 func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
                                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>,
-    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>,
+      memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
 // CHECK-LABEL: func @generic_with_tensor_input
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: {foo = 1 : i64}:
-// CHECK-SAME: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
+// CHECK-SAME: {foo = 1 : i64}
+// CHECK: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>

 // -----

-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -354,31 +348,30 @@ func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }

 func @generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
     -> (tensor<?x?x?xf32>) {
-  %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  %0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @generic_with_tensor_input_and_output
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-// CHECK-SAME: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: {foo = 1 : i64}
+// CHECK-SAME: %{{.*}}, %{{.*}}
+// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
 // CHECK: return {{.*}} : tensor<?x?x?xf32>

 // -----

-func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
@@ -389,22 +382,26 @@ func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }

 func @indexed_generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
     -> (tensor<?x?x?xf32>) {
-  %0 = linalg.indexed_generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  %0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-// CHECK-SAME: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: {foo = 1 : i64}
+// CHECK-SAME: %{{.*}}, %{{.*}}
+// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
 // CHECK: return {{.*}} : tensor<?x?x?xf32>

 // -----
@@ -460,10 +457,10 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)

 func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                      %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.generic #trait3 %arg0, %arg1 {
+  linalg.generic #trait3 {foo = 1} %arg0, %arg1 {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
@@ -471,17 +468,18 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
 // CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME: library_call = "some_external_function_name_2"
+// CHECK-SAME: {foo = 1 : i64}
 // CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 // CHECK: linalg.yield %{{.*}} : f32
-// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
-// CHECK-SAME: memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+// CHECK-SAME: memref<?x?x?xf32, #[[strided3D]]>

 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                       %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1 {
+  linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 {
     ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  }: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
   return
 }
@@ -489,9 +487,10 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
 // CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME: library_call = "some_external_function_name_2"
+// CHECK-SAME: {foo = 1 : i64}
 // CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 // CHECK: linalg.yield %{{.*}} : f32
-// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+// CHECK: }: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
 // CHECK-SAME: memref<?x?x?xf32, #[[strided3D]]>

 // -----
@@ -212,57 +212,71 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 // CHECK-LABEL: func @test_vectorize_fill
 // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>

-func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-  %d = mulf %a, %b: f32
-  %e = addf %c, %d: f32
-  return %e: f32
-}
 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
   affine_map<(m, n, k) -> (m, n)>
 ]
 #generic_matmul_trait = {
   args_in = 2,
   args_out = 1,
-  fun = @fma,
   indexing_maps = #matmul_accesses,
   library_call = "linalg_matmul",
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.generic #generic_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
-
+  linalg.generic #generic_matmul_trait %A, %B, %C {
+    ^bb(%a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
-// CHECK-LABEL : func @fma
 // CHECK-LABEL : func @permute_generic
-// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK : linalg.generic {args_in = 2, args_out = 1,
+// CHECK-SAME : indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME : iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME : library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}}
+// CHECK : memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME : memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME : memref<?x?xf32, #[[STRIDED_2D]]>

-func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
-  %d = mulf %a, %b: f32
-  %e = addf %c, %d: f32
-  return %e: f32
-}
 #indexed_matmul_trait = {
   args_in = 2,
   args_out = 1,
-  fun = @fma_indexed,
   indexing_maps = #matmul_accesses,
   library_call = "linalg_matmul_indexed",
   iterator_types = ["parallel", "parallel", "reduction"]
 }
-func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+func @permute_generic_indexed(
+    %A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+    ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  } : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
-// CHECK-LABEL : func @fma_indexed
 // CHECK-LABEL : func @permute_generic_indexed
-// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1,
+// CHECK-SAME : indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME : iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME : library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK : memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME : memref<?x?xf32, #[[STRIDED_2D]]>,
+// CHECK-SAME : memref<?x?xf32, #[[STRIDED_2D]]>

 func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
                %y: memref<?xf32, offset: ?, strides: [1]>,
@@ -111,7 +111,7 @@ def : Pattern<(FillOp:$op $_, $_),
                     HasLinalgTransformMarker<"VECTORIZE">,
                     PreconditionVectorizeLinalgOp
                   ]>>)]>;
-def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
               [(VectorizeLinalgOp)],
              [(Constraint<And<[
                 HasLinalgTransformMarker<"VECTORIZE">,
@@ -122,7 +122,7 @@ def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//
-def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
           (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
           [(Constraint<And<[
             HasNoLinalgTransformMarker,
@@ -130,7 +130,7 @@ def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
             PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
           ]>>)]>;

-def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_),
           (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
           [(Constraint<And<[
             HasNoLinalgTransformMarker,