forked from OSchip/llvm-project
Linalg portion of the tutorial - part 3
This CL starts the third part of the Linalg tutorial by adding support for ops to declare how they lower themselves to other ops. Tests are added that demonstrate matmul lowering to a loop over matvec and matvec lowering to a loop over dot. This is part of a list of CLs that add new Transforms and Analyses to Linalg3: it iseasier to integrate in small chunks. As part of working with the TensorContractionBase template class and in an effort to add pieces incrementally without copying code, it is easiest to define operations ahead of time in Linalg2/TensorOps.h and gradually implement them as needed. This CL performs the necessary refactoring for this to happen. -- PiperOrigin-RevId: 241605869
This commit is contained in:
parent
7fa2864954
commit
72ccfcee1e
|
@ -166,6 +166,7 @@ using constant_float = ValueBuilder<ConstantFloatOp>;
|
|||
using constant_index = ValueBuilder<ConstantIndexOp>;
|
||||
using constant_int = ValueBuilder<ConstantIntOp>;
|
||||
using dealloc = OperationBuilder<DeallocOp>;
|
||||
using dim = ValueBuilder<DimOp>;
|
||||
using load = ValueBuilder<LoadOp>;
|
||||
using ret = OperationBuilder<ReturnOp>;
|
||||
using select = ValueBuilder<SelectOp>;
|
||||
|
|
|
@ -44,12 +44,6 @@ mlir::Value *getViewSupportingMemRef(mlir::Value *view);
|
|||
std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view,
|
||||
unsigned dim);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Helper functions to avoid dispatching at all client sites.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Asserts `view` is of ViewType and returns its rank.
|
||||
unsigned getViewRank(mlir::Value *view);
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG1_ANALYSIS_H_
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
//===- Utils.h - Linalg dialect utility functions definitions -------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG1_UTILS_H_
|
||||
#define LINALG1_UTILS_H_
|
||||
|
||||
namespace mlir {
|
||||
class Value;
|
||||
} // namespace mlir
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// Asserts `view` is of ViewType and returns its rank.
|
||||
unsigned getViewRank(mlir::Value *view);
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG1_UTILS_H_
|
|
@ -73,13 +73,3 @@ std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view,
|
|||
unsigned parentDim = dim > sliceDim ? dim + 1 : dim;
|
||||
return getViewRootIndexing(parentView, parentDim);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Helper functions to avoid dispatch at all client sites.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
unsigned linalg::getViewRank(Value *view) {
|
||||
assert(view->getType().isa<ViewType>() && "expected a ViewType");
|
||||
if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
|
||||
return viewOp.getRank();
|
||||
return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "linalg1/Analysis.h"
|
||||
#include "linalg1/Ops.h"
|
||||
#include "linalg1/Types.h"
|
||||
#include "linalg1/Utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
//===- Utils.cpp - Implementation of utiliy functions for Linalg ----------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements utility functions for the linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg1/Utils.h"
|
||||
#include "linalg1/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
|
||||
unsigned linalg::getViewRank(Value *view) {
|
||||
assert(view->getType().isa<ViewType>() && "expected a ViewType");
|
||||
if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
|
||||
return viewOp.getRank();
|
||||
return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#ifndef LINALG2_TENSOROPS_INL_H_
|
||||
#define LINALG2_TENSOROPS_INL_H_
|
||||
|
||||
#include "linalg2/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin(), op->operand_begin() + getNumInputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin() + getNumInputs(),
|
||||
op->operand_begin() + getNumInputs() + getNumOutputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
|
||||
auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
if (getNumInputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one input");
|
||||
if (getNumOutputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one output");
|
||||
if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
|
||||
concreteOp->emitOpError("expected " +
|
||||
llvm::Twine(getNumInputs() + getNumOutputs()) +
|
||||
" operands");
|
||||
}
|
||||
for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
|
||||
if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" not a ViewType");
|
||||
}
|
||||
for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
|
||||
++i) {
|
||||
auto viewType =
|
||||
concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" not a ViewType");
|
||||
if (viewType.getRank() != getNumParallelDims())
|
||||
return concreteOp->emitOpError("operand " + llvm::Twine(i) +
|
||||
" must be of rank " +
|
||||
llvm::Twine(getNumParallelDims()));
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
bool linalg::TensorContractionBase<ConcreteOp>::parse(
|
||||
mlir::OpAsmParser *parser, mlir::OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
}
|
||||
|
||||
// A TensorContraction prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// concrete_op_name {%0, %1} -> {%2}
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
|
||||
// view.
|
||||
template <class ConcreteOp>
|
||||
void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
|
||||
*p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
|
||||
auto *lastInput = *std::prev(getInputs().end());
|
||||
for (auto *i : getInputs()) {
|
||||
*p << *i << ((i == lastInput) ? "} -> {" : ", ");
|
||||
}
|
||||
auto *lastOutput = *std::prev(getOutputs().end());
|
||||
for (auto *o : getOutputs()) {
|
||||
*p << *o << ((o == lastOutput) ? "}" : ",");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG2_TENSOROPS_INL_H_
|
|
@ -15,12 +15,16 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG2_MATMULOP_H_
|
||||
#define LINALG2_MATMULOP_H_
|
||||
#ifndef LINALG2_TENSOROPS_H_
|
||||
#define LINALG2_TENSOROPS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineForOp;
|
||||
} // namespace mlir
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// A generic TensorContraction base class which captures the generic behavior
|
||||
|
@ -41,13 +45,6 @@ protected:
|
|||
// Op-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TensorContractionBase() = default;
|
||||
|
||||
mlir::Type getInputElementType(unsigned i);
|
||||
mlir::Type getOutputElementType(unsigned i);
|
||||
mlir::Value *getInputView(unsigned i);
|
||||
mlir::Value *getOutputView(unsigned i);
|
||||
mlir::Value *getInputMemRef(unsigned i);
|
||||
mlir::Value *getOutputMemRef(unsigned i);
|
||||
mlir::Operation::operand_range getInputs();
|
||||
mlir::Operation::operand_range getOutputs();
|
||||
|
||||
|
@ -69,6 +66,20 @@ public:
|
|||
unsigned getNumReductionDims() {
|
||||
return static_cast<ConcreteOp *>(this)->numReductionDims;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
mlir::Value *getInputView(unsigned i);
|
||||
mlir::Value *getOutputView(unsigned i);
|
||||
/// Computes a mapping from all the ranges of the operands to the enclosing
|
||||
/// loops. In order to support "broadcast"-style semantics, we need to
|
||||
/// consider all the operands (i.e. input operands are not sufficient).
|
||||
/// The operands and their ranges are in the order defined by the particular
|
||||
/// ConcreteOp implementation, the resulting map must match those.
|
||||
/// This is currently computed but can also be specified explicitly in each
|
||||
/// operator to generalize to cases where an analysis is not available.
|
||||
mlir::AffineMap operandRangesToLoopsMap();
|
||||
};
|
||||
|
||||
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
|
||||
|
@ -101,6 +112,13 @@ public:
|
|||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 0;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
};
|
||||
|
||||
/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
|
||||
|
@ -133,6 +151,13 @@ public:
|
|||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 1;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
};
|
||||
|
||||
/// Implements C = A * B on 2-D matrices.
|
||||
|
@ -165,7 +190,20 @@ public:
|
|||
static constexpr unsigned numOutputs = 1;
|
||||
static constexpr unsigned numParallelDims = 2;
|
||||
static constexpr unsigned numReductionDims = 1;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Used in Linalg3 and later.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
|
||||
/// loop over matvec). Does nothing by default.
|
||||
void writeAsFinerGrainTensorContraction();
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG2_MATMULOP_H_
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#include "linalg2/TensorOps-inl.h"
|
||||
|
||||
#endif // LINALG2_TENSOROPS_H_
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg2/Analysis.h"
|
||||
#include "linalg1/Utils.h"
|
||||
#include "linalg2/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
@ -34,116 +34,6 @@ using llvm::Twine;
|
|||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
|
||||
template <class ConcreteOp>
|
||||
Type linalg::TensorContractionBase<ConcreteOp>::getInputElementType(
|
||||
unsigned idx) {
|
||||
return getInputView(idx)
|
||||
->getType()
|
||||
.template cast<ViewType>()
|
||||
.getElementType();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
Type linalg::TensorContractionBase<ConcreteOp>::getOutputElementType(
|
||||
unsigned idx) {
|
||||
return getOutputView(idx)
|
||||
->getType()
|
||||
.template cast<ViewType>()
|
||||
.getElementType();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
Value *linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned idx) {
|
||||
return *(getInputs().begin() + idx);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
Value *linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned idx) {
|
||||
return *(getOutputs().begin() + idx);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
Value *linalg::TensorContractionBase<ConcreteOp>::getInputMemRef(unsigned idx) {
|
||||
return getViewSupportingMemRef(*(getInputs().begin() + idx));
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
Value *
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputMemRef(unsigned idx) {
|
||||
return getViewSupportingMemRef(*(getOutputs().begin() + idx));
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin(), op->operand_begin() + getNumInputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Operation::operand_range
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
|
||||
auto *op = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
return {op->operand_begin() + getNumInputs(),
|
||||
op->operand_begin() + getNumInputs() + getNumOutputs()};
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
|
||||
auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
|
||||
if (getNumInputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one input");
|
||||
if (getNumOutputs() <= 0)
|
||||
concreteOp->emitOpError("expected at least one output");
|
||||
if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
|
||||
concreteOp->emitOpError(
|
||||
"expected " + Twine(getNumInputs() + getNumOutputs()) + " operands");
|
||||
}
|
||||
for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
|
||||
if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
|
||||
return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
|
||||
}
|
||||
for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
|
||||
++i) {
|
||||
auto viewType =
|
||||
concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
|
||||
if (viewType.getRank() != getNumParallelDims())
|
||||
return concreteOp->emitOpError("operand " + Twine(i) +
|
||||
" must be of rank " +
|
||||
Twine(getNumParallelDims()));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
bool linalg::TensorContractionBase<ConcreteOp>::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
|
||||
}
|
||||
|
||||
// A TensorContraction prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// concrete_op_name {%0, %1} -> {%2}
|
||||
// ```
|
||||
//
|
||||
// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
|
||||
// view.
|
||||
template <class ConcreteOp>
|
||||
void linalg::TensorContractionBase<ConcreteOp>::print(OpAsmPrinter *p) {
|
||||
*p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
|
||||
auto *lastInput = *std::prev(getInputs().end());
|
||||
for (auto *i : getInputs()) {
|
||||
*p << *i << ((i == lastInput) ? "} -> {" : ", ");
|
||||
}
|
||||
auto *lastOutput = *std::prev(getOutputs().end());
|
||||
for (auto *o : getOutputs()) {
|
||||
*p << *o << ((o == lastOutput) ? "}" : ",");
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Op-specific Dot.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
//===- Example.cpp - Our running example ----------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
// RUN: %p/test | FileCheck %s
|
||||
|
||||
#include "TestHarness.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg2/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "linalg3/Transforms.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
using llvm::StringRef;
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace linalg;
|
||||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function *f = linalg::common::makeFunction(
|
||||
module, name,
|
||||
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||
|
||||
ScopedContext scope(f);
|
||||
// clang-format off
|
||||
ValueHandle
|
||||
M = dim(f->getArgument(0), 0),
|
||||
N = dim(f->getArgument(2), 1),
|
||||
K = dim(f->getArgument(0), 1),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
vA = view(f->getArgument(0), {rM, rK}),
|
||||
vB = view(f->getArgument(1), {rK, rN}),
|
||||
vC = view(f->getArgument(2), {rM, rN});
|
||||
matmul(vA, vB, vC);
|
||||
ret();
|
||||
// clang-format on
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
TEST_FUNC(matmul_as_matvec) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
|
||||
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
|
||||
// clang-format on
|
||||
cleanupAndPrintFunction(f);
|
||||
}
|
||||
|
||||
TEST_FUNC(matmul_as_dot) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
|
||||
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
|
||||
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
|
||||
// CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
|
||||
// CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
|
||||
// CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
|
||||
// CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
|
||||
// clang-format on
|
||||
cleanupAndPrintFunction(f);
|
||||
}
|
||||
|
||||
int main() {
|
||||
RUN_TESTS();
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
//===- Ops.h - Linalg Ops single entry point ------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG3_OPS_H_
|
||||
#define LINALG3_OPS_H_
|
||||
|
||||
#include "linalg2/Ops.h"
|
||||
#include "linalg3/TensorOps.h"
|
||||
|
||||
#endif // LINALG3_OPS_H_
|
|
@ -0,0 +1,43 @@
|
|||
//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#ifndef LINALG3_TENSOROPS_INL_H_
|
||||
#define LINALG3_TENSOROPS_INL_H_
|
||||
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg2/TensorOps.h"
|
||||
|
||||
namespace linalg {
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Value *
|
||||
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned i) {
|
||||
return *(getInputs().begin() + i);
|
||||
}
|
||||
|
||||
template <class ConcreteOp>
|
||||
mlir::Value *
|
||||
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
|
||||
return *(getOutputs().begin() + i);
|
||||
}
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG3_TENSOROPS-INL_H_
|
|
@ -0,0 +1,28 @@
|
|||
//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG3_TENSOROPS_H_
|
||||
#define LINALG3_TENSOROPS_H_
|
||||
|
||||
#include "linalg2/TensorOps.h"
|
||||
|
||||
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
|
||||
/// TensorOps by adding implementations as they are needed in the appropriate
|
||||
/// step in the tutorial.
|
||||
#include "linalg3/TensorOps-inl.h"
|
||||
|
||||
#endif // LINALG3_TENSOROPS_H_
|
|
@ -0,0 +1,39 @@
|
|||
//===- Transforms.h - Linalg dialect Transformations definition -----------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef LINALG3_TRANSFORMS_H_
|
||||
#define LINALG3_TRANSFORMS_H_
|
||||
|
||||
#include "linalg2/Transforms.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
} // namespace mlir
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
|
||||
/// to only use linalg.view operations.
|
||||
void composeSliceOps(mlir::Function *f);
|
||||
|
||||
/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
|
||||
/// as linalg.matvec (resp. linalg.dot, loop form).
|
||||
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
|
||||
|
||||
} // namespace linalg
|
||||
|
||||
#endif // LINALG3_TRANSFORMS_H_
|
|
@ -0,0 +1,70 @@
|
|||
//===- TensorOps.cpp - Implementation of the linalg TensorOps operation ---===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple IR operation to create new tensor computation
|
||||
// operations in the linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg1/Analysis.h"
|
||||
#include "linalg1/Common.h"
|
||||
#include "linalg2/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace linalg;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
|
||||
// The body expression for dot is: C() = A(r_i) * B(r_i);
|
||||
// So we must drop the `i` loop from the matvec.
|
||||
void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
|
||||
auto *op = getOperation();
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle i;
|
||||
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
|
||||
auto indexingPosPair = getViewRootIndexing(vA, 0);
|
||||
assert(indexingPosPair.first->getDefiningOp() &&
|
||||
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
|
||||
linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
|
||||
dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
|
||||
});
|
||||
}
|
||||
|
||||
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
|
||||
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
|
||||
// So we must drop the `j` loop from the matmul.
|
||||
// This is fine because parallel dimensions permute: we can just do it
|
||||
// declaratively.
|
||||
void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
|
||||
auto *op = getOperation();
|
||||
ScopedContext scope(FuncBuilder(op), op->getLoc());
|
||||
IndexHandle j;
|
||||
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
|
||||
auto indexingPosPair = getViewRootIndexing(vB, 1);
|
||||
assert(indexingPosPair.first->getDefiningOp() &&
|
||||
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
|
||||
linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
|
||||
matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
|
||||
});
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements analyses and transformations for the linalg dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "linalg3/Transforms.h"
|
||||
#include "linalg2/Intrinsics.h"
|
||||
#include "linalg3/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace linalg;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
void linalg::composeSliceOps(mlir::Function *f) {
|
||||
f->walkPostOrder<SliceOp>([](SliceOp sliceOp) {
|
||||
auto *sliceResult = sliceOp.getResult();
|
||||
auto viewOp = createFullyComposedView(sliceResult);
|
||||
sliceResult->replaceAllUsesWith(viewOp.getResult());
|
||||
sliceOp.erase();
|
||||
});
|
||||
}
|
||||
|
||||
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
|
||||
f->walkPostOrder([](Operation *op) {
|
||||
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
||||
matmulOp.writeAsFinerGrainTensorContraction();
|
||||
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
||||
matvecOp.writeAsFinerGrainTensorContraction();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
op->erase();
|
||||
});
|
||||
}
|
Loading…
Reference in New Issue