forked from OSchip/llvm-project
[mlir] Added support for symbols inside linalg.generic and map concatenation
This commit adds functionality needed for implementation of convolutions with linalg.generic op. Since linalg.generic right now expects indexing maps to be just permutations, offset indexing needed in convolutions is not possible. Therefore in this commit we address the issue by adding support for symbols inside indexing maps which enables more advanced indexing. The upcoming commit will solve the problem of computing loop bounds from such maps. Differential Revision: https://reviews.llvm.org/D83158
This commit is contained in:
parent
55fa315b03
commit
f9c8febc52
|
@ -485,7 +485,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
|
|||
AffineMapArrayAttr:$indexing_maps,
|
||||
ArrayAttr:$iterator_types,
|
||||
OptionalAttr<StrAttr>:$doc,
|
||||
OptionalAttr<StrAttr>:$library_call);
|
||||
OptionalAttr<StrAttr>:$library_call,
|
||||
Confined<OptionalAttr<I64Attr>,
|
||||
[IntMinValue<0>]>:$symbol_source);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
|
||||
let regions = (region AnyRegion:$region);
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -493,7 +495,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
|
|||
return SmallVector<StringRef, 8>{
|
||||
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
|
||||
getIndexingMapsAttrName(), getLibraryCallAttrName(),
|
||||
getIteratorTypesAttrName()
|
||||
getIteratorTypesAttrName(), getSymbolSourceAttrName()
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -508,12 +510,18 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
|
|||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
llvm_unreachable(
|
||||
"No such thing as reference iterator types for a generic op.");
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
llvm_unreachable(
|
||||
"No such thing as reference indexing maps for a generic op.");
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<unsigned> getSymbolSource() {
|
||||
auto ss = symbol_source();
|
||||
return ss.hasValue() ?
|
||||
llvm::Optional<unsigned>(ss.getValue().getLimitedValue()) : llvm::None;
|
||||
}
|
||||
}];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
|
@ -549,6 +557,10 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
Each element of the list represents and iterator of one of the following
|
||||
types:
|
||||
parallel, reduction, window
|
||||
- symbol_source: index of the operand whose dimensions will be propagated
|
||||
as symbols to the indexing maps. When specified the number of symbols
|
||||
in each of the indexing maps has to be either 0 or the rank of the
|
||||
specified operand.
|
||||
|
||||
Example:
|
||||
Defining a #matmul_trait attribute in MLIR can be done as follows:
|
||||
|
@ -630,6 +642,35 @@ def GenericOp : GenericOpBase<"generic"> {
|
|||
escape naturally. Still, transformations and rewrites that take advantage of
|
||||
tensor SSA values are expected to be useful and will be added in the near
|
||||
future.
|
||||
|
||||
Example of 1D convolution with symbols:
|
||||
```mlir
|
||||
#conv_1d_accesses = [
|
||||
affine_map<(m, n)[dimN] -> (m + n - dimN floordiv 2)>, // in
|
||||
affine_map<(m, n)[dimN] -> (n)>, // filter
|
||||
affine_map<(m, n)[dimN] -> (m)> // out
|
||||
]
|
||||
|
||||
#conv_1d_trait = {
|
||||
doc = "O(m) += I(m + n - size(n) floordiv 2) * K(n)",
|
||||
indexing_maps = #conv_1d_accesses,
|
||||
library_call = "linalg_conv_1d",
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
symbol_source = 1
|
||||
}
|
||||
|
||||
linalg.generic #conv_1d_trait %in, %filter, %out {
|
||||
^bb0(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<?xf32>,
|
||||
memref<?xf32>,
|
||||
memref<?xf32>
|
||||
```
|
||||
where symbol s0 will be substituted with `dim %filter, %c0` i.e. the first
|
||||
and only dimension of the second operand as specified by the symbol_source
|
||||
attribute.
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
|
|
|
@ -101,11 +101,28 @@ template <typename ConcreteOp>
|
|||
SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOp linalgOp) {
|
||||
auto loc = linalgOp.getLoc();
|
||||
SmallVector<Value, 8> res;
|
||||
SmallVector<unsigned, 4> ranks;
|
||||
for (auto v : linalgOp.getInputsAndOutputBuffers()) {
|
||||
MemRefType t = v.getType().template cast<MemRefType>();
|
||||
ranks.push_back(t.getRank());
|
||||
for (unsigned i = 0; i < t.getRank(); ++i)
|
||||
res.push_back(builder.create<DimOp>(loc, v, i));
|
||||
}
|
||||
|
||||
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
|
||||
if (attr) {
|
||||
// Find the correct position for inserting values for symbols.
|
||||
unsigned numSymb = ranks[attr.getInt()], symbolsPos = 0;
|
||||
for (unsigned idx = 0; idx < attr.getInt(); idx++)
|
||||
symbolsPos += ranks[idx];
|
||||
|
||||
// Append or rewrite the end of the value list that corresponds to the
|
||||
// values mapping to symbols. Since inside concatinated map symbols are
|
||||
// repeated we have to repeat the sizes as well.
|
||||
for (unsigned idx = 0, s = ranks.size(); idx < s; ++idx)
|
||||
for (unsigned idx2 = 0; idx2 < numSymb; ++idx2)
|
||||
res.push_back(res[symbolsPos + idx2]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,6 +46,10 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
|
|||
return indexingMaps == maps;
|
||||
}
|
||||
|
||||
/// Attribute name for the IntegerAttr which encodes the index of operand
|
||||
/// whose dimensions will be propagated as symbols to the indexing maps
|
||||
constexpr StringRef getSymbolSourceAttrName() { return "symbol_source"; }
|
||||
|
||||
/// Attribute name for the AffineArrayAttr which encodes the relationship
|
||||
/// between a structured op iterators' and its operands.
|
||||
constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
|
||||
|
|
|
@ -118,6 +118,10 @@ public:
|
|||
AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
|
||||
ArrayRef<AffineExpr> symReplacements) const;
|
||||
|
||||
/// Replace symbols[0 .. numDims - 1] by
|
||||
/// symbols[shift .. shift + numDims - 1].
|
||||
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
|
||||
|
||||
AffineExpr operator+(int64_t v) const;
|
||||
AffineExpr operator+(AffineExpr other) const;
|
||||
AffineExpr operator-() const;
|
||||
|
|
|
@ -69,7 +69,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
|||
builder.getAffineMapArrayAttr(maps),
|
||||
builder.getStrArrayAttr(iteratorStrTypes),
|
||||
StringAttr() /*doc*/,
|
||||
StringAttr() /*library_call*/
|
||||
StringAttr() /*library_call*/,
|
||||
IntegerAttr() /*symbol_source*/
|
||||
/* TODO: other attributes in op */
|
||||
)
|
||||
.getOperation();
|
||||
|
|
|
@ -79,7 +79,8 @@ void GenericOp::build(
|
|||
builder.getI64IntegerAttr(argsOut),
|
||||
builder.getAffineMapArrayAttr(indexingMaps),
|
||||
builder.getStrArrayAttr(iteratorTypes),
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr);
|
||||
if (!bodyBuild)
|
||||
return;
|
||||
|
||||
|
@ -103,7 +104,8 @@ void IndexedGenericOp::build(
|
|||
builder.getI64IntegerAttr(argsOut),
|
||||
builder.getAffineMapArrayAttr(indexingMaps),
|
||||
builder.getStrArrayAttr(iteratorTypes),
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr);
|
||||
/*doc=*/nullptr, /*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr);
|
||||
if (!bodyBuild)
|
||||
return;
|
||||
|
||||
|
@ -257,6 +259,15 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
|
||||
return failure();
|
||||
|
||||
auto attr = op.template getAttrOfType<IntegerAttr>("symbol_source");
|
||||
int64_t targetRank = 0;
|
||||
if (attr) {
|
||||
unsigned index = attr.getInt();
|
||||
if (index >= op.getNumOperands())
|
||||
return op.emitOpError("symbol_source index out of range");
|
||||
targetRank = op.getShapedType(index).getRank();
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 4> indexingMaps;
|
||||
indexingMaps.reserve(op.indexing_maps().size());
|
||||
for (auto en : llvm::enumerate(op.indexing_maps())) {
|
||||
|
@ -266,9 +277,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
|
||||
: op.getOutputShapedType(idx - nInputViews);
|
||||
|
||||
if (m.getNumSymbols() != 0)
|
||||
return op.emitOpError("expected indexing_map #")
|
||||
<< idx << " to have no symbols";
|
||||
if (m.getNumSymbols() != targetRank)
|
||||
return op.emitOpError("expected the number of symbols in indexing_map #")
|
||||
<< idx << " to match target rank";
|
||||
|
||||
if (m.getNumDims() != nLoops)
|
||||
return op.emitOpError("expected indexing_map #")
|
||||
|
@ -281,8 +292,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
}
|
||||
|
||||
auto concatMap = concatAffineMaps(indexingMaps);
|
||||
auto aggregateMap = inversePermutation(concatMap);
|
||||
if (!aggregateMap)
|
||||
// TODO: Bound inference for maps with symbols
|
||||
if (!concatMap.getNumSymbols() && !inversePermutation(concatMap))
|
||||
return op.emitOpError("expected the concatenation of maps in indexing_map "
|
||||
"to be invertible");
|
||||
|
||||
|
|
|
@ -319,7 +319,8 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
|
|||
genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
||||
genericOp.iterator_types(),
|
||||
/*doc = */ nullptr,
|
||||
/*library_call = */ nullptr);
|
||||
/*library_call = */ nullptr,
|
||||
/*symbol_source = */ nullptr);
|
||||
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
|
||||
replacementOp.region().begin());
|
||||
|
||||
|
|
|
@ -510,7 +510,8 @@ struct FuseGenericOpsOnTensors {
|
|||
rewriter.getArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr)
|
||||
/*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr)
|
||||
.getOperation();
|
||||
} else {
|
||||
fusedOp =
|
||||
|
@ -524,7 +525,8 @@ struct FuseGenericOpsOnTensors {
|
|||
rewriter.getArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr)
|
||||
/*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr)
|
||||
.getOperation();
|
||||
}
|
||||
|
||||
|
@ -787,7 +789,8 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
|
|||
rewriter.getI64IntegerAttr(consumer.getNumResults()),
|
||||
rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
/*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr);
|
||||
auto &fusedRegion = fusedOp.region();
|
||||
rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
|
||||
fusedRegion.begin());
|
||||
|
@ -843,7 +846,8 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
|
|||
rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
|
||||
producer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
/*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr);
|
||||
auto &fusedRegion = fusedOp.region();
|
||||
rewriter.cloneRegionBefore(producer.region(), fusedRegion,
|
||||
fusedRegion.begin());
|
||||
|
@ -893,7 +897,8 @@ template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
|
|||
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
||||
consumer.iterator_types(),
|
||||
/*doc=*/nullptr,
|
||||
/*library_call=*/nullptr);
|
||||
/*library_call=*/nullptr,
|
||||
/*symbol_source=*/nullptr);
|
||||
|
||||
// Map the block argument corresponding to the replaced argument with the
|
||||
// scalar constant.
|
||||
|
|
|
@ -36,13 +36,13 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
|
|||
ArrayRef<Value> vals) {
|
||||
if (map.isEmpty())
|
||||
return {};
|
||||
assert(map.getNumSymbols() == 0);
|
||||
|
||||
assert(map.getNumInputs() == vals.size());
|
||||
SmallVector<Value, 8> res;
|
||||
res.reserve(map.getNumResults());
|
||||
auto dims = map.getNumDims();
|
||||
for (auto e : map.getResults()) {
|
||||
auto exprMap = AffineMap::get(dims, 0, e);
|
||||
auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
|
||||
SmallVector<Value, 4> operands(vals.begin(), vals.end());
|
||||
canonicalizeMapAndOperands(&exprMap, &operands);
|
||||
res.push_back(affine_apply(exprMap, operands));
|
||||
|
@ -165,19 +165,29 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
|
|||
SmallVector<Value, 4> indexedValues;
|
||||
indexedValues.reserve(nInputs + nOutputs);
|
||||
|
||||
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
|
||||
auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
|
||||
if (attr) {
|
||||
auto operand = linalgOp.getOperand(attr.getInt());
|
||||
auto shapedType = operand.getType().template cast<ShapedType>();
|
||||
allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
|
||||
for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
|
||||
allIvsPlusDims.push_back(b.create<DimOp>(loc, operand, idx));
|
||||
}
|
||||
|
||||
// TODO: Avoid the loads if the corresponding argument of the
|
||||
// region has no uses.
|
||||
// 1.a. Emit load from input views.
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getInputIndexingMap(i), allIvs);
|
||||
b, loc, linalgOp.getInputIndexingMap(i), allIvsPlusDims);
|
||||
// Passing through IndexedValueType emits the proper load operation.
|
||||
indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
|
||||
}
|
||||
// 1.b. Emit load from output views.
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
|
||||
b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims);
|
||||
// Passing through IndexedValueType emits the proper load operation.
|
||||
indexedValues.push_back(
|
||||
IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
|
||||
|
@ -190,7 +200,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
|
|||
SmallVector<Value, 8> outputBuffers;
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
|
||||
b, loc, linalgOp.getOutputIndexingMap(i), allIvsPlusDims));
|
||||
outputBuffers.push_back(linalgOp.getOutputBuffer(i));
|
||||
}
|
||||
inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues, indexing,
|
||||
|
@ -457,7 +467,24 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
|
|||
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
|
||||
auto maps = llvm::to_vector<8>(
|
||||
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
|
||||
AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
|
||||
SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
|
||||
AffineMap map = concatAffineMaps(maps);
|
||||
if (map.getNumSymbols()) {
|
||||
// Ignore symbols for now as they are not supported by inversePermutation.
|
||||
unsigned dims = map.getNumDims();
|
||||
SmallVector<AffineExpr, 8> zeros(
|
||||
map.getNumSymbols(), getAffineConstantExpr(0, map.getContext()));
|
||||
SmallVector<AffineExpr, 8> res;
|
||||
for (auto result : map.getResults())
|
||||
res.push_back(result.replaceDimsAndSymbols({}, zeros));
|
||||
|
||||
map = AffineMap::get(dims, 0, res, map.getContext());
|
||||
|
||||
// Cut off values that would have been applied to symbols
|
||||
sizes.resize(res.size());
|
||||
}
|
||||
|
||||
AffineMap invertedMap = inversePermutation(map);
|
||||
if (!invertedMap)
|
||||
return {};
|
||||
if (invertedMap.isEmpty()) {
|
||||
|
@ -466,9 +493,8 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
|
|||
}
|
||||
|
||||
SmallVector<Value, 4> allIvs;
|
||||
auto loopRanges =
|
||||
emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
|
||||
getViewSizes(builder, linalgOp));
|
||||
auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
|
||||
invertedMap, sizes);
|
||||
GenerateLoopNest<LoopTy>::doit(
|
||||
loopRanges, linalgOp.iterator_types().getValue(), [&](ValueRange ivs) {
|
||||
allIvs.append(ivs.begin(), ivs.end());
|
||||
|
|
|
@ -65,7 +65,8 @@ public:
|
|||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
|
||||
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
|
||||
op.iterator_types(), op.docAttr(), op.library_callAttr());
|
||||
op.iterator_types(), op.docAttr(), op.library_callAttr(),
|
||||
op.symbol_sourceAttr());
|
||||
|
||||
// Create a new block in the region of the new Generic Op.
|
||||
Block &oldBlock = op.getRegion().front();
|
||||
|
|
|
@ -93,6 +93,14 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
|
|||
llvm_unreachable("Unknown AffineExpr");
|
||||
}
|
||||
|
||||
/// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
|
||||
AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const {
|
||||
SmallVector<AffineExpr, 4> symbols;
|
||||
for (unsigned idx = 0; idx < numSymbols; ++idx)
|
||||
symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
|
||||
return replaceDimsAndSymbols({}, symbols);
|
||||
}
|
||||
|
||||
/// Returns true if this expression is made out of only symbols and
|
||||
/// constants (no dimensional identifiers).
|
||||
bool AffineExpr::isSymbolicOrConstant() const {
|
||||
|
|
|
@ -434,18 +434,19 @@ AffineMap mlir::inversePermutation(AffineMap map) {
|
|||
}
|
||||
|
||||
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
|
||||
unsigned numResults = 0;
|
||||
unsigned numResults = 0, numDims = 0, numSymbols = 0;
|
||||
for (auto m : maps)
|
||||
numResults += m.getNumResults();
|
||||
unsigned numDims = 0;
|
||||
SmallVector<AffineExpr, 8> results;
|
||||
results.reserve(numResults);
|
||||
for (auto m : maps) {
|
||||
assert(m.getNumSymbols() == 0 && "expected map without symbols");
|
||||
results.append(m.getResults().begin(), m.getResults().end());
|
||||
for (auto res : m.getResults())
|
||||
results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
|
||||
|
||||
numSymbols += m.getNumSymbols();
|
||||
numDims = std::max(m.getNumDims(), numDims);
|
||||
}
|
||||
return AffineMap::get(numDims, /*numSymbols=*/0, results,
|
||||
return AffineMap::get(numDims, numSymbols, results,
|
||||
maps.front().getContext());
|
||||
}
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
|
|||
// -----
|
||||
|
||||
func @generic_symbol_in_map(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
|
||||
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match target rank}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
|
@ -120,6 +120,22 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @generic_symbol_source_out_of_range(%arg0: memref<i32>) {
|
||||
// expected-error @+1 {{symbol_source index out of range}}
|
||||
linalg.generic {
|
||||
args_in = 0,
|
||||
args_out = 1,
|
||||
indexing_maps = [ affine_map<()[N] -> (0)> ],
|
||||
iterator_types = ["parallel"],
|
||||
symbol_source = 1
|
||||
} %arg0 {
|
||||
^bb(%i : i32):
|
||||
linalg.yield %i : i32
|
||||
}: memref<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
|
||||
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
|
||||
linalg.generic {
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
// CHECKLOOP-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
|
||||
// CHECKLOOP-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
|
||||
// CHECKLOOP-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
|
||||
// CHECKLOOP-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
|
||||
|
||||
// CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
|
||||
// CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
|
||||
|
@ -25,6 +26,7 @@
|
|||
// CHECKPARALLEL-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
|
||||
// CHECKPARALLEL-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
|
||||
// CHECKPARALLEL-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
|
||||
// CHECKPARALLEL-DAG: #[[$convMap:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 - s0 floordiv 2)>
|
||||
|
||||
|
||||
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
|
||||
|
@ -910,3 +912,331 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
|
|||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
|
||||
|
||||
#conv_1d_accesses = [
|
||||
affine_map<(m, n)[s0] -> (m + n - s0 floordiv 2)>, // in
|
||||
affine_map<(m, n)[s0] -> (n)>, // filter
|
||||
affine_map<(m, n)[s0] -> (m)> // out
|
||||
]
|
||||
|
||||
#conv_1d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
doc = "C(m) += A(m) * B(n)",
|
||||
indexing_maps = #conv_1d_accesses,
|
||||
library_call = "linalg_conv_1d",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel"],
|
||||
symbol_source = 1
|
||||
}
|
||||
|
||||
func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
|
||||
linalg.generic #conv_1d_trait %in, %filter, %out {
|
||||
^bb0(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<?xf32>,
|
||||
memref<?xf32>,
|
||||
memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @conv1d
|
||||
// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKLOOP: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
|
||||
// CHECKLOOP: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
|
||||
// CHECKLOOP: scf.for %[[b:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[m:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
|
||||
// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
|
||||
// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
|
||||
// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
|
||||
// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
|
||||
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
|
||||
// CHECKPARALLEL-LABEL: @conv1d
|
||||
// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32>
|
||||
// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%{{.*}}, %{{.*}}) to (%[[dim1]], %[[dim0]]) step ({{.*}}) {
|
||||
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim2]]]
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
|
||||
#conv_2d_accesses = [
|
||||
affine_map<(m, n, m1, n1)[s0, s1] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2)>, // in
|
||||
affine_map<(m, n, m1, n1)[s0, s1] -> (m1, n1)>, // filter
|
||||
affine_map<(m, n, m1, n1)[s0, s1] -> (m, n)> // out
|
||||
]
|
||||
|
||||
#conv_2d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
doc = "C(m,n) += A(m,n) * B(m1,n1)",
|
||||
indexing_maps = #conv_2d_accesses,
|
||||
library_call = "linalg_conv_2d",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel"],
|
||||
symbol_source = 1
|
||||
}
|
||||
|
||||
func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
|
||||
linalg.generic #conv_2d_trait %in, %filter, %out {
|
||||
^bb0(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<?x?xf32>,
|
||||
memref<?x?xf32>,
|
||||
memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @conv2d
|
||||
// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKLOOP: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
|
||||
// CHECKLOOP: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
|
||||
// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
|
||||
// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
|
||||
// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
|
||||
|
||||
// CHECKPARALLEL-LABEL: @conv2d
|
||||
// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step ({{.*}}) {
|
||||
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim4]]]
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim5]]]
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i2]], %[[i3]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]]] : memref<?x?xf32>
|
||||
|
||||
#conv_3d_accesses = [
|
||||
affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2)>, // in
|
||||
affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m1, n1, k1)>, // filter
|
||||
affine_map<(m, n, k, m1, n1, k1)[s0, s1, s2] -> (m, n, k)> // out
|
||||
]
|
||||
|
||||
#conv_3d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
doc = "C(m,n,k) += A(m,n,k) * B(m1,n1,k1)",
|
||||
indexing_maps = #conv_3d_accesses,
|
||||
library_call = "linalg_conv_3d",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
|
||||
symbol_source = 1
|
||||
}
|
||||
|
||||
func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
|
||||
linalg.generic #conv_3d_trait %in, %filter, %out {
|
||||
^bb0(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<?x?x?xf32>,
|
||||
memref<?x?x?xf32>,
|
||||
memref<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @conv3d
|
||||
// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKLOOP: %[[c2:.*]] = constant 2 : index
|
||||
// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
|
||||
// CHECKLOOP: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
|
||||
// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
|
||||
// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
|
||||
// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
|
||||
// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
|
||||
|
||||
// CHECKPARALLEL-LABEL: @conv3d
|
||||
// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
|
||||
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step ({{.*}}) {
|
||||
// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim6]]]
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim7]]]
|
||||
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i3]], %[[i4]], %[[i5]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]]] : memref<?x?x?xf32>
|
||||
|
||||
#conv_4d_accesses = [
|
||||
affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m + m1 - s0 floordiv 2, n + n1 - s1 floordiv 2, k + k1 - s2 floordiv 2, l + l1 - s3 floordiv 2)>, // in
|
||||
affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m1, n1, k1, l1)>, // filter
|
||||
affine_map<(m, n, k, l, m1, n1, k1, l1)[s0, s1, s2, s3] -> (m, n, k, l)> // out
|
||||
]
|
||||
|
||||
#conv_4d_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
doc = "C(m,n,k,l) += A(m,n,k,l) * B(m1,n1,k1,l1)",
|
||||
indexing_maps = #conv_4d_accesses,
|
||||
library_call = "linalg_conv_4d",
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"],
|
||||
symbol_source = 1
|
||||
}
|
||||
|
||||
func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : memref<?x?x?x?xf32>) -> () {
|
||||
linalg.generic #conv_4d_trait %in, %filter, %out {
|
||||
^bb0(%a: f32, %b: f32, %c: f32) :
|
||||
%d = mulf %a, %b : f32
|
||||
%e = addf %c, %d : f32
|
||||
linalg.yield %e : f32
|
||||
} : memref<?x?x?x?xf32>,
|
||||
memref<?x?x?x?xf32>,
|
||||
memref<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @conv4d
|
||||
// CHECKLOOP-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKLOOP-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKLOOP: %[[c2:.*]] = constant 2 : index
|
||||
// CHECKLOOP: %[[c3:.*]] = constant 3 : index
|
||||
// CHECKLOOP: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: scf.for %[[i0:.*]] = %{{.*}} to %[[dim4]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i1:.*]] = %{{.*}} to %[[dim5]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i2:.*]] = %{{.*}} to %[[dim6]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i3:.*]] = %{{.*}} to %[[dim7]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i4:.*]] = %{{.*}} to %[[dim0]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i5:.*]] = %{{.*}} to %[[dim1]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i6:.*]] = %{{.*}} to %[[dim2]] step %{{.*}} {
|
||||
// CHECKLOOP: scf.for %[[i7:.*]] = %{{.*}} to %[[dim3]] step %{{.*}} {
|
||||
// CHECKLOOP: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
|
||||
// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
|
||||
// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
|
||||
// CHECKLOOP: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
|
||||
// CHECKLOOP: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
|
||||
// CHECKLOOP: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKLOOP: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
|
||||
|
||||
// CHECKPARALLEL-LABEL: @conv4d
|
||||
// CHECKPARALLEL-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[c0:.*]] = constant 0 : index
|
||||
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKPARALLEL: %[[c2:.*]] = constant 2 : index
|
||||
// CHECKPARALLEL: %[[c3:.*]] = constant 3 : index
|
||||
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim6:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim7:.*]] = dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[i0:.*]], %[[i1:.*]], %[[i2:.*]], %[[i3:.*]], %[[i4:.*]], %[[i5:.*]], %[[i6:.*]], %[[i7:.*]]) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[dim4]], %[[dim5]], %[[dim6]], %[[dim7]], %[[dim0]], %[[dim1]], %[[dim2]], %[[dim3]]) step ({{.*}}) {
|
||||
// CHECKPARALLEL: %[[dim8:.*]] = dim %[[arg1]], %[[c0]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim9:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim10:.*]] = dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim11:.*]] = dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[aff1:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim8]]]
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim9]]]
|
||||
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim10]]]
|
||||
// CHECKPARALLEL: %[[aff4:.*]] = affine.apply #[[$convMap]](%{{.*}}, %{{.*}})[%[[dim11]]]
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg0]][%[[aff1]], %[[aff2]], %[[aff3]], %[[aff4]]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg1]][%[[i4]], %[[i5]], %[[i6]], %[[i7]]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
|
||||
|
|
|
@ -77,7 +77,8 @@ struct TestBufferPlacementPreparationPass
|
|||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
|
||||
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
|
||||
op.iterator_types(), op.docAttr(), op.library_callAttr());
|
||||
op.iterator_types(), op.docAttr(), op.library_callAttr(),
|
||||
op.symbol_sourceAttr());
|
||||
|
||||
// Create a new block in the region of the new Generic Op.
|
||||
Block &oldBlock = op.getRegion().front();
|
||||
|
|
Loading…
Reference in New Issue