From a43f216fd5764a0ebc8a5a92be2ad07bc9bcdf37 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 8 Apr 2019 22:37:37 -0700 Subject: [PATCH] Automated rollback of changelist 242546977. PiperOrigin-RevId: 242604949 --- .../Linalg2/include/linalg2/TensorOps-inl.h | 120 +++++++++++++++ .../Linalg2/include/linalg2/TensorOps.h | 133 ++++++++-------- .../examples/Linalg/Linalg2/lib/TensorOps.cpp | 116 +++----------- .../Linalg3/include/linalg3/TensorOps-inl.h | 145 ++++++++++++++++++ .../Linalg3/include/linalg3/TensorOps.h | 13 +- .../Linalg3/lib/ConvertToLLVMDialect.cpp | 2 - .../examples/Linalg/Linalg3/lib/TensorOps.cpp | 95 +----------- .../Linalg/Linalg3/lib/Transforms.cpp | 3 - .../Linalg/Linalg4/lib/Transforms.cpp | 26 ++-- 9 files changed, 374 insertions(+), 279 deletions(-) create mode 100644 mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h create mode 100644 mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h new file mode 100644 index 000000000000..940f8d7d312c --- /dev/null +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h @@ -0,0 +1,120 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG2_TENSOROPS_INL_H_ +#define LINALG2_TENSOROPS_INL_H_ + +#include "linalg2/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +namespace linalg { + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getInputs() { + auto *op = static_cast(this)->getOperation(); + return {op->operand_begin(), op->operand_begin() + getNumInputs()}; +} + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getOutputs() { + auto *op = static_cast(this)->getOperation(); + return {op->operand_begin() + getNumInputs(), + op->operand_begin() + getNumInputs() + getNumOutputs()}; +} + +template +mlir::Operation::operand_range +linalg::TensorContractionBase::getInputsAndOutputs() { + return {getInputs().begin(), getOutputs().end()}; +} + +template +mlir::LogicalResult linalg::TensorContractionBase::verify() { + auto *concreteOp = static_cast(this)->getOperation(); + if (getNumInputs() <= 0) + concreteOp->emitOpError("expected at least one input"); + if (getNumOutputs() <= 0) + concreteOp->emitOpError("expected at least one output"); + if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) { + concreteOp->emitOpError("expected " + + llvm::Twine(getNumInputs() + getNumOutputs()) + + " operands"); + } + for (unsigned i = 0, e = getNumInputs(); i < e; ++i) { + if (!concreteOp->getOperand(i)->getType().template isa()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + } + for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e; + ++i) { + auto viewType = + concreteOp->getOperand(i)->getType().template dyn_cast(); + if (!viewType) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " not a ViewType"); + if (viewType.getRank() != getNumParallelDims()) + return concreteOp->emitOpError("operand " + llvm::Twine(i) + + " must be of rank " + + llvm::Twine(getNumParallelDims())); + } + return mlir::success(); +} + +template +bool linalg::TensorContractionBase::parse( + mlir::OpAsmParser *parser, mlir::OperationState *result) { + llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); +} + +// A TensorContraction prints as: +// +// ```{.mlir} +// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : view +// ``` +// +// Where %0, %1 and %2 are ssa-values of type ViewType. +template +void linalg::TensorContractionBase::print(mlir::OpAsmPrinter *p) { + *p << static_cast(this)->getOperationName() << "("; + auto *last = *std::prev(getInputsAndOutputs().end()); + for (auto *i : getInputsAndOutputs()) { + *p << *i << ((i == last) ? "" : ", "); + } + *p << ") : "; + auto *lastOutput = *std::prev(getOutputs().end()); + for (auto *o : getOutputs()) { + *p << o->getType() << ((o == lastOutput) ? "" : ","); + } +} + +} // namespace linalg + +#endif // LINALG2_TENSOROPS_INL_H_ diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h index cac813e6824a..39e51f057d3a 100644 --- a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h @@ -29,29 +29,44 @@ namespace linalg { /// A generic TensorContraction base class which captures the generic behavior /// of tensor contraction operations (with broadcast). -class TensorContractionBase { -public: - virtual ~TensorContractionBase() {} +template class TensorContractionBase { +protected: + using TensorContractionBaseType = TensorContractionBase; + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + /// Generic implementation of hooks that should be called from `ConcreteType`s + mlir::LogicalResult verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + +public: ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. ////////////////////////////////////////////////////////////////////////////// - virtual llvm::StringRef getTensorContractionName() = 0; + TensorContractionBase() = default; mlir::Operation::operand_range getInputs(); mlir::Operation::operand_range getOutputs(); - mlir::Operation::operand_range getInputsAndOutputs() { - return {getInputs().begin(), getOutputs().end()}; - } + mlir::Operation::operand_range getInputsAndOutputs(); /// These are better as methods calling into the ConcreteOp instead of /// template parameters because methods allow more generic behavior and avoid /// specializing for number of arguments. All derived classes have /// `VariadicOperands` and a build method from both an ArrayRef /// and the proper number of mlir::Value*. - virtual unsigned getNumInputs() = 0; - virtual unsigned getNumOutputs() = 0; - virtual unsigned getNumParallelDims() = 0; - virtual unsigned getNumReductionDims() = 0; + unsigned getNumInputs() { + return static_cast(this)->numInputs; + }; + unsigned getNumOutputs() { + return static_cast(this)->numOutputs; + }; + unsigned getNumParallelDims() { + return static_cast(this)->numParallelDims; + }; + unsigned getNumReductionDims() { + return static_cast(this)->numReductionDims; + }; ////////////////////////////////////////////////////////////////////////////// // Used in Linalg3 and later. @@ -64,18 +79,13 @@ public: : getOutputView(viewIndex - getNumInputs()); } - /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a - /// loop over matvec). Does nothing by default. - virtual void writeAsFinerGrainTensorContraction() {} - /// Each op is responsible for declaring how it lowers itself to scalar form, /// given the enclosing parallel and reduction induction variables. /// `emitScalarImplementation` emits the scalar IR for the op in the nesting /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or /// `parallel.back()`). - virtual void - emitScalarImplementation(llvm::ArrayRef parallelIvs, - llvm::ArrayRef reductionIvs) {} + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); /// Represents a mapping from the loops to all the ranges of the operands. /// The operands and their ranges are in the order defined by the particular @@ -84,17 +94,17 @@ public: /// it explicitly is not expensive and generalizes to cases where an analysis /// is not available. For details, see the description of /// loopsToOperandRangeMaps in each ConcreteOp. - virtual llvm::SmallVector loopsToOperandRangeMaps() { - return llvm::SmallVector(); - } + llvm::SmallVector loopsToOperandRangeMaps(); }; /// Implements c = A * B where c is a scalar and A and B are 1-D vectors. -class DotOp : public TensorContractionBase, +class DotOp : public TensorContractionBase, public mlir::Op { public: using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; ////////////////////////////////////////////////////////////////////////////// // Hooks to customize the behavior of this op. @@ -113,28 +123,24 @@ public: ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. ////////////////////////////////////////////////////////////////////////////// - llvm::StringRef getTensorContractionName() override { - return getOperationName(); - } - unsigned getNumInputs() override { return 2; } - unsigned getNumOutputs() override { return 1; } - unsigned getNumParallelDims() override { return 0; } - unsigned getNumReductionDims() override { return 1; } + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 0; + static constexpr unsigned numReductionDims = 1; -#if LINALG_STEP > 2 ////////////////////////////////////////////////////////////////////////////// // Used in Linalg3 and later. ////////////////////////////////////////////////////////////////////////////// /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. - void writeAsFinerGrainTensorContraction() override; + void writeAsFinerGrainTensorContraction(); /// Inputs to this map will be (%k) coming from enclosing loops. /// Therefore, the mapping to get back to A(K), B(K), C() is: /// (d0) -> (d0, d0)(%k) /// And the operands ranges are: /// (%k, %k) - llvm::SmallVector loopsToOperandRangeMaps() override; + llvm::SmallVector loopsToOperandRangeMaps(); /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding /// to: @@ -147,18 +153,18 @@ public: /// cond = (r_i == zero) /// scalarC = select(cond, zerof, C[]); /// C[] = scalarC + A[r_i] * B[r_i]; - void - emitScalarImplementation(llvm::ArrayRef parallelIvs, - llvm::ArrayRef reductionIvs) override; -#endif // LINALG_STEP + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; /// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors. -class MatvecOp : public TensorContractionBase, +class MatvecOp : public TensorContractionBase, public mlir::Op { public: using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; ////////////////////////////////////////////////////////////////////////////// // Hooks to customize the behavior of this op. @@ -177,28 +183,24 @@ public: ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. ////////////////////////////////////////////////////////////////////////////// - llvm::StringRef getTensorContractionName() override { - return getOperationName(); - } - unsigned getNumInputs() override { return 2; } - unsigned getNumOutputs() override { return 1; } - unsigned getNumParallelDims() override { return 1; } - unsigned getNumReductionDims() override { return 1; } + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 1; + static constexpr unsigned numReductionDims = 1; -#if LINALG_STEP > 2 ////////////////////////////////////////////////////////////////////////////// // Used in Linalg3 and later. ////////////////////////////////////////////////////////////////////////////// /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. - void writeAsFinerGrainTensorContraction() override; + void writeAsFinerGrainTensorContraction(); /// Inputs to this map will be (%m, %k) coming from enclosing loops. /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is: /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k) /// And the operands ranges are: /// (%m, %k, %k, %m) - llvm::SmallVector loopsToOperandRangeMaps() override; + llvm::SmallVector loopsToOperandRangeMaps(); /// Given an enclosing parallel loop with iv `i` and an enclosing parallel /// loop with iv `r_j`, emits MLIR corresponding to: @@ -211,18 +213,18 @@ public: /// cond = (r_j == zero) /// scalarC = select(cond, zerof, C(i)); /// C(i) = scalarC + A(i, r_j) * B(r_j); - void - emitScalarImplementation(llvm::ArrayRef parallelIvs, - llvm::ArrayRef reductionIvs) override; -#endif // LINALG_STEP + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; /// Implements C = A * B on 2-D matrices. -class MatmulOp : public TensorContractionBase, +class MatmulOp : public TensorContractionBase, public mlir::Op { public: using Op::Op; + using TensorContractionBaseType = + TensorContractionBase::TensorContractionBaseType; ////////////////////////////////////////////////////////////////////////////// // Hooks to customize the behavior of this op. @@ -241,28 +243,24 @@ public: ////////////////////////////////////////////////////////////////////////////// // Op-specific functionality. ////////////////////////////////////////////////////////////////////////////// - llvm::StringRef getTensorContractionName() override { - return getOperationName(); - } - unsigned getNumInputs() override { return 2; } - unsigned getNumOutputs() override { return 1; } - unsigned getNumParallelDims() override { return 2; } - unsigned getNumReductionDims() override { return 1; } + static constexpr unsigned numInputs = 2; + static constexpr unsigned numOutputs = 1; + static constexpr unsigned numParallelDims = 2; + static constexpr unsigned numReductionDims = 1; -#if LINALG_STEP > 2 ////////////////////////////////////////////////////////////////////////////// // Used in Linalg3 and later. ////////////////////////////////////////////////////////////////////////////// /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a /// loop over matvec). Does nothing by default. - void writeAsFinerGrainTensorContraction() override; + void writeAsFinerGrainTensorContraction(); /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops. /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is: /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k) /// And the operands ranges are: /// (%m, %k, %k, %n, %m, %n) - llvm::SmallVector loopsToOperandRangeMaps() override; + llvm::SmallVector loopsToOperandRangeMaps(); /// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing /// reduction loop with iv `r_k`, emits MLIR corresponding to: @@ -275,12 +273,15 @@ public: /// cond = (r_k == zero) /// scalarC = select(cond, zerof, C[i, j]); /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j]; - void - emitScalarImplementation(llvm::ArrayRef parallelIvs, - llvm::ArrayRef reductionIvs) override; -#endif // LINALG_STEP + void emitScalarImplementation(llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs); }; } // namespace linalg +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#include "linalg2/TensorOps-inl.h" + #endif // LINALG2_TENSOROPS_H_ diff --git a/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp index 6aeefc8d6a11..8a47e5d70eab 100644 --- a/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp @@ -26,100 +26,14 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +using llvm::ArrayRef; +using llvm::Twine; + using namespace mlir; using namespace linalg; -#define TENSOR_CONTRACTION_DISPATCH(FUNCTION_NAME) \ - if (getTensorContractionName() == MatmulOp::getOperationName()) { \ - return FUNCTION_NAME(static_cast(*this)); \ - } \ - if (getTensorContractionName() == MatvecOp::getOperationName()) { \ - return FUNCTION_NAME(static_cast(*this)); \ - } \ - if (getTensorContractionName() == DotOp::getOperationName()) { \ - return FUNCTION_NAME(static_cast(*this)); \ - } \ - llvm_unreachable("Missing linalg op"); - -template -static mlir::Operation::operand_range getInputs(ConcreteOp &concreteOp) { - return {concreteOp.operand_begin(), - concreteOp.operand_begin() + concreteOp.getNumInputs()}; -} - -mlir::Operation::operand_range linalg::TensorContractionBase::getInputs() { - TENSOR_CONTRACTION_DISPATCH(::getInputs); -} - -template -static mlir::Operation::operand_range getOutputs(ConcreteOp &concreteOp) { - return {concreteOp.operand_begin() + concreteOp.getNumInputs(), - concreteOp.operand_begin() + concreteOp.getNumInputs() + - concreteOp.getNumOutputs()}; -} - -mlir::Operation::operand_range linalg::TensorContractionBase::getOutputs() { - TENSOR_CONTRACTION_DISPATCH(::getOutputs); -} - -template -static mlir::LogicalResult verifyLinalgOp(LinalgOp op) { - if (op.getNumInputs() <= 0) - op.emitOpError("expected at least one input"); - if (op.getNumOutputs() <= 0) - op.emitOpError("expected at least one output"); - if (op.getNumOperands() != op.getNumInputs() + op.getNumOutputs()) { - op.emitOpError("expected " + - llvm::Twine(op.getNumInputs() + op.getNumOutputs()) + - " operands"); - } - for (unsigned i = 0, e = op.getNumInputs(); i < e; ++i) { - if (!op.getOperand(i)->getType().template isa()) - return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType"); - } - for (unsigned i = op.getNumInputs(), - e = op.getNumInputs() + op.getNumOutputs(); - i < e; ++i) { - auto viewType = op.getOperand(i)->getType().template dyn_cast(); - if (!viewType) - return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType"); - if (viewType.getRank() != op.getNumParallelDims()) - return op.emitOpError("operand " + llvm::Twine(i) + " must be of rank " + - llvm::Twine(op.getNumParallelDims())); - } - return mlir::success(); -} - -// A TensorContraction prints as: -// -// ```{.mlir} -// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types -// ``` -// -// for example: -// -// ``` -// linalg.matmul(%0, %1, %2) : view -// ``` -// -// Where %0, %1 and %2 are ssa-values of type ViewType. -template -static void printLinalgOp(mlir::OpAsmPrinter *p, LinalgOp op) { - *p << op.getOperationName() << "("; - auto *last = *std::prev(op.getInputsAndOutputs().end()); - for (auto *i : op.getInputsAndOutputs()) { - *p << *i << ((i == last) ? "" : ", "); - } - *p << ") : "; - auto *lastOutput = *std::prev(op.getOutputs().end()); - for (auto *o : op.getOutputs()) { - *p << o->getType() << ((o == lastOutput) ? "" : ","); - } -} - ////////////////////////////////////////////////////////////////////////////// // Op-specific Dot. ////////////////////////////////////////////////////////////////////////////// @@ -129,7 +43,7 @@ void linalg::DotOp::build(Builder *b, OperationState *result, } LogicalResult linalg::DotOp::verify() { - if (failed(verifyLinalgOp(*this))) + if (failed(TensorContractionBaseType::verify())) return failure(); auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2); unsigned index = 0; @@ -146,10 +60,12 @@ LogicalResult linalg::DotOp::verify() { // Parsing of the linalg dialect is not supported in this tutorial. bool linalg::DotOp::parse(mlir::OpAsmParser *parser, mlir::OperationState *result) { - llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); + return TensorContractionBaseType::parse(parser, result); } -void linalg::DotOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); } +void linalg::DotOp::print(mlir::OpAsmPrinter *p) { + TensorContractionBaseType::print(p); +} ////////////////////////////////////////////////////////////////////////////// // Op-specific Matvec. @@ -160,7 +76,7 @@ void linalg::MatvecOp::build(Builder *b, OperationState *result, } LogicalResult linalg::MatvecOp::verify() { - if (failed(verifyLinalgOp(*this))) + if (failed(TensorContractionBaseType::verify())) return failure(); auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2); if (getViewRank(A) != 2) @@ -178,10 +94,12 @@ LogicalResult linalg::MatvecOp::verify() { // Parsing of the linalg dialect is not supported in this tutorial. bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser, mlir::OperationState *result) { - llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); + return TensorContractionBaseType::parse(parser, result); } -void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); } +void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { + TensorContractionBaseType::print(p); +} ////////////////////////////////////////////////////////////////////////////// // Op-specific Matmul. @@ -192,7 +110,7 @@ void linalg::MatmulOp::build(Builder *b, OperationState *result, } LogicalResult linalg::MatmulOp::verify() { - if (failed(verifyLinalgOp(*this))) + if (failed(TensorContractionBaseType::verify())) return failure(); auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2); unsigned index = 0; @@ -207,7 +125,9 @@ LogicalResult linalg::MatmulOp::verify() { // Parsing of the linalg dialect is not supported in this tutorial. bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser, mlir::OperationState *result) { - llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); + return TensorContractionBaseType::parse(parser, result); } -void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); } +void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { + TensorContractionBaseType::print(p); +} diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h new file mode 100644 index 000000000000..b65105344b1f --- /dev/null +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h @@ -0,0 +1,145 @@ +//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#ifndef LINALG3_TENSOROPS_INL_H_ +#define LINALG3_TENSOROPS_INL_H_ + +#include "linalg1/Common.h" +#include "linalg1/Utils.h" +#include "linalg2/TensorOps.h" +#include "linalg3/Analysis.h" +#include "linalg3/Ops.h" + +template +mlir::Value * +linalg::TensorContractionBase::getInputView(unsigned viewIndex) { + return *(getInputs().begin() + viewIndex); +} + +template +mlir::Value * +linalg::TensorContractionBase::getOutputView(unsigned viewIndex) { + return *(getOutputs().begin() + viewIndex); +} + +template +llvm::SmallVector +linalg::TensorContractionBase::loopsToOperandRangeMaps() { + return static_cast(this)->loopsToOperandRangeMaps(); +} + +template +void linalg::TensorContractionBase::emitScalarImplementation( + llvm::ArrayRef parallelIvs, + llvm::ArrayRef reductionIvs) { + static_cast(this)->emitScalarImplementation(parallelIvs, + reductionIvs); +} + +template +mlir::AffineMap linalg::operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction) { + mlir::AffineMap current; + // Individual submaps may not be invertible but their union must be invertible + // by construction. + for (auto m : tensorContraction.loopsToOperandRangeMaps()) { + if (!m) + continue; + if (!current) { + current = m; + continue; + } + llvm::SmallVector results(current.getResults().begin(), + current.getResults().end()); + results.append(m.getResults().begin(), m.getResults().end()); + current = mlir::AffineMap::get( + std::max(current.getNumDims(), m.getNumDims()), + current.getNumSymbols() + m.getNumSymbols(), results, {}); + } + return inverseSubMap(current); +} + +// Extract the ranges from a given ViewOp or SliceOp. +// +// In the case of a ViewOp, things are simple: just traverse the indexings and +// get all the ranges (i.e. drop the indices). +// +// In the case of a SliceOp, things are trickier because we need to handle a +// potential rank-reduction: +// 1. Examine the indexing to determine if it is rank-reducing. +// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such +// that `d >= slicingDim`. This is to account for the rank reduction. +// `getRootIndex` is then called on the **parent** view +static llvm::SmallVector +extractRangesFromViewOrSliceOp(mlir::Value *view) { + // This expects a viewType which must come from either ViewOp or SliceOp. + assert(view->getType().isa() && "expected ViewType"); + if (auto viewOp = view->getDefiningOp()->dyn_cast()) + return viewOp.getRanges(); + + auto sliceOp = view->getDefiningOp()->cast(); + unsigned slicingDim = sliceOp.getSlicingDim(); + auto *indexing = *(sliceOp.getIndexings().begin()); + bool isRankReducing = indexing->getType().isa(); + unsigned offset = 0; + llvm::SmallVector res; + res.reserve(sliceOp.getRank()); + for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { + if (d == slicingDim && isRankReducing) + offset = 1; + auto *parentView = sliceOp.getParentView(); + auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); + res.push_back(indexingPosPair.first); + } + return res; +} + +template +static llvm::SmallVector +getInputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *in : tensorContraction.getInputs()) { + auto subres = extractRangesFromViewOrSliceOp(in); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +static llvm::SmallVector +getOutputRanges(linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res; + for (auto *out : tensorContraction.getOutputs()) { + auto subres = extractRangesFromViewOrSliceOp(out); + res.append(subres.begin(), subres.end()); + } + return res; +} + +template +llvm::SmallVector linalg::getRanges( + linalg::TensorContractionBase &tensorContraction) { + llvm::SmallVector res = getInputRanges(tensorContraction); + llvm::SmallVector tmp = getOutputRanges(tensorContraction); + res.append(tmp.begin(), tmp.end()); + return res; +} + +#endif // LINALG3_TENSOROPS_INL_H_ diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps.h index cbb247cc6126..bf5a377d7ff5 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps.h @@ -29,8 +29,9 @@ namespace linalg { /// Takes a `tensorContraction` and a returns an AffineMap that can be used to /// map ranges to enclosing loops for all the operands' ranges. -mlir::AffineMap -operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction); +template +mlir::AffineMap operandRangesToLoopsMap( + linalg::TensorContractionBase &tensorContraction); /// Takes a `tensorContraction` and returns the ranges of all its operands. /// When an operand comes from a ViewOp, things are simple: @@ -39,9 +40,15 @@ operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction); /// In the case of a SliceOp, things are more involved because we need to handle /// potential rank-reductions. /// This function abstracts this complexity away and returns all the ranges. +template llvm::SmallVector -getRanges(linalg::TensorContractionBase &tensorContraction); +getRanges(linalg::TensorContractionBase &tensorContraction); } // namespace linalg +/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of +/// TensorOps by adding implementations as they are needed in the appropriate +/// step in the tutorial. +#include "linalg3/TensorOps-inl.h" + #endif // LINALG3_TENSOROPS_H_ diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 1acdf7aab46f..fd2afd90bb19 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -27,9 +27,7 @@ #include "mlir/IR/Types.h" #include "mlir/LLVMIR/LLVMDialect.h" #include "mlir/LLVMIR/Transforms.h" -#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "linalg1/ConvertToLLVMDialect.h" #include "linalg1/LLVMIntrinsics.h" diff --git a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp index 673c6863b2cf..a5b094c777e4 100644 --- a/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp @@ -20,8 +20,8 @@ // //===----------------------------------------------------------------------===// +#include "linalg1/Analysis.h" #include "linalg1/Common.h" -#include "linalg3/Analysis.h" #include "linalg3/Intrinsics.h" #include "linalg3/Ops.h" #include "mlir/IR/Builders.h" @@ -36,99 +36,6 @@ using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; -mlir::Value *linalg::TensorContractionBase::getInputView(unsigned viewIndex) { - return *(getInputs().begin() + viewIndex); -} - -mlir::Value *linalg::TensorContractionBase::getOutputView(unsigned viewIndex) { - return *(getOutputs().begin() + viewIndex); -} - -mlir::AffineMap linalg::operandRangesToLoopsMap( - linalg::TensorContractionBase &tensorContraction) { - mlir::AffineMap current; - // Individual submaps may not be invertible but their union must be invertible - // by construction. - for (auto m : tensorContraction.loopsToOperandRangeMaps()) { - if (!m) - continue; - if (!current) { - current = m; - continue; - } - llvm::SmallVector results(current.getResults().begin(), - current.getResults().end()); - results.append(m.getResults().begin(), m.getResults().end()); - current = mlir::AffineMap::get( - std::max(current.getNumDims(), m.getNumDims()), - current.getNumSymbols() + m.getNumSymbols(), results, {}); - } - return inverseSubMap(current); -} - -// Extract the ranges from a given ViewOp or SliceOp. -// -// In the case of a ViewOp, things are simple: just traverse the indexings and -// get all the ranges (i.e. drop the indices). -// -// In the case of a SliceOp, things are trickier because we need to handle a -// potential rank-reduction: -// 1. Examine the indexing to determine if it is rank-reducing. -// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such -// that `d >= slicingDim`. This is to account for the rank reduction. -// `getRootIndex` is then called on the **parent** view -static llvm::SmallVector -extractRangesFromViewOrSliceOp(mlir::Value *view) { - // This expects a viewType which must come from either ViewOp or SliceOp. - assert(view->getType().isa() && "expected ViewType"); - if (auto viewOp = view->getDefiningOp()->dyn_cast()) - return viewOp.getRanges(); - - auto sliceOp = view->getDefiningOp()->cast(); - unsigned slicingDim = sliceOp.getSlicingDim(); - auto *indexing = *(sliceOp.getIndexings().begin()); - bool isRankReducing = indexing->getType().isa(); - unsigned offset = 0; - llvm::SmallVector res; - res.reserve(sliceOp.getRank()); - for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) { - if (d == slicingDim && isRankReducing) - offset = 1; - auto *parentView = sliceOp.getParentView(); - auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset); - res.push_back(indexingPosPair.first); - } - return res; -} - -static llvm::SmallVector -getInputRanges(linalg::TensorContractionBase &tensorContraction) { - llvm::SmallVector res; - for (auto *in : tensorContraction.getInputs()) { - auto subres = extractRangesFromViewOrSliceOp(in); - res.append(subres.begin(), subres.end()); - } - return res; -} - -static llvm::SmallVector -getOutputRanges(linalg::TensorContractionBase &tensorContraction) { - llvm::SmallVector res; - for (auto *out : tensorContraction.getOutputs()) { - auto subres = extractRangesFromViewOrSliceOp(out); - res.append(subres.begin(), subres.end()); - } - return res; -} - -llvm::SmallVector -linalg::getRanges(linalg::TensorContractionBase &tensorContraction) { - llvm::SmallVector res = getInputRanges(tensorContraction); - llvm::SmallVector tmp = getOutputRanges(tensorContraction); - res.append(tmp.begin(), tmp.end()); - return res; -} - ////////////////////////////////////////////////////////////////////////////// // Implementation of DotOp. ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 3f1c36d49eb3..d9a56c6f4e1c 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -20,11 +20,8 @@ //===----------------------------------------------------------------------===// #include "linalg3/Transforms.h" -#include "linalg1/Common.h" #include "linalg2/Intrinsics.h" #include "linalg3/Ops.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 7a160890dd85..05865e9e53c7 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -20,14 +20,11 @@ //===----------------------------------------------------------------------===// #include "linalg4/Transforms.h" -#include "linalg1/Common.h" -#include "linalg1/Utils.h" #include "linalg3/Intrinsics.h" #include "linalg3/TensorOps.h" #include "mlir/AffineOps/AffineOps.h" #include "mlir/EDSC/Helpers.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/LoopUtils.h" @@ -59,25 +56,27 @@ static bool isZeroIndex(Value *v) { v->getDefiningOp()->dyn_cast().getValue() == 0; } +template static llvm::SmallVector -makeTiledRanges(TensorContractionBase &contraction, ArrayRef allRanges, - llvm::ArrayRef ivs, +makeTiledRanges(TensorContractionBase &contraction, + ArrayRef allRanges, llvm::ArrayRef ivs, llvm::ArrayRef tileSizes) { assert(ivs.size() == tileSizes.size()); if (ivs.empty()) return RangeParts(allRanges).makeRanges(); + auto *op = static_cast(&contraction); RangeParts result(allRanges.size()); RangeParts rangeParts(allRanges); - for (auto map : contraction.loopsToOperandRangeMaps()) { + for (auto map : op->loopsToOperandRangeMaps()) { // 1. Take the first ivs results of the map, the other ones are not composed // but merely copied over. assert(map.getNumSymbols() == 0); assert(map.getRangeSizes().empty()); MLIRContext *context = ScopedContext::getContext(); - unsigned numParallel = contraction.getNumParallelDims(); - unsigned numReduction = contraction.getNumReductionDims(); + unsigned numParallel = op->getNumParallelDims(); + unsigned numReduction = op->getNumReductionDims(); if (ivs.size() < numParallel + numReduction) { // Inject zeros in positions that are not tiled. SmallVector dimReplacements(numParallel + numReduction); @@ -121,8 +120,9 @@ makeTiledRanges(TensorContractionBase &contraction, ArrayRef allRanges, return result.makeRanges(); } +template static SmallVector -makeTiledViews(linalg::TensorContractionBase &contraction, +makeTiledViews(linalg::TensorContractionBase &contraction, ArrayRef ivs, ArrayRef tileSizes) { auto tiledRanges = makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes); @@ -141,15 +141,15 @@ makeTiledViews(linalg::TensorContractionBase &contraction, return res; } -template +template static SmallVector -writeContractionAsTiledViews(ConcreteOp &contraction, +writeContractionAsTiledViews(TensorContractionBase &contraction, ArrayRef tileSizes) { assert(tileSizes.size() <= contraction.getNumParallelDims() + contraction.getNumReductionDims()); - ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()), - contraction.getLoc()); + auto *op = static_cast(&contraction); + ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc()); SmallVector ivs(tileSizes.size()); auto pivs = IndexHandle::makeIndexHandlePointers(ivs);