[mlir][Linalg][Python] Create the body of builtin named Linalg ops

This revision adds support for properly adding the body of registered
builtin named Linalg ops.
At this time, indexing_map and iterator_type support is still
missing, so the op is not executable yet.

Differential Revision: https://reviews.llvm.org/D99578
Nicolas Vasilache 2021-03-30 11:41:41 +00:00
parent 465b9a4a33
commit 43b9fa3ce0
12 changed files with 196 additions and 14 deletions

View File

@@ -17,6 +17,11 @@
extern "C" {
#endif
/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
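
For orientation, a minimal sketch of driving this C API entry point from client code (not part of the diff; `ctx` and `matmulOp` are assumed to exist and the op must have been created with its single region left empty):

// Sketch only: assumes the Linalg dialect is registered in `ctx` and
// `matmulOp` is a builtin named Linalg op (e.g. linalg.matmul) whose
// region has no blocks yet.
MlirDialect linalgDialect = mlirContextGetOrLoadDialect(
    ctx, mlirStringRefCreateFromCString("linalg"));
mlirLinalgFillBuiltinNamedOpRegion(linalgDialect, matmulOp);
// Afterwards the region holds the scalar body produced by the op's
// registered regionBuilder (std.mulf / std.addf / linalg.yield for matmul).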

View File

@@ -37,6 +37,14 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
let extraClassDeclaration = [{
  using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;

  RegionBuilderFunType getRegionBuilder(StringRef name) {
    return namedStructuredOpRegionBuilders.lookup(name);
  }

  private:
    llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
}];
}
// Whether a type is a RangeType.
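
For context, a minimal C++ sketch of querying this new dialect hook (not part of the diff; variable names are illustrative):

// Look up the region builder registered for a builtin named structured op;
// the lookup returns a null function_ref when no builder was registered.
auto *dialect = ctx->getOrLoadDialect<mlir::linalg::LinalgDialect>();
mlir::linalg::LinalgDialect::RegionBuilderFunType regionBuilder =
    dialect->getRegionBuilder("linalg.matmul");
if (!regionBuilder) {
  // Not a builtin named op; a client would fall back to linalg.generic.
}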

View File

@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"

View File

@@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
INSTALL_DIR
python
SOURCES
DialectLinalg.cpp
MainModule.cpp
IRAffine.cpp
IRAttributes.cpp

View File

@@ -0,0 +1,34 @@
//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
namespace mlir {
namespace python {
void populateDialectLinalgSubmodule(py::module &m) {
  m.def(
      "fill_builtin_region",
      [](PyDialectDescriptor &dialect, PyOperation &op) {
        return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
      },
      py::arg("dialect"), py::arg("op"),
      "Fill the region for `op`, which is assumed to be a builtin named Linalg "
      "op.");
}
} // namespace python
} // namespace mlir

View File

@@ -0,0 +1,22 @@
//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
#include "PybindUtils.h"
namespace mlir {
namespace python {
void populateDialectLinalgSubmodule(pybind11::module &m);
} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H

View File

@@ -10,6 +10,7 @@
#include "PybindUtils.h"
#include "DialectLinalg.h"
#include "ExecutionEngine.h"
#include "Globals.h"
#include "IRModule.h"
@@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) {
  auto executionEngineModule =
      m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
  populateExecutionEngineSubmodule(executionEngineModule);

  // Define and populate Linalg submodule.
  auto dialectsModule = m.def_submodule("dialects");
  auto linalgModule = dialectsModule.def_submodule("linalg");
  populateDialectLinalgSubmodule(linalgModule);
}

View File

@@ -61,11 +61,10 @@ class DefinedOpCallable:
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}")
# TODO: this file should probably not be called dsl.py but rather is a client
# of the dsl.py.
from .... import linalg as linalg_ops
emit_generic = (emit_generic or
(not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
ctx = ir.Context.current
linalgDialect = ctx.get_dialect_descriptor("linalg")
fully_qualified_name = 'linalg.' + self.op_name
emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
op_config = op_configs[0]
if op_config.structured_op:

View File

@@ -7,6 +7,9 @@ from typing import Dict, Sequence
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
# TODO: resolve name collision for Linalg functionality that is injected inside
# the _mlir.dialects.linalg directly via pybind.
from _mlir.dialects.linalg import fill_builtin_region
from .scalar_expr import *
from .config import *
@@ -16,7 +19,6 @@ __all__ = [
"emit_named_structured_op",
]
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
outs: Value):
@@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
type_mapping, indexing_maps_attr, iterator_types_attr = \
prepare_common_structured_op(op_config, *ins, outs = outs)
if not op_class_name in linalg.__dict__.keys():
# If we get here, there must exist a builtin class `op_class_name`.
ctx = Context.current
fully_qualified_name = 'linalg.' + op_name
if (not ctx.is_registered_operation(fully_qualified_name) or
not op_class_name in linalg.__dict__.keys()):
raise NotImplementedError(
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
linalgDialect = ctx.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, named_op.operation)
if len(out_arg_defs) == 1:
return named_op.result
else:

View File

@@ -10,5 +10,30 @@
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
mlir::linalg::LinalgDialect)
using namespace mlir;
using namespace mlir::linalg;
/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
                                        MlirOperation mlirOp) {
  Operation *op = unwrap(mlirOp);
  LinalgDialect::RegionBuilderFunType fun =
      static_cast<LinalgDialect *>(unwrap(linalgDialect))
          ->getRegionBuilder(op->getName().getStringRef());
  assert(fun && "Expected a builtin named Linalg op.");
  assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
  assert(op->getRegion(0).getBlocks().empty() &&
         "Expected Linalg op with 0 blocks");

  SmallVector<Type, 8> argTypes;
  auto linalgOp = cast<LinalgOp>(op);
  for (auto t : linalgOp.getShapedOperandTypes())
    argTypes.push_back(getElementTypeOrSelf(t));

  OpBuilder b(op->getContext());
  Region &region = op->getRegion(0);
  Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
  // TODO: allow captures.
  fun(*body, ValueRange{});
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

View File

@@ -57,6 +57,38 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
// LinalgDialect
//===----------------------------------------------------------------------===//
/// Trait to check if T provides a `regionBuilder` method.
template <typename T, typename... Args>
using has_region_builder = decltype(T::regionBuilder);
template <typename T>
using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;

/// SFINAE helper for a single C++ class without a `regionBuilder` method
/// (e.g., an OpInterface).
template <typename OpType, typename = std::enable_if_t<
                               !detect_has_region_builder<OpType>::value>>
void addNamedOpBuilderImpl(
    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
  // Do nothing.
}

template <typename OpType,
          typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
          typename = void>
void addNamedOpBuilderImpl(
    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
  map.insert(std::make_pair(
      OpType::getOperationName(),
      static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
}

template <typename... OpTypes>
void addNamedOpBuilders(
    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
  (void)std::initializer_list<int>{
      0, (addNamedOpBuilderImpl<OpTypes>(map), 0)...};
}
void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
@@ -72,6 +104,12 @@ void mlir::linalg::LinalgDialect::initialize() {
#include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
>();
// Fill the Linalg-specific OpName to RegionBuilder map.
addNamedOpBuilders<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(namedStructuredOpRegionBuilders);
addInterfaces<LinalgInlinerInterface>();
}
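
As a standalone illustration of the detection idiom used above (a hedged sketch, not part of the diff; the two struct names are made up):

#include "llvm/ADT/STLExtras.h" // provides llvm::is_detected

// One hypothetical class with a static `regionBuilder`, one without.
struct OpWithBuilder {
  static void regionBuilder();
};
struct OpWithoutBuilder {};

template <typename T>
using has_region_builder = decltype(T::regionBuilder);

static_assert(llvm::is_detected<has_region_builder, OpWithBuilder>::value,
              "OpWithBuilder is detected");
static_assert(!llvm::is_detected<has_region_builder, OpWithoutBuilder>::value,
              "OpWithoutBuilder is not detected");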

View File

@@ -5,7 +5,6 @@ from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -82,9 +81,9 @@ def testStructuredOpOnBuffers():
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
print(module)
# CHECK-LABEL: TEST: testNamedStructuredOp
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOp():
def testNamedStructuredOpCustomForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
@@ -93,10 +92,45 @@ def testNamedStructuredOp():
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# CHECK: linalg.matmul
# TODO: properly hook up the region.
# First check the named form with custom format
# CHECK: linalg.matmul
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
# CHECK-SAME: -> tensor<4x8xf32>
# CHECK-NEXT: return
return linalg.matmul(lhs, rhs, outs=[init_result.result])
print(module)
# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# CHECK: "linalg.matmul"(%{{.*}})
# CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
# CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])
module.operation.print(print_generic_op_form=True)
# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def generic_form(lhs, rhs):