forked from OSchip/llvm-project
Add support for a Linalg base op class
This CL uses a pattern proposed by aminim@ to add a base Linalg op that further dispatches to the proper op implementation. This CL adds a LinalgOp which implements isclassof for only a subset of the linalg ops: the ops that behave like a library call for the purpose of transformations like tiling. This uses a static dispatch mechanism based on the LinalgLibraryOps.td ops declarations to avoid switch or visitor patterns. This may later be replaced by Tablegen'd dispatch when it is available. As a consequence, the list of library like operations in Linalg may now grow without having to modify any of the dispatch or transformation support. More details in the concept-based dispatch, as explained by aminim@ ``` This is inspired by Sean Parent's: https://sean-parent.stlab.cc/papers-and-presentations/#value-semantics-and-concept-based-polymorphism A key difference is that the set of classes erased is statically known, which avoids to use dynamic memory allocation. We use a zero-sized templated class to emit the virtual table and generate a singleton object for each instantiation of this class. We pay the cost of initialization once on construction (find which class to dispatch to) and then a virtual dispatch on every call. ``` -- PiperOrigin-RevId: 248258921
This commit is contained in:
parent
cde4d5a6d9
commit
fa01679e7c
|
@ -0,0 +1,42 @@
|
|||
//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This is the definition file for base linear algebra support.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef LINALG_OPS
|
||||
#else
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def Linalg_Dialect : Dialect {
|
||||
let name = "linalg";
|
||||
}
|
||||
|
||||
// Whether a type is a BufferType.
|
||||
def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
|
||||
def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
|
||||
|
||||
// Whether a type is a ViewType.
|
||||
def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
|
||||
def View : Type<LinalgIsViewTypePred, "view">;
|
||||
|
||||
#endif // LINALG_OPS
|
|
@ -0,0 +1,91 @@
|
|||
//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This is the operation definition file for linear algebra operations that
|
||||
// correspond to underlying library calls (e.g. BLAS).
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef LINALG_OPS
|
||||
#else
|
||||
|
||||
#ifdef LINALG_BASE
|
||||
#else
|
||||
include "mlir/Linalg/IR/LinalgBase.td"
|
||||
#endif // LINALG_BASE
|
||||
|
||||
class LinalgParametricNativeOpTrait<string prop, string parameters> :
|
||||
NativeOpTrait<"linalg::" # prop # parameters>
|
||||
{}
|
||||
|
||||
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
|
||||
LinalgParametricNativeOpTrait<
|
||||
prop,
|
||||
!strconcat("<",
|
||||
!cast<string>(!head(parameters)),
|
||||
!foldl("",
|
||||
!tail(parameters),
|
||||
sum,
|
||||
param,
|
||||
sum # "," # !cast<string>(param)),
|
||||
">::Impl")>
|
||||
{}
|
||||
|
||||
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
|
||||
// to have a specified number of inputs and outputs, all passed as operands.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NInputsAndOutputs<int n_ins, int n_outs> :
|
||||
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
|
||||
{}
|
||||
|
||||
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
|
||||
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
|
||||
// loops.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NLoopTypes<int n_par, int n_red, int n_win> :
|
||||
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
|
||||
{}
|
||||
|
||||
// The linalg `ViewRanks` trait the API for ops that are known to have a
|
||||
// specified list of view ranks.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class ViewRanks<list<int> ranks> :
|
||||
LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
|
||||
{}
|
||||
|
||||
// Base Tablegen class for Linalg ops.
|
||||
class LinalgOp<string mnemonic, list<OpTrait> props> :
|
||||
Op<Linalg_Dialect, mnemonic, props> {
|
||||
let arguments = (ins Variadic<View>); // default variadic builder
|
||||
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
|
||||
let printer = [{ printLinalgLibraryOp(p, *this); }];
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Concrete Linalg ops.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<0, 1, 0>,
|
||||
ViewRanks<[1, 1, 0]>]> {}
|
||||
def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<1, 1, 0>,
|
||||
ViewRanks<[2, 1, 1]>]> {}
|
||||
def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<2, 1, 0>,
|
||||
ViewRanks<[2, 2, 2]>]> {}
|
||||
|
||||
#endif // LINALG_OPS
|
|
@ -18,6 +18,7 @@
|
|||
#ifndef MLIR_LINALG_LINALGOPS_H_
|
||||
#define MLIR_LINALG_LINALGOPS_H_
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Linalg/IR/LinalgTraits.h"
|
||||
#include "mlir/Linalg/IR/LinalgTypes.h"
|
||||
|
@ -274,6 +275,9 @@ public:
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgOps.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
|
||||
|
||||
/// Returns the list of maps that map loops to operands of a Linalg op.
|
||||
/// The i-th affine map identifies loop indices to subscripts that are used when
|
||||
/// accessing the i-th operand.
|
||||
|
@ -292,6 +296,116 @@ public:
|
|||
/// Only permutation maps are currently supported.
|
||||
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
|
||||
|
||||
/// A LinalgOp behaves like a base class for the Linalg operations that are
|
||||
/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
|
||||
/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
|
||||
/// method and dispatches to the appropriate LinalgLibraryOp.
|
||||
/// This allows writing generic passes, like tiling, for all current and future
|
||||
/// LinalgOps without requiring templating and dispatch in multiple places.
|
||||
class LinalgOp : public Op<LinalgOp> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
LinalgOp(Operation *op) : Op<LinalgOp>(op) {
|
||||
impl = ModelDispatch<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
|
||||
>::dispatch(op);
|
||||
}
|
||||
|
||||
static bool classof(Operation *op) {
|
||||
return ModelDispatch<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
|
||||
>::classof(op);
|
||||
}
|
||||
|
||||
unsigned getNumParallelLoops() {
|
||||
return impl->getNumParallelLoops(getOperation());
|
||||
}
|
||||
unsigned getNumReductionLoops() {
|
||||
return impl->getNumReductionLoops(getOperation());
|
||||
}
|
||||
unsigned getNumWindowLoops() {
|
||||
return impl->getNumWindowLoops(getOperation());
|
||||
}
|
||||
unsigned getNumInputsAndOutputs() {
|
||||
return impl->getNumInputsAndOutputs(getOperation());
|
||||
}
|
||||
Operation *create(FuncBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> operands) {
|
||||
return impl->create(builder, loc, operands);
|
||||
}
|
||||
|
||||
private:
|
||||
struct Concept {
|
||||
virtual ~Concept() = default;
|
||||
virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
|
||||
virtual unsigned getNumParallelLoops(Operation *op) = 0;
|
||||
virtual unsigned getNumReductionLoops(Operation *op) = 0;
|
||||
virtual unsigned getNumWindowLoops(Operation *op) = 0;
|
||||
virtual unsigned getNumLoops(Operation *op) = 0;
|
||||
virtual Operation *create(FuncBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> operands) = 0;
|
||||
};
|
||||
|
||||
/// The implementation is inspired from Sean Parent's concept-based
|
||||
/// polymorphism. A key difference is that the set of classes erased is
|
||||
/// statically known, which alleviates the need for using dynamic memory
|
||||
/// allocation.
|
||||
/// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
|
||||
/// virtual table and generate a singleton object for each instantiation of
|
||||
/// this class.
|
||||
/// We pay the cost of initialization once on construction (find which class
|
||||
/// to dispatch to) and then a virtual dispatch on every call.
|
||||
template <typename ConcreteOp> struct Model : public Concept {
|
||||
static Model<ConcreteOp> &instance() {
|
||||
static Model<ConcreteOp> singleton;
|
||||
return singleton;
|
||||
}
|
||||
unsigned getNumInputsAndOutputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumInputsAndOutputs();
|
||||
}
|
||||
unsigned getNumParallelLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumParallelLoops();
|
||||
}
|
||||
unsigned getNumReductionLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumReductionLoops();
|
||||
}
|
||||
unsigned getNumWindowLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumWindowLoops();
|
||||
}
|
||||
unsigned getNumLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumLoops();
|
||||
}
|
||||
Operation *create(FuncBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> operands) override {
|
||||
return builder.create<ConcreteOp>(loc, operands);
|
||||
}
|
||||
};
|
||||
Concept *impl;
|
||||
|
||||
template <typename... Types> struct ModelDispatch;
|
||||
|
||||
template <typename First, typename... Rest>
|
||||
struct ModelDispatch<First, Rest...> {
|
||||
static bool classof(Operation *op) {
|
||||
return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
|
||||
}
|
||||
static Concept *dispatch(Operation *op) {
|
||||
return isa<First>(op) ? &Model<First>::instance()
|
||||
: ModelDispatch<Rest...>::dispatch(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename...> struct ModelDispatch {
|
||||
static bool classof(Operation *op) { return false; }
|
||||
static Concept *dispatch(Operation *op) {
|
||||
llvm_unreachable("Invalid LinalgOp");
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===//
|
||||
//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -22,69 +22,10 @@
|
|||
#ifdef LINALG_OPS
|
||||
#else
|
||||
|
||||
#ifdef OP_BASE
|
||||
#ifdef LINALG_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def Linalg_Dialect : Dialect {
|
||||
let name = "linalg";
|
||||
}
|
||||
|
||||
// Whether a type is a BufferType.
|
||||
def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
|
||||
def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
|
||||
|
||||
// Whether a type is a ViewType.
|
||||
def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
|
||||
def View : Type<LinalgIsViewTypePred, "view">;
|
||||
|
||||
class LinalgParametricNativeOpTrait<string prop, string parameters> :
|
||||
NativeOpTrait<"linalg::" # prop # parameters>
|
||||
{}
|
||||
|
||||
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
|
||||
LinalgParametricNativeOpTrait<
|
||||
prop,
|
||||
!strconcat("<",
|
||||
!cast<string>(!head(parameters)),
|
||||
!foldl("",
|
||||
!tail(parameters),
|
||||
sum,
|
||||
param,
|
||||
sum # "," # !cast<string>(param)),
|
||||
">::Impl")>
|
||||
{}
|
||||
|
||||
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
|
||||
// to have a specified number of inputs and outputs, all passed as operands.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NInputsAndOutputs<int n_ins, int n_outs> :
|
||||
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
|
||||
{}
|
||||
|
||||
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
|
||||
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
|
||||
// loops.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class NLoopTypes<int n_par, int n_red, int n_win> :
|
||||
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
|
||||
{}
|
||||
|
||||
// The linalg `ViewRanks` trait the API for ops that are known to have a
|
||||
// specified list of view ranks.
|
||||
// See Linalg/LinalgTraits.h for implementation details an usage.
|
||||
class ViewRanks<list<int> ranks> :
|
||||
LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
|
||||
{}
|
||||
|
||||
// Base Tablegen class for Linalg ops.
|
||||
class LinalgOp<string mnemonic, list<OpTrait> props> :
|
||||
Op<Linalg_Dialect, mnemonic, props> {
|
||||
let arguments = (ins Variadic<View>); // default variadic builder
|
||||
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
|
||||
let printer = [{ printLinalgLibraryOp(p, *this); }];
|
||||
}
|
||||
include "mlir/Linalg/IR/LinalgBase.td"
|
||||
#endif // LINALG_BASE
|
||||
|
||||
def BufferSizeOp :
|
||||
Op<Linalg_Dialect, "buffer_size", [NoSideEffect]>,
|
||||
|
@ -127,17 +68,4 @@ def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Concrete Linalg ops.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<0, 1, 0>,
|
||||
ViewRanks<[1, 1, 0]>]> {}
|
||||
def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<1, 1, 0>,
|
||||
ViewRanks<[2, 1, 1]>]> {}
|
||||
def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
|
||||
NLoopTypes<2, 1, 0>,
|
||||
ViewRanks<[2, 2, 2]>]> {}
|
||||
|
||||
#endif // LINALG_OPS
|
|
@ -593,6 +593,9 @@ namespace mlir {
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
// Ideally this should all be Tablegen'd but there is no good story for
|
||||
|
|
|
@ -37,6 +37,10 @@ mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
struct mlir::linalg::BufferTypeStorage : public TypeStorage {
|
||||
|
|
|
@ -248,7 +248,6 @@ static SmallVector<Value *, 4> makeTiledViews(FuncBuilder *b, Location loc,
|
|||
return res;
|
||||
}
|
||||
|
||||
template <class LinalgOp>
|
||||
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
||||
PerFunctionState &state) {
|
||||
// Enforce the convention that "tiling by zero" skips tiling a particular
|
||||
|
@ -278,7 +277,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
|||
assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
|
||||
auto views =
|
||||
makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
|
||||
b->create<LinalgOp>(loc, views);
|
||||
op.create(*b, loc, views);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()});
|
||||
|
@ -286,7 +285,6 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <class LinalgOp>
|
||||
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
|
||||
PerFunctionState &state) {
|
||||
if (tileSizes.empty())
|
||||
|
@ -319,13 +317,8 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
|
|||
// TODO(ntv) expose as a primitive for other passes.
|
||||
static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
|
||||
PerFunctionState &state) {
|
||||
if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
|
||||
return tileLinalgOp(matmulOp, tileSizes, state);
|
||||
} else if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
|
||||
return tileLinalgOp(matvecOp, tileSizes, state);
|
||||
} else if (auto dotOp = dyn_cast<DotOp>(op)) {
|
||||
return tileLinalgOp(dotOp, tileSizes, state);
|
||||
}
|
||||
if (auto linalgOp = dyn_cast<LinalgOp>(op))
|
||||
return tileLinalgOp(linalgOp, tileSizes, state);
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue