Add support for a Linalg base op class

This CL uses a pattern proposed by aminim@ to add a base Linalg op that further dispatches to the proper op implementation.
    This CL adds a LinalgOp which implements isclassof for only a subset of the linalg ops: the ops that behave like a library call for the purpose of transformations like tiling.
    This uses a static dispatch mechanism based on the LinalgLibraryOps.td ops declarations to avoid switch or visitor patterns. This may later be replaced by Tablegen'd dispatch when it is available.

    As a consequence, the list of library like operations in Linalg may now grow without having to modify any of the dispatch or transformation support.

    More details in the concept-based dispatch, as explained by aminim@
    ```
    This is inspired by Sean Parent's: https://sean-parent.stlab.cc/papers-and-presentations/#value-semantics-and-concept-based-polymorphism

    A key difference is that the set of classes erased is statically known, which avoids to use dynamic memory allocation.
    We use a zero-sized templated class to emit the virtual table and generate a singleton object for each instantiation of this class. We pay the cost of initialization once on construction (find which class to dispatch to) and then a virtual dispatch on every call.
    ```

--

PiperOrigin-RevId: 248258921
This commit is contained in:
Nicolas Vasilache 2019-05-14 19:37:48 -07:00 committed by Mehdi Amini
parent cde4d5a6d9
commit fa01679e7c
7 changed files with 261 additions and 86 deletions

View File

@ -0,0 +1,42 @@
//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the definition file for base linear algebra support.
//
//===----------------------------------------------------------------------===//
#ifdef LINALG_OPS
#else
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def Linalg_Dialect : Dialect {
let name = "linalg";
}
// Whether a type is a BufferType.
def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
// Whether a type is a ViewType.
def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
def View : Type<LinalgIsViewTypePred, "view">;
#endif // LINALG_OPS

View File

@ -0,0 +1,91 @@
//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the operation definition file for linear algebra operations that
// correspond to underlying library calls (e.g. BLAS).
//
//===----------------------------------------------------------------------===//
#ifdef LINALG_OPS
#else
#ifdef LINALG_BASE
#else
include "mlir/Linalg/IR/LinalgBase.td"
#endif // LINALG_BASE
class LinalgParametricNativeOpTrait<string prop, string parameters> :
NativeOpTrait<"linalg::" # prop # parameters>
{}
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
LinalgParametricNativeOpTrait<
prop,
!strconcat("<",
!cast<string>(!head(parameters)),
!foldl("",
!tail(parameters),
sum,
param,
sum # "," # !cast<string>(param)),
">::Impl")>
{}
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
// to have a specified number of inputs and outputs, all passed as operands.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NInputsAndOutputs<int n_ins, int n_outs> :
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
{}
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
// loops.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NLoopTypes<int n_par, int n_red, int n_win> :
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
{}
// The linalg `ViewRanks` trait the API for ops that are known to have a
// specified list of view ranks.
// See Linalg/LinalgTraits.h for implementation details an usage.
class ViewRanks<list<int> ranks> :
LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
{}
// Base Tablegen class for Linalg ops.
class LinalgOp<string mnemonic, list<OpTrait> props> :
Op<Linalg_Dialect, mnemonic, props> {
let arguments = (ins Variadic<View>); // default variadic builder
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
let printer = [{ printLinalgLibraryOp(p, *this); }];
}
////////////////////////////////////////////////////////////////////////////////
// Concrete Linalg ops.
////////////////////////////////////////////////////////////////////////////////
def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
NLoopTypes<0, 1, 0>,
ViewRanks<[1, 1, 0]>]> {}
def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
NLoopTypes<1, 1, 0>,
ViewRanks<[2, 1, 1]>]> {}
def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
NLoopTypes<2, 1, 0>,
ViewRanks<[2, 2, 2]>]> {}
#endif // LINALG_OPS

View File

@ -18,6 +18,7 @@
#ifndef MLIR_LINALG_LINALGOPS_H_
#define MLIR_LINALG_LINALGOPS_H_
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Linalg/IR/LinalgTraits.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
@ -274,6 +275,9 @@ public:
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
/// Returns the list of maps that map loops to operands of a Linalg op.
/// The i-th affine map identifies loop indices to subscripts that are used when
/// accessing the i-th operand.
@ -292,6 +296,116 @@ public:
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
/// A LinalgOp behaves like a base class for the Linalg operations that are
/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
/// method and dispatches to the appropriate LinalgLibraryOp.
/// This allows writing generic passes, like tiling, for all current and future
/// LinalgOps without requiring templating and dispatch in multiple places.
class LinalgOp : public Op<LinalgOp> {
public:
using Op::Op;
LinalgOp(Operation *op) : Op<LinalgOp>(op) {
impl = ModelDispatch<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
>::dispatch(op);
}
static bool classof(Operation *op) {
return ModelDispatch<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
>::classof(op);
}
unsigned getNumParallelLoops() {
return impl->getNumParallelLoops(getOperation());
}
unsigned getNumReductionLoops() {
return impl->getNumReductionLoops(getOperation());
}
unsigned getNumWindowLoops() {
return impl->getNumWindowLoops(getOperation());
}
unsigned getNumInputsAndOutputs() {
return impl->getNumInputsAndOutputs(getOperation());
}
Operation *create(FuncBuilder &builder, Location loc,
ArrayRef<Value *> operands) {
return impl->create(builder, loc, operands);
}
private:
struct Concept {
virtual ~Concept() = default;
virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
virtual unsigned getNumParallelLoops(Operation *op) = 0;
virtual unsigned getNumReductionLoops(Operation *op) = 0;
virtual unsigned getNumWindowLoops(Operation *op) = 0;
virtual unsigned getNumLoops(Operation *op) = 0;
virtual Operation *create(FuncBuilder &builder, Location loc,
ArrayRef<Value *> operands) = 0;
};
/// The implementation is inspired from Sean Parent's concept-based
/// polymorphism. A key difference is that the set of classes erased is
/// statically known, which alleviates the need for using dynamic memory
/// allocation.
/// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
/// virtual table and generate a singleton object for each instantiation of
/// this class.
/// We pay the cost of initialization once on construction (find which class
/// to dispatch to) and then a virtual dispatch on every call.
template <typename ConcreteOp> struct Model : public Concept {
static Model<ConcreteOp> &instance() {
static Model<ConcreteOp> singleton;
return singleton;
}
unsigned getNumInputsAndOutputs(Operation *op) override {
return cast<ConcreteOp>(op).getNumInputsAndOutputs();
}
unsigned getNumParallelLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumParallelLoops();
}
unsigned getNumReductionLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumReductionLoops();
}
unsigned getNumWindowLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumWindowLoops();
}
unsigned getNumLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumLoops();
}
Operation *create(FuncBuilder &builder, Location loc,
ArrayRef<Value *> operands) override {
return builder.create<ConcreteOp>(loc, operands);
}
};
Concept *impl;
template <typename... Types> struct ModelDispatch;
template <typename First, typename... Rest>
struct ModelDispatch<First, Rest...> {
static bool classof(Operation *op) {
return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
}
static Concept *dispatch(Operation *op) {
return isa<First>(op) ? &Model<First>::instance()
: ModelDispatch<Rest...>::dispatch(op);
}
};
template <typename...> struct ModelDispatch {
static bool classof(Operation *op) { return false; }
static Concept *dispatch(Operation *op) {
llvm_unreachable("Invalid LinalgOp");
}
};
};
} // namespace linalg
} // namespace mlir

