[mlir][linalg] Purge linalg.indexed_generic.

Differential Revision: https://reviews.llvm.org/D104449
This commit is contained in:
Alexander Belyaev 2021-06-17 14:35:26 +02:00
parent aa6e8e9572
commit 5b3cb31edb
13 changed files with 12 additions and 494 deletions

View File

@ -32,7 +32,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
// always be 0 for index-free linalg ops. For IndexedGeneric, this must be
// equal to numLoops.
unsigned getNumPayloadInductionVariables() {
return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
return 0;
}
// Return whether the op accesses the iteration indices.
@ -671,140 +671,6 @@ def GenericOp : GenericOpBase<"generic"> {
let hasFolder = 1;
}
/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
/// the enclosing loop induction variables)
def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
let description = [{
Indexed Generic Linalg op form where the key properties of the computation
are specified as attributes. In pretty form, a `linalg.indexed_generic` op
is written as:
```mlir
linalg.indexed_generic #trait_attribute
ins(%A, %B : memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>)
outs(%C : memref<?x?xf32, stride_specification>)
attrs = {other-optional-attributes}
{region}
```
Where #trait_attributes is an alias of a dictionary attribute containing:
- doc [optional]: a documentation string
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
and output view. Such AffineMapAttr specifies the mapping between the
loops and the indexing within each view.
- library_call [optional]: a StringAttr containing the name of an
external library function that the linalg.indexed_generic operation
maps to. The external library is assumed to be dynamically linked and
no strong compile-time guarantees are provided. In the absence of such
a library call, linalg.indexed_generic will always lower to loops.
- iterator_types: an ArrayAttr they type of the enclosing loops; Each
element of the list represents and iterator of one of the following
types:
parallel, reduction, window
Example:
Defining a #matmul_trait attribute in MLIR can be done as follows:
```mlir
#matmul_accesses = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
]
#matmul_trait = {
doc = "C(m, n) += A(m, k) * B(k, n)",
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
iterator_types = ["parallel", "parallel", "reduction"]
}
```
And can be reused in multiple places as:
```mlir
linalg.indexed_generic #matmul_trait
ins(%A, %B : memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>)
outs(%C : memref<?x?xf32, stride_specification>) {
(%offset_m: index, %offset_n: index, %offset_k: index,
%a: f32, %b: f32, %c: f32) :
"some_optional_computation"(%offset_m, %offset_n, %offset_k)
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg_yield %e : f32
}
```
This may lower to either:
```mlir
call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
(index, index, index,
memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>)
-> ()
```
or IR resembling:
```mlir
scf.for %m = %c0 to %M step %c1 {
scf.for %n = %c0 to %N step %c1 {
scf.for %k = %c0 to %K step %c1 {
%a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
%b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
%c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
"some_optional_computation"(%m, %n, %k)
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
}
}
}
```
To allow progressive lowering from the value world (a.k.a tensor values) to
the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
allows mixing tensors and buffers operands and tensor results.
```mlir
%C = linalg.indexed_generic #trait_attribute
ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
outs(%C : tensor<?x?xf32>)
{other-optional-attributes}
{region_with_index_arguments}
-> (tensor<?x?xf32>)
```
}];
let builders = [
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
"StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
"StringRef":$doc, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>",
"nullptr">)>
];
let verifier = [{ return ::verify(*this); }];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.

View File

@ -100,10 +100,6 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
if (isa<CopyOp>(op))
return failure();
// Canonicalize indexed generic operations before library call conversion.
if (isa<IndexedGenericOp>(op))
return failure();
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
if (!libraryCallName)
return failure();

View File

