From fa01679e7c19b032856f8433f0c7af57dcfa4943 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 14 May 2019 19:37:48 -0700 Subject: [PATCH] Add support for a Linalg base op class This CL uses a pattern proposed by aminim@ to add a base Linalg op that further dispatches to the proper op implementation. This CL adds a LinalgOp which implements isclassof for only a subset of the linalg ops: the ops that behave like a library call for the purpose of transformations like tiling. This uses a static dispatch mechanism based on the LinalgLibraryOps.td ops declarations to avoid switch or visitor patterns. This may later be replaced by Tablegen'd dispatch when it is available. As a consequence, the list of library like operations in Linalg may now grow without having to modify any of the dispatch or transformation support. More details in the concept-based dispatch, as explained by aminim@ ``` This is inspired by Sean Parent's: https://sean-parent.stlab.cc/papers-and-presentations/#value-semantics-and-concept-based-polymorphism A key difference is that the set of classes erased is statically known, which avoids to use dynamic memory allocation. We use a zero-sized templated class to emit the virtual table and generate a singleton object for each instantiation of this class. We pay the cost of initialization once on construction (find which class to dispatch to) and then a virtual dispatch on every call. ``` -- PiperOrigin-RevId: 248258921 --- mlir/include/mlir/Linalg/IR/LinalgBase.td | 42 +++++++ .../mlir/Linalg/IR/LinalgLibraryOps.td | 91 ++++++++++++++ mlir/include/mlir/Linalg/IR/LinalgOps.h | 114 ++++++++++++++++++ mlir/include/mlir/Linalg/IR/LinalgOps.td | 80 +----------- mlir/lib/Linalg/IR/LinalgOps.cpp | 3 + mlir/lib/Linalg/IR/LinalgTypes.cpp | 4 + mlir/lib/Linalg/Transforms/Tiling.cpp | 13 +- 7 files changed, 261 insertions(+), 86 deletions(-) create mode 100644 mlir/include/mlir/Linalg/IR/LinalgBase.td create mode 100644 mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td diff --git a/mlir/include/mlir/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Linalg/IR/LinalgBase.td new file mode 100644 index 000000000000..42e5bcd5bc00 --- /dev/null +++ b/mlir/include/mlir/Linalg/IR/LinalgBase.td @@ -0,0 +1,42 @@ +//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the definition file for base linear algebra support. +// +//===----------------------------------------------------------------------===// + +#ifdef LINALG_OPS +#else + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def Linalg_Dialect : Dialect { + let name = "linalg"; +} + +// Whether a type is a BufferType. +def LinalgIsBufferTypePred : CPred<"$_self.isa()">; +def Buffer : Type; + +// Whether a type is a ViewType. +def LinalgIsViewTypePred : CPred<"$_self.isa()">; +def View : Type; + +#endif // LINALG_OPS \ No newline at end of file diff --git a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td new file mode 100644 index 000000000000..15b9fabfd5f5 --- /dev/null +++ b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td @@ -0,0 +1,91 @@ +//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the operation definition file for linear algebra operations that +// correspond to underlying library calls (e.g. BLAS). +// +//===----------------------------------------------------------------------===// + +#ifdef LINALG_OPS +#else + +#ifdef LINALG_BASE +#else +include "mlir/Linalg/IR/LinalgBase.td" +#endif // LINALG_BASE + +class LinalgParametricNativeOpTrait : + NativeOpTrait<"linalg::" # prop # parameters> +{} + +class LinalgParametricIntNativeOpTrait parameters> : + LinalgParametricNativeOpTrait< + prop, + !strconcat("<", + !cast(!head(parameters)), + !foldl("", + !tail(parameters), + sum, + param, + sum # "," # !cast(param)), + ">::Impl")> +{} + +// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known +// to have a specified number of inputs and outputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NInputsAndOutputs : + LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]> +{} + +// The linalg `NLoopTypes` trait provides the API for ops that are known to have +// a specified number of parallel (n_par), reduction (n_red) and window (n_win) +// loops. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NLoopTypes : +LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]> +{} + +// The linalg `ViewRanks` trait the API for ops that are known to have a +// specified list of view ranks. +// See Linalg/LinalgTraits.h for implementation details an usage. +class ViewRanks ranks> : +LinalgParametricIntNativeOpTrait<"ViewRanks", ranks> +{} + +// Base Tablegen class for Linalg ops. +class LinalgOp props> : +Op { + let arguments = (ins Variadic); // default variadic builder + let parser = [{ return parseLinalgLibraryOp(parser, result); }]; + let printer = [{ printLinalgLibraryOp(p, *this); }]; +} + +//////////////////////////////////////////////////////////////////////////////// +// Concrete Linalg ops. +//////////////////////////////////////////////////////////////////////////////// +def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>, + NLoopTypes<0, 1, 0>, + ViewRanks<[1, 1, 0]>]> {} +def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>, + NLoopTypes<1, 1, 0>, + ViewRanks<[2, 1, 1]>]> {} +def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>, + NLoopTypes<2, 1, 0>, + ViewRanks<[2, 2, 2]>]> {} + +#endif // LINALG_OPS \ No newline at end of file diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 6a6c953c3e0c..0c48cb05d23c 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -18,6 +18,7 @@ #ifndef MLIR_LINALG_LINALGOPS_H_ #define MLIR_LINALG_LINALGOPS_H_ +#include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Linalg/IR/LinalgTraits.h" #include "mlir/Linalg/IR/LinalgTypes.h" @@ -274,6 +275,9 @@ public: #define GET_OP_CLASSES #include "mlir/Linalg/IR/LinalgOps.h.inc" +#define GET_OP_CLASSES +#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc" + /// Returns the list of maps that map loops to operands of a Linalg op. /// The i-th affine map identifies loop indices to subscripts that are used when /// accessing the i-th operand. @@ -292,6 +296,116 @@ public: /// Only permutation maps are currently supported. SmallVector loopToOperandRangesMaps(Operation *op); +/// A LinalgOp behaves like a base class for the Linalg operations that are +/// defined in LinalgLibraryOps.td. The implementation does not use inheritance +/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof` +/// method and dispatches to the appropriate LinalgLibraryOp. +/// This allows writing generic passes, like tiling, for all current and future +/// LinalgOps without requiring templating and dispatch in multiple places. +class LinalgOp : public Op { +public: + using Op::Op; + + LinalgOp(Operation *op) : Op(op) { + impl = ModelDispatch< +#define GET_OP_LIST +#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc" + >::dispatch(op); + } + + static bool classof(Operation *op) { + return ModelDispatch< +#define GET_OP_LIST +#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc" + >::classof(op); + } + + unsigned getNumParallelLoops() { + return impl->getNumParallelLoops(getOperation()); + } + unsigned getNumReductionLoops() { + return impl->getNumReductionLoops(getOperation()); + } + unsigned getNumWindowLoops() { + return impl->getNumWindowLoops(getOperation()); + } + unsigned getNumInputsAndOutputs() { + return impl->getNumInputsAndOutputs(getOperation()); + } + Operation *create(FuncBuilder &builder, Location loc, + ArrayRef operands) { + return impl->create(builder, loc, operands); + } + +private: + struct Concept { + virtual ~Concept() = default; + virtual unsigned getNumInputsAndOutputs(Operation *op) = 0; + virtual unsigned getNumParallelLoops(Operation *op) = 0; + virtual unsigned getNumReductionLoops(Operation *op) = 0; + virtual unsigned getNumWindowLoops(Operation *op) = 0; + virtual unsigned getNumLoops(Operation *op) = 0; + virtual Operation *create(FuncBuilder &builder, Location loc, + ArrayRef operands) = 0; + }; + + /// The implementation is inspired from Sean Parent's concept-based + /// polymorphism. A key difference is that the set of classes erased is + /// statically known, which alleviates the need for using dynamic memory + /// allocation. + /// We use a zero-sized templated class `Model` to emit the + /// virtual table and generate a singleton object for each instantiation of + /// this class. + /// We pay the cost of initialization once on construction (find which class + /// to dispatch to) and then a virtual dispatch on every call. + template struct Model : public Concept { + static Model &instance() { + static Model singleton; + return singleton; + } + unsigned getNumInputsAndOutputs(Operation *op) override { + return cast(op).getNumInputsAndOutputs(); + } + unsigned getNumParallelLoops(Operation *op) override { + return cast(op).getNumParallelLoops(); + } + unsigned getNumReductionLoops(Operation *op) override { + return cast(op).getNumReductionLoops(); + } + unsigned getNumWindowLoops(Operation *op) override { + return cast(op).getNumWindowLoops(); + } + unsigned getNumLoops(Operation *op) override { + return cast(op).getNumLoops(); + } + Operation *create(FuncBuilder &builder, Location loc, + ArrayRef operands) override { + return builder.create(loc, operands); + } + }; + Concept *impl; + + template struct ModelDispatch; + + template + struct ModelDispatch { + static bool classof(Operation *op) { + return isa(op) || ModelDispatch::classof(op); + } + static Concept *dispatch(Operation *op) { + return isa(op) ? &Model::instance() + : ModelDispatch::dispatch(op); + } + }; + + template struct ModelDispatch { + static bool classof(Operation *op) { return false; } + static Concept *dispatch(Operation *op) { + llvm_unreachable("Invalid LinalgOp"); + } + }; +}; + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 58eb3f00401f..ecdb111bb048 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -1,4 +1,4 @@ -//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===// +//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===// // // Copyright 2019 The MLIR Authors. // @@ -22,69 +22,10 @@ #ifdef LINALG_OPS #else -#ifdef OP_BASE +#ifdef LINALG_BASE #else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -def Linalg_Dialect : Dialect { - let name = "linalg"; -} - -// Whether a type is a BufferType. -def LinalgIsBufferTypePred : CPred<"$_self.isa()">; -def Buffer : Type; - -// Whether a type is a ViewType. -def LinalgIsViewTypePred : CPred<"$_self.isa()">; -def View : Type; - -class LinalgParametricNativeOpTrait : - NativeOpTrait<"linalg::" # prop # parameters> -{} - -class LinalgParametricIntNativeOpTrait parameters> : - LinalgParametricNativeOpTrait< - prop, - !strconcat("<", - !cast(!head(parameters)), - !foldl("", - !tail(parameters), - sum, - param, - sum # "," # !cast(param)), - ">::Impl")> -{} - -// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known -// to have a specified number of inputs and outputs, all passed as operands. -// See Linalg/LinalgTraits.h for implementation details an usage. -class NInputsAndOutputs : - LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]> -{} - -// The linalg `NLoopTypes` trait provides the API for ops that are known to have -// a specified number of parallel (n_par), reduction (n_red) and window (n_win) -// loops. -// See Linalg/LinalgTraits.h for implementation details an usage. -class NLoopTypes : -LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]> -{} - -// The linalg `ViewRanks` trait the API for ops that are known to have a -// specified list of view ranks. -// See Linalg/LinalgTraits.h for implementation details an usage. -class ViewRanks ranks> : -LinalgParametricIntNativeOpTrait<"ViewRanks", ranks> -{} - -// Base Tablegen class for Linalg ops. -class LinalgOp props> : -Op { - let arguments = (ins Variadic); // default variadic builder - let parser = [{ return parseLinalgLibraryOp(parser, result); }]; - let printer = [{ printLinalgLibraryOp(p, *this); }]; -} +include "mlir/Linalg/IR/LinalgBase.td" +#endif // LINALG_BASE def BufferSizeOp : Op, @@ -127,17 +68,4 @@ def DimOp : Op, }]; } -//////////////////////////////////////////////////////////////////////////////// -// Concrete Linalg ops. -//////////////////////////////////////////////////////////////////////////////// -def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>, - NLoopTypes<0, 1, 0>, - ViewRanks<[1, 1, 0]>]> {} -def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>, - NLoopTypes<1, 1, 0>, - ViewRanks<[2, 1, 1]>]> {} -def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>, - NLoopTypes<2, 1, 0>, - ViewRanks<[2, 2, 2]>]> {} - #endif // LINALG_OPS \ No newline at end of file diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index e6e18bb197e3..d077927ec240 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -593,6 +593,9 @@ namespace mlir { #define GET_OP_CLASSES #include "mlir/Linalg/IR/LinalgOps.cpp.inc" +#define GET_OP_CLASSES +#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc" + } // namespace mlir // Ideally this should all be Tablegen'd but there is no good story for diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index 19105e8676cf..0e20eb85cf81 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -37,6 +37,10 @@ mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context) #define GET_OP_LIST #include "mlir/Linalg/IR/LinalgOps.cpp.inc" >(); + addOperations< +#define GET_OP_LIST +#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc" + >(); } struct mlir::linalg::BufferTypeStorage : public TypeStorage { diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index ba1fdbe27155..ff6f02f17570 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -248,7 +248,6 @@ static SmallVector makeTiledViews(FuncBuilder *b, Location loc, return res; } -template static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, PerFunctionState &state) { // Enforce the convention that "tiling by zero" skips tiling a particular @@ -278,7 +277,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands()); auto views = makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state); - b->create(loc, views); + op.create(*b, loc, views); /// NestedBuilders expect handles, we thus return an IndexHandle. return IndexHandle(); }()}); @@ -286,7 +285,6 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, return success(); } -template static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, PerFunctionState &state) { if (tileSizes.empty()) @@ -319,13 +317,8 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef tileSizes, // TODO(ntv) expose as a primitive for other passes. static LogicalResult tileLinalgOp(Operation *op, ArrayRef tileSizes, PerFunctionState &state) { - if (auto matmulOp = dyn_cast(op)) { - return tileLinalgOp(matmulOp, tileSizes, state); - } else if (auto matvecOp = dyn_cast(op)) { - return tileLinalgOp(matvecOp, tileSizes, state); - } else if (auto dotOp = dyn_cast(op)) { - return tileLinalgOp(dotOp, tileSizes, state); - } + if (auto linalgOp = dyn_cast(op)) + return tileLinalgOp(linalgOp, tileSizes, state); return failure(); }