[mlir][Python][Linalg] Add missing attributes to linalg ops

This revision tightens up the handling of attributes for both named and
generic linalg ops. To demonstrate the IR validity, a working e2e Linalg
example is added.

Differential Revision: https://reviews.llvm.org/D99430
Parent: e1d4fb1ebf
Commit: 335d2df533
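In brief, what the revision enables, condensed from the end-to-end test added at the bottom of this commit:

    # Condensed from the new e2e test in this commit: a linalg.matmul built
    # from Python now carries the attributes the verifier expects.
    from mlir.ir import *
    from mlir.dialects import builtin, linalg

    with Context(), Location.unknown():
      module = Module.create()
      f32 = F32Type.get()
      with InsertionPoint(module.body):
        @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
                                     MemRefType.get((16, 8), f32),
                                     MemRefType.get((4, 8), f32))
        def matmul_on_buffers(lhs, rhs, out):
          linalg.matmul(lhs, rhs, outs=[out])
      print(module)  # Verifies and prints; see the e2e test for execution.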
@@ -169,6 +169,17 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 
+/// Returns the simplified affine map resulting from dropping the symbols that
+/// do not appear in any of the individual maps in `affineMaps`.
+/// Asserts that all maps in `affineMaps` are normalized to the same number of
+/// dims and symbols.
+/// Takes a callback `populateResult` to fill the `res` container with value
+/// `m` at entry `idx`. This allows returning without worrying about ownership
+/// considerations.
+MLIR_CAPI_EXPORTED void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m));
+
 #ifdef __cplusplus
 }
 #endif
@@ -340,6 +340,11 @@ AffineMap simplifyAffineMap(AffineMap map);
 /// Drop the dims that are not used.
 AffineMap compressUnusedDims(AffineMap map);
 
+/// Drop the dims that are not used by any of the individual maps in `maps`.
+/// Asserts that all maps in `maps` are normalized to the same number of
+/// dims and symbols.
+SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
+
 /// Drop the dims that are not listed in `unusedDims`.
 AffineMap compressDims(AffineMap map,
                        const llvm::SmallDenseSet<unsigned> &unusedDims);

@@ -347,6 +352,11 @@ AffineMap compressDims(AffineMap map,
 /// Drop the symbols that are not used.
 AffineMap compressUnusedSymbols(AffineMap map);
 
+/// Drop the symbols that are not used by any of the individual maps in `maps`.
+/// Asserts that all maps in `maps` are normalized to the same number of
+/// dims and symbols.
+SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
+
 /// Drop the symbols that are not listed in `unusedSymbols`.
 AffineMap compressSymbols(AffineMap map,
                           const llvm::SmallDenseSet<unsigned> &unusedSymbols);
@@ -538,6 +538,23 @@ void mlir::python::populateIRAffine(py::module &m) {
             printAccum.parts.append(")");
             return printAccum.join();
           })
+      .def_static("compress_unused_symbols",
+                  [](py::list affineMaps, DefaultingPyMlirContext context) {
+                    SmallVector<MlirAffineMap> maps;
+                    pyListToVector<PyAffineMap, MlirAffineMap>(
+                        affineMaps, maps, "attempting to create an AffineMap");
+                    std::vector<MlirAffineMap> compressed(affineMaps.size());
+                    auto populate = [](void *result, intptr_t idx,
+                                       MlirAffineMap m) {
+                      static_cast<MlirAffineMap *>(result)[idx] = (m);
+                    };
+                    mlirAffineMapCompressUnusedSymbols(
+                        maps.data(), maps.size(), compressed.data(), populate);
+                    std::vector<PyAffineMap> res;
+                    for (auto m : compressed)
+                      res.push_back(PyAffineMap(context->getRef(), m));
+                    return res;
+                  })
       .def_property_readonly(
           "context",
           [](PyAffineMap &self) { return self.getContext().getObject(); },
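From Python, the new static method is a one-call wrapper over the C callback plumbing above. A minimal sketch of its use, assuming the existing AffineMap and AffineDimExpr bindings:

    # Two maps over (d0, d1) that both ignore their single symbol; joint
    # compression drops s0 from both signatures.
    from mlir.ir import AffineDimExpr, AffineMap, Context

    with Context():
      d0 = AffineDimExpr.get(0)
      d1 = AffineDimExpr.get(1)
      m0 = AffineMap.get(2, 1, [d0])  # (d0, d1)[s0] -> (d0)
      m1 = AffineMap.get(2, 1, [d1])  # (d0, d1)[s0] -> (d1)
      compressed = AffineMap.compress_unused_symbols([m0, m1], Context.current)
      print(compressed[0])  # (d0, d1) -> (d0)
      print(compressed[1])  # (d0, d1) -> (d1)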
@@ -19,6 +19,13 @@ __all__ = [
     "emit_named_structured_op",
 ]
 
+def isa(cls : Type, ty : Type):
+  try:
+    cls(ty)
+    return True
+  except ValueError:
+    return False
+
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value,
                                  outs: Value):
@@ -37,6 +44,8 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
 
+  result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
   # Extract type vars for input/output based types.
   type_mapping = dict()  # type: Dict[str, Type]
   for arg_def, arg_element_type in zip(
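A quick illustration of the isa() helper from the previous hunk together with the result_types filter above: MLIR Python type casts raise ValueError on mismatch, which isa() folds into a boolean, so only ranked tensor outputs become op result types.

    from mlir.ir import Context, F32Type, MemRefType, RankedTensorType, Type

    def isa(cls : Type, ty : Type):
      try:
        cls(ty)
        return True
      except ValueError:
        return False

    with Context():
      f32 = F32Type.get()
      out_types = [RankedTensorType.get([4, 8], f32),
                   MemRefType.get((4, 8), f32)]
      result_types = [t for t in out_types if isa(RankedTensorType, t)]
      assert len(result_types) == 1  # The memref output is filtered out.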
@@ -48,30 +57,37 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
   indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am) for am in op_config.indexing_maps])
+      [AffineMapAttr.get(am)
+       # TODO: linalg verification does not currently allow symbols.
+       # Compress them for now.
+       for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
+  sparse_attr = ArrayAttr.get(
+      [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)])
+  if len(sparse_attr) == 0:
+    sparse_attr = None
 
-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr)
+  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
+          type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
                                *ins: Value,
                                outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr = \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
       prepare_common_structured_op(op_config, *ins, outs = outs)
 
   generic_op = linalg.GenericOp(
-      result_tensors=out_types,
+      result_tensors=result_types,
      inputs=ins,
      outputs=outs,
      indexing_maps=indexing_maps_attr,
      iterator_types=iterator_types_attr,
      doc=None,  # TODO: Make optional.
      library_call=None,  # TODO: Make optional.
-      sparse=BoolAttr.get(False))  # TODO: Make optional.
+      sparse=sparse_attr)  # TODO: Make optional.
 
   # Construct the body.
   block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
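Pulled out of the hunk above, a standalone sketch of the attribute construction; the two permutation maps and the two-tensor-operand assumption are illustrative, not taken from the patch:

    from mlir.ir import (AffineMap, AffineMapAttr, ArrayAttr, BoolAttr,
                         Context)

    with Context():
      maps = [AffineMap.get_permutation([0, 1]),
              AffineMap.get_permutation([1, 0])]
      # linalg verification does not currently allow symbols: compress first.
      indexing_maps_attr = ArrayAttr.get(
          [AffineMapAttr.get(am)
           for am in AffineMap.compress_unused_symbols(maps, Context.current)])
      # One `false` per ranked-tensor operand; None when there are none, so
      # the attribute is omitted entirely on pure-memref ops.
      sparse_attr = ArrayAttr.get([BoolAttr.get(False), BoolAttr.get(False)])
      if len(sparse_attr) == 0:
        sparse_attr = None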
@@ -84,7 +100,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
     body_builder.assign(assignment)
   body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
 
-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
     return generic_op.result
   else:
     return generic_op.results
@@ -95,8 +111,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
                              op_class_name: str,
                              *ins: Value,
                              outs: Value = ()):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \
-  type_mapping, indexing_maps_attr, iterator_types_attr = \
+  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
+  type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
       prepare_common_structured_op(op_config, *ins, outs = outs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
@@ -107,11 +123,16 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
 
-  named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
   linalgDialect = ctx.get_dialect_descriptor("linalg")
   fill_builtin_region(linalgDialect, named_op.operation)
+  # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
+  # attribute that the non-yaml path does not. The non-yaml path hardcodes the
+  # indexing_maps in C++ directly.
+  named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
+  # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
 
-  if len(out_arg_defs) == 1:
+  if len(result_types) == 1:
     return named_op.result
   else:
     return named_op.results
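Tying the named-op path together, a hedged sketch that emits linalg.matmul and prints the module in generic op form, where linalg.memoized_indexing_maps becomes visible. Operation.get_asm(print_generic_op_form=...) is assumed from the existing bindings; the generic-form test further down exercises the same path.

    from mlir.ir import *
    from mlir.dialects import builtin, linalg

    with Context(), Location.unknown():
      module = Module.create()
      f32 = F32Type.get()
      with InsertionPoint(module.body):
        @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                     RankedTensorType.get((16, 8), f32))
        def named_form(lhs, rhs):
          init_result = linalg.InitTensorOp([4, 8], f32)
          return linalg.matmul(lhs, rhs, outs=[init_result.result])
      print(module.operation.get_asm(print_generic_op_form=True))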
@@ -137,3 +137,14 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
                                           intptr_t numResults) {
   return wrap(unwrap(affineMap).getMinorSubMap(numResults));
 }
+
+void mlirAffineMapCompressUnusedSymbols(
+    MlirAffineMap *affineMaps, intptr_t size, void *result,
+    void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {
+  SmallVector<AffineMap> maps;
+  for (intptr_t idx = 0; idx < size; ++idx)
+    maps.push_back(unwrap(affineMaps[idx]));
+  intptr_t idx = 0;
+  for (auto m : mlir::compressUnusedSymbols(maps))
+    populateResult(result, idx++, wrap(m));
+}
@@ -543,6 +543,41 @@ AffineMap mlir::compressUnusedDims(AffineMap map) {
   return compressDims(map, unusedDims);
 }
 
+static SmallVector<AffineMap>
+compressUnusedImpl(ArrayRef<AffineMap> maps,
+                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+  if (maps.empty())
+    return SmallVector<AffineMap>();
+  SmallVector<AffineExpr> allExprs;
+  allExprs.reserve(maps.size() * maps.front().getNumResults());
+  unsigned numDims = maps.front().getNumDims(),
+           numSymbols = maps.front().getNumSymbols();
+  for (auto m : maps) {
+    assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
+           "expected maps with same num dims and symbols");
+    llvm::append_range(allExprs, m.getResults());
+  }
+  AffineMap unifiedMap = compressionFun(
+      AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
+  unsigned unifiedNumDims = unifiedMap.getNumDims(),
+           unifiedNumSymbols = unifiedMap.getNumSymbols();
+  ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
+  SmallVector<AffineMap> res;
+  res.reserve(maps.size());
+  for (auto m : maps) {
+    res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
+                                 unifiedResults.take_front(m.getNumResults()),
+                                 m.getContext()));
+    unifiedResults = unifiedResults.drop_front(m.getNumResults());
+  }
+  return res;
+}
+
+SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl(maps,
+                            [](AffineMap m) { return compressUnusedDims(m); });
+}
+
 AffineMap
 mlir::compressSymbols(AffineMap map,
                       const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
@@ -576,6 +611,11 @@ AffineMap mlir::compressUnusedSymbols(AffineMap map) {
   return compressSymbols(map, unusedSymbols);
 }
 
+SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
+  return compressUnusedImpl(
+      maps, [](AffineMap m) { return compressUnusedSymbols(m); });
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 8> exprs;
   for (auto e : map.getResults()) {
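The unified-map trick above gives a joint guarantee: a symbol used by any map in the list survives in all of them, so the returned maps stay normalized to identical dim/symbol counts. A small sketch of that behavior through the Python binding; AffineSymbolExpr and expression arithmetic are assumed from the existing bindings:

    from mlir.ir import AffineDimExpr, AffineMap, AffineSymbolExpr, Context

    with Context():
      d0 = AffineDimExpr.get(0)
      s0 = AffineSymbolExpr.get(0)
      uses_s0 = AffineMap.get(1, 2, [d0 + s0])  # (d0)[s0, s1] -> (d0 + s0)
      no_syms = AffineMap.get(1, 2, [d0])       # (d0)[s0, s1] -> (d0)
      res = AffineMap.compress_unused_symbols([uses_s0, no_syms],
                                              Context.current)
      # s1 is used by neither map and is dropped from both; s0 is used by
      # one map and therefore kept in both.
      print(res[0])  # (d0)[s0] -> (d0 + s0)
      print(res[1])  # (d0)[s0] -> (d0)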
@@ -37,9 +37,9 @@ with Context() as ctx, Location.unknown():
   # Note that these all have the same indexing maps. We verify the first and
   # then do more permutation tests on casting and body generation
   # behavior.
-  # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-  # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-  # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+  # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+  # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
   # CHECK-LABEL: func @test_matmul_mono
   # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
@@ -94,6 +94,7 @@ def testNamedStructuredOpCustomForm():
     init_result = linalg.InitTensorOp([4, 8], f32)
     # First check the named form with custom format
     # CHECK: linalg.matmul
+    # CHECK-NOT: linalg.memoized_indexing_maps
     # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
     # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
     # CHECK-SAME: -> tensor<4x8xf32>
@@ -118,7 +119,7 @@ def testNamedStructuredOpGenericForm():
     # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
     # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
     # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
-    # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+    # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
     # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
     return linalg.matmul(lhs, rhs, outs=[init_result.result])
@@ -0,0 +1,105 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import ctypes
+import sys
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+from mlir.passmanager import *
+from mlir.execution_engine import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors/info emitted by MLIR to stderr.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+boilerplate = """
+func @main() -> f32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0.0 : f32
+  %v1 = constant 1.0 : f32
+  %v2 = constant 2.0 : f32
+
+  %A = memref.alloc() : memref<4x16xf32>
+  %B = memref.alloc() : memref<16x8xf32>
+  %C = memref.alloc() : memref<4x8xf32>
+  linalg.fill(%A, %v1) : memref<4x16xf32>, f32
+  linalg.fill(%B, %v2) : memref<16x8xf32>, f32
+  linalg.fill(%C, %v0) : memref<4x8xf32>, f32
+
+  call @matmul_on_buffers(%A, %B, %C) :
+    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : f32
+}
+"""
+
+def transform(module):
+  import mlir.conversions
+  import mlir.dialects.linalg.passes
+  import mlir.transforms
+
+  # TODO: Allow cloning functions from one module to another.
+  # Atm we have to resort to string concatenation.
+  mod = Module.parse(
+      str(module.operation.regions[0].blocks[0].operations[0].operation) +
+      boilerplate)
+  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
+                         "convert-vector-to-llvm," +
+                         "convert-std-to-llvm")
+  pm.run(mod)
+  return mod
+
+def test_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])
+    # CHECK: RESULT: 32.0
+
+test_builtin()
+
+def test_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(MemRefType.get((4, 16), f32),
+                                   MemRefType.get((16, 8), f32),
+                                   MemRefType.get((4, 8), f32))
+      def matmul_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log('RESULT: ', res[0])
+    # CHECK: RESULT: 32.0
+
+test_generic()
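The CHECK value is forced by the fills in the boilerplate: A holds 1.0, B holds 2.0, and the contraction dimension has extent 16, so every entry of C is 32.0.

    # Why both tests CHECK "RESULT: 32.0": C[i, j] = sum over k of
    # A[i, k] * B[k, j], with A filled with 1.0, B filled with 2.0, and k
    # ranging over the shared dimension of extent 16.
    expected = sum(1.0 * 2.0 for _ in range(16))
    assert expected == 32.0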