[mlir][OpDSL] Add type function attributes.

Previously, OpDSL operation used hardcoded type conversion operations (cast or cast_unsigned). Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute which at operation instantiation time may be set to cast or cast_unsigned. We may for example, defina a matmul operation with a cast argument:

```
@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```

When instantiating the operation the attribute may be set to the desired cast function:

```
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
```

The revsion introduces a enum in the Linalg dialect that maps one-by-one to the type functions defined by OpDSL.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D119718
This commit is contained in:
gysit 2022-02-25 08:12:34 +00:00
parent 3fe6f9388f
commit 51fdd802c7
24 changed files with 759 additions and 475 deletions

View File

@ -44,6 +44,18 @@ add_dependencies(mlir-headers LinalgOdsGen)
add_mlir_dialect(LinalgOps linalg)
set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen)
add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen)
add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc)
add_dependencies(LinalgOpsDocGen LinalgOdsGen)

View File

@ -104,6 +104,19 @@ LogicalResult verifyStructuredOpInterface(Operation *op);
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// Linalg Enums
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
//===----------------------------------------------------------------------===//
// Linalg Attributes
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc"
//===----------------------------------------------------------------------===//
// Linalg Interfaces
//===----------------------------------------------------------------------===//

View File

@ -13,6 +13,7 @@
#ifndef LINALG_BASE
#define LINALG_BASE
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
def Linalg_Dialect : Dialect {
@ -57,4 +58,17 @@ def Linalg_Dialect : Dialect {
}];
}
// Define a TypeFn enum matching the OpDSL TypeFn class.
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}
#endif // LINALG_BASE

View File

@ -37,7 +37,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
return llvm::cast<LinalgOp>(getOperation()).reifyResultShapes(b,
reifiedReturnShapes);
}
}];

View File

@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRLinalg
DEPENDS
MLIRLinalgInterfacesIncGen
MLIRLinalgOpsAttributesIncGen
MLIRLinalgOpsEnumsIncGen
MLIRLinalgOpsIncGen
MLIRLinalgStructuredOpsIncGen

View File

@ -21,6 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@ -95,6 +96,10 @@ void addNamedOpBuilders(
}
void mlir::linalg::LinalgDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@ -144,3 +149,10 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
return op->emitError() << "attribute '" << attr.getName()
<< "' not supported by the linalg dialect";
}
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"

View File

@ -36,8 +36,6 @@
using namespace mlir;
using namespace mlir::linalg;
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
/// Forward declarations.
/// Generic entry point to create the block for the region of a LinalgOp.
@ -232,14 +230,14 @@ public:
return operand;
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value typefn__cast(Type toType, Value operand) {
return cast(toType, operand, false);
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value typefn__cast_unsigned(Type toType, Value operand) {
return cast(toType, operand, true);
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
case TypeFn::cast:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);
}
llvm_unreachable("unsupported type conversion function");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.

View File

@ -111,7 +111,7 @@ class TensorUse(TensorExpression):
@property
def tensor_name(self) -> str:
name = self.operand_def.name
assert name is not None, "TensorDef not attached"
assert name is not None, "TensorDef not registered with an op"
return name
def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
@ -129,7 +129,8 @@ class TensorUse(TensorExpression):
return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
def __repr__(self):
return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
return (f"{self.operand_def.name}"
f"[{', '.join([repr(i) for i in self.indices])}]")
class TensorArithFn(TensorExpression):
@ -156,14 +157,22 @@ class TensorArithFn(TensorExpression):
class TensorTypeFn(TensorExpression):
"""Application of a type conversion function."""
def __init__(self, type_fn: "TypeFn", type_var: TypeVar,
def __init__(self, type_fn: Optional["TypeFn"],
operand_def: Optional["OperandDef"], type_var: TypeVar,
arg: TensorExpression):
if bool(type_fn) + bool(operand_def) != 1:
raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
self.type_fn = type_fn
self.operand_def = operand_def
self.type_var = type_var
self.arg = arg
def to_scalar_expression(self) -> ScalarExpression:
return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
if self.operand_def:
assert self.operand_def.name, "TypeFnAttr not registered with an op"
fn_name = self.type_fn.fn_name if self.type_fn else None
attr_name = self.operand_def.name if self.operand_def else None
return ScalarTypeFn(fn_name, attr_name, self.type_var,
self.arg.to_scalar_expression()).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
@ -171,7 +180,8 @@ class TensorTypeFn(TensorExpression):
self.arg.visit_tensor_exprs(callback)
def __repr__(self):
return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
f"({self.type_var}, {self.arg})")
class TensorReduceFn(TensorExpression):
@ -260,7 +270,7 @@ class TypeFnType:
self.fn_name = fn_name
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
return TensorTypeFn(self, type_var, arg)
return TensorTypeFn(self, None, type_var, arg)
def __repr__(self):
return f"{self.fn_name}"
@ -370,10 +380,11 @@ class ReduceFn:
class OperandKind(Enum):
InputTensor = 0
Scalar = 1
OutputTensor = 2
IndexAttr = 3
INPUT_TENSOR = 0
SCALAR = 1
OUTPUT_TENSOR = 2
INDEX_ATTR = 3
TYPE_FN_ATTR = 4
class OperandDef:
@ -388,7 +399,8 @@ class OperandDef:
type_var: Optional[TypeVar] = None,
size_exprs: Optional[Sequence[AffineExprDef]] = None,
index_dims: Optional[Sequence[DimDef]] = None,
default_vals: Optional[Sequence[int]] = None):
default_indices: Optional[Sequence[int]] = None,
default_fn: Optional[str] = None):
if type_var and not isinstance(type_var, TypeVar):
raise ValueError(
f"OperandDef requires a TypeVar but got {repr(type_var)}")
@ -396,25 +408,40 @@ class OperandDef:
self.type_var = type_var
self.size_exprs = size_exprs
self.index_dims = index_dims
self.default_vals = default_vals
self.default_indices = default_indices
self.default_fn = default_fn
self.kind = kind
self.name = None # type: Optional[str]
self.registered_index = -1 # type: int
def attach(self, index: int, name: str, owner: "LinalgOpDef"):
if self.owner:
raise ValueError(f"OperandDef already registered with op: {self}")
raise ValueError(f"OperandDef already registered with an op: {self}")
self.registered_index = index
self.name = name
self.owner = owner
def is_input(self) -> bool:
return (self.kind == OperandKind.SCALAR or
self.kind == OperandKind.INPUT_TENSOR)
def is_tensor(self) -> bool:
return (self.kind == OperandKind.INPUT_TENSOR or
self.kind == OperandKind.OUTPUT_TENSOR)
def is_attribute(self) -> bool:
return (self.kind == OperandKind.INDEX_ATTR or
self.kind == OperandKind.TYPE_FN_ATTR)
def __hash__(self):
return hash(id(self))
def __repr__(self):
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
f"index_dims={self.index_dims}, default_vals={self.default_vals})")
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
f"index_dims={self.index_dims}, "
f"default_indices={self.default_indices}, "
f"default_fn={self.default_fn})")
class TensorDef:
@ -440,12 +467,12 @@ class TensorDef:
if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
raise ValueError(f"TensorDef requires index dims of type DimDef but "
f"got {index_dims}")
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
self.operand_def = OperandDef(
kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
assert self.operand_def.owner, "TensorDef is not registered with an op"
state = AffineBuildState(
global_state=self.operand_def.owner._affine_state,
allow_new_symbols=False)
@ -486,12 +513,12 @@ class ScalarDef(TensorExpression):
"""
def __init__(self, type_var: TypeVar):
self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
@property
def scalar_name(self) -> str:
name = self.operand_def.name
assert name is not None, "ScalarDef not attached"
assert name is not None, "ScalarDef not registered with an op"
return name
def to_scalar_expression(self) -> ScalarExpression:
@ -517,7 +544,26 @@ class IndexAttrDef:
raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
f"but got {len(default)}")
self.operand_def = OperandDef(
OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
class TypeFnAttrDef:
"""Type conversion function attribute definition.
Type conversion function attributes provide a way to make type conversions
parameterizable. Every attribute specifies a default type conversion function
that may be overwritten at operation instantiation time.
"""
def __init__(self, default: "TypeFnType"):
if not isinstance(default, TypeFnType):
raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
f"but got {default}")
self.operand_def = OperandDef(
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
return TensorTypeFn(None, self.operand_def, type_var, arg)
###############################################################################
@ -615,17 +661,21 @@ class LinalgOpDef:
if name in self.registered_operands:
raise ValueError(f"The operand {name} is already registered "
f"to {self.registered_operands['name']}")
structured_op_methods = [
"inputs", "outputs", "result_tensors", "region", "iterator_types",
"indexing_maps", "getRegionBuilder", "getLibraryCallName"
]
if operand.is_attribute() and name in structured_op_methods:
raise ValueError(f"The attribute name {name} conflicts with a structured "
f"op method name")
# Ensure output tensors are registered after input tensors and scalars and
# attributes are registered after all other operand types.
registered_kinds = [
operand.kind.value for operand in self.registered_operands.values()
]
if registered_kinds:
maximum = max(registered_kinds)
if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
raise ValueError(
f"The operand {name} of kind {operand.kind.name} is registered "
f"after an operand of kind {OperandKind(maximum).name}")
if operand.is_input() and any(
not op_def.is_input() for op_def in self.registered_operands.values()):
raise ValueError(f"Input {name} registered after an output or attribute")
if operand.kind == OperandKind.OUTPUT_TENSOR and any(
op_def.is_attribute() for op_def in self.registered_operands.values()):
raise ValueError(f"Output {name} registered after an attribute")
operand.attach(len(self.registered_operands), name, self)
self.registered_operands[name] = operand

View File

@ -55,28 +55,26 @@ class OperandDefConfig(YAMLObject):
def name(self) -> str:
return self.operand_def.name
@property
def kind(self) -> OperandKind:
return self.operand_def.kind
@property
def type_var(self) -> TypeVar:
return self.operand_def.type_var
@property
def usage(self) -> str:
if self.operand_def.kind == OperandKind.IndexAttr:
return "IndexAttr"
if self.operand_def.kind == OperandKind.OutputTensor:
return "Output"
return "Input"
def to_yaml_custom_dict(self):
self_dict = dict(name=self.name, usage=self.usage)
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
if self.type_var:
self_dict["type_var"] = self.type_var.name
if self.shape_map:
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
if self.index_attr_map:
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
if self.operand_def.default_vals:
self_dict["default_vals"] = self.operand_def.default_vals
if self.operand_def.default_indices:
self_dict["default_indices"] = self.operand_def.default_indices
if self.operand_def.default_fn:
self_dict["default_fn"] = self.operand_def.default_fn
return self_dict
def __repr__(self):
@ -166,7 +164,7 @@ class LinalgStructuredOpConfig(YAMLObject):
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
if operand.kind == OperandKind.IndexAttr:
if operand.is_attribute():
collected_attr_defs.append(operand)
# Collect all tensors with manual indexing annotation.
@ -244,12 +242,12 @@ class LinalgStructuredOpConfig(YAMLObject):
# Set the indexing map of all scalar uses to the empty map.
for operand_config in self.operands.values():
if operand_config.operand_def.kind == OperandKind.Scalar:
if operand_config.operand_def.kind == OperandKind.SCALAR:
operand_config.indexing_map = self._get_scalar_map()
# Check all registered tensor and scalar operands have an indexing map.
for operand in registered_operands:
if operand.kind == OperandKind.IndexAttr:
if operand.is_attribute():
continue
if not (operand in self.operands and self.operands[operand].indexing_map):
raise ValueError(f"Failed to compute an indexing map for operand "
@ -311,7 +309,8 @@ class LinalgStructuredOpConfig(YAMLObject):
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
if operand_def.kind == OperandKind.Scalar:
if (operand_def.kind == OperandKind.SCALAR or
operand_def.kind == OperandKind.TYPE_FN_ATTR):
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:
@ -323,7 +322,7 @@ class LinalgStructuredOpConfig(YAMLObject):
assert local_state.local_dim_count == 0
affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
if operand_def.kind == OperandKind.IndexAttr:
if operand_def.kind == OperandKind.INDEX_ATTR:
self.operands[operand_def] = OperandDefConfig(
operand_def, index_attr_map=affine_map)
else:
@ -429,8 +428,7 @@ class LinalgOpConfig(YAMLObject):
context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
"""Expands a LinalgOpDef into corresponding Linalg configured ops."""
# TODO: Many LinalgOpDef patterns need to expand to multiple generics.
assert len(
op_def.comprehensions) == 1, "Only one comprehension supported"
assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
return [
LinalgOpConfig(
op_def.metadata,

View File

@ -129,7 +129,8 @@ def linalg_structured_op(dsl_func=None,
sig = inspect.signature(dsl_func)
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)):
if isinstance(param_default,
(TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(

View File

@ -37,11 +37,21 @@ def isa(cls: Type, ty: Type):
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value, outs: ValueList,
**attrs: Sequence[int]):
**attrs: Union[Sequence[int], TypeFnType]):
all_arg_defs = op_config.ordered_operands
in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
in_arg_defs = [
d for d in all_arg_defs
if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
]
out_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
]
index_attr_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
]
type_fn_attr_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
]
# Verify outs is a sequence or a list of results.
if not isinstance(outs, (Sequence, OpResultList)):
@ -56,11 +66,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
f"{len(outs)} for {op_config}")
# Compute a replacement list for all attribute symbols.
# Compute a replacement list for all index attribute symbols.
expressions = [] # type: Sequence[AffineExpr]
replacements = [] # type: Sequence[AffineExpr]
for index_attr in index_attr_arg_defs:
index_attr_vals = index_attr.operand_def.default_vals
index_attr_vals = index_attr.operand_def.default_indices
if index_attr.name in attrs:
index_attr_vals = attrs.get(index_attr.name)
assert index_attr_vals, "Index attribute has no value"
@ -125,15 +135,29 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
array = np.array(index_attr_vals, dtype=np.int64)
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
# Compute the type function attribute mapping.
type_fn_attr_mapping = {}
for type_fn_attr in type_fn_attr_arg_defs:
attr_val = type_fn_attr.operand_def.default_fn
if type_fn_attr.name in attrs:
type_fn = attrs.get(type_fn_attr.name)
if not isinstance(type_fn, TypeFnType):
raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
f"TypeFnType but got {type(attr_val)}")
attr_val = type_fn.fn_name
assert attr_val, "Type function attribute has no value"
type_fn_attr_mapping[type_fn_attr.name] = attr_val
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr,
index_attrs, block_arg_types)
type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
type_fn_attr_mapping, block_arg_types)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# An operation that accesses only scalars and scalar/rank zero tensors is
@ -147,10 +171,9 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
tensor_map = AffineMap.get_identity(rank)
indexing_maps = []
for arg_def in all_arg_defs:
if arg_def.operand_def.kind == OperandKind.Scalar:
if arg_def.operand_def.kind == OperandKind.SCALAR:
indexing_maps.append(scalar_map)
if (arg_def.operand_def.kind == OperandKind.InputTensor or
arg_def.operand_def.kind == OperandKind.OutputTensor):
if arg_def.operand_def.is_tensor():
indexing_maps.append(tensor_map)
indexing_maps_attr = ArrayAttr.get(
[AffineMapAttr.get(am) for am in indexing_maps])
@ -169,7 +192,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
block = generic_op.regions[0].blocks.append(*block_arg_types)
block_arg_mapping = dict(zip(block_arg_names, block.arguments))
with InsertionPoint(block):
body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
type_fn_attr_mapping)
for assignment in op_config.assignments:
body_builder.assign(assignment)
body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
@ -184,7 +208,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# If we get here, there must exist a builtin class `op_class_name`.
@ -200,6 +225,11 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
# Set the type function attributes.
for name, value in type_fn_attr_mapping.items():
named_op.operation.attributes[name] = Attribute.parse(
f"#linalg.type_fn<{value}>")
linalg.fill_builtin_region(named_op.operation)
if len(result_types) == 1:
@ -212,9 +242,11 @@ class _BodyBuilder:
"""Constructs a structured op body by evaluating assignments."""
def __init__(self, type_mapping: Dict[str, Type],
block_arg_mapping: Dict[str, Value]):
block_arg_mapping: Dict[str, Value],
type_fn_attr_mapping: Dict[str, str]):
self.type_mapping = type_mapping
self.block_arg_mapping = block_arg_mapping
self.type_fn_attr_mapping = type_fn_attr_mapping
self.yield_mapping = dict() # type: Dict[str, Value]
def assign(self, assignment: ScalarAssign):
@ -245,7 +277,10 @@ class _BodyBuilder:
]
return fn(*operand_values)
elif expr.type_fn:
fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}")
fn_name = expr.type_fn.fn_name
if expr.type_fn.attr_name:
fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
fn = self._get_function(f"_typefn_{fn_name}")
operand = self.expression(expr.type_fn.operand)
return fn(expr.type_fn.type_var.name, operand)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")

View File

@ -46,9 +46,10 @@ class ScalarArithFn:
class ScalarTypeFn:
"""A type of ScalarExpression that applies a type conversion function."""
def __init__(self, fn_name: str, type_var: TypeVar,
operand: "ScalarExpression"):
def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
type_var: TypeVar, operand: "ScalarExpression"):
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
self.operand = operand
@ -56,7 +57,8 @@ class ScalarTypeFn:
return ScalarExpression(type_fn=self)
def __repr__(self):
return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})"
return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
f"({self.type_var}, {self.operand})")
class ScalarArg:
@ -138,12 +140,15 @@ class ScalarExpression(YAMLObject):
# Note that even though operands must be arity 1, we write it the
# same way as for apply because it allows handling code to be more
# generic vs having a special form.
return dict(
type_fn=dict(
fn_name=self.type_fn.fn_name,
type_var=self.type_fn.type_var.name,
operands=[self.type_fn.operand],
))
type_fn_dict = dict(
type_var=self.type_fn.type_var.name,
operands=[self.type_fn.operand],
)
if self.type_fn.fn_name:
type_fn_dict["fn_name"] = self.type_fn.fn_name
if self.type_fn.attr_name:
type_fn_dict["attr_name"] = self.type_fn.attr_name
return dict(type_fn=type_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:

View File

@ -10,7 +10,8 @@ Batch = S.Batch
def matmul(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
"""Performs a matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -18,7 +19,7 @@ def matmul(
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
@linalg_structured_op

View File

@ -36,6 +36,20 @@ func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32x
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
// -----
// Verifies that cast attributes control the cast operations used.
func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned
// CHECK: = arith.extui
// -----
func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {

View File

@ -1258,7 +1258,7 @@ class IndexExpr(abc.ABC):
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
# Emit the structured op representation for the destination tensor.
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
lang.OperandKind.OutputTensor)
lang.OperandKind.OUTPUT_TENSOR)
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
@ -1893,6 +1893,6 @@ def _emit_structured_op_input(
name = expr.tensor.name
dim_sym = _mlir_symbols_from_index_vars(indices)
opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym)
opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
op_def.add_operand(name, opnd)
return opnd

View File

@ -2,7 +2,8 @@
# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL
# @linalg_structured_op
# def test1(O=TensorDef(T, S.M, S.N, output=True)):
# def test1(O=TensorDef(T, S.M, S.N, output=True),
# cast=TypeFnAttrDef(default=TypeFn.cast)):
# """Title.
# Detailed description.
@ -21,9 +22,13 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: O
usage: Output
kind: output_tensor
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@ -39,18 +44,18 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
type_fn:
fn_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_const: '42 : i64'
attr_name: cast
- !ScalarExpression
type_fn:
fn_name: cast_unsigned
type_var: T
operands:
- !ScalarExpression
scalar_index: 1
attr_name: cast
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
@ -61,16 +66,22 @@ structured_op: !LinalgStructuredOpConfig
# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
# ODS: let builders =
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
# ODS-NEXT: "ValueRange":$outputs,
# ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
# ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast,
# ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
# ODS: $_state.addOperands(inputs);
# ODS-NEXT: $_state.addOperands(outputs);
# ODS-NEXT: $_state.addTypes(resultTensorTypes);
# ODS-NEXT: $_state.addAttribute("cast", cast)
# ODS-NEXT: $_state.addAttributes(attributes);
# ODS-NEXT: $_state.addAttribute(
# ODS-NEXT: "operand_segment_sizes",
@ -85,10 +96,18 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL: TypeFn castVal = TypeFn::cast;
# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
# IMPL-NEXT: return attr.getName() == "cast"; });
# IMPL-NEXT: if (castIter != attrs.end()) {
# IMPL-NEXT: if (auto attr = castIter->getValue().dyn_cast<TypeFnAttr>())
# IMPL-NEXT: castVal = attr.getValue();
# IMPL-NEXT: }
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
@ -114,19 +133,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: Input
kind: input_tensor
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: O
usage: Output
kind: output_tensor
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttr
kind: index_attr
index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
default_vals:
default_indices:
- 1
- 2
indexing_maps: !LinalgIndexingMapsConfig
@ -201,11 +220,11 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
usage: Input
kind: scalar
type_var: T1
- !LinalgOperandDefConfig
name: O
usage: Output
kind: output_tensor
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig

View File

@ -7,30 +7,34 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
# CHECK: usage: Input
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
# CHECK: usage: Input
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
# CHECK: usage: Output
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
# CHECK: ---
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
# CHECK: usage: Input
# CHECK: kind: scalar
# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
@ -42,17 +46,17 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
# CHECK-LABEL: strided_copy
# CHECK: args:
# CHECK: name: I
# CHECK: usage: Input
# CHECK: kind: input_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
# CHECK: usage: Output
# CHECK: kind: output_tensor
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
# CHECK: usage: IndexAttr
# CHECK: kind: index_attr
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
# CHECK: default_vals:
# CHECK: default_indices:
# CHECK: - 1
# CHECK: - 2
@linalg_structured_op

View File

@ -19,16 +19,19 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: A
# CHECK: attr_name: cast
# CHECK: type_fn:
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: B
# CHECK: attr_name: cast
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
# CHECK: ---

View File

@ -24,19 +24,10 @@ def matmul_mono(
def matmul_poly(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
domain(D.m, D.n, D.k)
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
@linalg_structured_op
def matmul_unsigned_poly(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
domain(D.m, D.n, D.k)
C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
U, B[D.k, D.n])
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
with Context() as ctx, Location.unknown():
@ -92,7 +83,8 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
RankedTensorType.get((4, 8), i32))
def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
return matmul_poly(
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
# CHECK-LABEL: @test_i8i16i32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
@ -143,7 +135,8 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
RankedTensorType.get((4, 8), f32))
def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
return matmul_poly(
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
# CHECK-LABEL: @test_f16f16f32_matmul
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)

View File

@ -6,6 +6,8 @@ from mlir.dialects import linalg
from mlir.dialects import std
from mlir.dialects import arith
from mlir.dialects.linalg.opdsl.lang import *
def run(f):
print("\nTEST:", f.__name__)
@ -98,12 +100,14 @@ def testNamedStructuredOpCustomForm():
init_result = linalg.InitTensorOp([4, 8], f32)
# First check the named form with custom format
# CHECK: linalg.matmul
# CHECK: cast = #linalg.type_fn<cast_unsigned>
# CHECK-NOT: linalg.memoized_indexing_maps
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
# CHECK-SAME: -> tensor<4x8xf32>
# CHECK-NEXT: return
return linalg.matmul(lhs, rhs, outs=[init_result.result])
return linalg.matmul(
lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
print(module)

View File

@ -9,6 +9,8 @@ from mlir.dialects import std
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.dialects.linalg.opdsl.lang import *
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
@ -20,21 +22,28 @@ def log(*args):
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
%v0 = arith.constant 0.0 : f32
%v1 = arith.constant 1.0 : f32
%v1 = arith.constant -1 : i8
%v2 = arith.constant 2.0 : f32
%A = memref.alloc() : memref<4x16xf32>
%A = memref.alloc() : memref<4x16xi8>
%B = memref.alloc() : memref<16x8xf32>
%C = memref.alloc() : memref<4x8xf32>
linalg.fill(%v1, %A) : f32, memref<4x16xf32>
%C0 = memref.alloc() : memref<4x8xf32>
%C1 = memref.alloc() : memref<4x8xf32>
linalg.fill(%v1, %A) : i8, memref<4x16xi8>
linalg.fill(%v2, %B) : f32, memref<16x8xf32>
linalg.fill(%v0, %C) : f32, memref<4x8xf32>
linalg.fill(%v0, %C0) : f32, memref<4x8xf32>
linalg.fill(%v0, %C1) : f32, memref<4x8xf32>
call @matmul_on_buffers(%A, %B, %C) :
(memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
call @matmul_signed_on_buffers(%A, %B, %C0) :
(memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
call @matmul_unsigned_on_buffers(%A, %B, %C1) :
(memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
%c0 = arith.constant 0 : index
%0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
%res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32>
%res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32>
%0 = arith.addf %res0, %res1 : f32
// TODO: FFI-based solution to allow testing and printing with python code.
return %0 : f32
@ -157,8 +166,8 @@ def transform(module, boilerplate):
pm = PassManager.parse(
"builtin.func(convert-linalg-to-loops, lower-affine, " +
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," +
"convert-memref-to-llvm, convert-std-to-llvm," +
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
+ "convert-memref-to-llvm, convert-std-to-llvm," +
"reconcile-unrealized-casts")
pm.run(mod)
return mod
@ -168,14 +177,21 @@ def test_matmul_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i8 = IntegerType.get_signless(8)
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
MemRefType.get((4, 8), f32))
def matmul_on_buffers(lhs, rhs, out):
def matmul_signed_on_buffers(lhs, rhs, out):
linalg.matmul(lhs, rhs, outs=[out])
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
MemRefType.get((4, 8), f32))
def matmul_unsigned_on_buffers(lhs, rhs, out):
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
execution_engine = ExecutionEngine(transform(module, matmul_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
@ -186,7 +202,9 @@ def test_matmul_builtin():
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
# CHECK: RESULT: 32.0
# matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
# matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
# CHECK: RESULT: 8128
test_matmul_builtin()
@ -196,14 +214,22 @@ def test_matmul_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i8 = IntegerType.get_signless(8)
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
MemRefType.get((4, 8), f32))
def matmul_on_buffers(lhs, rhs, out):
def matmul_signed_on_buffers(lhs, rhs, out):
linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
MemRefType.get((4, 8), f32))
def matmul_unsigned_on_buffers(lhs, rhs, out):
linalg.matmul(
lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)
execution_engine = ExecutionEngine(transform(module, matmul_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
@ -214,7 +240,9 @@ def test_matmul_generic():
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
# CHECK: RESULT: 32.0
# matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
# matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
# CHECK: RESULT: 8128
test_matmul_generic()
@ -423,11 +451,7 @@ def test_min_pooling_builtin():
MemRefType.get((1, 2, 4, 1), i32))
# Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
input,
shape,
outs=[output],
strides=[2, 4])
linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
@ -458,11 +482,7 @@ def test_min_pooling_generic():
# Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
input,
shape,
outs=[output],
strides=[2, 4],
emit_generic=True)
input, shape, outs=[output], strides=[2, 4], emit_generic=True)
execution_engine = ExecutionEngine(transform(module, pooling_boiler))

View File

@ -61,15 +61,22 @@ struct SerializedAffineMap {
AffineMap affineMap() { return affineMapAttr.getValue(); }
};
enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
enum class LinalgOperandDefKind {
InputTensor,
Scalar,
OutputTensor,
IndexAttr,
TypeFnAttr
};
struct LinalgOperandDef {
std::string name;
LinalgOperandDefUsage usage;
LinalgOperandDefKind kind;
Optional<std::string> typeVar;
Optional<SerializedAffineMap> shapeMap;
Optional<SerializedAffineMap> indexAttrMap;
Optional<SmallVector<int64_t>> defaultVals;
Optional<SmallVector<int64_t>> defaultIndices;
Optional<std::string> defaultFn;
};
enum class LinalgIteratorTypeDef {
@ -91,11 +98,12 @@ struct ScalarArithFn {
};
struct ScalarTypeFn {
std::string fnName;
std::string typeVar;
// NOTE: This must be of arity 1, but to break the self-referential cycle,
// we use a heap allocated vector.
std::vector<ScalarExpression> operands;
Optional<std::string> fnName;
Optional<std::string> attrName;
};
struct ScalarExpression {
@ -180,27 +188,32 @@ struct MappingTraits<LinalgStructuredOpConfig> {
/// index attribute symbols. During op creation these symbols are replaced
/// by the corresponding `name` index attribue values. Only index attribute
/// arguments have an `index_attr_map`.
/// - `default_vals`: An optional default initialization for index attribute
/// - `default_indices`: An optional default initialization for index
/// attribute arguments.
/// - `default_fn`: An optional default initialization for function attribute
/// arguments.
template <>
struct MappingTraits<LinalgOperandDef> {
static void mapping(IO &io, LinalgOperandDef &info) {
io.mapRequired("name", info.name);
io.mapRequired("usage", info.usage);
io.mapRequired("kind", info.kind);
io.mapOptional("type_var", info.typeVar);
io.mapOptional("shape_map", info.shapeMap);
io.mapOptional("index_attr_map", info.indexAttrMap);
io.mapOptional("default_vals", info.defaultVals);
io.mapOptional("default_indices", info.defaultIndices);
io.mapOptional("default_fn", info.defaultFn);
}
};
/// Usage enum for a named argument.
template <>
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
struct ScalarEnumerationTraits<LinalgOperandDefKind> {
static void enumeration(IO &io, LinalgOperandDefKind &value) {
io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
}
};
@ -281,9 +294,10 @@ struct MappingTraits<ScalarArithFn> {
template <>
struct MappingTraits<ScalarTypeFn> {
static void mapping(IO &io, ScalarTypeFn &info) {
io.mapRequired("fn_name", info.fnName);
io.mapRequired("type_var", info.typeVar);
io.mapRequired("operands", info.operands);
io.mapOptional("fn_name", info.fnName);
io.mapOptional("attr_name", info.attrName);
}
};
@ -399,8 +413,9 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
// Search all argument types.
for (const auto &it : llvm::enumerate(args)) {
if (it.value().usage != LinalgOperandDefUsage::Input &&
it.value().usage != LinalgOperandDefUsage::Output)
if (it.value().kind != LinalgOperandDefKind::InputTensor &&
it.value().kind != LinalgOperandDefKind::Scalar &&
it.value().kind != LinalgOperandDefKind::OutputTensor)
continue;
if (it.value().typeVar.getValue() == typeVar)
return llvm::formatv("block.getArgument({0}).getType()", it.index())
@ -552,6 +567,8 @@ static const char structuredOpBuilderFormat[] = R"FMT(
$_state.addOperands(inputs);
$_state.addOperands(outputs);
$_state.addTypes(resultTensorTypes);
{2}
$_state.addAttributes(attributes);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
@ -562,8 +579,6 @@ static const char structuredOpBuilderFormat[] = R"FMT(
$_state,
TypeRange(inputs),
TypeRange(outputs));
{2}
$_state.addAttributes(attributes);
}]>
)FMT";
@ -681,42 +696,56 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
// Assemble the attribute specific logic required for the op definition.
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage == LinalgOperandDefUsage::IndexAttr;
return arg.kind == LinalgOperandDefKind::IndexAttr ||
arg.kind == LinalgOperandDefKind::TypeFnAttr;
})) {
SmallVector<std::string> attrDefs;
SmallVector<std::string> attrParams;
SmallVector<std::string> attrStmts;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
assert(arg.indexAttrMap.hasValue());
assert(arg.defaultVals.hasValue());
size_t size = arg.indexAttrMap->affineMap().getNumResults();
assert(arg.defaultVals.getValue().size() == size);
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
static const char paramFmt[] = "\"Attribute\":${0}";
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
std::string defaultVals;
llvm::raw_string_ostream ss(defaultVals);
ss << "{ ";
llvm::interleave(
arg.defaultVals.getValue(), ss,
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
", ");
ss << " }";
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
ss.str(), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
// Add the type conversion attributes to the op definition and builders.
if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
assert(arg.defaultFn.hasValue());
static const char typeFmt[] = "TypeFn::{0}";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
llvm::formatv(typeFmt, arg.defaultFn),
arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
// Add the index attributes to the op definition and builders.
if (arg.kind == LinalgOperandDefKind::IndexAttr) {
assert(arg.indexAttrMap.hasValue());
assert(arg.defaultIndices.hasValue());
size_t size = arg.indexAttrMap->affineMap().getNumResults();
assert(arg.defaultIndices.getValue().size() == size);
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}";
std::string defaultVals;
llvm::raw_string_ostream ss(defaultVals);
llvm::interleave(
arg.defaultIndices.getValue(), ss,
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
", ");
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
ss.str(), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
}
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::IndexAttr;
})) {
attrMethods = R"(
bool hasDynamicIndexingMaps();
LogicalResult verifyIndexingMapRequiredAttributes();
)";
}
attrList = ",\n" + llvm::join(attrDefs, ",\n");
attrMethods = R"(
bool hasDynamicIndexingMaps();
LogicalResult verifyIndexingMapRequiredAttributes();
)";
attrBuilder = llvm::formatv(
structuredOpBuilderFormat, opConfig.metadata->cppClassName,
llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
@ -746,7 +775,9 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
// Compute the number of scalar and tensor arguments.
int64_t numOfArgs =
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage != LinalgOperandDefUsage::IndexAttr;
return arg.kind == LinalgOperandDefKind::InputTensor ||
arg.kind == LinalgOperandDefKind::Scalar ||
arg.kind == LinalgOperandDefKind::OutputTensor;
});
// An operation that accesses only scalars and scalar/rank zero tensors is
@ -817,7 +848,7 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
)FMT";
// Update all symbol bindings mapped to an attribute.
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
if (arg.kind != LinalgOperandDefKind::IndexAttr)
continue;
assert(arg.indexAttrMap.hasValue());
for (auto &en :
@ -910,11 +941,11 @@ std::string {0}::getLibraryCallName() {{
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.usage == LinalgOperandDefUsage::IndexAttr;
return arg.kind == LinalgOperandDefKind::IndexAttr;
})) {
std::vector<std::string> attrVerifications;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
if (arg.kind != LinalgOperandDefKind::IndexAttr)
continue;
assert(arg.indexAttrMap.hasValue());
// Verify index attribute. Paramters:
@ -952,7 +983,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
// Generates a regionBuilder method. Parameters.
// {0}: Class name
// {1}: Number of args
// {2}: Statements
// {2}: Attributes
// {3}: Statements
static const char structuredOpRegionBuilderFormat[] = R"FMT(
void {0}::regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs) {{
@ -961,6 +993,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
SmallVector<Value> yields;
{2}
{3}
helper.yieldOutputs(yields);
}
)FMT";
@ -968,9 +1001,27 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
auto &assignments = opConfig.structuredOp->assignments;
size_t generatedAssignmentCount = 0;
int localCounter = 0;
SmallVector<std::string> attrs;
SmallVector<std::string> stmts;
for (LinalgOperandDef &arg : args) {
if (arg.usage != LinalgOperandDefUsage::Output)
if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
continue;
// Obtain the type function attribute values. Parameters.
// {0}: attribute name
// {1}: default type function name
static const char attrDef[] = R"FMT(
TypeFn {0}Val = TypeFn::{1};
auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
return attr.getName() == "{0}"; });
if ({0}Iter != attrs.end()) {{
if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
{0}Val = attr.getValue();
}
)FMT";
attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
}
for (LinalgOperandDef &arg : args) {
if (arg.kind != LinalgOperandDefKind::OutputTensor)
continue;
// Find the assignment that correlates with the argument.
@ -1048,11 +1099,25 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
<< "an argument type but it does not";
return None;
}
// Use the function name or the attribute to build the type function.
std::string typeFunc = llvm::formatv(
"TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
if (expression.typeFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
arg.name == expression.typeFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing type function attribute "
<< expression.typeFn->attrName.getValue();
}
typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});",
cppIdent, expression.typeFn->fnName,
typeCppValue.getValue(), *operandCppValue));
stmts.push_back(llvm::formatv(
"Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
typeFunc, typeCppValue.getValue(), *operandCppValue));
return cppIdent;
}
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
@ -1069,6 +1134,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
<< "mismatched number of assignments vs output arguments";
os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
interleaveToString(attrs, "\n "),
interleaveToString(stmts, "\n "));
}

View File

@ -6720,6 +6720,22 @@ gentbl_cc_library(
],
"include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc",
),
(
["-gen-enum-decls"],
"include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc",
),
(
["-gen-attrdef-decls"],
"include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc",
),
(
["-gen-attrdef-defs"],
"include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Linalg/IR/LinalgOps.td",