@ -525,69 +525,8 @@ void GenericOp::build(
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
build(builder, result, resultTensorTypes, inputs, outputs,
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
libraryCall.empty() ? StringAttr()
: builder.getStringAttr(libraryCall));
if (!bodyBuild)
return;
unsigned nLoops = iteratorTypes.size();
SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
for (ValueRange container : {inputs, outputs})
for (Value v : container)
blockArgTypes.push_back(getElementTypeOrSelf(v));
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
bodyBuild(builder, result.location,
bodyBlock->getArguments().take_front(nLoops),
bodyBlock->getArguments().drop_front(nLoops));
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall, bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
/*doc=*/"", /*libraryCall=*/"", bodyBuild);
}
void IndexedGenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild);
}
template <typename GenericOpType>
static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
static void print(OpAsmPrinter &p, GenericOp op) {
p << op.getOperationName() << " ";
// Print extra attributes.
@ -628,12 +567,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
printNamedStructuredOpResults(p, op.result_tensors().getTypes());
}
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
static void print(OpAsmPrinter &p, IndexedGenericOp op) {
printGenericOp(p, op);
}
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
DictionaryAttr dictAttr;
// Parse the core linalg traits that must check into a dictAttr.
@ -704,15 +637,6 @@ void GenericOp::getEffects(
outputBuffers);
}
void IndexedGenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
SmallVector<Value> inputBuffers = getInputBufferOperands();
SmallVector<Value> outputBuffers = getOutputBufferOperands();
getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
outputBuffers);
}
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
return success();
@ -720,52 +644,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
namespace {
/// Replace indexed_generic ops by generic ops that access the iteration indices
/// using index operation calls.
struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IndexedGenericOp indexedOp,
PatternRewriter &rewriter) const override {
// Replace all uses of the index block arguments.
BlockAndValueMapping bvm;
if (Block *body = indexedOp.getBody()) {
rewriter.setInsertionPointToStart(body);
for (const auto &en : llvm::enumerate(
body->getArguments().take_front(indexedOp.getNumLoops()))) {
Value index = rewriter.create<IndexOp>(indexedOp.getLoc(), en.index());
bvm.map(en.value(), index);
}
}
// Create a generic replacement operation and clone the body.
rewriter.setInsertionPointAfter(indexedOp);
SmallVector<Value> inputOperands = indexedOp.getInputOperands();
SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
SmallVector<StringRef> iterators = llvm::to_vector<4>(
indexedOp.iterator_types().getAsValueRange<StringAttr>());
GenericOp genericOp = rewriter.create<GenericOp>(
indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
outputOperands, indexedOp.getIndexingMaps(), iterators);
Region &genericRegion = genericOp.region();
Region &indexedRegion = indexedOp.region();
rewriter.cloneRegionBefore(indexedRegion, genericRegion,
genericRegion.begin(), bvm);
rewriter.replaceOp(indexedOp, genericOp->getResults());
return success();
}
};
} // namespace
void IndexedGenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertIndexedToGenericOp>(context);
}
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
@ -3230,7 +3108,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
PatternRewriter &rewriter) const override {
// This pattern reduces the number of arguments of an op, which breaks
// the invariants of semantically charged named ops.
if (!isa<GenericOp, IndexedGenericOp>(op))
if (!isa<GenericOp>(op))
return failure();
// Associate each input to an equivalent "canonical" input that has the same
@ -3290,10 +3168,6 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// the value from the original op.
newLinalgOp.setNumInputs(canonicalInput.size());
// linalg.indexed_generic payloads have additional arguments prepended to
// the block arg list.
int bbArgBaseOffset = newLinalgOp.getNumPayloadInductionVariables();
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp->getRegion(0).front();
@ -3305,10 +3179,10 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
unsigned operandNumber = opOperand->getOperandNumber();
if (canonicalInputIndices[operandNumber] == operandNumber)
continue;
payload.getArgument(bbArgBaseOffset + operandNumber)
.replaceAllUsesWith(payload.getArgument(
bbArgBaseOffset + canonicalInputIndices[operandNumber]));
payload.eraseArgument(bbArgBaseOffset + operandNumber);
payload.getArgument(operandNumber)
.replaceAllUsesWith(
payload.getArgument(canonicalInputIndices[operandNumber]));
payload.eraseArgument(operandNumber);
}
rewriter.replaceOp(op, newOp->getResults());
@ -3316,7 +3190,7 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
}
};
/// Remove generic/indexed_generic operations (on tensors) that are just copying
/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
@ -3335,7 +3209,7 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
}
}
if (!isa<GenericOp, IndexedGenericOp>(op))
if (!isa<GenericOp>(op))
return failure();
if (!op.hasTensorSemantics())
return failure();

View File

@ -202,10 +202,6 @@ public:
LogicalResult
matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Canonicalize indexed generic operations before bufferization.
if (isa<IndexedGenericOp>(op))
return failure();
// GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
if (!op->hasAttr("operand_segment_sizes"))
return failure();

View File

@ -230,7 +230,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// When the producer has index semantics, we have to transform the indices of
// the producer according to the tiling of the consumer, i.e. offset them by
// the values computed in `loopRanges`.
assert(!isa<IndexedGenericOp>(producer) && "unexpected op");
if (producer.hasIndexSemantics()) {
assert(clonedOp->getNumRegions() == 1 &&
clonedOp->getRegion(0).getBlocks().size() == 1 &&
@ -426,10 +425,6 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
if (!fusableDependence)
return llvm::None;
// Canonicalize indexed generic ops before fusion.
if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
return llvm::None;
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
if (!producerOp)
return llvm::None;
@ -507,10 +502,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
// Canonicalize indexed generic ops before fusion.
if (isa<IndexedGenericOp>(producerOpResult.getOwner()))
return llvm::None;
auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
if (!producerOp)
return llvm::None;
@ -766,9 +757,6 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
if (!fusableDependence)
continue;
// Canonicalize indexed generic ops before fusion.
if (isa<IndexedGenericOp>(fusableDependence->getDependentOp()))
continue;
LinalgOp producerOp =
dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
if (!producerOp)

View File

@ -1402,7 +1402,6 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(

View File

@ -130,8 +130,8 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
// No nothing to do for linalg.generic and linalg.indexed_generic.
if (isa<GenericOp, IndexedGenericOp>(rootOp))
// No nothing to do for linalg.generic.
if (isa<GenericOp>(rootOp))
return failure();
GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);

View File

@ -418,10 +418,6 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
AffineStoreOp, memref::StoreOp>::type;
// Canonicalize indexed_generic operations before lowering them to loops.
if (isa<IndexedGenericOp>(linalgOp))
return llvm::None;
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&

View File

@ -163,10 +163,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
if (llvm::all_of(tileSizes, isZero))
return llvm::None;
// Canonicalize indexed generic operations before tiling.
if (isa<IndexedGenericOp>(op))
return llvm::None;
if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
// For conv op only support tiling along batch dimension (which is the first
// loop).

View File

@ -89,26 +89,3 @@ func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf3
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Test case: linalg.indexed_generic.
// Other than the payload argument handling, everything else is the same.
#map = affine_map<(d0) -> (d0)>
// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @indexed_generic
func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[BBARG:.*]]: f32, %{{[a-zA-Z0-9]+}}: f32):
// CHECK: addf %[[BBARG]], %[[BBARG]]
%0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>)
outs(%arg0 : tensor<?xf32>) {
^bb0(%index: index, %arg1: f32, %arg2: f32, %arg3: f32):
%1 = addf %arg1, %arg2 : f32
linalg.yield %1 : f32
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -842,39 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
linalg.indexed_generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xindex>)
outs(%arg1 : memref<?x?xindex>) {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
%0 = addi %arg4, %arg5 : index
%1 = addi %0, %arg6 : index
%2 = addi %1, %arg7 : index
linalg.yield %2 : index
}
return
}
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @indexed_generic
// CHECK-NEXT: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[ARG0:[A-Za-z0-9_]+]] : memref<?x?xindex>)
// CHECK-SAME: outs(%[[ARG1:[A-Za-z0-9_]+]] : memref<?x?xindex>)
// CHECK: ^bb0(%[[ARG2:[A-Za-z0-9_]+]]: index, %[[ARG3:[A-Za-z0-9_]+]]: index):
// CHECK-NEXT: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-NEXT: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-NEXT: %[[SUM0:.+]] = addi %[[IDX0]], %[[IDX1]] : index
// CHECK-NEXT: %[[SUM1:.+]] = addi %[[SUM0]], %[[ARG2]] : index
// CHECK-NEXT: %[[SUM2:.+]] = addi %[[SUM1]], %[[ARG3]] : index
// CHECK-NEXT: linalg.yield %[[SUM2]] : index
// -----
func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c0 = constant 0 : index
%cst = constant 0.0 : f32

