forked from OSchip/llvm-project
Automated rollback of changelist 242546977.
PiperOrigin-RevId: 242604949
This commit is contained in:
parent
6b18e34de4
commit
a43f216fd5
|
@ -0,0 +1,120 @@
|
|||
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#ifndef LINALG2_TENSOROPS_INL_H_
|
||||
#define LINALG2_TENSOROPS_INL_H_
|
||||
|
||||
#include "linalg2/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin(), op->operand_begin() + getNumInputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin() + getNumInputs(),
|
||||
op->operand_begin() + getNumInputs() + getNumOutputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
|
||||
return {getInputs().begin(), getOutputs().end()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
|
||||
auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
if (getNumInputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one input");
|
||||
if (getNumOutputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one output");
|
||||
if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
|
||||
concreteOp->emitOpError("expected " +
|
||||
llvm::Twine(getNumInputs() + getNumOutputs()) +
|
||||
" operands");
|
||||
}
|
||||
for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
|
||||
if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" not a ViewType");
|
||||
}
|
||||
for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
|
||||
++i) {
|
||||
auto viewType =
|
||||
concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" not a ViewType");
|
||||
if (viewType.getRank() != getNumParallelDims())
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" must be of rank " +
|
||||
llvm::Twine(getNumParallelDims()));
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
bool linalg::TensorContractionBase<ConcreteOp>::parse(
|
||||
mlir::OpAsmParser *parser, mlir::OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
}
|
||||
|
||||
// A TensorContraction prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
|
||||
// ```
|
||||
//
|
||||
// for example:
|
||||
//
|
||||
// ```
|
||||
// linalg.matmul(%0, %1, %2) : view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 and %2 are ssa-values of type ViewType.
|
||||
template <class ConcreteOp>
|
||||
void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
|
||||
*p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
|
||||
auto *last = *std::prev(getInputsAndOutputs().end());
|
||||
for (auto *i : getInputsAndOutputs()) {
|
||||
*p << *i << ((i == last) ? "" : ", ");
|
||||
}
|
||||
*p << ") : ";
|
||||
auto *lastOutput = *std::prev(getOutputs().end());
|
||||
for (auto *o : getOutputs()) {
|
||||
*p << o->getType() << ((o == lastOutput) ? "" : ",");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG2_TENSOROPS_INL_H_
|
|
@ -29,29 +29,44 @@ namespace linalg {
|
|||
|
||||
/// A generic TensorContraction base class which captures the generic behavior
|
||||
/// of tensor contraction operations (with broadcast).
|
||||
class TensorContractionBase {
|
||||
public:
|
||||
virtual ~TensorContractionBase() {}
|
||||
template <class ConcreteOp> class TensorContractionBase {
|
||||
protected:
|
||||
using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Generic implementation of hooks that should be called from `ConcreteType`s
|
||||
mlir::LogicalResult verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
public:
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
virtual llvm::StringRef getTensorContractionName() = 0;
|
||||
TensorContractionBase() = default;
|
||||
mlir::Operation::operand_range getInputs();
|
||||
mlir::Operation::operand_range getOutputs();
|
||||
mlir::Operation::operand_range getInputsAndOutputs() {
|
||||
return {getInputs().begin(), getOutputs().end()};
|
||||
}
|
||||
mlir::Operation::operand_range getInputsAndOutputs();
|
||||
|
||||
/// These are better as methods calling into the ConcreteOp instead of
|
||||
/// template parameters because methods allow more generic behavior and avoid
|
||||
/// specializing for number of arguments. All derived classes have
|
||||
/// `VariadicOperands` and a build method from both an ArrayRef<mlirValue*>
|
||||
/// and the proper number of mlir::Value*.
|
||||
virtual unsigned getNumInputs() = 0;
|
||||
virtual unsigned getNumOutputs() = 0;
|
||||
virtual unsigned getNumParallelDims() = 0;
|
||||
virtual unsigned getNumReductionDims() = 0;
|
||||
unsigned getNumInputs() {
|
||||
return static_cast<ConcreteOp *>(this)->numInputs;
|
||||
};
|
||||
unsigned getNumOutputs() {
|
||||
return static_cast<ConcreteOp *>(this)->numOutputs;
|
||||
};
|
||||
unsigned getNumParallelDims() {
|
||||
return static_cast<ConcreteOp *>(this)->numParallelDims;
|
||||
};
|
||||
unsigned getNumReductionDims() {
|
||||
return static_cast<ConcreteOp *>(this)->numReductionDims;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
|
@ -64,18 +79,13 @@ public:
|
|||
: getOutputView(viewIndex - getNumInputs());
|
||||
}
|
||||
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
virtual void writeAsFinerGrainTensorContraction() {}
|
||||
|
||||
/// Each op is responsible for declaring how it lowers itself to scalar form,
|
||||
/// given the enclosing parallel and reduction induction variables.
|
||||
/// `emitScalarImplementation` emits the scalar IR for the op in the nesting
|
||||
/// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
|
||||
/// `parallel.back()`).
|
||||
virtual void
|
||||
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs) {}
|
||||
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs);
|
||||
|
||||
/// Represents a mapping from the loops to all the ranges of the operands.
|
||||
/// The operands and their ranges are in the order defined by the particular
|
||||
|
@ -84,17 +94,17 @@ public:
|
|||
/// it explicitly is not expensive and generalizes to cases where an analysis
|
||||
/// is not available. For details, see the description of
|
||||
/// loopsToOperandRangeMaps in each ConcreteOp.
|
||||
virtual llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() {
|
||||
return llvm::SmallVector<mlir::AffineMap, 8>();
|
||||
}
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||
};
|
||||
|
||||
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
|
||||
class DotOp : public TensorContractionBase,
|
||||
class DotOp : public TensorContractionBase<DotOp>,
|
||||
public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
using TensorContractionBaseType =
|
||||
TensorContractionBase::TensorContractionBaseType;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
|
@ -113,28 +123,24 @@ public:
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
llvm::StringRef getTensorContractionName() override {
|
||||
return getOperationName();
|
||||
}
|
||||
unsigned getNumInputs() override { return 2; }
|
||||
unsigned getNumOutputs() override { return 1; }
|
||||
unsigned getNumParallelDims() override { return 0; }
|
||||
unsigned getNumReductionDims() override { return 1; }
|
||||
static constexpr unsigned numInputs = 2;
|
||||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 0;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
#if LINALG_STEP > 2
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction() override;
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
|
||||
/// Inputs to this map will be (%k) coming from enclosing loops.
|
||||
/// Therefore, the mapping to get back to A(K), B(K), C() is:
|
||||
/// (d0) -> (d0, d0)(%k)
|
||||
/// And the operands ranges are:
|
||||
/// (%k, %k)
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||
|
||||
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
|
||||
/// to:
|
||||
|
@ -147,18 +153,18 @@ public:
|
|||
/// cond = (r_i == zero)
|
||||
/// scalarC = select(cond, zerof, C[]);
|
||||
/// C[] = scalarC + A[r_i] * B[r_i];
|
||||
void
|
||||
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
|
||||
#endif // LINALG_STEP
|
||||
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs);
|
||||
};
|
||||
|
||||
/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
|
||||
class MatvecOp : public TensorContractionBase,
|
||||
class MatvecOp : public TensorContractionBase<MatvecOp>,
|
||||
public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
using TensorContractionBaseType =
|
||||
TensorContractionBase::TensorContractionBaseType;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
|
@ -177,28 +183,24 @@ public:
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
llvm::StringRef getTensorContractionName() override {
|
||||
return getOperationName();
|
||||
}
|
||||
unsigned getNumInputs() override { return 2; }
|
||||
unsigned getNumOutputs() override { return 1; }
|
||||
unsigned getNumParallelDims() override { return 1; }
|
||||
unsigned getNumReductionDims() override { return 1; }
|
||||
static constexpr unsigned numInputs = 2;
|
||||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 1;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
#if LINALG_STEP > 2
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction() override;
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
|
||||
/// Inputs to this map will be (%m, %k) coming from enclosing loops.
|
||||
/// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
|
||||
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
||||
/// And the operands ranges are:
|
||||
/// (%m, %k, %k, %m)
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||
|
||||
/// Given an enclosing parallel loop with iv `i` and an enclosing parallel
|
||||
/// loop with iv `r_j`, emits MLIR corresponding to:
|
||||
|
@ -211,18 +213,18 @@ public:
|
|||
/// cond = (r_j == zero)
|
||||
/// scalarC = select(cond, zerof, C(i));
|
||||
/// C(i) = scalarC + A(i, r_j) * B(r_j);
|
||||
void
|
||||
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
|
||||
#endif // LINALG_STEP
|
||||
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs);
|
||||
};
|
||||
|
||||
/// Implements C = A * B on 2-D matrices.
|
||||
class MatmulOp : public TensorContractionBase,
|
||||
class MatmulOp : public TensorContractionBase<MatmulOp>,
|
||||
public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
using TensorContractionBaseType =
|
||||
TensorContractionBase::TensorContractionBaseType;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Hooks to customize the behavior of this op.
|
||||
|
@ -241,28 +243,24 @@ public:
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
llvm::StringRef getTensorContractionName() override {
|
||||
return getOperationName();
|
||||
}
|
||||
unsigned getNumInputs() override { return 2; }
|
||||
unsigned getNumOutputs() override { return 1; }
|
||||
unsigned getNumParallelDims() override { return 2; }
|
||||
unsigned getNumReductionDims() override { return 1; }
|
||||
static constexpr unsigned numInputs = 2;
|
||||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 2;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
#if LINALG_STEP > 2
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction() override;
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
|
||||
/// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
|
||||
/// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
|
||||
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
||||
/// And the operands ranges are:
|
||||
/// (%m, %k, %k, %n, %m, %n)
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
|
||||
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||
|
||||
/// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
|
||||
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
|
||||
|
@ -275,12 +273,15 @@ public:
|
|||
/// cond = (r_k == zero)
|
||||
/// scalarC = select(cond, zerof, C[i, j]);
|
||||
/// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
|
||||
void
|
||||
emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs) override;
|
||||
#endif // LINALG_STEP
|
||||
void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs);
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#include "linalg2/TensorOps-inl.h"
|
||||
|
||||
#endif // LINALG2_TENSOROPS_H_
|
||||
|
|
|
@ -26,100 +26,14 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::Twine;
|
||||
|
||||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
|
||||
#define TENSOR_CONTRACTION_DISPATCH(FUNCTION_NAME) \
|
||||
if (getTensorContractionName() == MatmulOp::getOperationName()) { \
|
||||
return FUNCTION_NAME(static_cast<MatmulOp &>(*this)); \
|
||||
} \
|
||||
if (getTensorContractionName() == MatvecOp::getOperationName()) { \
|
||||
return FUNCTION_NAME(static_cast<MatvecOp &>(*this)); \
|
||||
} \
|
||||
if (getTensorContractionName() == DotOp::getOperationName()) { \
|
||||
return FUNCTION_NAME(static_cast<DotOp &>(*this)); \
|
||||
} \
|
||||
llvm_unreachable("Missing linalg op");
|
||||
|
||||
template <typename ConcreteOp>
|
||||
static mlir::Operation::operand_range getInputs(ConcreteOp &concreteOp) {
|
||||
return {concreteOp.operand_begin(),
|
||||
concreteOp.operand_begin() + concreteOp.getNumInputs()};
|
||||
}
|
||||
|
||||
mlir::Operation::operand_range linalg::TensorContractionBase::getInputs() {
|
||||
TENSOR_CONTRACTION_DISPATCH(::getInputs);
|
||||
}
|
||||
|
||||
template <typename ConcreteOp>
|
||||
static mlir::Operation::operand_range getOutputs(ConcreteOp &concreteOp) {
|
||||
return {concreteOp.operand_begin() + concreteOp.getNumInputs(),
|
||||
concreteOp.operand_begin() + concreteOp.getNumInputs() +
|
||||
concreteOp.getNumOutputs()};
|
||||
}
|
||||
|
||||
mlir::Operation::operand_range linalg::TensorContractionBase::getOutputs() {
|
||||
TENSOR_CONTRACTION_DISPATCH(::getOutputs);
|
||||
}
|
||||
|
||||
template <typename LinalgOp>
|
||||
static mlir::LogicalResult verifyLinalgOp(LinalgOp op) {
|
||||
if (op.getNumInputs() <= 0)
|
||||
op.emitOpError("expected at least one input");
|
||||
if (op.getNumOutputs() <= 0)
|
||||
op.emitOpError("expected at least one output");
|
||||
if (op.getNumOperands() != op.getNumInputs() + op.getNumOutputs()) {
|
||||
op.emitOpError("expected " +
|
||||
llvm::Twine(op.getNumInputs() + op.getNumOutputs()) +
|
||||
" operands");
|
||||
}
|
||||
for (unsigned i = 0, e = op.getNumInputs(); i < e; ++i) {
|
||||
if (!op.getOperand(i)->getType().template isa<ViewType>())
|
||||
return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
|
||||
}
|
||||
for (unsigned i = op.getNumInputs(),
|
||||
e = op.getNumInputs() + op.getNumOutputs();
|
||||
i < e; ++i) {
|
||||
auto viewType = op.getOperand(i)->getType().template dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
|
||||
if (viewType.getRank() != op.getNumParallelDims())
|
||||
return op.emitOpError("operand " + llvm::Twine(i) + " must be of rank " +
|
||||
llvm::Twine(op.getNumParallelDims()));
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// A TensorContraction prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
|
||||
// ```
|
||||
//
|
||||
// for example:
|
||||
//
|
||||
// ```
|
||||
// linalg.matmul(%0, %1, %2) : view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 and %2 are ssa-values of type ViewType.
|
||||
template <typename LinalgOp>
|
||||
static void printLinalgOp(mlir::OpAsmPrinter *p, LinalgOp op) {
|
||||
*p << op.getOperationName() << "(";
|
||||
auto *last = *std::prev(op.getInputsAndOutputs().end());
|
||||
for (auto *i : op.getInputsAndOutputs()) {
|
||||
*p << *i << ((i == last) ? "" : ", ");
|
||||
}
|
||||
*p << ") : ";
|
||||
auto *lastOutput = *std::prev(op.getOutputs().end());
|
||||
for (auto *o : op.getOutputs()) {
|
||||
*p << o->getType() << ((o == lastOutput) ? "" : ",");
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific Dot.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -129,7 +43,7 @@ void linalg::DotOp::build(Builder *b, OperationState *result,
|
|||
}
|
||||
|
||||
LogicalResult linalg::DotOp::verify() {
|
||||
if (failed(verifyLinalgOp(*this)))
|
||||
if (failed(TensorContractionBaseType::verify()))
|
||||
return failure();
|
||||
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
|
||||
unsigned index = 0;
|
||||
|
@ -146,10 +60,12 @@ LogicalResult linalg::DotOp::verify() {
|
|||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::DotOp::parse(mlir::OpAsmParser *parser,
|
||||
mlir::OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
return TensorContractionBaseType::parse(parser, result);
|
||||
}
|
||||
|
||||
void linalg::DotOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
|
||||
void linalg::DotOp::print(mlir::OpAsmPrinter *p) {
|
||||
TensorContractionBaseType::print(p);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific Matvec.
|
||||
|
@ -160,7 +76,7 @@ void linalg::MatvecOp::build(Builder *b, OperationState *result,
|
|||
}
|
||||
|
||||
LogicalResult linalg::MatvecOp::verify() {
|
||||
if (failed(verifyLinalgOp(*this)))
|
||||
if (failed(TensorContractionBaseType::verify()))
|
||||
return failure();
|
||||
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
|
||||
if (getViewRank(A) != 2)
|
||||
|
@ -178,10 +94,12 @@ LogicalResult linalg::MatvecOp::verify() {
|
|||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
|
||||
mlir::OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
return TensorContractionBaseType::parse(parser, result);
|
||||
}
|
||||
|
||||
void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
|
||||
void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) {
|
||||
TensorContractionBaseType::print(p);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific Matmul.
|
||||
|
@ -192,7 +110,7 @@ void linalg::MatmulOp::build(Builder *b, OperationState *result,
|
|||
}
|
||||
|
||||
LogicalResult linalg::MatmulOp::verify() {
|
||||
if (failed(verifyLinalgOp(*this)))
|
||||
if (failed(TensorContractionBaseType::verify()))
|
||||
return failure();
|
||||
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
|
||||
unsigned index = 0;
|
||||
|
@ -207,7 +125,9 @@ LogicalResult linalg::MatmulOp::verify() {
|
|||
// Parsing of the linalg dialect is not supported in this tutorial.
|
||||
bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
|
||||
mlir::OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
return TensorContractionBaseType::parse(parser, result);
|
||||
}
|
||||
|
||||
void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
|
||||
void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) {
|
||||
TensorContractionBaseType::print(p);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#ifndef LINALG3_TENSOROPS_INL_H_
|
||||
#define LINALG3_TENSOROPS_INL_H_
|
||||
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg1/Utils.h"
|
||||
#include "linalg2/TensorOps.h"
|
||||
#include "linalg3/Analysis.h"
|
||||
#include "linalg3/Ops.h"
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Value *
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
|
||||
return *(getInputs().begin() + viewIndex);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Value *
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
|
||||
return *(getOutputs().begin() + viewIndex);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
llvm::SmallVector<mlir::AffineMap, 8>
|
||||
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
|
||||
return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
|
||||
llvm::ArrayRef<mlir::Value *> parallelIvs,
|
||||
llvm::ArrayRef<mlir::Value *> reductionIvs) {
|
||||
static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
|
||||
reductionIvs);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::AffineMap linalg::operandRangesToLoopsMap(
|
||||
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
||||
mlir::AffineMap current;
|
||||
// Individual submaps may not be invertible but their union must be invertible
|
||||
// by construction.
|
||||
for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
|
||||
if (!m)
|
||||
continue;
|
||||
if (!current) {
|
||||
current = m;
|
||||
continue;
|
||||
}
|
||||
llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
|
||||
current.getResults().end());
|
||||
results.append(m.getResults().begin(), m.getResults().end());
|
||||
current = mlir::AffineMap::get(
|
||||
std::max(current.getNumDims(), m.getNumDims()),
|
||||
current.getNumSymbols() + m.getNumSymbols(), results, {});
|
||||
}
|
||||
return inverseSubMap(current);
|
||||
}
|
||||
|
||||
// Extract the ranges from a given ViewOp or SliceOp.
|
||||
//
|
||||
// In the case of a ViewOp, things are simple: just traverse the indexings and
|
||||
// get all the ranges (i.e. drop the indices).
|
||||
//
|
||||
// In the case of a SliceOp, things are trickier because we need to handle a
|
||||
// potential rank-reduction:
|
||||
// 1. Examine the indexing to determine if it is rank-reducing.
|
||||
// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
|
||||
// that `d >= slicingDim`. This is to account for the rank reduction.
|
||||
// `getRootIndex` is then called on the **parent** view
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
extractRangesFromViewOrSliceOp(mlir::Value *view) {
|
||||
// This expects a viewType which must come from either ViewOp or SliceOp.
|
||||
assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
|
||||
if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
|
||||
return viewOp.getRanges();
|
||||
|
||||
auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
|
||||
unsigned slicingDim = sliceOp.getSlicingDim();
|
||||
auto *indexing = *(sliceOp.getIndexings().begin());
|
||||
bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
|
||||
unsigned offset = 0;
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
res.reserve(sliceOp.getRank());
|
||||
for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
|
||||
if (d == slicingDim && isRankReducing)
|
||||
offset = 1;
|
||||
auto *parentView = sliceOp.getParentView();
|
||||
auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
|
||||
res.push_back(indexingPosPair.first);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
for (auto *in : tensorContraction.getInputs()) {
|
||||
auto subres = extractRangesFromViewOrSliceOp(in);
|
||||
res.append(subres.begin(), subres.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
for (auto *out : tensorContraction.getOutputs()) {
|
||||
auto subres = extractRangesFromViewOrSliceOp(out);
|
||||
res.append(subres.begin(), subres.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
|
||||
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
|
||||
llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
|
||||
res.append(tmp.begin(), tmp.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif // LINALG3_TENSOROPS_INL_H_
|
|
@ -29,8 +29,9 @@ namespace linalg {
|
|||
|
||||
/// Takes a `tensorContraction` and a returns an AffineMap that can be used to
|
||||
/// map ranges to enclosing loops for all the operands' ranges.
|
||||
mlir::AffineMap
|
||||
operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
|
||||
template <class ConcreteOp>
|
||||
mlir::AffineMap operandRangesToLoopsMap(
|
||||
linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
|
||||
|
||||
/// Takes a `tensorContraction` and returns the ranges of all its operands.
|
||||
/// When an operand comes from a ViewOp, things are simple:
|
||||
|
@ -39,9 +40,15 @@ operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
|
|||
/// In the case of a SliceOp, things are more involved because we need to handle
|
||||
/// potential rank-reductions.
|
||||
/// This function abstracts this complexity away and returns all the ranges.
|
||||
template <class ConcreteOp>
|
||||
llvm::SmallVector<mlir::Value *, 8>
|
||||
getRanges(linalg::TensorContractionBase &tensorContraction);
|
||||
getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#include "linalg3/TensorOps-inl.h"
|
||||
|
||||
#endif // LINALG3_TENSOROPS_H_
|
||||
|
|
|
@ -27,9 +27,7 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/LLVMIR/Transforms.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "linalg1/ConvertToLLVMDialect.h"
|
||||
#include "linalg1/LLVMIntrinsics.h"
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg1/Analysis.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg3/Analysis.h"
|
||||
#include "linalg3/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -36,99 +36,6 @@ using namespace mlir::edsc::intrinsics;
|
|||
using namespace linalg;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
mlir::Value *linalg::TensorContractionBase::getInputView(unsigned viewIndex) {
|
||||
return *(getInputs().begin() + viewIndex);
|
||||
}
|
||||
|
||||
mlir::Value *linalg::TensorContractionBase::getOutputView(unsigned viewIndex) {
|
||||
return *(getOutputs().begin() + viewIndex);
|
||||
}
|
||||
|
||||
mlir::AffineMap linalg::operandRangesToLoopsMap(
|
||||
linalg::TensorContractionBase &tensorContraction) {
|
||||
mlir::AffineMap current;
|
||||
// Individual submaps may not be invertible but their union must be invertible
|
||||
// by construction.
|
||||
for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
|
||||
if (!m)
|
||||
continue;
|
||||
if (!current) {
|
||||
current = m;
|
||||
continue;
|
||||
}
|
||||
llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
|
||||
current.getResults().end());
|
||||
results.append(m.getResults().begin(), m.getResults().end());
|
||||
current = mlir::AffineMap::get(
|
||||
std::max(current.getNumDims(), m.getNumDims()),
|
||||
current.getNumSymbols() + m.getNumSymbols(), results, {});
|
||||
}
|
||||
return inverseSubMap(current);
|
||||
}
|
||||
|
||||
// Extract the ranges from a given ViewOp or SliceOp.
|
||||
//
|
||||
// In the case of a ViewOp, things are simple: just traverse the indexings and
|
||||
// get all the ranges (i.e. drop the indices).
|
||||
//
|
||||
// In the case of a SliceOp, things are trickier because we need to handle a
|
||||
// potential rank-reduction:
|
||||
// 1. Examine the indexing to determine if it is rank-reducing.
|
||||
// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
|
||||
// that `d >= slicingDim`. This is to account for the rank reduction.
|
||||
// `getRootIndex` is then called on the **parent** view
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
extractRangesFromViewOrSliceOp(mlir::Value *view) {
|
||||
// This expects a viewType which must come from either ViewOp or SliceOp.
|
||||
assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
|
||||
if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
|
||||
return viewOp.getRanges();
|
||||
|
||||
auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
|
||||
unsigned slicingDim = sliceOp.getSlicingDim();
|
||||
auto *indexing = *(sliceOp.getIndexings().begin());
|
||||
bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
|
||||
unsigned offset = 0;
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
res.reserve(sliceOp.getRank());
|
||||
for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
|
||||
if (d == slicingDim && isRankReducing)
|
||||
offset = 1;
|
||||
auto *parentView = sliceOp.getParentView();
|
||||
auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
|
||||
res.push_back(indexingPosPair.first);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
getInputRanges(linalg::TensorContractionBase &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
for (auto *in : tensorContraction.getInputs()) {
|
||||
auto subres = extractRangesFromViewOrSliceOp(in);
|
||||
res.append(subres.begin(), subres.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static llvm::SmallVector<mlir::Value *, 8>
|
||||
getOutputRanges(linalg::TensorContractionBase &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res;
|
||||
for (auto *out : tensorContraction.getOutputs()) {
|
||||
auto subres = extractRangesFromViewOrSliceOp(out);
|
||||
res.append(subres.begin(), subres.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value *, 8>
|
||||
linalg::getRanges(linalg::TensorContractionBase &tensorContraction) {
|
||||
llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
|
||||
llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
|
||||
res.append(tmp.begin(), tmp.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Implementation of DotOp.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -20,11 +20,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg3/Transforms.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg2/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
|
|
@ -20,14 +20,11 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg4/Transforms.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg1/Utils.h"
|
||||
#include "linalg3/Intrinsics.h"
|
||||
#include "linalg3/TensorOps.h"
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
|
||||
|
@ -59,25 +56,27 @@ static bool isZeroIndex(Value *v) {
|
|||
v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
|
||||
}
|
||||
|
||||
template <typename ConcreteOp>
|
||||
static llvm::SmallVector<Value *, 4>
|
||||
makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
|
||||
llvm::ArrayRef<Value *> ivs,
|
||||
makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
|
||||
ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
|
||||
llvm::ArrayRef<Value *> tileSizes) {
|
||||
assert(ivs.size() == tileSizes.size());
|
||||
if (ivs.empty())
|
||||
return RangeParts(allRanges).makeRanges();
|
||||
|
||||
auto *op = static_cast<ConcreteOp *>(&contraction);
|
||||
RangeParts result(allRanges.size());
|
||||
RangeParts rangeParts(allRanges);
|
||||
|
||||
for (auto map : contraction.loopsToOperandRangeMaps()) {
|
||||
for (auto map : op->loopsToOperandRangeMaps()) {
|
||||
// 1. Take the first ivs results of the map, the other ones are not composed
|
||||
// but merely copied over.
|
||||
assert(map.getNumSymbols() == 0);
|
||||
assert(map.getRangeSizes().empty());
|
||||
MLIRContext *context = ScopedContext::getContext();
|
||||
unsigned numParallel = contraction.getNumParallelDims();
|
||||
unsigned numReduction = contraction.getNumReductionDims();
|
||||
unsigned numParallel = op->getNumParallelDims();
|
||||
unsigned numReduction = op->getNumReductionDims();
|
||||
if (ivs.size() < numParallel + numReduction) {
|
||||
// Inject zeros in positions that are not tiled.
|
||||
SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
|
||||
|
@ -121,8 +120,9 @@ makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
|
|||
return result.makeRanges();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
static SmallVector<Value *, 4>
|
||||
makeTiledViews(linalg::TensorContractionBase &contraction,
|
||||
makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
|
||||
ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
|
||||
auto tiledRanges =
|
||||
makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
|
||||
|
@ -141,15 +141,15 @@ makeTiledViews(linalg::TensorContractionBase &contraction,
|
|||
return res;
|
||||
}
|
||||
|
||||
template <typename ConcreteOp>
|
||||
template <class ConcreteOp>
|
||||
static SmallVector<mlir::AffineForOp, 8>
|
||||
writeContractionAsTiledViews(ConcreteOp &contraction,
|
||||
writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
|
||||
ArrayRef<Value *> tileSizes) {
|
||||
assert(tileSizes.size() <=
|
||||
contraction.getNumParallelDims() + contraction.getNumReductionDims());
|
||||
|
||||
ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
|
||||
contraction.getLoc());
|
||||
auto *op = static_cast<ConcreteOp *>(&contraction);
|
||||
ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
|
||||
SmallVector<IndexHandle, 4> ivs(tileSizes.size());
|
||||
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
||||
|
||||
|
|
Loading…
Reference in New Issue