forked from OSchip/llvm-project
Add a region to linalg.generic
This CL extends the Linalg GenericOp with an alternative way of specifying the body of the computation based on a single block region. The "fun" attribute becomes optional. Either a SymbolRef "fun" attribute or a single block region must be specified to describe the side-effect-free computation. Upon lowering to loops, the new region body is inlined in the innermost loop. The parser, verifier and pretty printer are extended. Appropriate roundtrip, negative and lowering to loop tests are added. PiperOrigin-RevId: 261895568
This commit is contained in:
parent
24647750d4
commit
4b422a51ed
|
@ -335,12 +335,13 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
|
|||
```
|
||||
}];
|
||||
let arguments = (ins Variadic<View>:$views,
|
||||
SymbolRefAttr:$fun,
|
||||
AffineMapArrayAttr:$indexing_maps,
|
||||
I64ArrayAttr:$n_loop_types,
|
||||
I64ArrayAttr:$n_views,
|
||||
OptionalAttr<StrAttr>:$doc,
|
||||
OptionalAttr<SymbolRefAttr>:$fun,
|
||||
OptionalAttr<StrAttr>:$library_call);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<StringRef, 8> linalgTraitAttrNames() {
|
||||
return SmallVector<StringRef, 8>{
|
||||
|
@ -386,8 +387,10 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
|
|||
return getNumParallelLoops() + getNumReductionLoops() +
|
||||
getNumWindowLoops();
|
||||
}
|
||||
StringRef getFunName() {
|
||||
return fun();
|
||||
FuncOp getFunction() {
|
||||
auto moduleOp = getParentOfType<ModuleOp>();
|
||||
return fun().hasValue() ?
|
||||
moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
|
||||
}
|
||||
StringRef getLibraryCallName() {
|
||||
return library_call().hasValue() ? library_call().getValue() : "";
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Linalg/IR/LinalgTraits.h"
|
||||
#include "mlir/Linalg/IR/LinalgTypes.h"
|
||||
|
|
|
@ -213,4 +213,15 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
|
||||
Arguments<(ins Variadic<AnyType>:$values)> {
|
||||
let summary = "Linalg yield operation";
|
||||
let description = [{
|
||||
"linalg.yield" is a special terminator operation for blocks inside regions
|
||||
in linalg ops. It returns values to the immediately enclosing linalg op.
|
||||
|
||||
linalg.yield %f0, %f1 : f32, f32
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // LINALG_OPS
|
||||
|
|
|
@ -113,6 +113,11 @@ public:
|
|||
}
|
||||
/// Return the number of input and output views.
|
||||
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
|
||||
/// Return the `i`-th view type.
|
||||
mlir::linalg::ViewType getViewType(unsigned i) {
|
||||
return (i < nInputs()) ? getInputViewType(i)
|
||||
: getOutputViewType(i - nInputs());
|
||||
}
|
||||
/// Return the range over input and output views.
|
||||
Operation::operand_range getInputsAndOutputs() {
|
||||
auto range = this->getOperation()->getOperands();
|
||||
|
|
|
@ -538,13 +538,15 @@ static void print(OpAsmPrinter *p, GenericOp op) {
|
|||
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
|
||||
*p << op.getOperationName() << " " << dictAttr << " ";
|
||||
p->printOperands(op.getOperands());
|
||||
if (!op.region().empty())
|
||||
p->printRegion(op.region());
|
||||
p->printOptionalAttrDict(op.getAttrs(), attrNames);
|
||||
*p << ": ";
|
||||
interleaveComma(op.getOperandTypes(), *p);
|
||||
}
|
||||
|
||||
static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo;
|
||||
DictionaryAttr dictAttr;
|
||||
// Parse the core linalg traits that must check into a dictAttr.
|
||||
// The name is unimportant as we will overwrite result->attributes.
|
||||
|
@ -556,8 +558,13 @@ static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
|
|||
result->attributes.assign(dictAttr.getValue().begin(),
|
||||
dictAttr.getValue().end());
|
||||
|
||||
Region ®ion = *result->addRegion();
|
||||
SmallVector<Type, 8> operandTypes, regionTypes;
|
||||
// Optional attributes may be added.
|
||||
SmallVector<Type, 8> operandTypes;
|
||||
// Either Optional "fun" attribute or region must be specified.
|
||||
if (!dictAttr.get("fun") &&
|
||||
parser->parseOptionalRegion(region, regionOperandsInfo, regionTypes))
|
||||
return failure();
|
||||
if (parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonTypeList(operandTypes))
|
||||
return failure();
|
||||
|
@ -572,18 +579,36 @@ static LogicalResult verify(GenericOp op) {
|
|||
if (nViews != llvm::size(op.views()))
|
||||
return op.emitError("op expected exactly ") << nViews << " view operands";
|
||||
|
||||
auto m = op.getParentOfType<ModuleOp>();
|
||||
auto fun = m.lookupSymbol<FuncOp>(op.fun());
|
||||
if (!fun || !fun.getType())
|
||||
return op.emitError(
|
||||
"op expected fun attribute to refer to a defined symbol");
|
||||
auto ®ion = op.region();
|
||||
auto funOp = op.getFunction();
|
||||
auto funType = funOp ? funOp.getType() : FunctionType();
|
||||
if (!region.empty()) {
|
||||
if (region.getBlocks().size() != 1)
|
||||
return op.emitError("op expected region with 1 block");
|
||||
|
||||
auto funType = fun.getType();
|
||||
if (funType.getNumInputs() != nViews)
|
||||
return op.emitError("op expected fun arguments to match number of views");
|
||||
if (funType.getNumResults() != op.getNumOutputs())
|
||||
return op.emitError(
|
||||
"op expected fun results to match number of output views");
|
||||
auto &block = region.getBlocks().front();
|
||||
if (block.getNumArguments() != nViews)
|
||||
return op.emitError(
|
||||
"op expected number of block arguments to match number of views");
|
||||
|
||||
for (unsigned i = 0; i < nViews; ++i) {
|
||||
auto viewType = op.getViewType(i);
|
||||
if (viewType.getElementType() != block.getArgument(i)->getType())
|
||||
return op.emitError("op expected block argument ")
|
||||
<< i << " of the same type as elemental type of "
|
||||
<< ((i < nInputViews) ? "input " : "output ")
|
||||
<< "view: " << viewType;
|
||||
}
|
||||
} else {
|
||||
if (!funOp || !funOp.getType())
|
||||
return op.emitError(
|
||||
"op expected fun attribute to refer to a defined symbol");
|
||||
if (funType.getNumInputs() != nViews)
|
||||
return op.emitError("op expected fun arguments to match number of views");
|
||||
if (funType.getNumResults() != op.getNumOutputs())
|
||||
return op.emitError(
|
||||
"op expected fun results to match number of output views");
|
||||
}
|
||||
|
||||
auto nLoops = op.getNumLoops();
|
||||
SmallVector<AffineMap, 4> indexingMaps;
|
||||
|
@ -615,15 +640,18 @@ static LogicalResult verify(GenericOp op) {
|
|||
return op.emitError("op expected indexing_map #")
|
||||
<< idx << " results to match view rank: " << view;
|
||||
|
||||
if (funType.getInput(idx) != view.getElementType())
|
||||
return op.emitError("op expected fun argument ")
|
||||
<< idx << " to match view element type: " << view.getElementType();
|
||||
if (funType) {
|
||||
if (funType.getInput(idx) != view.getElementType())
|
||||
return op.emitError("op expected fun argument ")
|
||||
<< idx
|
||||
<< " to match view element type: " << view.getElementType();
|
||||
|
||||
if (idx >= nInputViews)
|
||||
if (funType.getResult(idx - nInputViews) != view.getElementType())
|
||||
return op.emitError("op expected fun result ")
|
||||
<< idx << " to match output view element type: "
|
||||
<< view.getElementType();
|
||||
if (idx >= nInputViews)
|
||||
if (funType.getResult(idx - nInputViews) != view.getElementType())
|
||||
return op.emitError("op expected fun result ")
|
||||
<< idx << " to match output view element type: "
|
||||
<< view.getElementType();
|
||||
}
|
||||
}
|
||||
|
||||
auto concatMap = concatAffineMaps(indexingMaps);
|
||||
|
@ -635,6 +663,56 @@ static LogicalResult verify(GenericOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseYieldOp(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
llvm::SMLoc loc = parser->getCurrentLocation();
|
||||
return failure(parser->parseOperandList(opInfo) ||
|
||||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
||||
parser->resolveOperands(opInfo, types, loc, result->operands));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, YieldOp op) {
|
||||
*p << op.getOperationName();
|
||||
if (op.getNumOperands() > 0) {
|
||||
*p << ' ';
|
||||
p->printOperands(op.operand_begin(), op.operand_end());
|
||||
*p << " : ";
|
||||
interleaveComma(op.getOperands(), *p,
|
||||
[&](Value *e) { p->printType(e->getType()); });
|
||||
}
|
||||
}
|
||||
|
||||
static LogicalResult verify(YieldOp op) {
|
||||
auto *parentOp = op.getParentOp();
|
||||
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
|
||||
return op.emitOpError("op expected single non-empty parent region");
|
||||
|
||||
auto genericOp = dyn_cast<GenericOp>(parentOp);
|
||||
if (!genericOp)
|
||||
return op.emitOpError("op expected '")
|
||||
<< GenericOp::getOperationName() << "' parent op";
|
||||
|
||||
// The operand number and types must match the view element types.
|
||||
auto nOutputViews = genericOp.getNumOutputs();
|
||||
if (op.getNumOperands() != nOutputViews)
|
||||
return op.emitOpError("op expected ")
|
||||
<< nOutputViews << " operand to match enclosing linalg.generic op";
|
||||
|
||||
for (unsigned i = 0; i != nOutputViews; ++i) {
|
||||
auto elementType = genericOp.getOutputViewType(i).getElementType();
|
||||
if (op.getOperand(i)->getType() != elementType)
|
||||
return op.emitError("type of return operand ")
|
||||
<< i << " (" << op.getOperand(i)->getType()
|
||||
<< ") doesn't match view element type (" << elementType << ")";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, SubViewOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand(0) << "[";
|
||||
auto ranges = op.getRanges();
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Linalg/IR/LinalgTypes.h"
|
||||
|
@ -241,17 +242,44 @@ public:
|
|||
linalg_load(genericOp.getOutput(i), indexing);
|
||||
}
|
||||
|
||||
// 2. Emit call.
|
||||
auto m = genericOp.getParentOfType<ModuleOp>();
|
||||
auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
|
||||
Operation *callOp = call(fun, indexedValues);
|
||||
assert(callOp->getNumResults() == genericOp.getNumOutputs());
|
||||
auto funcOp = genericOp.getFunction();
|
||||
if (funcOp) {
|
||||
// 2. Emit call.
|
||||
Operation *callOp = call(funcOp, indexedValues);
|
||||
assert(callOp->getNumResults() == genericOp.getNumOutputs());
|
||||
|
||||
// 3. Emit linalg_store.
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
|
||||
// 3. Emit linalg_store.
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
|
||||
}
|
||||
} else {
|
||||
// TODO(ntv): When a region inliner exists, use it.
|
||||
// 2. Inline region, currently only works for a single basic block.
|
||||
BlockAndValueMapping map;
|
||||
auto &block = genericOp.region().front();
|
||||
for (auto it : llvm::zip(block.getArguments(), indexedValues))
|
||||
map.map(std::get<0>(it), std::get<1>(it));
|
||||
for (auto &op : block) {
|
||||
// Skip terminator.
|
||||
if (&op == &block.back())
|
||||
continue;
|
||||
assert(op.getNumRegions() == 0);
|
||||
auto *newOp = b.clone(op, map);
|
||||
for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
|
||||
map.map(std::get<0>(it), std::get<1>(it));
|
||||
}
|
||||
|
||||
// 3. Emit linalg_store.
|
||||
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
|
||||
assert(yieldOp->getNumOperands() == nOutputs);
|
||||
for (unsigned i = 0, e = nOutputs; i < e; ++i) {
|
||||
ValueHandleArray indexing(foldedAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
|
||||
linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
|
||||
indexing);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// Function Attribute tests /////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: at_least_2_operands
|
||||
|
@ -194,3 +198,78 @@ func @singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
|
|||
} %arg0, %arg1: !linalg.view<?xf32>, !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////// Region tests /////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: empty_region
|
||||
func @empty_region(%arg0: !linalg.view<f32>) {
|
||||
// expected-error @+1 {{op expected region with 1 block}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [1, 1],
|
||||
n_loop_types = [0, 0, 0]
|
||||
} %arg0, %arg0 {
|
||||
^bb1:
|
||||
^bb2:
|
||||
}: !linalg.view<f32>, !linalg.view<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: mismatched_num_arguments
|
||||
func @mismatched_num_arguments(%arg0: !linalg.view<f32>) {
|
||||
// expected-error @+1 {{op expected number of block arguments to match number of views}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
n_loop_types = [0, 0, 0]
|
||||
} %arg0 {
|
||||
^bb:
|
||||
}: !linalg.view<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: block_arg_type
|
||||
func @block_arg_type(%arg0: !linalg.view<f32>) {
|
||||
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: '!linalg.view<f32>'}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ () -> (0) ],
|
||||
n_views = [0, 1],
|
||||
n_loop_types = [0, 0, 0]
|
||||
} %arg0 {
|
||||
^bb(%i: i1):
|
||||
}: !linalg.view<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fun_result_0_element_type
|
||||
func @fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
|
||||
// expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
|
||||
linalg.generic {
|
||||
indexing_maps = [ (i) -> (i) ],
|
||||
n_views = [0, 1],
|
||||
n_loop_types = [1, 0, 0]
|
||||
} %arg0 {
|
||||
^bb(%i: f32):
|
||||
%0 = constant 0: i1
|
||||
linalg.yield %0: i1
|
||||
}: !linalg.view<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: wrong_yield_parent
|
||||
func @fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
|
||||
// expected-error @+1 {{op expected 'linalg.generic' parent op}}
|
||||
linalg.yield %arg0: !linalg.view<?xf32>
|
||||
}
|
||||
|
|
|
@ -219,13 +219,13 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
|
|||
library_call = "external_function_name",
|
||||
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
|
||||
}
|
||||
func @generic(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
|
||||
func @generic_function(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
|
||||
linalg.generic #trait %arg0, %arg1, %arg2:
|
||||
!linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: @foo
|
||||
// CHECK-LABEL: @generic
|
||||
// CHECK-LABEL: @generic_function
|
||||
// CHECK: loop.for %[[i:.*]] = {{.*}}
|
||||
// CHECK: loop.for %[[j:.*]] = {{.*}}
|
||||
// CHECK: loop.for %[[k:.*]] = {{.*}}
|
||||
|
@ -235,3 +235,31 @@ func @generic(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg
|
|||
// CHECK: %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
|
||||
// CHECK: linalg.store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
|
||||
// CHECK: linalg.store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
|
||||
|
||||
#trait2 = {
|
||||
n_views = [1, 2],
|
||||
n_loop_types = [3, 0, 0],
|
||||
indexing_maps = #accesses,
|
||||
library_call = "external_function_name",
|
||||
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
|
||||
}
|
||||
func @generic_region(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
|
||||
linalg.generic #trait2 %arg0, %arg1, %arg2 {
|
||||
^bb0(%a: f32, %b: f32, %c: f32):
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %d, %e : f32, f32
|
||||
}: !linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: @generic_region
|
||||
// CHECK: loop.for %[[i:.*]] = {{.*}}
|
||||
// CHECK: loop.for %[[j:.*]] = {{.*}}
|
||||
// CHECK: loop.for %[[k:.*]] = {{.*}}
|
||||
// CHECK: %[[a:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]]] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[b:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
|
||||
// CHECK: %[[c:.*]] = linalg.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
|
||||
// CHECK: %[[d:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK: %[[e:.*]] = addf %[[c]], %[[d]] : f32
|
||||
// CHECK: linalg.store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
|
||||
// CHECK: linalg.store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
|
||||
|
|
|
@ -179,3 +179,22 @@ func @generic(%arg0: !linalg.view<?x?xvector<3x4xi4>>, %arg1: !linalg.view<?x?x?
|
|||
// CHECK-LABEL: func @foo
|
||||
// CHECK-LABEL: func @generic
|
||||
// CHECK: linalg.generic {fun = @foo, indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
|
||||
|
||||
#trait2 = {
|
||||
indexing_maps = #accesses,
|
||||
n_views = [1, 1],
|
||||
n_loop_types = [3, 0, 0],
|
||||
library_call = "external_function_name"
|
||||
}
|
||||
func @generic_region(%arg0: !linalg.view<?x?xvector<3x4xi4>>, %arg1: !linalg.view<?x?x?xf32>) {
|
||||
linalg.generic #trait2 %arg0, %arg1 {
|
||||
^bb(%a: vector<3x4xi4>, %b: f32) :
|
||||
linalg.yield %b : f32
|
||||
} {foo = 1}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @generic_region
|
||||
// CHECK: linalg.generic {indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
|
||||
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
|
||||
// CHECK: linalg.yield %{{.*}} : f32
|
||||
// CHECK: } {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
|
Loading…
Reference in New Issue