View File

@ -227,72 +227,6 @@ func @generic_scalar_operand_block_arg_type(%arg0: f32) {
// -----
func @indexed_generic_block_arg_count(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]}
outs(%arg0 : memref<?xf32>) {
^bb(%f: f32):
linalg.yield %f : f32
}
}
// -----
func @indexed_generic_block_induction_var_arg_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{op expected index block argument #0}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
outs(%arg0 : memref<?xf32>) {
^bb(%i: f64, %f: f32):
linalg.yield %f: f32
}
}
// -----
func @indexed_generic_block_arg_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected type of bb argument #1 ('i1') to match element or self type of the corresponding operand ('f32')}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
outs(%arg0 : memref<?xf32>) {
^bb(%i: index, %f: i1):
linalg.yield %i: index
}
}
// -----
func @indexed_generic_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{expected as many non-induction variable region arguments as the number of input/output operands}}
linalg.indexed_generic {
indexing_maps = [ affine_map<()[] -> ()> ],
iterator_types = []}
outs(%arg0 : memref<f32>) {
^bb(%0: index, %1: f32):
linalg.yield %1: f32
}
return
}
// -----
func @indexed_generic_result_count(%arg0: memref<?xf32>) {
// expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
linalg.indexed_generic {
indexing_maps = [ affine_map<(d0) -> (d0)> ],
iterator_types = ["parallel"]}
outs(%arg0 : memref<?xf32>) {
^bb(%i: index, %val: f32):
linalg.yield %val, %val: f32, f32
}
}
// -----
func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
linalg.generic {

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
// TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
// TODO: Re-enable LLVM lowering test.
//
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
@ -457,43 +457,6 @@ func @generic_with_multiple_tensor_outputs(
// -----
#accesses_2 = [
affine_map<(i, j, k) -> (j, i)>,
affine_map<(i, j, k) -> (i, k, i + j)>,
affine_map<(i, j, k) -> (i, k, i + j)>
]
#trait_2 = {
indexing_maps = #accesses_2,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
}
func @indexed_generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
%0 = linalg.indexed_generic #trait_2
ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
outs(%arg1 : tensor<?x?x?xf32>)
attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32, %2: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
// CHECK-SAME: outs({{.*}} : tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
// -----
#broadcast_access = [
affine_map<(i, j) -> ()>,
affine_map<(i, j) -> (i, j)>
@ -516,17 +479,6 @@ func @generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tens
return %0 : tensor<3x4xf32>
}
func @indexed_generic_op_zero_rank(%arg0: tensor<f32>, %arg1 : tensor<3x4xf32>) -> (tensor<3x4xf32>)
{
%0 = linalg.indexed_generic #trait_broadcast
ins(%arg0 : tensor<f32>)
outs(%arg1 : tensor<3x4xf32>) {
^bb(%i: index, %j: index, %a: f32, %b: f32) :
linalg.yield %a : f32
} -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
// -----
@ -569,29 +521,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// CHECK: %{{.*}} = linalg.index 2 : index
// CHECK: linalg.yield %{{.*}} : f32
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic #trait_3
ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
}
return
}
// CHECK-LABEL: func @indexed_generic
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: }
// -----
func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,