forked from OSchip/llvm-project
[mlir][Linalg][Python] Create the body of builtin named Linalg ops
This revision adds support to properly add the body of registered builtin named linalg ops. At this time, indexing_map and iterator_type support is still missing so the op is not executable yet. Differential Revision: https://reviews.llvm.org/D99578
This commit is contained in:
parent
465b9a4a33
commit
43b9fa3ce0
|
@ -17,6 +17,11 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// Apply the special region builder for the builtin named Linalg op.
|
||||
/// Assert that `op` is a builtin named Linalg op.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -37,6 +37,14 @@ def Linalg_Dialect : Dialect {
|
|||
let dependentDialects = [
|
||||
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
|
||||
RegionBuilderFunType getRegionBuilder(StringRef name) {
|
||||
return namedStructuredOpRegionBuilders.lookup(name);
|
||||
}
|
||||
private:
|
||||
llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
|
||||
}];
|
||||
}
|
||||
|
||||
// Whether a type is a RangeType.
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
|
|||
INSTALL_DIR
|
||||
python
|
||||
SOURCES
|
||||
DialectLinalg.cpp
|
||||
MainModule.cpp
|
||||
IRAffine.cpp
|
||||
IRAttributes.cpp
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "mlir-c/Dialect/Linalg.h"
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
void populateDialectLinalgSubmodule(py::module &m) {
|
||||
m.def(
|
||||
"fill_builtin_region",
|
||||
[](PyDialectDescriptor &dialect, PyOperation &op) {
|
||||
return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
|
||||
},
|
||||
py::arg("dialect"), py::arg("op"),
|
||||
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
|
||||
"op.");
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
|
@ -0,0 +1,22 @@
|
|||
//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
|
||||
#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
void populateDialectLinalgSubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "DialectLinalg.h"
|
||||
#include "ExecutionEngine.h"
|
||||
#include "Globals.h"
|
||||
#include "IRModule.h"
|
||||
|
@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) {
|
|||
auto executionEngineModule =
|
||||
m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
|
||||
populateExecutionEngineSubmodule(executionEngineModule);
|
||||
|
||||
// Define and populate Linalg submodule.
|
||||
auto dialectsModule = m.def_submodule("dialects");
|
||||
auto linalgModule = dialectsModule.def_submodule("linalg");
|
||||
populateDialectLinalgSubmodule(linalgModule);
|
||||
}
|
||||
|
|
|
@ -61,11 +61,10 @@ class DefinedOpCallable:
|
|||
raise NotImplementedError(
|
||||
f"Emission of composite linalg ops not supported: {op_configs}")
|
||||
|
||||
# TODO: this file should probably not be called dsl.py but rather is a client
|
||||
# of the dsl.py.
|
||||
from .... import linalg as linalg_ops
|
||||
emit_generic = (emit_generic or
|
||||
(not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
|
||||
ctx = ir.Context.current
|
||||
linalgDialect = ctx.get_dialect_descriptor("linalg")
|
||||
fully_qualified_name = 'linalg.' + self.op_name
|
||||
emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
|
||||
|
||||
op_config = op_configs[0]
|
||||
if op_config.structured_op:
|
||||
|
|
|
@ -7,6 +7,9 @@ from typing import Dict, Sequence
|
|||
from mlir.ir import *
|
||||
from mlir.dialects import linalg
|
||||
from mlir.dialects import std
|
||||
# TODO: resolve name collision for Linalg functionality that is injected inside
|
||||
# the _mlir.dialects.linalg directly via pybind.
|
||||
from _mlir.dialects.linalg import fill_builtin_region
|
||||
|
||||
from .scalar_expr import *
|
||||
from .config import *
|
||||
|
@ -16,7 +19,6 @@ __all__ = [
|
|||
"emit_named_structured_op",
|
||||
]
|
||||
|
||||
|
||||
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
*ins: Value,
|
||||
outs: Value):
|
||||
|
@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
|
|||
type_mapping, indexing_maps_attr, iterator_types_attr = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs)
|
||||
|
||||
if not op_class_name in linalg.__dict__.keys():
|
||||
# If we get here, there must exist a builtin class `op_class_name`.
|
||||
ctx = Context.current
|
||||
fully_qualified_name = 'linalg.' + op_name
|
||||
if (not ctx.is_registered_operation(fully_qualified_name) or
|
||||
not op_class_name in linalg.__dict__.keys()):
|
||||
raise NotImplementedError(
|
||||
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
|
||||
|
||||
named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
|
||||
linalgDialect = ctx.get_dialect_descriptor("linalg")
|
||||
fill_builtin_region(linalgDialect, named_op.operation)
|
||||
|
||||
if len(out_arg_defs) == 1:
|
||||
return named_op.result
|
||||
else:
|
||||
|
|
|
@ -10,5 +10,30 @@
|
|||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
|
||||
mlir::linalg::LinalgDialect)
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
/// Apply the special region builder for the builtin named Linalg op.
|
||||
/// Assert that `op` is a builtin named Linalg op.
|
||||
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
|
||||
MlirOperation mlirOp) {
|
||||
Operation *op = unwrap(mlirOp);
|
||||
LinalgDialect::RegionBuilderFunType fun =
|
||||
static_cast<LinalgDialect *>(unwrap(linalgDialect))
|
||||
->getRegionBuilder(op->getName().getStringRef());
|
||||
assert(fun && "Expected a builtin named Linalg op.");
|
||||
assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
|
||||
assert(op->getRegion(0).getBlocks().empty() &&
|
||||
"Expected Linalg op with 0 blocks");
|
||||
SmallVector<Type, 8> argTypes;
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
for (auto t : linalgOp.getShapedOperandTypes())
|
||||
argTypes.push_back(getElementTypeOrSelf(t));
|
||||
OpBuilder b(op->getContext());
|
||||
Region ®ion = op->getRegion(0);
|
||||
Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes);
|
||||
// TODO: allow captures.
|
||||
fun(*body, ValueRange{});
|
||||
}
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
|
||||
|
|
|
@ -57,6 +57,38 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
|
|||
// LinalgDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Trait to check if T provides a `regionBuilder` method.
|
||||
template <typename T, typename... Args>
|
||||
using has_region_builder = decltype(T::regionBuilder);
|
||||
template <typename T>
|
||||
using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;
|
||||
|
||||
/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g.
|
||||
/// an OpInterface).
|
||||
template <typename OpType, typename = std::enable_if_t<
|
||||
!detect_has_region_builder<OpType>::value>>
|
||||
void addNamedOpBuilderImpl(
|
||||
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
template <typename OpType,
|
||||
typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
|
||||
typename = void>
|
||||
void addNamedOpBuilderImpl(
|
||||
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
|
||||
map.insert(std::make_pair(
|
||||
OpType::getOperationName(),
|
||||
static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
|
||||
}
|
||||
|
||||
template <typename... OpTypes>
|
||||
void addNamedOpBuilders(
|
||||
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
|
||||
(void)std::initializer_list<int>{0,
|
||||
(addNamedOpBuilderImpl<OpTypes>(map), 0)...};
|
||||
}
|
||||
|
||||
void mlir::linalg::LinalgDialect::initialize() {
|
||||
addTypes<RangeType>();
|
||||
addOperations<
|
||||
|
@ -72,6 +104,12 @@ void mlir::linalg::LinalgDialect::initialize() {
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
|
||||
>();
|
||||
|
||||
// Fill the Linalg-specific OpName to RegionBuilder map.
|
||||
addNamedOpBuilders<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
||||
>(namedStructuredOpRegionBuilders);
|
||||
|
||||
addInterfaces<LinalgInlinerInterface>();
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ from mlir.dialects import builtin
|
|||
from mlir.dialects import linalg
|
||||
from mlir.dialects import std
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
@ -82,9 +81,9 @@ def testStructuredOpOnBuffers():
|
|||
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
|
||||
print(module)
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedStructuredOp
|
||||
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
|
||||
@run
|
||||
def testNamedStructuredOp():
|
||||
def testNamedStructuredOpCustomForm():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
|
@ -93,10 +92,45 @@ def testNamedStructuredOp():
|
|||
RankedTensorType.get((16, 8), f32))
|
||||
def named_form(lhs, rhs):
|
||||
init_result = linalg.InitTensorOp([4, 8], f32)
|
||||
# CHECK: linalg.matmul
|
||||
# TODO: prperly hook up the region.
|
||||
# First check the named form with custom format
|
||||
# CHECK: linalg.matmul
|
||||
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
|
||||
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
|
||||
# CHECK-SAME: -> tensor<4x8xf32>
|
||||
# CHECK-NEXT: return
|
||||
return linalg.matmul(lhs, rhs, outs=[init_result.result])
|
||||
|
||||
print(module)
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
|
||||
@run
|
||||
def testNamedStructuredOpGenericForm():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
|
||||
RankedTensorType.get((16, 8), f32))
|
||||
def named_form(lhs, rhs):
|
||||
init_result = linalg.InitTensorOp([4, 8], f32)
|
||||
# CHECK: "linalg.matmul"(%{{.*}})
|
||||
# CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
|
||||
# CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
|
||||
# CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
|
||||
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
|
||||
# CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
|
||||
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
|
||||
return linalg.matmul(lhs, rhs, outs=[init_result.result])
|
||||
|
||||
module.operation.print(print_generic_op_form=True)
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
|
||||
@run
|
||||
def testNamedStructuredAsGenericOp():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
|
||||
RankedTensorType.get((16, 8), f32))
|
||||
def generic_form(lhs, rhs):
|
||||
|
|
Loading…
Reference in New Issue