diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt index 09db72806565..14ee4ea6968a 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -45,8 +45,8 @@ add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen) add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen) add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen) -set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td) -mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs) -add_public_tablegen_target(MLIRLinalgStructuredOpsInterfaceIncGen) -add_dependencies(mlir-headers MLIRLinalgStructuredOpsInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td) +mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgInterfacesIncGen) +add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h new file mode 100644 index 000000000000..e4fddd594580 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -0,0 +1,44 @@ +//===- LinalgInterface.h - Linalg operations interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the operation interfaces for Linalg operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ +#define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ + +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +namespace linalg { + +/// Returns the values obtained by applying `map` to the list of values. +SmallVector applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, ValueRange values); + +namespace detail { + +/// Verify that `op` conforms to the invariants of StructuredOpInterface +LogicalResult verifyStructuredOpInterface(Operation *op); + +} // namespace detail +} // namespace linalg +} // namespace mlir + +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" + +/// Include the generated interface declarations. +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc" + +#endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td similarity index 98% rename from mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td rename to mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 7f3839a02b2f..a38b04ca16b2 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1,4 +1,4 @@ -//===- LinalgStructuredInterface.td- Linalg StructuredIfce -*- tablegen -*-===// +//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// // -// This is the definition file for the structured interface for Linalg ops. +// This is the definition file for the structured interface sfor Linalg ops. // //===----------------------------------------------------------------------===// -#ifndef LINALG_IR_STRUCTURED_OPS_INTERFACE -#define LINALG_IR_STRUCTURED_OPS_INTERFACE +#ifndef LINALG_IR_LINALGINTERFACES +#define LINALG_IR_LINALGINTERFACES -include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/IR/OpBase.td" // The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' // interface. @@ -33,10 +33,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodName=*/"getNumPayloadInductionVariables", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - return isa(this->getOperation()) ? - $_op.getNumLoops() : 0; - }] + /*defaultImplementation=*/"" >, //===------------------------------------------------------------------===// // Loop types handling. @@ -570,7 +567,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ unsigned bbArgNumber = - getNumPayloadInductionVariables() + opOperand->getOperandNumber(); + $_op.getNumPayloadInductionVariables() + opOperand->getOperandNumber(); // Safeguard against the named linalg ops that are manually defined and // that only support buffer semantics: we should not be there. // Such ops have an empty regionBuilder and are not constructed with a @@ -1117,4 +1114,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; } -#endif // LINALG_IR_STRUCTURED_OPS_INTERFACE +#endif // LINALG_IR_LINALGINTERFACES diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h index 4075ddd12117..f75e3010d3c5 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -42,10 +42,6 @@ class PoolingSumOp; using LoopRangeBuilder = std::function(OpBuilder &, Location)>; -/// Returns the values obtained by applying `map` to the list of values. -SmallVector applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, ValueRange values); - /// Provide a very simple inference procedure to build the loop ranges from the /// op and its operands. This only works with permutation affine maps and /// patterns of the form `(m, n)[s] -> (m + n - s floordiv 2)`. @@ -122,7 +118,7 @@ namespace linalg { class IndexedGenericOp; } // namespace linalg } // namespace mlir -#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index c88e1201f84b..8988a3a11efd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -15,7 +15,7 @@ #define LINALG_STRUCTURED_OPS include "mlir/Dialect/Linalg/IR/LinalgBase.td" -include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -25,13 +25,22 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op {} + LinalgStructuredInterface])> { + code structuredOpsBaseDecls = [{ + // Return the number of induction variables in the basic block. This should + // always be 0 for index-free linalg ops. For IndexedGeneric, this must be + // equal to numLoops. + unsigned getNumPayloadInductionVariables() { + return isa(this->getOperation()) ? getNumLoops() : 0; + } + }]; +} class LinalgStructured_Op props> : LinalgStructuredBase_Op])> { - code libraryCallName = [{ + code structuredOpsDecls = structuredOpsBaseDecls # [{ std::string getLibraryCallName() { return generateLibraryCallName(getOperation()); } @@ -110,7 +119,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; - let extraClassDeclaration = libraryCallName # [{ + let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return getOperands().take_front(); } ValueRange outputs() { return getOperands().take_back(); } @@ -155,7 +164,7 @@ def FillOp : LinalgStructured_Op<"fill", []> { let arguments = (ins AnyShaped:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); let results = (outs Optional:$result); - let extraClassDeclaration = libraryCallName # [{ + let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return {}; } ValueRange outputs() { return getOperands().take_front(); } @@ -232,7 +241,7 @@ class PoolingBase_Op props> for both low and high in each of the dimensions, if not specified. }]; - code commonUtils = libraryCallName # [{ + code commonUtils = structuredOpsDecls # [{ int64_t getStride(unsigned i) { assert(i < getNumWindowLoops()); if (!strides().hasValue()) return 1; @@ -497,7 +506,7 @@ class GenericOpBase : LinalgStructuredBase_Op:$sparse); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); - let extraClassDeclaration = [{ + let extraClassDeclaration = structuredOpsBaseDecls # [{ SmallVector linalgTraitAttrNames() { return SmallVector{ getDocAttrName(), diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index 3ed79a554b31..8522919bacb3 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRLinalg + LinalgInterfaces.cpp LinalgOps.cpp LinalgTypes.cpp @@ -6,9 +7,9 @@ add_mlir_dialect_library(MLIRLinalg ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg DEPENDS + MLIRLinalgInterfacesIncGen MLIRLinalgOpsIncGen MLIRLinalgStructuredOpsIncGen - MLIRLinalgStructuredOpsInterfaceIncGen LINK_LIBS PUBLIC MLIRAffine diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp new file mode 100644 index 000000000000..f9b17dd38fe0 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -0,0 +1,294 @@ +//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; +using namespace mlir::linalg; + +/// Include the definitions of the copy operation interface. +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" + +/// Fully compose map with operands and canonicalize the result. +/// Return the `createOrFold`'ed AffineApply op. +static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ValueRange operandsRef) { + SmallVector operands(operandsRef.begin(), operandsRef.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + canonicalizeMapAndOperands(&map, &operands); + return b.createOrFold(loc, map, operands); +} + +SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, + AffineMap map, + ValueRange values) { + SmallVector res; + res.reserve(map.getNumResults()); + unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); + // For each `expr` in `map`, applies the `expr` to the values extracted from + // ranges. If the resulting application can be folded into a Value, the + // folding occurs eagerly. + for (auto expr : map.getResults()) { + AffineMap map = AffineMap::get(numDims, numSym, expr); + res.push_back(createFoldedComposedAffineApply(b, loc, map, values)); + } + return res; +} + +SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, + Location loc) { + SmallVector res; + for (Value v : getShapedOperands()) { + ShapedType t = v.getType().template cast(); + for (unsigned i = 0, e = t.getRank(); i < e; ++i) + res.push_back(b.create(loc, v, i)); + } + return res; +} + +SmallVector LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { + AffineMap map = getLoopsToShapesMap(); + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + auto viewSizes = createFlatListOfOperandDims(b, loc); + SmallVector res(numDims); + Value zeroVal = b.create(loc, 0); + Value oneVal = b.create(loc, 1); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = result.dyn_cast()) { + if (res[d.getPosition()].offset) + continue; + res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal}; + } + } + return res; +} + +/// Visitor to check if any of the given set of positions from AffineDimExprs +/// are used within an AffineExpr. +struct HasAffineDimExprVisitor + : public AffineExprVisitor { + HasAffineDimExprVisitor(llvm::SmallSet &positions) + : positions(positions) {} + + bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) { + return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS()); + } + + bool visitDimExpr(AffineDimExpr dimExpr) { + return positions.count(dimExpr.getPosition()); + } + + bool visitConstantExpr(AffineConstantExpr constExpr) { return false; } + + bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; } + +private: + llvm::SmallSet positions; +}; + +Optional LinalgOp::inferResultDimFromInputShapes(OpBuilder &b, + Location loc, + unsigned resultIdx, + unsigned dim) { + // An example that helps understand the logic below. + // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) + // We want to express the shape of dim 0 of O in terms of shape of the inputs. + // This is achieved as follows. + // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) + // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1) + // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) + // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap) + // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1) + AffineMap loopsToShapesMap = getLoopsToShapesMap(); + + // Find the position in the above map that represents the shape of the + // result:dim being inferred. + Optional resultDimSubMapPos = + getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim); + if (!resultDimSubMapPos) + return {}; + + /// From loopsToShapesMap extract the submap that represents the shape of the + /// (resultIdx, dim) needed + AffineMap loopToResultDimShapeMap = + loopsToShapesMap.getSubMap(*resultDimSubMapPos); + AffineMap operandShapesToResultDimMap = + loopToResultDimShapeMap.compose(getShapesToLoopsMap()); + + // Check that the result dim map does not contain the positions corresponding + // to the outputs. + llvm::SmallSet outputDims; + unsigned outputDimPosStart = + getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue(); + unsigned outputDimPosEnd = + getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, + getOutputOpOperands() + .back() + .get() + .getType() + .cast() + .getRank() - + 1) + .getValue(); + llvm::for_each(llvm::seq(outputDimPosStart, outputDimPosEnd), + [&outputDims](unsigned dim) { outputDims.insert(dim); }); + HasAffineDimExprVisitor checkDimExpr(outputDims); + if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) + return llvm::None; + return applyMapToValues(b, loc, operandShapesToResultDimMap, + createFlatListOfOperandDims(b, loc))[0]; +} + +LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { + LinalgOp linalgOp = cast(op); + // Expect at least one shaped operand. + // This means an op that constructs a tensor out of indices cannot be a + // LinalgOp at the moment. For now this will have to be a special op until we + // have output shape operands that are not tensors. + auto nShapedOperands = linalgOp.getNumShapedOperands(); + if (nShapedOperands == 0) + return linalgOp.emitOpError("expected at least 1 Shaped operand"); + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands))) + return failure(); + // Should have at least one output tensor per result tensor. + // Can also have outbut buffers that do not correspond to results. + if (op->getNumResults() > linalgOp.getNumOutputTensors()) + return op->emitError("unexpected #results > #outputs"); + + // All shaped operands must be indexed. + if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands()) + return linalgOp.emitOpError("expected the number of indexing_map (") + << linalgOp.indexing_maps().size() + << ") to be equal to the number of shaped operands (" + << linalgOp.getNumShapedOperands() << ")"; + + SmallVector indexingMaps; + indexingMaps.reserve(linalgOp.indexing_maps().size()); + for (auto en : llvm::enumerate(linalgOp.indexing_maps())) { + auto idx = en.index(); + auto m = en.value().template cast().getValue(); + indexingMaps.push_back(m); // Save reference to map for further checks. + auto shapedValue = linalgOp.getShapedType(idx); + + // Symbols disallowed. + if (m.getNumSymbols() != 0) + return linalgOp.emitOpError("unexpected symbols in indexing_map #") + << idx; + + // Domain must be consistent. + auto nLoops = linalgOp.getNumLoops(); + if (m.getNumDims() != nLoops) + return linalgOp.emitOpError("expected indexing_map #") + << idx << " to have " << nLoops + << " dim(s) to match the number of loops"; + + if (m.getNumResults() != shapedValue.getRank()) + return linalgOp.emitOpError("expected shaped value rank (") + << shapedValue.getRank() + << ") to match the result rank of indexing_map #" << idx << " (" + << m.getNumResults() << ")"; + } + + SmallVector redDims; + linalgOp.getReductionDims(redDims); + + // Simplifying assumption: either full tensor or full buffer mode. + // This allows simpler verification of output operands vs result types + // without premature tracking of which operand is what in mixed-mode. + // TODO: relax when mixed-mode needs to pass verification. + if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0) + return op->emitError("expected output operands to all have tensor type or " + "all have buffer type"); + + for (auto it : + llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) { + if (!std::get<0>(it).get().getType().isa()) + continue; + if (std::get<0>(it).get().getType() != std::get<1>(it)) + return op->emitError("expected type of operand #") + << std::get<0>(it).getOperandNumber() << " (" + << std::get<0>(it).get().getType() << ")" + << " to match type of corresponding result (" << std::get<1>(it) + << ")"; + } + + // Output tensor indexing map may not depend on reduction indices. + for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) { + AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber()); + for (auto expr : outputMap.getResults()) { + for (auto dim : redDims) { + unsigned pos = dim.cast().getPosition(); + if (expr.isFunctionOfDim(pos)) { + std::string exprStr; + { + llvm::raw_string_ostream os(exprStr); + os << expr; + } + return op->emitError( + "unexpected output tensor expression in indexing map #") + << (opOperand.getOperandNumber() - linalgOp.getNumInputs()) + << " a.k.a '" << exprStr + << "' is function of reduction iterator 'd" << pos << "'"; + } + } + } + } + + // Named ops that are defined manually have a region builder but no region at + // this time. Assume the region is well-formed by specification. + // TODO: use linalg-ods-gen for all ops when we have enough expressive power. + if (linalgOp->getNumRegions() == 0) { + assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region"); + return success(); + } + + auto ®ion = linalgOp->getRegion(0); + if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region)) + return op->emitOpError("expected 1 region with 1 block"); + + if (!linalgOp.getShapesToLoopsMap()) + return op->emitOpError("expected the shape-to-loops map to be non-null"); + + // Simplifying assumption: bbargs match 1-1 with shape operands elemental + // types. + // TODO: once ranked shape types are plugged in, we may want to drop the + // corresponding bbargs, that can never be read from. This will be subject to + // consistency discussions (i.e. what to do with output tensors whose bbarg is + // not used). + Block &block = linalgOp->getRegion(0).front(); + unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables(); + + if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments()) + return op->emitError("expected as many non-induction variable region " + "arguments as the number of shaped operands"); + + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < numBBIvs; ++i) + if (!block.getArgument(i).getType().isIndex()) + return op->emitOpError("expected index block argument #") << i; + + unsigned idx = 0; + for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(), + block.getArguments().drop_front(numBBIvs))) { + if (std::get<0>(it).getElementType() != std::get<1>(it).getType()) + return op->emitError("expected type of bb argument #") + << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")" + << " to match element type of corresponding shaped operand (" + << std::get<0>(it).getElementType() << ")"; + ++idx; + } + + return success(); +} diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 7d2685f8166a..7a720d3e68bc 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,138 +32,6 @@ using namespace mlir; using namespace mlir::linalg; -/// Fully compose map with operands and canonicalize the result. -/// Return the `createOrFold`'ed AffineApply op. -static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, - AffineMap map, - ValueRange operandsRef) { - SmallVector operands(operandsRef.begin(), operandsRef.end()); - fullyComposeAffineMapAndOperands(&map, &operands); - canonicalizeMapAndOperands(&map, &operands); - return b.createOrFold(loc, map, operands); -} - -SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, - AffineMap map, - ValueRange values) { - SmallVector res; - res.reserve(map.getNumResults()); - unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); - // For each `expr` in `map`, applies the `expr` to the values extracted from - // ranges. If the resulting application can be folded into a Value, the - // folding occurs eagerly. - for (auto expr : map.getResults()) { - AffineMap map = AffineMap::get(numDims, numSym, expr); - res.push_back(createFoldedComposedAffineApply(b, loc, map, values)); - } - return res; -} - -SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, - Location loc) { - SmallVector res; - for (Value v : getShapedOperands()) { - ShapedType t = v.getType().template cast(); - for (unsigned i = 0, e = t.getRank(); i < e; ++i) - res.push_back(b.create(loc, v, i)); - } - return res; -} - -SmallVector LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { - AffineMap map = getLoopsToShapesMap(); - unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); - auto viewSizes = createFlatListOfOperandDims(b, loc); - SmallVector res(numDims); - Value zeroVal = b.create(loc, 0); - Value oneVal = b.create(loc, 1); - for (unsigned idx = 0; idx < numRes; ++idx) { - auto result = map.getResult(idx); - if (auto d = result.dyn_cast()) { - if (res[d.getPosition()].offset) - continue; - res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal}; - } - } - return res; -} - -/// Visitor to check if any of the given set of positions from AffineDimExprs -/// are used within an AffineExpr. -struct HasAffineDimExprVisitor - : public AffineExprVisitor { - HasAffineDimExprVisitor(llvm::SmallSet &positions) - : positions(positions) {} - - bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) { - return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS()); - } - - bool visitDimExpr(AffineDimExpr dimExpr) { - return positions.count(dimExpr.getPosition()); - } - - bool visitConstantExpr(AffineConstantExpr constExpr) { return false; } - - bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; } - -private: - llvm::SmallSet positions; -}; - -Optional LinalgOp::inferResultDimFromInputShapes(OpBuilder &b, - Location loc, - unsigned resultIdx, - unsigned dim) { - // An example that helps understand the logic below. - // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) - // We want to express the shape of dim 0 of O in terms of shape of the inputs. - // This is achieved as follows. - // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) - // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1) - // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) - // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap) - // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1) - AffineMap loopsToShapesMap = getLoopsToShapesMap(); - - // Find the position in the above map that represents the shape of the - // result:dim being inferred. - Optional resultDimSubMapPos = - getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim); - if (!resultDimSubMapPos) - return {}; - - /// From loopsToShapesMap extract the submap that represents the shape of the - /// (resultIdx, dim) needed - AffineMap loopToResultDimShapeMap = - loopsToShapesMap.getSubMap(*resultDimSubMapPos); - AffineMap operandShapesToResultDimMap = - loopToResultDimShapeMap.compose(getShapesToLoopsMap()); - - // Check that the result dim map does not contain the positions corresponding - // to the outputs. - llvm::SmallSet outputDims; - unsigned outputDimPosStart = - getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue(); - unsigned outputDimPosEnd = - getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, - getOutputOpOperands() - .back() - .get() - .getType() - .cast() - .getRank() - - 1) - .getValue(); - llvm::for_each(llvm::seq(outputDimPosStart, outputDimPosEnd), - [&outputDims](unsigned dim) { outputDims.insert(dim); }); - HasAffineDimExprVisitor checkDimExpr(outputDims); - if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) - return llvm::None; - return applyMapToValues(b, loc, operandShapesToResultDimMap, - createFlatListOfOperandDims(b, loc))[0]; -} - /// Forward declarations. template static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, @@ -215,11 +83,6 @@ static LogicalResult foldMemRefCast(Operation *op) { return success(folded); } -///////////////////// Operations defined with Tablegen ///////////////////////// -// For such operations that do not correspond to library calls (i.e. defined in -// LinalgOps.td), we define an overloaded `print` function and a -// parse`className` function. - //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// @@ -471,148 +334,6 @@ void IndexedGenericOp::getEffects( getInputBuffers(), getOutputBuffers()); } -LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { - LinalgOp linalgOp = cast(op); - // Expect at least one shaped operand. - // This means an op that constructs a tensor out of indices cannot be a - // LinalgOp at the moment. For now this will have to be a special op until we - // have output shape operands that are not tensors. - auto nShapedOperands = linalgOp.getNumShapedOperands(); - if (nShapedOperands == 0) - return linalgOp.emitOpError("expected at least 1 Shaped operand"); - if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands))) - return failure(); - // Should have at least one output tensor per result tensor. - // Can also have outbut buffers that do not correspond to results. - if (op->getNumResults() > linalgOp.getNumOutputTensors()) - return op->emitError("unexpected #results > #outputs"); - - // All shaped operands must be indexed. - if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands()) - return linalgOp.emitOpError("expected the number of indexing_map (") - << linalgOp.indexing_maps().size() - << ") to be equal to the number of shaped operands (" - << linalgOp.getNumShapedOperands() << ")"; - - SmallVector indexingMaps; - indexingMaps.reserve(linalgOp.indexing_maps().size()); - for (auto en : llvm::enumerate(linalgOp.indexing_maps())) { - auto idx = en.index(); - auto m = en.value().template cast().getValue(); - indexingMaps.push_back(m); // Save reference to map for further checks. - auto shapedValue = linalgOp.getShapedType(idx); - - // Symbols disallowed. - if (m.getNumSymbols() != 0) - return linalgOp.emitOpError("unexpected symbols in indexing_map #") - << idx; - - // Domain must be consistent. - auto nLoops = linalgOp.getNumLoops(); - if (m.getNumDims() != nLoops) - return linalgOp.emitOpError("expected indexing_map #") - << idx << " to have " << nLoops - << " dim(s) to match the number of loops"; - - if (m.getNumResults() != shapedValue.getRank()) - return linalgOp.emitOpError("expected shaped value rank (") - << shapedValue.getRank() - << ") to match the result rank of indexing_map #" << idx << " (" - << m.getNumResults() << ")"; - } - - SmallVector redDims; - linalgOp.getReductionDims(redDims); - - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. - if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0) - return op->emitError("expected output operands to all have tensor type or " - "all have buffer type"); - - for (auto it : - llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) { - if (!std::get<0>(it).get().getType().isa()) - continue; - if (std::get<0>(it).get().getType() != std::get<1>(it)) - return op->emitError("expected type of operand #") - << std::get<0>(it).getOperandNumber() << " (" - << std::get<0>(it).get().getType() << ")" - << " to match type of corresponding result (" << std::get<1>(it) - << ")"; - } - - // Output tensor indexing map may not depend on reduction indices. - for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) { - AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber()); - for (auto expr : outputMap.getResults()) { - for (auto dim : redDims) { - unsigned pos = dim.cast().getPosition(); - if (expr.isFunctionOfDim(pos)) { - std::string exprStr; - { - llvm::raw_string_ostream os(exprStr); - os << expr; - } - return op->emitError( - "unexpected output tensor expression in indexing map #") - << (opOperand.getOperandNumber() - linalgOp.getNumInputs()) - << " a.k.a '" << exprStr - << "' is function of reduction iterator 'd" << pos << "'"; - } - } - } - } - - // Named ops that are defined manually have a region builder but no region at - // this time. Assume the region is well-formed by specification. - // TODO: use linalg-ods-gen for all ops when we have enough expressive power. - if (linalgOp->getNumRegions() == 0) { - assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region"); - return success(); - } - - auto ®ion = linalgOp->getRegion(0); - if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region)) - return op->emitOpError("expected 1 region with 1 block"); - - if (!linalgOp.getShapesToLoopsMap()) - return op->emitOpError("expected the shape-to-loops map to be non-null"); - - // Simplifying assumption: bbargs match 1-1 with shape operands elemental - // types. - // TODO: once ranked shape types are plugged in, we may want to drop the - // corresponding bbargs, that can never be read from. This will be subject to - // consistency discussions (i.e. what to do with output tensors whose bbarg is - // not used). - Block &block = linalgOp->getRegion(0).front(); - unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables(); - - if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments()) - return op->emitError("expected as many non-induction variable region " - "arguments as the number of shaped operands"); - - // Note: the number and type of yield values are checked in the YieldOp. - for (unsigned i = 0; i < numBBIvs; ++i) - if (!block.getArgument(i).getType().isIndex()) - return op->emitOpError("expected index block argument #") << i; - - unsigned idx = 0; - for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(), - block.getArguments().drop_front(numBBIvs))) { - if (std::get<0>(it).getElementType() != std::get<1>(it).getType()) - return op->emitError("expected type of bb argument #") - << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")" - << " to match element type of corresponding shaped operand (" - << std::get<0>(it).getElementType() << ")"; - ++idx; - } - - return success(); -} - namespace { template @@ -1901,8 +1622,6 @@ struct EraseDeadLinalgOp; struct FoldTensorCastOp; } // namespace -#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc" - #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 47841c840fe5..9bf763079470 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1863,12 +1863,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, let hasFolder = 1; let hasCanonicalizer = 1; - let extraClassDeclaration = [{{ + let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); static void regionBuilder(Block &block); - static std::function getRegionBuilder() {{ return regionBuilder; } + static std::function getRegionBuilder() {{ + return regionBuilder; + } // Generic methods. static unsigned getNumRegionArgs() {{ return {4}; }