Automated rollback of changelist 242546977.

PiperOrigin-RevId: 242604949
This commit is contained in:
Mehdi Amini 2019-04-08 22:37:37 -07:00 committed by Mehdi Amini
parent 6b18e34de4
commit a43f216fd5
9 changed files with 374 additions and 279 deletions

View File

@ -0,0 +1,120 @@
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
#ifndef LINALG2_TENSOROPS_INL_H_
#define LINALG2_TENSOROPS_INL_H_
#include "linalg2/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
namespace linalg {
/// Returns the range made of the first `getNumInputs()` operands of the
/// underlying operation.
template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getInputs() {
  auto *operation = static_cast<ConcreteOp *>(this)->getOperation();
  auto inputsBegin = operation->operand_begin();
  return {inputsBegin, inputsBegin + getNumInputs()};
}
/// Returns the range made of the `getNumOutputs()` operands that immediately
/// follow the inputs in the operand list of the underlying operation.
template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
  auto *operation = static_cast<ConcreteOp *>(this)->getOperation();
  auto outputsBegin = operation->operand_begin() + getNumInputs();
  return {outputsBegin, outputsBegin + getNumOutputs()};
}
/// Returns a single range spanning from the first input to the last output.
/// Inputs and outputs are laid out back to back in the operand list, so the
/// range is built from the start of the inputs and the end of the outputs.
template <class ConcreteOp>
mlir::Operation::operand_range
linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
  auto inputs = getInputs();
  auto outputs = getOutputs();
  return {inputs.begin(), outputs.end()};
}
/// Generic structural verification shared by all tensor contraction ops.
///
/// Checks, in order, that:
///   1. the op declares at least one input and at least one output;
///   2. the number of operands is exactly `getNumInputs() + getNumOutputs()`;
///   3. every input operand is a ViewType;
///   4. every output operand is a ViewType of rank `getNumParallelDims()`.
///
/// Returns failure (after emitting an op error) on the first violated check,
/// and success otherwise.
template <class ConcreteOp>
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
  const unsigned numInputs = getNumInputs();
  const unsigned numOutputs = getNumOutputs();
  const unsigned numExpectedOperands = numInputs + numOutputs;

  // Each diagnostic must be returned: merely emitting it would let the
  // verifier carry on and possibly report success for an invalid op.
  if (numInputs == 0)
    return concreteOp->emitOpError("expected at least one input");
  if (numOutputs == 0)
    return concreteOp->emitOpError("expected at least one output");
  // Bail out before the loops below index operands that may not exist.
  if (concreteOp->getNumOperands() != numExpectedOperands)
    return concreteOp->emitOpError(
        "expected " + llvm::Twine(numExpectedOperands) + " operands");

  // Inputs only need to be views.
  for (unsigned i = 0; i < numInputs; ++i) {
    if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " not a ViewType");
  }

  // Outputs need to be views whose rank matches the number of parallel dims.
  for (unsigned i = numInputs; i < numExpectedOperands; ++i) {
    auto viewType =
        concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
    if (!viewType)
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " not a ViewType");
    if (viewType.getRank() != getNumParallelDims())
      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
                                     " must be of rank " +
                                     llvm::Twine(getNumParallelDims()));
  }
  return mlir::success();
}
/// Parsing hook shared by all tensor contraction ops.
/// Parsing is intentionally not implemented in this tutorial: reaching this
/// function is treated as a programming error, so both arguments are unused.
template <class ConcreteOp>
bool linalg::TensorContractionBase<ConcreteOp>::parse(
    mlir::OpAsmParser * /*parser*/, mlir::OperationState * /*result*/) {
  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}
// A TensorContraction prints as:
//
// ```{.mlir}
//   concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
// ```
//
// for example:
//
// ```
//   linalg.matmul(%0, %1, %2) : view<?x?xf32>
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
//
// Separators are emitted *before* every element but the first. Deciding where
// to put a separator by comparing each value against the last one would drop
// a separator whenever the same SSA value also appears earlier in the list
// (e.g. `op(%0, %1, %0)`), and would require dereferencing the end of a
// possibly empty range.
template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
  *p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
  // Operands: inputs followed by outputs, separated by ", ".
  const char *separator = "";
  for (auto *operand : getInputsAndOutputs()) {
    *p << separator << *operand;
    separator = ", ";
  }
  *p << ") : ";
  // Types of the outputs only, separated by ",".
  separator = "";
  for (auto *output : getOutputs()) {
    *p << separator << output->getType();
    separator = ",";
  }
}
} // namespace linalg
#endif // LINALG2_TENSOROPS_INL_H_

