[mlir][Linalg] NFC - Extract a standalone LinalgInterfaces
This separation improves the layering and paves the way for more interfaces coming up in the future.

Differential revision: https://reviews.llvm.org/D95941
commit 1029c82c1e (parent a2fdf9d4d7)
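As a rough, illustrative sketch of what the new layering enables (not part of this commit): code that only needs to query the Linalg op interface can include the standalone header directly. The helper below is hypothetical and assumes the MLIR C++ API at this revision.

    // Illustrative sketch only; `isParallelOnlyLinalgOp` is a hypothetical helper.
    #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

    // Returns true if `op` implements the Linalg structured-op interface and
    // all of its loops are parallel (no reduction/window iterators).
    static bool isParallelOnlyLinalgOp(mlir::Operation *op) {
      if (auto linalgOp = mlir::dyn_cast<mlir::linalg::LinalgOp>(op))
        return linalgOp.getNumParallelLoops() == linalgOp.getNumLoops();
      return false;
    }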
@@ -45,8 +45,8 @@ add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
 add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen)
 add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen)
 
-set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td)
-mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls)
-mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRLinalgStructuredOpsInterfaceIncGen)
-add_dependencies(mlir-headers MLIRLinalgStructuredOpsInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td)
+mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRLinalgInterfacesIncGen)
+add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen)
@@ -0,0 +1,44 @@
+//===- LinalgInterface.h - Linalg operations interfaces -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the operation interfaces for Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
+#define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+namespace mlir {
+namespace linalg {
+
+/// Returns the values obtained by applying `map` to the list of values.
+SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
+                                       AffineMap map, ValueRange values);
+
+namespace detail {
+
+/// Verify that `op` conforms to the invariants of StructuredOpInterface
+LogicalResult verifyStructuredOpInterface(Operation *op);
+
+} // namespace detail
+} // namespace linalg
+} // namespace mlir
+
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_
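A hedged usage sketch of the `applyMapToValues` helper declared above (not part of the commit); the wrapper name `permuteIvs` and its arguments are assumptions made purely for illustration.

    #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/Builders.h"
    #include "llvm/ADT/SmallVector.h"

    // Hypothetical helper: permute three index values as (i, j, k) -> (k, i, j)
    // by materializing affine.apply ops through applyMapToValues.
    static llvm::SmallVector<mlir::Value, 4>
    permuteIvs(mlir::OpBuilder &b, mlir::Location loc, mlir::Value i,
               mlir::Value j, mlir::Value k) {
      mlir::AffineMap map =
          mlir::AffineMap::getPermutationMap({2u, 0u, 1u}, b.getContext());
      llvm::SmallVector<mlir::Value, 3> ivs = {i, j, k};
      return mlir::linalg::applyMapToValues(b, loc, map, ivs);
    }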
@@ -1,4 +1,4 @@
-//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===//
+//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is the definition file for the structured interface for Linalg ops.
+// This is the definition file for the structured interface sfor Linalg ops.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE
-#define LINALG_IR_STRUCTURED_OPS_INTERFACE
+#ifndef LINALG_IR_LINALGINTERFACES
+#define LINALG_IR_LINALGINTERFACES
 
-include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/IR/OpBase.td"
 
 // The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
 // interface.
@@ -33,10 +33,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodName=*/"getNumPayloadInductionVariables",
       /*args=*/(ins),
       /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return isa<IndexedGenericOp>(this->getOperation()) ?
-          $_op.getNumLoops() : 0;
-      }]
+      /*defaultImplementation=*/""
    >,
    //===------------------------------------------------------------------===//
    // Loop types handling.
@@ -570,7 +567,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        unsigned bbArgNumber =
-         getNumPayloadInductionVariables() + opOperand->getOperandNumber();
+         $_op.getNumPayloadInductionVariables() + opOperand->getOperandNumber();
        // Safeguard against the named linalg ops that are manually defined and
        // that only support buffer semantics: we should not be there.
        // Such ops have an empty regionBuilder and are not constructed with a
@@ -1117,4 +1114,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
 }
 
-#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE
+#endif // LINALG_IR_LINALGINTERFACES
@@ -42,10 +42,6 @@ class PoolingSumOp;
 using LoopRangeBuilder =
     std::function<SmallVector<Range, 4>(OpBuilder &, Location)>;
 
-/// Returns the values obtained by applying `map` to the list of values.
-SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
-                                       AffineMap map, ValueRange values);
-
 /// Provide a very simple inference procedure to build the loop ranges from the
 /// op and its operands. This only works with permutation affine maps and
 /// patterns of the form `(m, n)[s] -> (m + n - s floordiv 2)`.
@@ -122,7 +118,7 @@ namespace linalg {
 class IndexedGenericOp;
 } // namespace linalg
 } // namespace mlir
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc"
@@ -15,7 +15,7 @@
 #define LINALG_STRUCTURED_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
-include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -25,13 +25,22 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
-       LinalgStructuredInterface])> {}
+       LinalgStructuredInterface])> {
+  code structuredOpsBaseDecls = [{
+    // Return the number of induction variables in the basic block. This should
+    // always be 0 for index-free linalg ops. For IndexedGeneric, this must be
+    // equal to numLoops.
+    unsigned getNumPayloadInductionVariables() {
+      return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
+    }
+  }];
+}
 
 class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
   : LinalgStructuredBase_Op<mnemonic,
     !listconcat(props, [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>])> {
-  code libraryCallName = [{
+  code structuredOpsDecls = structuredOpsBaseDecls # [{
     std::string getLibraryCallName() {
       return generateLibraryCallName(getOperation());
     }
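For illustration only (not part of the commit), the net effect of the TableGen change above on the generated C++ op classes: every op deriving from LinalgStructuredBase_Op now exposes getNumPayloadInductionVariables() alongside the helpers pulled in via structuredOpsDecls. `inspectCopy` is a hypothetical function operating on linalg::CopyOp as defined in this file.

    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"

    static void inspectCopy(mlir::linalg::CopyOp copyOp) {
      // From structuredOpsBaseDecls: 0 for index-free ops such as linalg.copy.
      unsigned numIvs = copyOp.getNumPayloadInductionVariables();
      // From structuredOpsDecls (previously named `libraryCallName`).
      std::string callee = copyOp.getLibraryCallName();
      (void)numIvs;
      (void)callee;
    }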
@@ -110,7 +119,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
        $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
    }]>];
 
-  let extraClassDeclaration = libraryCallName # [{
+  let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return getOperands().take_front(); }
     ValueRange outputs() { return getOperands().take_back(); }
 
@@ -155,7 +164,7 @@ def FillOp : LinalgStructured_Op<"fill", []> {
   let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
   let results = (outs Optional<AnyRankedTensor>:$result);
-  let extraClassDeclaration = libraryCallName # [{
+  let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
 
@@ -232,7 +241,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
     for both low and high in each of the dimensions, if not specified.
   }];
 
-  code commonUtils = libraryCallName # [{
+  code commonUtils = structuredOpsDecls # [{
     int64_t getStride(unsigned i) {
       assert(i < getNumWindowLoops());
       if (!strides().hasValue()) return 1;
@@ -497,7 +506,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
                        OptionalAttr<ArrayAttr>:$sparse);
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
   let regions = (region AnyRegion:$region);
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
       return SmallVector<StringRef, 8>{
         getDocAttrName(),
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalg
+  LinalgInterfaces.cpp
   LinalgOps.cpp
   LinalgTypes.cpp
 
@@ -6,9 +7,9 @@ add_mlir_dialect_library(MLIRLinalg
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
 
   DEPENDS
+  MLIRLinalgInterfacesIncGen
   MLIRLinalgOpsIncGen
   MLIRLinalgStructuredOpsIncGen
-  MLIRLinalgStructuredOpsInterfaceIncGen
 
   LINK_LIBS PUBLIC
   MLIRAffine
@@ -0,0 +1,294 @@
+//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Include the definitions of the copy operation interface.
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
+
+/// Fully compose map with operands and canonicalize the result.
+/// Return the `createOrFold`'ed AffineApply op.
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+                                             AffineMap map,
+                                             ValueRange operandsRef) {
+  SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  canonicalizeMapAndOperands(&map, &operands);
+  return b.createOrFold<AffineApplyOp>(loc, map, operands);
+}
+
+SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
+                                                     AffineMap map,
+                                                     ValueRange values) {
+  SmallVector<Value, 4> res;
+  res.reserve(map.getNumResults());
+  unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
+  // For each `expr` in `map`, applies the `expr` to the values extracted from
+  // ranges. If the resulting application can be folded into a Value, the
+  // folding occurs eagerly.
+  for (auto expr : map.getResults()) {
+    AffineMap map = AffineMap::get(numDims, numSym, expr);
+    res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
+  }
+  return res;
+}
+
+SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
+                                                            Location loc) {
+  SmallVector<Value, 4> res;
+  for (Value v : getShapedOperands()) {
+    ShapedType t = v.getType().template cast<ShapedType>();
+    for (unsigned i = 0, e = t.getRank(); i < e; ++i)
+      res.push_back(b.create<DimOp>(loc, v, i));
+  }
+  return res;
+}
+
+SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
+  AffineMap map = getLoopsToShapesMap();
+  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+  auto viewSizes = createFlatListOfOperandDims(b, loc);
+  SmallVector<Range, 4> res(numDims);
+  Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
+  Value oneVal = b.create<ConstantIndexOp>(loc, 1);
+  for (unsigned idx = 0; idx < numRes; ++idx) {
+    auto result = map.getResult(idx);
+    if (auto d = result.dyn_cast<AffineDimExpr>()) {
+      if (res[d.getPosition()].offset)
+        continue;
+      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
+    }
+  }
+  return res;
+}
+
+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr.
+struct HasAffineDimExprVisitor
+    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
+      : positions(positions) {}
+
+  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+  }
+
+  bool visitDimExpr(AffineDimExpr dimExpr) {
+    return positions.count(dimExpr.getPosition());
+  }
+
+  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+  llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+                                                        Location loc,
+                                                        unsigned resultIdx,
+                                                        unsigned dim) {
+  // An example that helps understand the logic below.
+  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+  // This is achieved as follows.
+  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
+  //   resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
+  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
+  AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+  // Find the position in the above map that represents the shape of the
+  // result:dim being inferred.
+  Optional<unsigned> resultDimSubMapPos =
+      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+  if (!resultDimSubMapPos)
+    return {};
+
+  /// From loopsToShapesMap extract the submap that represents the shape of the
+  /// (resultIdx, dim) needed
+  AffineMap loopToResultDimShapeMap =
+      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+  AffineMap operandShapesToResultDimMap =
+      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+  // Check that the result dim map does not contain the positions corresponding
+  // to the outputs.
+  llvm::SmallSet<unsigned, 4> outputDims;
+  unsigned outputDimPosStart =
+      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+  unsigned outputDimPosEnd =
+      getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
+                                                 getOutputOpOperands()
+                                                         .back()
+                                                         .get()
+                                                         .getType()
+                                                         .cast<ShapedType>()
+                                                         .getRank() -
+                                                     1)
+          .getValue();
+  llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
+                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
+  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+    return llvm::None;
+  return applyMapToValues(b, loc, operandShapesToResultDimMap,
+                          createFlatListOfOperandDims(b, loc))[0];
+}
+
+LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
+  LinalgOp linalgOp = cast<LinalgOp>(op);
+  // Expect at least one shaped operand.
+  // This means an op that constructs a tensor out of indices cannot be a
+  // LinalgOp at the moment. For now this will have to be a special op until we
+  // have output shape operands that are not tensors.
+  auto nShapedOperands = linalgOp.getNumShapedOperands();
+  if (nShapedOperands == 0)
+    return linalgOp.emitOpError("expected at least 1 Shaped operand");
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands)))
+    return failure();
+  // Should have at least one output tensor per result tensor.
+  // Can also have outbut buffers that do not correspond to results.
+  if (op->getNumResults() > linalgOp.getNumOutputTensors())
+    return op->emitError("unexpected #results > #outputs");
+
+  // All shaped operands must be indexed.
+  if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
+    return linalgOp.emitOpError("expected the number of indexing_map (")
+           << linalgOp.indexing_maps().size()
+           << ") to be equal to the number of shaped operands ("
+           << linalgOp.getNumShapedOperands() << ")";
+
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.reserve(linalgOp.indexing_maps().size());
+  for (auto en : llvm::enumerate(linalgOp.indexing_maps())) {
+    auto idx = en.index();
+    auto m = en.value().template cast<AffineMapAttr>().getValue();
+    indexingMaps.push_back(m); // Save reference to map for further checks.
+    auto shapedValue = linalgOp.getShapedType(idx);
+
+    // Symbols disallowed.
+    if (m.getNumSymbols() != 0)
+      return linalgOp.emitOpError("unexpected symbols in indexing_map #")
+             << idx;
+
+    // Domain must be consistent.
+    auto nLoops = linalgOp.getNumLoops();
+    if (m.getNumDims() != nLoops)
+      return linalgOp.emitOpError("expected indexing_map #")
+             << idx << " to have " << nLoops
+             << " dim(s) to match the number of loops";
+
+    if (m.getNumResults() != shapedValue.getRank())
+      return linalgOp.emitOpError("expected shaped value rank (")
+             << shapedValue.getRank()
+             << ") to match the result rank of indexing_map #" << idx << " ("
+             << m.getNumResults() << ")";
+  }
+
+  SmallVector<AffineExpr, 4> redDims;
+  linalgOp.getReductionDims(redDims);
+
+  // Simplifying assumption: either full tensor or full buffer mode.
+  // This allows simpler verification of output operands vs result types
+  // without premature tracking of which operand is what in mixed-mode.
+  // TODO: relax when mixed-mode needs to pass verification.
+  if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0)
+    return op->emitError("expected output operands to all have tensor type or "
+                         "all have buffer type");
+
+  for (auto it :
+       llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) {
+    if (!std::get<0>(it).get().getType().isa<RankedTensorType>())
+      continue;
+    if (std::get<0>(it).get().getType() != std::get<1>(it))
+      return op->emitError("expected type of operand #")
+             << std::get<0>(it).getOperandNumber() << " ("
+             << std::get<0>(it).get().getType() << ")"
+             << " to match type of corresponding result (" << std::get<1>(it)
+             << ")";
+  }
+
+  // Output tensor indexing map may not depend on reduction indices.
+  for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) {
+    AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber());
+    for (auto expr : outputMap.getResults()) {
+      for (auto dim : redDims) {
+        unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+        if (expr.isFunctionOfDim(pos)) {
+          std::string exprStr;
+          {
+            llvm::raw_string_ostream os(exprStr);
+            os << expr;
+          }
+          return op->emitError(
+                     "unexpected output tensor expression in indexing map #")
+                 << (opOperand.getOperandNumber() - linalgOp.getNumInputs())
+                 << " a.k.a '" << exprStr
+                 << "' is function of reduction iterator 'd" << pos << "'";
+        }
+      }
+    }
+  }
+
+  // Named ops that are defined manually have a region builder but no region at
+  // this time. Assume the region is well-formed by specification.
+  // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
+  if (linalgOp->getNumRegions() == 0) {
+    assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
+    return success();
+  }
+
+  auto &region = linalgOp->getRegion(0);
+  if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
+    return op->emitOpError("expected 1 region with 1 block");
+
+  if (!linalgOp.getShapesToLoopsMap())
+    return op->emitOpError("expected the shape-to-loops map to be non-null");
+
+  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
+  // types.
+  // TODO: once ranked shape types are plugged in, we may want to drop the
+  // corresponding bbargs, that can never be read from. This will be subject to
+  // consistency discussions (i.e. what to do with output tensors whose bbarg is
+  // not used).
+  Block &block = linalgOp->getRegion(0).front();
+  unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables();
+
+  if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments())
+    return op->emitError("expected as many non-induction variable region "
+                         "arguments as the number of shaped operands");
+
+  // Note: the number and type of yield values are checked in the YieldOp.
+  for (unsigned i = 0; i < numBBIvs; ++i)
+    if (!block.getArgument(i).getType().isIndex())
+      return op->emitOpError("expected index block argument #") << i;
+
+  unsigned idx = 0;
+  for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(),
+                           block.getArguments().drop_front(numBBIvs))) {
+    if (std::get<0>(it).getElementType() != std::get<1>(it).getType())
+      return op->emitError("expected type of bb argument #")
+             << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")"
+             << " to match element type of corresponding shaped operand ("
+             << std::get<0>(it).getElementType() << ")";
+    ++idx;
+  }
+
+  return success();
+}
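A minimal sketch of how a visitor such as HasAffineDimExprVisitor above is typically driven (illustrative only; it assumes the struct is visible in the same translation unit and `exampleUsesDim0` is a hypothetical helper).

    #include "mlir/IR/AffineExpr.h"
    #include "llvm/ADT/SmallSet.h"

    static bool exampleUsesDim0(mlir::MLIRContext *ctx) {
      // Build the expression d0 + s0 and ask whether it references dimension 0.
      mlir::AffineExpr expr =
          mlir::getAffineDimExpr(0, ctx) + mlir::getAffineSymbolExpr(0, ctx);
      llvm::SmallSet<unsigned, 4> positions;
      positions.insert(0);
      HasAffineDimExprVisitor visitor(positions);
      return visitor.visit(expr); // true: d0 appears in the expression.
    }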
@@ -32,138 +32,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Fully compose map with operands and canonicalize the result.
-/// Return the `createOrFold`'ed AffineApply op.
-static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
-                                             AffineMap map,
-                                             ValueRange operandsRef) {
-  SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
-  fullyComposeAffineMapAndOperands(&map, &operands);
-  canonicalizeMapAndOperands(&map, &operands);
-  return b.createOrFold<AffineApplyOp>(loc, map, operands);
-}
-
-SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
-                                                     AffineMap map,
-                                                     ValueRange values) {
-  SmallVector<Value, 4> res;
-  res.reserve(map.getNumResults());
-  unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
-  // For each `expr` in `map`, applies the `expr` to the values extracted from
-  // ranges. If the resulting application can be folded into a Value, the
-  // folding occurs eagerly.
-  for (auto expr : map.getResults()) {
-    AffineMap map = AffineMap::get(numDims, numSym, expr);
-    res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
-  }
-  return res;
-}
-
-SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
-                                                            Location loc) {
-  SmallVector<Value, 4> res;
-  for (Value v : getShapedOperands()) {
-    ShapedType t = v.getType().template cast<ShapedType>();
-    for (unsigned i = 0, e = t.getRank(); i < e; ++i)
-      res.push_back(b.create<DimOp>(loc, v, i));
-  }
-  return res;
-}
-
-SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
-  AffineMap map = getLoopsToShapesMap();
-  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
-  auto viewSizes = createFlatListOfOperandDims(b, loc);
-  SmallVector<Range, 4> res(numDims);
-  Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
-  Value oneVal = b.create<ConstantIndexOp>(loc, 1);
-  for (unsigned idx = 0; idx < numRes; ++idx) {
-    auto result = map.getResult(idx);
-    if (auto d = result.dyn_cast<AffineDimExpr>()) {
-      if (res[d.getPosition()].offset)
-        continue;
-      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
-    }
-  }
-  return res;
-}
-
-/// Visitor to check if any of the given set of positions from AffineDimExprs
-/// are used within an AffineExpr.
-struct HasAffineDimExprVisitor
-    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
-  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
-      : positions(positions) {}
-
-  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
-    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
-  }
-
-  bool visitDimExpr(AffineDimExpr dimExpr) {
-    return positions.count(dimExpr.getPosition());
-  }
-
-  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
-
-  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
-
-private:
-  llvm::SmallSet<unsigned, 4> positions;
-};
-
-Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
-                                                        Location loc,
-                                                        unsigned resultIdx,
-                                                        unsigned dim) {
-  // An example that helps understand the logic below.
-  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
-  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
-  // This is achieved as follows.
-  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
-  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
-  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
-  //   resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
-  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
-  AffineMap loopsToShapesMap = getLoopsToShapesMap();
-
-  // Find the position in the above map that represents the shape of the
-  // result:dim being inferred.
-  Optional<unsigned> resultDimSubMapPos =
-      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
-  if (!resultDimSubMapPos)
-    return {};
-
-  /// From loopsToShapesMap extract the submap that represents the shape of the
-  /// (resultIdx, dim) needed
-  AffineMap loopToResultDimShapeMap =
-      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
-  AffineMap operandShapesToResultDimMap =
-      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
-
-  // Check that the result dim map does not contain the positions corresponding
-  // to the outputs.
-  llvm::SmallSet<unsigned, 4> outputDims;
-  unsigned outputDimPosStart =
-      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
-  unsigned outputDimPosEnd =
-      getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
-                                                 getOutputOpOperands()
-                                                         .back()
-                                                         .get()
-                                                         .getType()
-                                                         .cast<ShapedType>()
-                                                         .getRank() -
-                                                     1)
-          .getValue();
-  llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
-                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
-  HasAffineDimExprVisitor checkDimExpr(outputDims);
-  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
-    return llvm::None;
-  return applyMapToValues(b, loc, operandShapesToResultDimMap,
-                          createFlatListOfOperandDims(b, loc))[0];
-}
-
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -215,11 +83,6 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
-///////////////////// Operations defined with Tablegen /////////////////////////
-// For such operations that do not correspond to library calls (i.e. defined in
-// LinalgOps.td), we define an overloaded `print` function and a
-// parse`className` function.
-
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
@@ -471,148 +334,6 @@ void IndexedGenericOp::getEffects(
                         getInputBuffers(), getOutputBuffers());
 }
 
-LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
-  LinalgOp linalgOp = cast<LinalgOp>(op);
-  // Expect at least one shaped operand.
-  // This means an op that constructs a tensor out of indices cannot be a
-  // LinalgOp at the moment. For now this will have to be a special op until we
-  // have output shape operands that are not tensors.
-  auto nShapedOperands = linalgOp.getNumShapedOperands();
-  if (nShapedOperands == 0)
-    return linalgOp.emitOpError("expected at least 1 Shaped operand");
-  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands)))
-    return failure();
-  // Should have at least one output tensor per result tensor.
-  // Can also have outbut buffers that do not correspond to results.
-  if (op->getNumResults() > linalgOp.getNumOutputTensors())
-    return op->emitError("unexpected #results > #outputs");
-
-  // All shaped operands must be indexed.
-  if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
-    return linalgOp.emitOpError("expected the number of indexing_map (")
-           << linalgOp.indexing_maps().size()
-           << ") to be equal to the number of shaped operands ("
-           << linalgOp.getNumShapedOperands() << ")";
-
-  SmallVector<AffineMap, 4> indexingMaps;
-  indexingMaps.reserve(linalgOp.indexing_maps().size());
-  for (auto en : llvm::enumerate(linalgOp.indexing_maps())) {
-    auto idx = en.index();
-    auto m = en.value().template cast<AffineMapAttr>().getValue();
-    indexingMaps.push_back(m); // Save reference to map for further checks.
-    auto shapedValue = linalgOp.getShapedType(idx);
-
-    // Symbols disallowed.
-    if (m.getNumSymbols() != 0)
-      return linalgOp.emitOpError("unexpected symbols in indexing_map #")
-             << idx;
-
-    // Domain must be consistent.
-    auto nLoops = linalgOp.getNumLoops();
-    if (m.getNumDims() != nLoops)
-      return linalgOp.emitOpError("expected indexing_map #")
-             << idx << " to have " << nLoops
-             << " dim(s) to match the number of loops";
-
-    if (m.getNumResults() != shapedValue.getRank())
-      return linalgOp.emitOpError("expected shaped value rank (")
-             << shapedValue.getRank()
-             << ") to match the result rank of indexing_map #" << idx << " ("
-             << m.getNumResults() << ")";
-  }
-
-  SmallVector<AffineExpr, 4> redDims;
-  linalgOp.getReductionDims(redDims);
-
-  // Simplifying assumption: either full tensor or full buffer mode.
-  // This allows simpler verification of output operands vs result types
-  // without premature tracking of which operand is what in mixed-mode.
-  // TODO: relax when mixed-mode needs to pass verification.
-  if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0)
-    return op->emitError("expected output operands to all have tensor type or "
-                         "all have buffer type");
-
-  for (auto it :
-       llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) {
-    if (!std::get<0>(it).get().getType().isa<RankedTensorType>())
-      continue;
-    if (std::get<0>(it).get().getType() != std::get<1>(it))
-      return op->emitError("expected type of operand #")
-             << std::get<0>(it).getOperandNumber() << " ("
-             << std::get<0>(it).get().getType() << ")"
-             << " to match type of corresponding result (" << std::get<1>(it)
-             << ")";
-  }
-
-  // Output tensor indexing map may not depend on reduction indices.
-  for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) {
-    AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber());
-    for (auto expr : outputMap.getResults()) {
-      for (auto dim : redDims) {
-        unsigned pos = dim.cast<AffineDimExpr>().getPosition();
-        if (expr.isFunctionOfDim(pos)) {
-          std::string exprStr;
-          {
-            llvm::raw_string_ostream os(exprStr);
-            os << expr;
-          }
-          return op->emitError(
-                     "unexpected output tensor expression in indexing map #")
-                 << (opOperand.getOperandNumber() - linalgOp.getNumInputs())
-                 << " a.k.a '" << exprStr
-                 << "' is function of reduction iterator 'd" << pos << "'";
-        }
-      }
-    }
-  }
-
-  // Named ops that are defined manually have a region builder but no region at
-  // this time. Assume the region is well-formed by specification.
-  // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
-  if (linalgOp->getNumRegions() == 0) {
-    assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
-    return success();
-  }
-
-  auto &region = linalgOp->getRegion(0);
-  if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
-    return op->emitOpError("expected 1 region with 1 block");
-
-  if (!linalgOp.getShapesToLoopsMap())
-    return op->emitOpError("expected the shape-to-loops map to be non-null");
-
-  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
-  // types.
-  // TODO: once ranked shape types are plugged in, we may want to drop the
-  // corresponding bbargs, that can never be read from. This will be subject to
-  // consistency discussions (i.e. what to do with output tensors whose bbarg is
-  // not used).
-  Block &block = linalgOp->getRegion(0).front();
-  unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables();
-
-  if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments())
-    return op->emitError("expected as many non-induction variable region "
-                         "arguments as the number of shaped operands");
-
-  // Note: the number and type of yield values are checked in the YieldOp.
-  for (unsigned i = 0; i < numBBIvs; ++i)
-    if (!block.getArgument(i).getType().isIndex())
-      return op->emitOpError("expected index block argument #") << i;
-
-  unsigned idx = 0;
-  for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(),
-                           block.getArguments().drop_front(numBBIvs))) {
-    if (std::get<0>(it).getElementType() != std::get<1>(it).getType())
-      return op->emitError("expected type of bb argument #")
-             << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")"
-             << " to match element type of corresponding shaped operand ("
-             << std::get<0>(it).getElementType() << ")";
-    ++idx;
-  }
-
-  return success();
-}
-
 namespace {
 
 template <typename GenericOpType>
@@ -1901,8 +1622,6 @@ struct EraseDeadLinalgOp;
 struct FoldTensorCastOp;
 } // namespace
 
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
-
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
 
 #define GET_OP_CLASSES
@@ -1863,12 +1863,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       let hasFolder = 1;
       let hasCanonicalizer = 1;
 
-      let extraClassDeclaration = [{{
+      let extraClassDeclaration = structuredOpsBaseDecls # [{{
        // Auto-generated.
        ArrayAttr iterator_types();
        ArrayAttr indexing_maps();
        static void regionBuilder(Block &block);
-       static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
+       static std::function<void(Block &)> getRegionBuilder() {{
+         return regionBuilder;
+       }
 
        // Generic methods.
        static unsigned getNumRegionArgs() {{ return {4}; }