View File

@ -1,4 +1,4 @@
//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===//
//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -22,69 +22,10 @@
#ifdef LINALG_OPS
#else
#ifdef OP_BASE
#ifdef LINALG_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def Linalg_Dialect : Dialect {
let name = "linalg";
}
// Whether a type is a BufferType.
def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
// Whether a type is a ViewType.
def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
def View : Type<LinalgIsViewTypePred, "view">;
class LinalgParametricNativeOpTrait<string prop, string parameters> :
NativeOpTrait<"linalg::" # prop # parameters>
{}
class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
LinalgParametricNativeOpTrait<
prop,
!strconcat("<",
!cast<string>(!head(parameters)),
!foldl("",
!tail(parameters),
sum,
param,
sum # "," # !cast<string>(param)),
">::Impl")>
{}
// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
// to have a specified number of inputs and outputs, all passed as operands.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NInputsAndOutputs<int n_ins, int n_outs> :
LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
{}
// The linalg `NLoopTypes` trait provides the API for ops that are known to have
// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
// loops.
// See Linalg/LinalgTraits.h for implementation details an usage.
class NLoopTypes<int n_par, int n_red, int n_win> :
LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
{}
// The linalg `ViewRanks` trait the API for ops that are known to have a
// specified list of view ranks.
// See Linalg/LinalgTraits.h for implementation details an usage.
class ViewRanks<list<int> ranks> :
LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
{}
// Base Tablegen class for Linalg ops.
class LinalgOp<string mnemonic, list<OpTrait> props> :
Op<Linalg_Dialect, mnemonic, props> {
let arguments = (ins Variadic<View>); // default variadic builder
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
let printer = [{ printLinalgLibraryOp(p, *this); }];
}
include "mlir/Linalg/IR/LinalgBase.td"
#endif // LINALG_BASE
def BufferSizeOp :
Op<Linalg_Dialect, "buffer_size", [NoSideEffect]>,
@ -127,17 +68,4 @@ def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
}];
}
////////////////////////////////////////////////////////////////////////////////
// Concrete Linalg ops.
////////////////////////////////////////////////////////////////////////////////
def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
NLoopTypes<0, 1, 0>,
ViewRanks<[1, 1, 0]>]> {}
def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
NLoopTypes<1, 1, 0>,
ViewRanks<[2, 1, 1]>]> {}
def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
NLoopTypes<2, 1, 0>,
ViewRanks<[2, 2, 2]>]> {}
#endif // LINALG_OPS

View File

@ -593,6 +593,9 @@ namespace mlir {
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
} // namespace mlir
// Ideally this should all be Tablegen'd but there is no good story for

View File

@ -37,6 +37,10 @@ mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
>();
}
struct mlir::linalg::BufferTypeStorage : public TypeStorage {

View File

@ -248,7 +248,6 @@ static SmallVector<Value *, 4> makeTiledViews(FuncBuilder *b, Location loc,
return res;
}
template <class LinalgOp>
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
PerFunctionState &state) {
// Enforce the convention that "tiling by zero" skips tiling a particular
@ -278,7 +277,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
auto views =
makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
b->create<LinalgOp>(loc, views);
op.create(*b, loc, views);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()});
@ -286,7 +285,6 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
return success();
}
template <class LinalgOp>
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
if (tileSizes.empty())
@ -319,13 +317,8 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
// TODO(ntv) expose as a primitive for other passes.
static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
return tileLinalgOp(matmulOp, tileSizes, state);
} else if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
return tileLinalgOp(matvecOp, tileSizes, state);
} else if (auto dotOp = dyn_cast<DotOp>(op)) {
return tileLinalgOp(dotOp, tileSizes, state);
}
if (auto linalgOp = dyn_cast<LinalgOp>(op))
return tileLinalgOp(linalgOp, tileSizes, state);
return failure();
}