View File

@ -29,29 +29,44 @@ namespace linalg {
/// A generic TensorContraction base class which captures the generic behavior
/// of tensor contraction operations (with broadcast).
class TensorContractionBase {
public:
virtual ~TensorContractionBase() {}
template <class ConcreteOp> class TensorContractionBase {
protected:
using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
/// Generic implementation of hooks that should be called from `ConcreteType`s
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
virtual llvm::StringRef getTensorContractionName() = 0;
TensorContractionBase() = default;
mlir::Operation::operand_range getInputs();
mlir::Operation::operand_range getOutputs();
mlir::Operation::operand_range getInputsAndOutputs() {
return {getInputs().begin(), getOutputs().end()};
}
mlir::Operation::operand_range getInputsAndOutputs();
/// These are better as methods calling into the ConcreteOp instead of
/// template parameters because methods allow more generic behavior and avoid
/// specializing for number of arguments. All derived classes have
/// `VariadicOperands` and a build method from both an ArrayRef<mlir::Value*>
/// and the proper number of mlir::Value*.
virtual unsigned getNumInputs() = 0;
virtual unsigned getNumOutputs() = 0;
virtual unsigned getNumParallelDims() = 0;
virtual unsigned getNumReductionDims() = 0;
unsigned getNumInputs() {
return static_cast<ConcreteOp *>(this)->numInputs;
};
unsigned getNumOutputs() {
return static_cast<ConcreteOp *>(this)->numOutputs;
};
unsigned getNumParallelDims() {
return static_cast<ConcreteOp *>(this)->numParallelDims;
};
unsigned getNumReductionDims() {
return static_cast<ConcreteOp *>(this)->numReductionDims;
};
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
@ -64,18 +79,13 @@ public:
: getOutputView(viewIndex - getNumInputs());
}
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
virtual void writeAsFinerGrainTensorContraction() {}
/// Each op is responsible for declaring how it lowers itself to scalar form,
/// given the enclosing parallel and reduction induction variables.
/// `emitScalarImplementation` emits the scalar IR for the op in the nesting
/// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
/// `parallel.back()`).
virtual void
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs) {}
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
/// Represents a mapping from the loops to all the ranges of the operands.
/// The operands and their ranges are in the order defined by the particular
@ -84,17 +94,17 @@ public:
/// it explicitly is not expensive and generalizes to cases where an analysis
/// is not available. For details, see the description of
/// loopsToOperandRangeMaps in each ConcreteOp.
virtual llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() {
return llvm::SmallVector<mlir::AffineMap, 8>();
}
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
};
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
class DotOp : public TensorContractionBase,
class DotOp : public TensorContractionBase<DotOp>,
public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
using TensorContractionBaseType =
TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
@ -113,28 +123,24 @@ public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
llvm::StringRef getTensorContractionName() override {
return getOperationName();
}
unsigned getNumInputs() override { return 2; }
unsigned getNumOutputs() override { return 1; }
unsigned getNumParallelDims() override { return 0; }
unsigned getNumReductionDims() override { return 1; }
static constexpr unsigned numInputs = 2;
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 0;
static constexpr unsigned numReductionDims = 1;
#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction() override;
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(K), B(K), C() is:
/// (d0) -> (d0, d0)(%k)
/// And the operands ranges are:
/// (%k, %k)
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
/// to:
@ -147,18 +153,18 @@ public:
/// cond = (r_i == zero)
/// scalarC = select(cond, zerof, C[]);
/// C[] = scalarC + A[r_i] * B[r_i];
void
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
#endif // LINALG_STEP
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors.
class MatvecOp : public TensorContractionBase,
class MatvecOp : public TensorContractionBase<MatvecOp>,
public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
using TensorContractionBaseType =
TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
@ -177,28 +183,24 @@ public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
llvm::StringRef getTensorContractionName() override {
return getOperationName();
}
unsigned getNumInputs() override { return 2; }
unsigned getNumOutputs() override { return 1; }
unsigned getNumParallelDims() override { return 1; }
unsigned getNumReductionDims() override { return 1; }
static constexpr unsigned numInputs = 2;
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 1;
static constexpr unsigned numReductionDims = 1;
#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction() override;
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%m, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %m)
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
/// Given an enclosing parallel loop with iv `i` and an enclosing reduction
/// loop with iv `r_j`, emits MLIR corresponding to:
@ -211,18 +213,18 @@ public:
/// cond = (r_j == zero)
/// scalarC = select(cond, zerof, C(i));
/// C(i) = scalarC + A(i, r_j) * B(r_j);
void
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
#endif // LINALG_STEP
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B on 2-D matrices.
class MatmulOp : public TensorContractionBase,
class MatmulOp : public TensorContractionBase<MatmulOp>,
public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
using TensorContractionBaseType =
TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
@ -241,28 +243,24 @@ public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
llvm::StringRef getTensorContractionName() override {
return getOperationName();
}
unsigned getNumInputs() override { return 2; }
unsigned getNumOutputs() override { return 1; }
unsigned getNumParallelDims() override { return 2; }
unsigned getNumReductionDims() override { return 1; }
static constexpr unsigned numInputs = 2;
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 2;
static constexpr unsigned numReductionDims = 1;
#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction() override;
void writeAsFinerGrainTensorContraction();
/// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %n, %m, %n)
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
/// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
@ -275,12 +273,15 @@ public:
/// cond = (r_k == zero)
/// scalarC = select(cond, zerof, C[i, j]);
/// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
void
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
#endif // LINALG_STEP
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
llvm::ArrayRef<mlir::Value *> reductionIvs);
};
} // namespace linalg
/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
#include "linalg2/TensorOps-inl.h"
#endif // LINALG2_TENSOROPS_H_

View File

@ -26,100 +26,14 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
using llvm::ArrayRef;
using llvm::Twine;
using namespace mlir;
using namespace linalg;
#define TENSOR_CONTRACTION_DISPATCH(FUNCTION_NAME) \
if (getTensorContractionName() == MatmulOp::getOperationName()) { \
return FUNCTION_NAME(static_cast<MatmulOp &>(*this)); \
} \
if (getTensorContractionName() == MatvecOp::getOperationName()) { \
return FUNCTION_NAME(static_cast<MatvecOp &>(*this)); \
} \
if (getTensorContractionName() == DotOp::getOperationName()) { \
return FUNCTION_NAME(static_cast<DotOp &>(*this)); \
} \
llvm_unreachable("Missing linalg op");
template <typename ConcreteOp>
static mlir::Operation::operand_range getInputs(ConcreteOp &concreteOp) {
return {concreteOp.operand_begin(),
concreteOp.operand_begin() + concreteOp.getNumInputs()};
}
mlir::Operation::operand_range linalg::TensorContractionBase::getInputs() {
TENSOR_CONTRACTION_DISPATCH(::getInputs);
}
template <typename ConcreteOp>
static mlir::Operation::operand_range getOutputs(ConcreteOp &concreteOp) {
return {concreteOp.operand_begin() + concreteOp.getNumInputs(),
concreteOp.operand_begin() + concreteOp.getNumInputs() +
concreteOp.getNumOutputs()};
}
mlir::Operation::operand_range linalg::TensorContractionBase::getOutputs() {
TENSOR_CONTRACTION_DISPATCH(::getOutputs);
}
template <typename LinalgOp>
static mlir::LogicalResult verifyLinalgOp(LinalgOp op) {
if (op.getNumInputs() <= 0)
op.emitOpError("expected at least one input");
if (op.getNumOutputs() <= 0)
op.emitOpError("expected at least one output");
if (op.getNumOperands() != op.getNumInputs() + op.getNumOutputs()) {
op.emitOpError("expected " +
llvm::Twine(op.getNumInputs() + op.getNumOutputs()) +
" operands");
}
for (unsigned i = 0, e = op.getNumInputs(); i < e; ++i) {
if (!op.getOperand(i)->getType().template isa<ViewType>())
return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
}
for (unsigned i = op.getNumInputs(),
e = op.getNumInputs() + op.getNumOutputs();
i < e; ++i) {
auto viewType = op.getOperand(i)->getType().template dyn_cast<ViewType>();
if (!viewType)
return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
if (viewType.getRank() != op.getNumParallelDims())
return op.emitOpError("operand " + llvm::Twine(i) + " must be of rank " +
llvm::Twine(op.getNumParallelDims()));
}
return mlir::success();
}
// A TensorContraction prints as:
//
// ```{.mlir}
// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
// ```
//
// for example:
//
// ```
// linalg.matmul(%0, %1, %2) : view<?x?xf32>
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
template <typename LinalgOp>
static void printLinalgOp(mlir::OpAsmPrinter *p, LinalgOp op) {
*p << op.getOperationName() << "(";
auto *last = *std::prev(op.getInputsAndOutputs().end());
for (auto *i : op.getInputsAndOutputs()) {
*p << *i << ((i == last) ? "" : ", ");
}
*p << ") : ";
auto *lastOutput = *std::prev(op.getOutputs().end());
for (auto *o : op.getOutputs()) {
*p << o->getType() << ((o == lastOutput) ? "" : ",");
}
}
//////////////////////////////////////////////////////////////////////////////
// Op-specific Dot.
//////////////////////////////////////////////////////////////////////////////
@ -129,7 +43,7 @@ void linalg::DotOp::build(Builder *b, OperationState *result,
}
LogicalResult linalg::DotOp::verify() {
if (failed(verifyLinalgOp(*this)))
if (failed(TensorContractionBaseType::verify()))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
unsigned index = 0;
@ -146,10 +60,12 @@ LogicalResult linalg::DotOp::verify() {
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::DotOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return TensorContractionBaseType::parse(parser, result);
}
void linalg::DotOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
void linalg::DotOp::print(mlir::OpAsmPrinter *p) {
TensorContractionBaseType::print(p);
}
//////////////////////////////////////////////////////////////////////////////
// Op-specific Matvec.
@ -160,7 +76,7 @@ void linalg::MatvecOp::build(Builder *b, OperationState *result,
}
LogicalResult linalg::MatvecOp::verify() {
if (failed(verifyLinalgOp(*this)))
if (failed(TensorContractionBaseType::verify()))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
if (getViewRank(A) != 2)
@ -178,10 +94,12 @@ LogicalResult linalg::MatvecOp::verify() {
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return TensorContractionBaseType::parse(parser, result);
}
void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) {
TensorContractionBaseType::print(p);
}
//////////////////////////////////////////////////////////////////////////////
// Op-specific Matmul.
@ -192,7 +110,7 @@ void linalg::MatmulOp::build(Builder *b, OperationState *result,
}
LogicalResult linalg::MatmulOp::verify() {
if (failed(verifyLinalgOp(*this)))
if (failed(TensorContractionBaseType::verify()))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
unsigned index = 0;
@ -207,7 +125,9 @@ LogicalResult linalg::MatmulOp::verify() {
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return TensorContractionBaseType::parse(parser, result);
}
void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) {
TensorContractionBaseType::print(p);
}

View File

@ -0,0 +1,145 @@
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
#ifndef LINALG3_TENSOROPS_INL_H_
#define LINALG3_TENSOROPS_INL_H_
#include "linalg1/Common.h"
#include "linalg1/Utils.h"
#include "linalg2/TensorOps.h"
#include "linalg3/Analysis.h"
#include "linalg3/Ops.h"
/// Returns the input operand at position `viewIndex`, counted from the first
/// input. No bounds checking is performed.
template <class ConcreteOp>
mlir::Value *
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
  auto inputs = getInputs();
  return *(inputs.begin() + viewIndex);
}
/// Returns the output operand at position `viewIndex`, counted from the first
/// output. No bounds checking is performed.
template <class ConcreteOp>
mlir::Value *
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
  auto outputs = getOutputs();
  return *(outputs.begin() + viewIndex);
}
/// Static (CRTP) dispatch: forwards to the `loopsToOperandRangeMaps`
/// implementation found on `ConcreteOp`.
template <class ConcreteOp>
llvm::SmallVector<mlir::AffineMap, 8>
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
  auto *concreteOp = static_cast<ConcreteOp *>(this);
  return concreteOp->loopsToOperandRangeMaps();
}
/// Static (CRTP) dispatch: forwards both induction variable lists unchanged to
/// the `emitScalarImplementation` implementation found on `ConcreteOp`.
template <class ConcreteOp>
void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
    llvm::ArrayRef<mlir::Value *> parallelIvs,
    llvm::ArrayRef<mlir::Value *> reductionIvs) {
  auto *concreteOp = static_cast<ConcreteOp *>(this);
  concreteOp->emitScalarImplementation(parallelIvs, reductionIvs);
}
/// Concatenates the results of all the non-null maps returned by
/// `tensorContraction.loopsToOperandRangeMaps()` into one map and returns
/// `inverseSubMap` of that concatenation.
///
/// The combined map uses the largest dimension count among the merged maps and
/// the sum of their symbol counts.
template <class ConcreteOp>
mlir::AffineMap linalg::operandRangesToLoopsMap(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  mlir::AffineMap merged;
  // Individual submaps may not be invertible but their union must be invertible
  // by construction.
  for (auto map : tensorContraction.loopsToOperandRangeMaps()) {
    // Null maps contribute nothing.
    if (!map)
      continue;
    if (merged) {
      // Append the results of `map` after the ones accumulated so far.
      auto mergedResults = merged.getResults();
      auto mapResults = map.getResults();
      llvm::SmallVector<mlir::AffineExpr, 8> allResults;
      allResults.append(mergedResults.begin(), mergedResults.end());
      allResults.append(mapResults.begin(), mapResults.end());
      unsigned numDims = std::max(merged.getNumDims(), map.getNumDims());
      unsigned numSymbols = merged.getNumSymbols() + map.getNumSymbols();
      merged = mlir::AffineMap::get(numDims, numSymbols, allResults, {});
    } else {
      // The first non-null map seeds the accumulation.
      merged = map;
    }
  }
  return inverseSubMap(merged);
}
// Extract the ranges from a given ViewOp or SliceOp.
//
// In the case of a ViewOp, things are simple: just return the ranges of the
// ViewOp.
//
// In the case of a SliceOp, things are trickier because we need to handle a
// potential rank-reduction:
//   1. Examine the first indexing to determine if it is rank-reducing (i.e.
//      it is of IndexType).
//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
//      that `d >= slicingDim`. This is to account for the rank reduction.
//      `getViewRootIndexing` is then called on the **parent** view.
//
// `view` must be of ViewType and must be the result of a ViewOp or a SliceOp;
// in particular it must have a defining operation. Both preconditions are
// asserted.
static llvm::SmallVector<mlir::Value *, 8>
extractRangesFromViewOrSliceOp(mlir::Value *view) {
  // This expects a viewType which must come from either ViewOp or SliceOp.
  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
  // A value without a defining operation would otherwise be dereferenced as a
  // null pointer below.
  auto *definingOp = view->getDefiningOp();
  assert(definingOp && "expected view to be defined by a ViewOp or SliceOp");
  if (auto viewOp = definingOp->dyn_cast<linalg::ViewOp>())
    return viewOp.getRanges();

  auto sliceOp = definingOp->cast<linalg::SliceOp>();
  unsigned slicingDim = sliceOp.getSlicingDim();
  auto *indexing = *(sliceOp.getIndexings().begin());
  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
  // The parent view does not depend on the dimension being processed.
  auto *parentView = sliceOp.getParentView();
  unsigned offset = 0;
  llvm::SmallVector<mlir::Value *, 8> res;
  res.reserve(sliceOp.getRank());
  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
    // From the sliced dimension onwards, skip the dimension that the
    // rank-reducing slice removed from the parent view.
    if (d == slicingDim && isRankReducing)
      offset = 1;
    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
    res.push_back(indexingPosPair.first);
  }
  return res;
}
/// Collects, in operand order, the ranges extracted from every input view of
/// `tensorContraction`.
template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> ranges;
  for (auto *inputView : tensorContraction.getInputs()) {
    auto viewRanges = extractRangesFromViewOrSliceOp(inputView);
    ranges.append(viewRanges.begin(), viewRanges.end());
  }
  return ranges;
}
/// Collects, in operand order, the ranges extracted from every output view of
/// `tensorContraction`.
template <class ConcreteOp>
static llvm::SmallVector<mlir::Value *, 8>
getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> ranges;
  for (auto *outputView : tensorContraction.getOutputs()) {
    auto viewRanges = extractRangesFromViewOrSliceOp(outputView);
    ranges.append(viewRanges.begin(), viewRanges.end());
  }
  return ranges;
}
/// Returns the ranges of all the operands of `tensorContraction`: the ranges
/// of the inputs first, followed by the ranges of the outputs.
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> allRanges =
      getInputRanges(tensorContraction);
  llvm::SmallVector<mlir::Value *, 8> outputRanges =
      getOutputRanges(tensorContraction);
  allRanges.append(outputRanges.begin(), outputRanges.end());
  return allRanges;
}
#endif // LINALG3_TENSOROPS_INL_H_

View File

@ -29,8 +29,9 @@ namespace linalg {
/// Takes a `tensorContraction` and returns an AffineMap that can be used to
/// map ranges to enclosing loops for all the operands' ranges.
mlir::AffineMap
operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
template <class ConcreteOp>
mlir::AffineMap operandRangesToLoopsMap(
linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
/// Takes a `tensorContraction` and returns the ranges of all its operands.
/// When an operand comes from a ViewOp, things are simple:
@ -39,9 +40,15 @@ operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
/// In the case of a SliceOp, things are more involved because we need to handle
/// potential rank-reductions.
/// This function abstracts this complexity away and returns all the ranges.
template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8>
getRanges(linalg::TensorContractionBase &tensorContraction);
getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
} // namespace linalg
/// The TensorOps-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
#include "linalg3/TensorOps-inl.h"
#endif // LINALG3_TENSOROPS_H_

View File

@ -27,9 +27,7 @@
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "linalg1/ConvertToLLVMDialect.h"
#include "linalg1/LLVMIntrinsics.h"

View File

@ -20,8 +20,8 @@
//
//===----------------------------------------------------------------------===//
#include "linalg1/Analysis.h"
#include "linalg1/Common.h"
#include "linalg3/Analysis.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
@ -36,99 +36,6 @@ using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
mlir::Value *linalg::TensorContractionBase::getInputView(unsigned viewIndex) {
return *(getInputs().begin() + viewIndex);
}
mlir::Value *linalg::TensorContractionBase::getOutputView(unsigned viewIndex) {
return *(getOutputs().begin() + viewIndex);
}
mlir::AffineMap linalg::operandRangesToLoopsMap(
linalg::TensorContractionBase &tensorContraction) {
mlir::AffineMap current;
// Individual submaps may not be invertible but their union must be invertible
// by construction.
for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
if (!m)
continue;
if (!current) {
current = m;
continue;
}
llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
current.getResults().end());
results.append(m.getResults().begin(), m.getResults().end());
current = mlir::AffineMap::get(
std::max(current.getNumDims(), m.getNumDims()),
current.getNumSymbols() + m.getNumSymbols(), results, {});
}
return inverseSubMap(current);
}
// Extract the ranges from a given ViewOp or SliceOp.
//
// In the case of a ViewOp, things are simple: just traverse the indexings and
// get all the ranges (i.e. drop the indices).
//
// In the case of a SliceOp, things are trickier because we need to handle a
// potential rank-reduction:
// 1. Examine the indexing to determine if it is rank-reducing.
// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
// that `d >= slicingDim`. This is to account for the rank reduction.
// `getRootIndex` is then called on the **parent** view
static llvm::SmallVector<mlir::Value *, 8>
extractRangesFromViewOrSliceOp(mlir::Value *view) {
// This expects a viewType which must come from either ViewOp or SliceOp.
assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
return viewOp.getRanges();
auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
unsigned slicingDim = sliceOp.getSlicingDim();
auto *indexing = *(sliceOp.getIndexings().begin());
bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
unsigned offset = 0;
llvm::SmallVector<mlir::Value *, 8> res;
res.reserve(sliceOp.getRank());
for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
if (d == slicingDim && isRankReducing)
offset = 1;
auto *parentView = sliceOp.getParentView();
auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
res.push_back(indexingPosPair.first);
}
return res;
}
static llvm::SmallVector<mlir::Value *, 8>
getInputRanges(linalg::TensorContractionBase &tensorContraction) {
llvm::SmallVector<mlir::Value *, 8> res;
for (auto *in : tensorContraction.getInputs()) {
auto subres = extractRangesFromViewOrSliceOp(in);
res.append(subres.begin(), subres.end());
}
return res;
}
// Gathers the ranges of every output of `tensorContraction` into a single flat
// list, visiting the outputs in the order `getOutputs()` yields them and
// appending the ranges extracted from each one.
static llvm::SmallVector<mlir::Value *, 8>
getOutputRanges(linalg::TensorContractionBase &tensorContraction) {
  llvm::SmallVector<mlir::Value *, 8> outputRanges;
  for (auto *outputView : tensorContraction.getOutputs()) {
    auto viewRanges = extractRangesFromViewOrSliceOp(outputView);
    outputRanges.append(viewRanges.begin(), viewRanges.end());
  }
  return outputRanges;
}
// Returns all the ranges of `tensorContraction`: the ranges of its inputs
// first, followed by the ranges of its outputs.
llvm::SmallVector<mlir::Value *, 8>
linalg::getRanges(linalg::TensorContractionBase &tensorContraction) {
  auto allRanges = getInputRanges(tensorContraction);
  auto outputRanges = getOutputRanges(tensorContraction);
  allRanges.append(outputRanges.begin(), outputRanges.end());
  return allRanges;
}
//////////////////////////////////////////////////////////////////////////////
// Implementation of DotOp.
//////////////////////////////////////////////////////////////////////////////

View File

@ -20,11 +20,8 @@
//===----------------------------------------------------------------------===//
#include "linalg3/Transforms.h"
#include "linalg1/Common.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"

View File

@ -20,14 +20,11 @@
//===----------------------------------------------------------------------===//
#include "linalg4/Transforms.h"
#include "linalg1/Common.h"
#include "linalg1/Utils.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/TensorOps.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/LoopUtils.h"
@ -59,25 +56,27 @@ static bool isZeroIndex(Value *v) {
v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
}
template <typename ConcreteOp>
static llvm::SmallVector<Value *, 4>
makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
llvm::ArrayRef<Value *> ivs,
makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
llvm::ArrayRef<Value *> tileSizes) {
assert(ivs.size() == tileSizes.size());
if (ivs.empty())
return RangeParts(allRanges).makeRanges();
auto *op = static_cast<ConcreteOp *>(&contraction);
RangeParts result(allRanges.size());
RangeParts rangeParts(allRanges);
for (auto map : contraction.loopsToOperandRangeMaps()) {
for (auto map : op->loopsToOperandRangeMaps()) {
// 1. Take the first ivs results of the map, the other ones are not composed
// but merely copied over.
assert(map.getNumSymbols() == 0);
assert(map.getRangeSizes().empty());
MLIRContext *context = ScopedContext::getContext();
unsigned numParallel = contraction.getNumParallelDims();
unsigned numReduction = contraction.getNumReductionDims();
unsigned numParallel = op->getNumParallelDims();
unsigned numReduction = op->getNumReductionDims();
if (ivs.size() < numParallel + numReduction) {
// Inject zeros in positions that are not tiled.
SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
@ -121,8 +120,9 @@ makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
return result.makeRanges();
}
template <class ConcreteOp>
static SmallVector<Value *, 4>
makeTiledViews(linalg::TensorContractionBase &contraction,
makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
auto tiledRanges =
makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
@ -141,15 +141,15 @@ makeTiledViews(linalg::TensorContractionBase &contraction,
return res;
}
template <typename ConcreteOp>
template <class ConcreteOp>
static SmallVector<mlir::AffineForOp, 8>
writeContractionAsTiledViews(ConcreteOp &contraction,
writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
ArrayRef<Value *> tileSizes) {
assert(tileSizes.size() <=
contraction.getNumParallelDims() + contraction.getNumReductionDims());
ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
contraction.getLoc());
auto *op = static_cast<ConcreteOp *>(&contraction);
ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
SmallVector<IndexHandle, 4> ivs(tileSizes.size());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);