forked from OSchip/llvm-project
[mlir][OpDSL] Add type function attributes.
Previously, OpDSL operation used hardcoded type conversion operations (cast or cast_unsigned). Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute which at operation instantiation time may be set to cast or cast_unsigned. We may for example, defina a matmul operation with a cast argument: ``` @linalg_structured_op def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), cast=TypeFnAttrDef(default=TypeFn.cast)): C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) ``` When instantiating the operation the attribute may be set to the desired cast function: ``` linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) ``` The revsion introduces a enum in the Linalg dialect that maps one-by-one to the type functions defined by OpDSL. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D119718
This commit is contained in:
parent
3fe6f9388f
commit
51fdd802c7
|
@ -44,6 +44,18 @@ add_dependencies(mlir-headers LinalgOdsGen)
|
|||
|
||||
add_mlir_dialect(LinalgOps linalg)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
|
||||
mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen)
|
||||
add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
|
||||
mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen)
|
||||
add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen)
|
||||
|
||||
add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc)
|
||||
add_dependencies(LinalgOpsDocGen LinalgOdsGen)
|
||||
|
||||
|
|
|
@ -104,6 +104,19 @@ LogicalResult verifyStructuredOpInterface(Operation *op);
|
|||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg Enums
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef LINALG_BASE
|
||||
#define LINALG_BASE
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Linalg_Dialect : Dialect {
|
||||
|
@ -57,4 +58,17 @@ def Linalg_Dialect : Dialect {
|
|||
}];
|
||||
}
|
||||
|
||||
// Define a TypeFn enum matching the OpDSL TypeFn class.
|
||||
def TypeFn : I32EnumAttr<"TypeFn", "", [
|
||||
I32EnumAttrCase<"cast", 0>,
|
||||
I32EnumAttrCase<"cast_unsigned", 1>
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
|
||||
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
#endif // LINALG_BASE
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -37,7 +37,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
|
|||
|
||||
LogicalResult reifyResultShapes(OpBuilder &b,
|
||||
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
||||
return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
|
||||
return llvm::cast<LinalgOp>(getOperation()).reifyResultShapes(b,
|
||||
reifiedReturnShapes);
|
||||
}
|
||||
}];
|
||||
|
|
|
@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRLinalg
|
|||
|
||||
DEPENDS
|
||||
MLIRLinalgInterfacesIncGen
|
||||
MLIRLinalgOpsAttributesIncGen
|
||||
MLIRLinalgOpsEnumsIncGen
|
||||
MLIRLinalgOpsIncGen
|
||||
MLIRLinalgStructuredOpsIncGen
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -95,6 +96,10 @@ void addNamedOpBuilders(
|
|||
}
|
||||
|
||||
void mlir::linalg::LinalgDialect::initialize() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
|
||||
>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
@ -144,3 +149,10 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
|
|||
return op->emitError() << "attribute '" << attr.getName()
|
||||
<< "' not supported by the linalg dialect";
|
||||
}
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
|
||||
|
|
|
@ -36,8 +36,6 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
|
||||
|
||||
/// Forward declarations.
|
||||
|
||||
/// Generic entry point to create the block for the region of a LinalgOp.
|
||||
|
@ -232,14 +230,14 @@ public:
|
|||
return operand;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value typefn__cast(Type toType, Value operand) {
|
||||
return cast(toType, operand, false);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
Value typefn__cast_unsigned(Type toType, Value operand) {
|
||||
return cast(toType, operand, true);
|
||||
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
|
||||
switch (typeFn) {
|
||||
case TypeFn::cast:
|
||||
return cast(toType, operand, false);
|
||||
case TypeFn::cast_unsigned:
|
||||
return cast(toType, operand, true);
|
||||
}
|
||||
llvm_unreachable("unsupported type conversion function");
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||
|
|
|
@ -111,7 +111,7 @@ class TensorUse(TensorExpression):
|
|||
@property
|
||||
def tensor_name(self) -> str:
|
||||
name = self.operand_def.name
|
||||
assert name is not None, "TensorDef not attached"
|
||||
assert name is not None, "TensorDef not registered with an op"
|
||||
return name
|
||||
|
||||
def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
|
||||
|
@ -129,7 +129,8 @@ class TensorUse(TensorExpression):
|
|||
return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
|
||||
return (f"{self.operand_def.name}"
|
||||
f"[{', '.join([repr(i) for i in self.indices])}]")
|
||||
|
||||
|
||||
class TensorArithFn(TensorExpression):
|
||||
|
@ -156,14 +157,22 @@ class TensorArithFn(TensorExpression):
|
|||
class TensorTypeFn(TensorExpression):
|
||||
"""Application of a type conversion function."""
|
||||
|
||||
def __init__(self, type_fn: "TypeFn", type_var: TypeVar,
|
||||
def __init__(self, type_fn: Optional["TypeFn"],
|
||||
operand_def: Optional["OperandDef"], type_var: TypeVar,
|
||||
arg: TensorExpression):
|
||||
if bool(type_fn) + bool(operand_def) != 1:
|
||||
raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
|
||||
self.type_fn = type_fn
|
||||
self.operand_def = operand_def
|
||||
self.type_var = type_var
|
||||
self.arg = arg
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
|
||||
if self.operand_def:
|
||||
assert self.operand_def.name, "TypeFnAttr not registered with an op"
|
||||
fn_name = self.type_fn.fn_name if self.type_fn else None
|
||||
attr_name = self.operand_def.name if self.operand_def else None
|
||||
return ScalarTypeFn(fn_name, attr_name, self.type_var,
|
||||
self.arg.to_scalar_expression()).expr()
|
||||
|
||||
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
|
||||
|
@ -171,7 +180,8 @@ class TensorTypeFn(TensorExpression):
|
|||
self.arg.visit_tensor_exprs(callback)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
|
||||
return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
|
||||
f"({self.type_var}, {self.arg})")
|
||||
|
||||
|
||||
class TensorReduceFn(TensorExpression):
|
||||
|
@ -260,7 +270,7 @@ class TypeFnType:
|
|||
self.fn_name = fn_name
|
||||
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
|
||||
return TensorTypeFn(self, type_var, arg)
|
||||
return TensorTypeFn(self, None, type_var, arg)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.fn_name}"
|
||||
|
@ -370,10 +380,11 @@ class ReduceFn:
|
|||
|
||||
|
||||
class OperandKind(Enum):
|
||||
InputTensor = 0
|
||||
Scalar = 1
|
||||
OutputTensor = 2
|
||||
IndexAttr = 3
|
||||
INPUT_TENSOR = 0
|
||||
SCALAR = 1
|
||||
OUTPUT_TENSOR = 2
|
||||
INDEX_ATTR = 3
|
||||
TYPE_FN_ATTR = 4
|
||||
|
||||
|
||||
class OperandDef:
|
||||
|
@ -388,7 +399,8 @@ class OperandDef:
|
|||
type_var: Optional[TypeVar] = None,
|
||||
size_exprs: Optional[Sequence[AffineExprDef]] = None,
|
||||
index_dims: Optional[Sequence[DimDef]] = None,
|
||||
default_vals: Optional[Sequence[int]] = None):
|
||||
default_indices: Optional[Sequence[int]] = None,
|
||||
default_fn: Optional[str] = None):
|
||||
if type_var and not isinstance(type_var, TypeVar):
|
||||
raise ValueError(
|
||||
f"OperandDef requires a TypeVar but got {repr(type_var)}")
|
||||
|
@ -396,25 +408,40 @@ class OperandDef:
|
|||
self.type_var = type_var
|
||||
self.size_exprs = size_exprs
|
||||
self.index_dims = index_dims
|
||||
self.default_vals = default_vals
|
||||
self.default_indices = default_indices
|
||||
self.default_fn = default_fn
|
||||
self.kind = kind
|
||||
self.name = None # type: Optional[str]
|
||||
self.registered_index = -1 # type: int
|
||||
|
||||
def attach(self, index: int, name: str, owner: "LinalgOpDef"):
|
||||
if self.owner:
|
||||
raise ValueError(f"OperandDef already registered with op: {self}")
|
||||
raise ValueError(f"OperandDef already registered with an op: {self}")
|
||||
self.registered_index = index
|
||||
self.name = name
|
||||
self.owner = owner
|
||||
|
||||
def is_input(self) -> bool:
|
||||
return (self.kind == OperandKind.SCALAR or
|
||||
self.kind == OperandKind.INPUT_TENSOR)
|
||||
|
||||
def is_tensor(self) -> bool:
|
||||
return (self.kind == OperandKind.INPUT_TENSOR or
|
||||
self.kind == OperandKind.OUTPUT_TENSOR)
|
||||
|
||||
def is_attribute(self) -> bool:
|
||||
return (self.kind == OperandKind.INDEX_ATTR or
|
||||
self.kind == OperandKind.TYPE_FN_ATTR)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
||||
def __repr__(self):
|
||||
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
|
||||
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
|
||||
f"index_dims={self.index_dims}, default_vals={self.default_vals})")
|
||||
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
|
||||
f"index_dims={self.index_dims}, "
|
||||
f"default_indices={self.default_indices}, "
|
||||
f"default_fn={self.default_fn})")
|
||||
|
||||
|
||||
class TensorDef:
|
||||
|
@ -440,12 +467,12 @@ class TensorDef:
|
|||
if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
|
||||
raise ValueError(f"TensorDef requires index dims of type DimDef but "
|
||||
f"got {index_dims}")
|
||||
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
|
||||
kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
|
||||
self.operand_def = OperandDef(
|
||||
kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
|
||||
|
||||
def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
|
||||
assert self.operand_def.owner, "TensorDef is not attached to an op"
|
||||
assert self.operand_def.owner, "TensorDef is not registered with an op"
|
||||
state = AffineBuildState(
|
||||
global_state=self.operand_def.owner._affine_state,
|
||||
allow_new_symbols=False)
|
||||
|
@ -486,12 +513,12 @@ class ScalarDef(TensorExpression):
|
|||
"""
|
||||
|
||||
def __init__(self, type_var: TypeVar):
|
||||
self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
|
||||
self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
|
||||
|
||||
@property
|
||||
def scalar_name(self) -> str:
|
||||
name = self.operand_def.name
|
||||
assert name is not None, "ScalarDef not attached"
|
||||
assert name is not None, "ScalarDef not registered with an op"
|
||||
return name
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
|
@ -517,7 +544,26 @@ class IndexAttrDef:
|
|||
raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
|
||||
f"but got {len(default)}")
|
||||
self.operand_def = OperandDef(
|
||||
OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
|
||||
OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
|
||||
|
||||
|
||||
class TypeFnAttrDef:
|
||||
"""Type conversion function attribute definition.
|
||||
|
||||
Type conversion function attributes provide a way to make type conversions
|
||||
parameterizable. Every attribute specifies a default type conversion function
|
||||
that may be overwritten at operation instantiation time.
|
||||
"""
|
||||
|
||||
def __init__(self, default: "TypeFnType"):
|
||||
if not isinstance(default, TypeFnType):
|
||||
raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
|
||||
f"but got {default}")
|
||||
self.operand_def = OperandDef(
|
||||
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
|
||||
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
|
||||
return TensorTypeFn(None, self.operand_def, type_var, arg)
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -615,17 +661,21 @@ class LinalgOpDef:
|
|||
if name in self.registered_operands:
|
||||
raise ValueError(f"The operand {name} is already registered "
|
||||
f"to {self.registered_operands['name']}")
|
||||
structured_op_methods = [
|
||||
"inputs", "outputs", "result_tensors", "region", "iterator_types",
|
||||
"indexing_maps", "getRegionBuilder", "getLibraryCallName"
|
||||
]
|
||||
if operand.is_attribute() and name in structured_op_methods:
|
||||
raise ValueError(f"The attribute name {name} conflicts with a structured "
|
||||
f"op method name")
|
||||
# Ensure output tensors are registered after input tensors and scalars and
|
||||
# attributes are registered after all other operand types.
|
||||
registered_kinds = [
|
||||
operand.kind.value for operand in self.registered_operands.values()
|
||||
]
|
||||
if registered_kinds:
|
||||
maximum = max(registered_kinds)
|
||||
if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
|
||||
raise ValueError(
|
||||
f"The operand {name} of kind {operand.kind.name} is registered "
|
||||
f"after an operand of kind {OperandKind(maximum).name}")
|
||||
if operand.is_input() and any(
|
||||
not op_def.is_input() for op_def in self.registered_operands.values()):
|
||||
raise ValueError(f"Input {name} registered after an output or attribute")
|
||||
if operand.kind == OperandKind.OUTPUT_TENSOR and any(
|
||||
op_def.is_attribute() for op_def in self.registered_operands.values()):
|
||||
raise ValueError(f"Output {name} registered after an attribute")
|
||||
operand.attach(len(self.registered_operands), name, self)
|
||||
self.registered_operands[name] = operand
|
||||
|
||||
|
|
|
@ -55,28 +55,26 @@ class OperandDefConfig(YAMLObject):
|
|||
def name(self) -> str:
|
||||
return self.operand_def.name
|
||||
|
||||
@property
|
||||
def kind(self) -> OperandKind:
|
||||
return self.operand_def.kind
|
||||
|
||||
@property
|
||||
def type_var(self) -> TypeVar:
|
||||
return self.operand_def.type_var
|
||||
|
||||
@property
|
||||
def usage(self) -> str:
|
||||
if self.operand_def.kind == OperandKind.IndexAttr:
|
||||
return "IndexAttr"
|
||||
if self.operand_def.kind == OperandKind.OutputTensor:
|
||||
return "Output"
|
||||
return "Input"
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
self_dict = dict(name=self.name, usage=self.usage)
|
||||
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
|
||||
if self.type_var:
|
||||
self_dict["type_var"] = self.type_var.name
|
||||
if self.shape_map:
|
||||
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
|
||||
if self.index_attr_map:
|
||||
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
|
||||
if self.operand_def.default_vals:
|
||||
self_dict["default_vals"] = self.operand_def.default_vals
|
||||
if self.operand_def.default_indices:
|
||||
self_dict["default_indices"] = self.operand_def.default_indices
|
||||
if self.operand_def.default_fn:
|
||||
self_dict["default_fn"] = self.operand_def.default_fn
|
||||
return self_dict
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -166,7 +164,7 @@ class LinalgStructuredOpConfig(YAMLObject):
|
|||
# Collect all attribute definitions.
|
||||
collected_attr_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.kind == OperandKind.IndexAttr:
|
||||
if operand.is_attribute():
|
||||
collected_attr_defs.append(operand)
|
||||
|
||||
# Collect all tensors with manual indexing annotation.
|
||||
|
@ -244,12 +242,12 @@ class LinalgStructuredOpConfig(YAMLObject):
|
|||
|
||||
# Set the indexing map of all scalar uses to the empty map.
|
||||
for operand_config in self.operands.values():
|
||||
if operand_config.operand_def.kind == OperandKind.Scalar:
|
||||
if operand_config.operand_def.kind == OperandKind.SCALAR:
|
||||
operand_config.indexing_map = self._get_scalar_map()
|
||||
|
||||
# Check all registered tensor and scalar operands have an indexing map.
|
||||
for operand in registered_operands:
|
||||
if operand.kind == OperandKind.IndexAttr:
|
||||
if operand.is_attribute():
|
||||
continue
|
||||
if not (operand in self.operands and self.operands[operand].indexing_map):
|
||||
raise ValueError(f"Failed to compute an indexing map for operand "
|
||||
|
@ -311,7 +309,8 @@ class LinalgStructuredOpConfig(YAMLObject):
|
|||
def add_operand(self, operand_def: OperandDef):
|
||||
if operand_def in self.operands:
|
||||
return
|
||||
if operand_def.kind == OperandKind.Scalar:
|
||||
if (operand_def.kind == OperandKind.SCALAR or
|
||||
operand_def.kind == OperandKind.TYPE_FN_ATTR):
|
||||
self.operands[operand_def] = OperandDefConfig(operand_def)
|
||||
return
|
||||
with self.context:
|
||||
|
@ -323,7 +322,7 @@ class LinalgStructuredOpConfig(YAMLObject):
|
|||
assert local_state.local_dim_count == 0
|
||||
affine_map = _ir.AffineMap.get(
|
||||
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
|
||||
if operand_def.kind == OperandKind.IndexAttr:
|
||||
if operand_def.kind == OperandKind.INDEX_ATTR:
|
||||
self.operands[operand_def] = OperandDefConfig(
|
||||
operand_def, index_attr_map=affine_map)
|
||||
else:
|
||||
|
@ -429,8 +428,7 @@ class LinalgOpConfig(YAMLObject):
|
|||
context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
|
||||
"""Expands a LinalgOpDef into corresponding Linalg configured ops."""
|
||||
# TODO: Many LinalgOpDef patterns need to expand to multiple generics.
|
||||
assert len(
|
||||
op_def.comprehensions) == 1, "Only one comprehension supported"
|
||||
assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
|
||||
return [
|
||||
LinalgOpConfig(
|
||||
op_def.metadata,
|
||||
|
|
|
@ -129,7 +129,8 @@ def linalg_structured_op(dsl_func=None,
|
|||
sig = inspect.signature(dsl_func)
|
||||
for param_name, param in sig.parameters.items():
|
||||
param_default = param.default
|
||||
if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)):
|
||||
if isinstance(param_default,
|
||||
(TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
|
||||
op_def.add_operand(param_name, param_default.operand_def)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
|
@ -37,11 +37,21 @@ def isa(cls: Type, ty: Type):
|
|||
|
||||
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
||||
*ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
**attrs: Union[Sequence[int], TypeFnType]):
|
||||
all_arg_defs = op_config.ordered_operands
|
||||
in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
|
||||
out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
|
||||
index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
|
||||
in_arg_defs = [
|
||||
d for d in all_arg_defs
|
||||
if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
|
||||
]
|
||||
out_arg_defs = [
|
||||
d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
|
||||
]
|
||||
index_attr_arg_defs = [
|
||||
d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
|
||||
]
|
||||
type_fn_attr_arg_defs = [
|
||||
d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
|
||||
]
|
||||
|
||||
# Verify outs is a sequence or a list of results.
|
||||
if not isinstance(outs, (Sequence, OpResultList)):
|
||||
|
@ -56,11 +66,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
|||
raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
|
||||
f"{len(outs)} for {op_config}")
|
||||
|
||||
# Compute a replacement list for all attribute symbols.
|
||||
# Compute a replacement list for all index attribute symbols.
|
||||
expressions = [] # type: Sequence[AffineExpr]
|
||||
replacements = [] # type: Sequence[AffineExpr]
|
||||
for index_attr in index_attr_arg_defs:
|
||||
index_attr_vals = index_attr.operand_def.default_vals
|
||||
index_attr_vals = index_attr.operand_def.default_indices
|
||||
if index_attr.name in attrs:
|
||||
index_attr_vals = attrs.get(index_attr.name)
|
||||
assert index_attr_vals, "Index attribute has no value"
|
||||
|
@ -125,15 +135,29 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
|
|||
array = np.array(index_attr_vals, dtype=np.int64)
|
||||
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
|
||||
|
||||
# Compute the type function attribute mapping.
|
||||
type_fn_attr_mapping = {}
|
||||
for type_fn_attr in type_fn_attr_arg_defs:
|
||||
attr_val = type_fn_attr.operand_def.default_fn
|
||||
if type_fn_attr.name in attrs:
|
||||
type_fn = attrs.get(type_fn_attr.name)
|
||||
if not isinstance(type_fn, TypeFnType):
|
||||
raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
|
||||
f"TypeFnType but got {type(attr_val)}")
|
||||
attr_val = type_fn.fn_name
|
||||
assert attr_val, "Type function attribute has no value"
|
||||
type_fn_attr_mapping[type_fn_attr.name] = attr_val
|
||||
|
||||
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
|
||||
type_mapping, indexing_maps_attr, iterator_types_attr,
|
||||
index_attrs, block_arg_types)
|
||||
type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
|
||||
type_fn_attr_mapping, block_arg_types)
|
||||
|
||||
|
||||
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
||||
outs: ValueList, **attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
|
||||
block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
||||
# An operation that accesses only scalars and scalar/rank zero tensors is
|
||||
|
@ -147,10 +171,9 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
|||
tensor_map = AffineMap.get_identity(rank)
|
||||
indexing_maps = []
|
||||
for arg_def in all_arg_defs:
|
||||
if arg_def.operand_def.kind == OperandKind.Scalar:
|
||||
if arg_def.operand_def.kind == OperandKind.SCALAR:
|
||||
indexing_maps.append(scalar_map)
|
||||
if (arg_def.operand_def.kind == OperandKind.InputTensor or
|
||||
arg_def.operand_def.kind == OperandKind.OutputTensor):
|
||||
if arg_def.operand_def.is_tensor():
|
||||
indexing_maps.append(tensor_map)
|
||||
indexing_maps_attr = ArrayAttr.get(
|
||||
[AffineMapAttr.get(am) for am in indexing_maps])
|
||||
|
@ -169,7 +192,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
|
|||
block = generic_op.regions[0].blocks.append(*block_arg_types)
|
||||
block_arg_mapping = dict(zip(block_arg_names, block.arguments))
|
||||
with InsertionPoint(block):
|
||||
body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
|
||||
body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
|
||||
type_fn_attr_mapping)
|
||||
for assignment in op_config.assignments:
|
||||
body_builder.assign(assignment)
|
||||
body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
|
||||
|
@ -184,7 +208,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
|
|||
op_class_name: str, *ins: Value, outs: ValueList,
|
||||
**attrs: Sequence[int]):
|
||||
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
|
||||
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
|
||||
block_arg_types = \
|
||||
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
|
||||
|
||||
# If we get here, there must exist a builtin class `op_class_name`.
|
||||
|
@ -200,6 +225,11 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
|
|||
for name, value in index_attrs.items():
|
||||
named_op.operation.attributes[name] = value
|
||||
|
||||
# Set the type function attributes.
|
||||
for name, value in type_fn_attr_mapping.items():
|
||||
named_op.operation.attributes[name] = Attribute.parse(
|
||||
f"#linalg.type_fn<{value}>")
|
||||
|
||||
linalg.fill_builtin_region(named_op.operation)
|
||||
|
||||
if len(result_types) == 1:
|
||||
|
@ -212,9 +242,11 @@ class _BodyBuilder:
|
|||
"""Constructs a structured op body by evaluating assignments."""
|
||||
|
||||
def __init__(self, type_mapping: Dict[str, Type],
|
||||
block_arg_mapping: Dict[str, Value]):
|
||||
block_arg_mapping: Dict[str, Value],
|
||||
type_fn_attr_mapping: Dict[str, str]):
|
||||
self.type_mapping = type_mapping
|
||||
self.block_arg_mapping = block_arg_mapping
|
||||
self.type_fn_attr_mapping = type_fn_attr_mapping
|
||||
self.yield_mapping = dict() # type: Dict[str, Value]
|
||||
|
||||
def assign(self, assignment: ScalarAssign):
|
||||
|
@ -245,7 +277,10 @@ class _BodyBuilder:
|
|||
]
|
||||
return fn(*operand_values)
|
||||
elif expr.type_fn:
|
||||
fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}")
|
||||
fn_name = expr.type_fn.fn_name
|
||||
if expr.type_fn.attr_name:
|
||||
fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
|
||||
fn = self._get_function(f"_typefn_{fn_name}")
|
||||
operand = self.expression(expr.type_fn.operand)
|
||||
return fn(expr.type_fn.type_var.name, operand)
|
||||
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
|
||||
|
|
|
@ -46,9 +46,10 @@ class ScalarArithFn:
|
|||
class ScalarTypeFn:
|
||||
"""A type of ScalarExpression that applies a type conversion function."""
|
||||
|
||||
def __init__(self, fn_name: str, type_var: TypeVar,
|
||||
operand: "ScalarExpression"):
|
||||
def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
|
||||
type_var: TypeVar, operand: "ScalarExpression"):
|
||||
self.fn_name = fn_name
|
||||
self.attr_name = attr_name
|
||||
self.type_var = type_var
|
||||
self.operand = operand
|
||||
|
||||
|
@ -56,7 +57,8 @@ class ScalarTypeFn:
|
|||
return ScalarExpression(type_fn=self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})"
|
||||
return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
|
||||
f"({self.type_var}, {self.operand})")
|
||||
|
||||
|
||||
class ScalarArg:
|
||||
|
@ -138,12 +140,15 @@ class ScalarExpression(YAMLObject):
|
|||
# Note that even though operands must be arity 1, we write it the
|
||||
# same way as for apply because it allows handling code to be more
|
||||
# generic vs having a special form.
|
||||
return dict(
|
||||
type_fn=dict(
|
||||
fn_name=self.type_fn.fn_name,
|
||||
type_var=self.type_fn.type_var.name,
|
||||
operands=[self.type_fn.operand],
|
||||
))
|
||||
type_fn_dict = dict(
|
||||
type_var=self.type_fn.type_var.name,
|
||||
operands=[self.type_fn.operand],
|
||||
)
|
||||
if self.type_fn.fn_name:
|
||||
type_fn_dict["fn_name"] = self.type_fn.fn_name
|
||||
if self.type_fn.attr_name:
|
||||
type_fn_dict["attr_name"] = self.type_fn.attr_name
|
||||
return dict(type_fn=type_fn_dict)
|
||||
elif self.scalar_arg:
|
||||
return dict(scalar_arg=self.scalar_arg.arg)
|
||||
elif self.scalar_const:
|
||||
|
|
|
@ -10,7 +10,8 @@ Batch = S.Batch
|
|||
def matmul(
|
||||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
"""Performs a matrix multiplication of two 2D inputs.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
|
@ -18,7 +19,7 @@ def matmul(
|
|||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
|
|
@ -36,6 +36,20 @@ func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32x
|
|||
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
|
||||
// CHECK-NEXT: -> tensor<16x32xi32>
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Verifies that cast attributes control the cast operations used.
|
||||
func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
|
||||
%0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
|
||||
ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
|
||||
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
|
||||
return %0: tensor<16x32xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned
|
||||
// CHECK: = arith.extui
|
||||
|
||||
// -----
|
||||
|
||||
func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
|
||||
|
|
|
@ -1258,7 +1258,7 @@ class IndexExpr(abc.ABC):
|
|||
value = self._emit_expression(expr_to_input_opnd, expr_to_info)
|
||||
# Emit the structured op representation for the destination tensor.
|
||||
dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
|
||||
lang.OperandKind.OutputTensor)
|
||||
lang.OperandKind.OUTPUT_TENSOR)
|
||||
dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
|
||||
dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
|
||||
|
||||
|
@ -1893,6 +1893,6 @@ def _emit_structured_op_input(
|
|||
name = expr.tensor.name
|
||||
|
||||
dim_sym = _mlir_symbols_from_index_vars(indices)
|
||||
opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym)
|
||||
opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
|
||||
op_def.add_operand(name, opnd)
|
||||
return opnd
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL
|
||||
|
||||
# @linalg_structured_op
|
||||
# def test1(O=TensorDef(T, S.M, S.N, output=True)):
|
||||
# def test1(O=TensorDef(T, S.M, S.N, output=True),
|
||||
# cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
# """Title.
|
||||
|
||||
# Detailed description.
|
||||
|
@ -21,9 +22,13 @@ structured_op: !LinalgStructuredOpConfig
|
|||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: Output
|
||||
kind: output_tensor
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: cast
|
||||
kind: type_fn_attr
|
||||
default_fn: cast
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
|
@ -39,18 +44,18 @@ structured_op: !LinalgStructuredOpConfig
|
|||
operands:
|
||||
- !ScalarExpression
|
||||
type_fn:
|
||||
fn_name: cast
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '42 : i64'
|
||||
attr_name: cast
|
||||
- !ScalarExpression
|
||||
type_fn:
|
||||
fn_name: cast_unsigned
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_index: 1
|
||||
attr_name: cast
|
||||
|
||||
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
|
||||
|
||||
|
@ -61,16 +66,22 @@ structured_op: !LinalgStructuredOpConfig
|
|||
|
||||
# ODS: let arguments =
|
||||
# ODS-NEXT: Variadic<AnyType>:$inputs,
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$outputs
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
|
||||
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
|
||||
|
||||
# ODS: let builders =
|
||||
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
# ODS-NEXT: "ValueRange":$outputs,
|
||||
# ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
|
||||
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
# ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast,
|
||||
# ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
|
||||
# ODS: $_state.addOperands(inputs);
|
||||
# ODS-NEXT: $_state.addOperands(outputs);
|
||||
# ODS-NEXT: $_state.addTypes(resultTensorTypes);
|
||||
# ODS-NEXT: $_state.addAttribute("cast", cast)
|
||||
# ODS-NEXT: $_state.addAttributes(attributes);
|
||||
# ODS-NEXT: $_state.addAttribute(
|
||||
# ODS-NEXT: "operand_segment_sizes",
|
||||
|
@ -85,10 +96,18 @@ structured_op: !LinalgStructuredOpConfig
|
|||
|
||||
# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
|
||||
# IMPL: TypeFn castVal = TypeFn::cast;
|
||||
# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
||||
# IMPL-NEXT: return attr.getName() == "cast"; });
|
||||
# IMPL-NEXT: if (castIter != attrs.end()) {
|
||||
# IMPL-NEXT: if (auto attr = castIter->getValue().dyn_cast<TypeFnAttr>())
|
||||
# IMPL-NEXT: castVal = attr.getValue();
|
||||
# IMPL-NEXT: }
|
||||
|
||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
|
||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
|
||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
|
||||
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
|
||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
|
||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
|
||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
|
||||
|
||||
|
||||
|
@ -114,19 +133,19 @@ structured_op: !LinalgStructuredOpConfig
|
|||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: I
|
||||
usage: Input
|
||||
kind: input_tensor
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: Output
|
||||
kind: output_tensor
|
||||
type_var: T
|
||||
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
|
||||
- !LinalgOperandDefConfig
|
||||
name: strides
|
||||
usage: IndexAttr
|
||||
kind: index_attr
|
||||
index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
|
||||
default_vals:
|
||||
default_indices:
|
||||
- 1
|
||||
- 2
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
|
@ -201,11 +220,11 @@ structured_op: !LinalgStructuredOpConfig
|
|||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: value
|
||||
usage: Input
|
||||
kind: scalar
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
usage: Output
|
||||
kind: output_tensor
|
||||
type_var: U
|
||||
shape_map: affine_map<() -> ()>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
|
|
|
@ -7,30 +7,34 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||
# CHECK-LABEL: matmul
|
||||
# CHECK: args:
|
||||
# CHECK: name: A
|
||||
# CHECK: usage: Input
|
||||
# CHECK: kind: input_tensor
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
|
||||
# CHECK: name: B
|
||||
# CHECK: usage: Input
|
||||
# CHECK: kind: input_tensor
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
|
||||
# CHECK: name: C
|
||||
# CHECK: usage: Output
|
||||
# CHECK: kind: output_tensor
|
||||
# CHECK: type_var: U
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
|
||||
# CHECK: name: cast
|
||||
# CHECK: kind: type_fn_attr
|
||||
# CHECK: default_fn: cast
|
||||
@linalg_structured_op
|
||||
def matmul(
|
||||
A=TensorDef(T, S.M, S.K),
|
||||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
# CHECK: ---
|
||||
# CHECK-LABEL: fill
|
||||
# CHECK: args:
|
||||
# CHECK: name: value
|
||||
# CHECK: usage: Input
|
||||
# CHECK: kind: scalar
|
||||
# CHECK-NOT: shape_map:
|
||||
# CHECK: type_var: T
|
||||
@linalg_structured_op
|
||||
|
@ -42,17 +46,17 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
|
|||
# CHECK-LABEL: strided_copy
|
||||
# CHECK: args:
|
||||
# CHECK: name: I
|
||||
# CHECK: usage: Input
|
||||
# CHECK: kind: input_tensor
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
|
||||
# CHECK: name: O
|
||||
# CHECK: usage: Output
|
||||
# CHECK: kind: output_tensor
|
||||
# CHECK: type_var: T
|
||||
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
|
||||
# CHECK: name: strides
|
||||
# CHECK: usage: IndexAttr
|
||||
# CHECK: kind: index_attr
|
||||
# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
|
||||
# CHECK: default_vals:
|
||||
# CHECK: default_indices:
|
||||
# CHECK: - 1
|
||||
# CHECK: - 2
|
||||
@linalg_structured_op
|
||||
|
|
|
@ -19,16 +19,19 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||
# CHECK: type_var: U
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_arg: A
|
||||
# CHECK: attr_name: cast
|
||||
# CHECK: type_fn:
|
||||
# CHECK: type_var: U
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_arg: B
|
||||
# CHECK: attr_name: cast
|
||||
@linalg_structured_op
|
||||
def matmul(
|
||||
A=TensorDef(T, S.M, S.K),
|
||||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
# CHECK: ---
|
||||
|
|
|
@ -24,19 +24,10 @@ def matmul_mono(
|
|||
def matmul_poly(
|
||||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
domain(D.m, D.n, D.k)
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def matmul_unsigned_poly(
|
||||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
domain(D.m, D.n, D.k)
|
||||
C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
|
||||
U, B[D.k, D.n])
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
|
@ -92,7 +83,8 @@ with Context() as ctx, Location.unknown():
|
|||
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), i32))
|
||||
def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
|
||||
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
|
||||
return matmul_poly(
|
||||
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
|
||||
|
||||
# CHECK-LABEL: @test_i8i16i32_matmul
|
||||
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
|
||||
|
@ -143,7 +135,8 @@ with Context() as ctx, Location.unknown():
|
|||
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
|
||||
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
|
||||
return matmul_poly(
|
||||
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
|
||||
|
||||
# CHECK-LABEL: @test_f16f16f32_matmul
|
||||
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
|
||||
|
|
|
@ -6,6 +6,8 @@ from mlir.dialects import linalg
|
|||
from mlir.dialects import std
|
||||
from mlir.dialects import arith
|
||||
|
||||
from mlir.dialects.linalg.opdsl.lang import *
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
|
@ -98,12 +100,14 @@ def testNamedStructuredOpCustomForm():
|
|||
init_result = linalg.InitTensorOp([4, 8], f32)
|
||||
# First check the named form with custom format
|
||||
# CHECK: linalg.matmul
|
||||
# CHECK: cast = #linalg.type_fn<cast_unsigned>
|
||||
# CHECK-NOT: linalg.memoized_indexing_maps
|
||||
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
|
||||
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
|
||||
# CHECK-SAME: -> tensor<4x8xf32>
|
||||
# CHECK-NEXT: return
|
||||
return linalg.matmul(lhs, rhs, outs=[init_result.result])
|
||||
return linalg.matmul(
|
||||
lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
|
||||
|
||||
print(module)
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ from mlir.dialects import std
|
|||
from mlir.passmanager import *
|
||||
from mlir.execution_engine import *
|
||||
|
||||
from mlir.dialects.linalg.opdsl.lang import *
|
||||
|
||||
|
||||
# Log everything to stderr and flush so that we have a unified stream to match
|
||||
# errors/info emitted by MLIR to stderr.
|
||||
|
@ -20,21 +22,28 @@ def log(*args):
|
|||
matmul_boiler = """
|
||||
func @main() -> f32 attributes {llvm.emit_c_interface} {
|
||||
%v0 = arith.constant 0.0 : f32
|
||||
%v1 = arith.constant 1.0 : f32
|
||||
%v1 = arith.constant -1 : i8
|
||||
%v2 = arith.constant 2.0 : f32
|
||||
|
||||
%A = memref.alloc() : memref<4x16xf32>
|
||||
%A = memref.alloc() : memref<4x16xi8>
|
||||
%B = memref.alloc() : memref<16x8xf32>
|
||||
%C = memref.alloc() : memref<4x8xf32>
|
||||
linalg.fill(%v1, %A) : f32, memref<4x16xf32>
|
||||
%C0 = memref.alloc() : memref<4x8xf32>
|
||||
%C1 = memref.alloc() : memref<4x8xf32>
|
||||
linalg.fill(%v1, %A) : i8, memref<4x16xi8>
|
||||
linalg.fill(%v2, %B) : f32, memref<16x8xf32>
|
||||
linalg.fill(%v0, %C) : f32, memref<4x8xf32>
|
||||
linalg.fill(%v0, %C0) : f32, memref<4x8xf32>
|
||||
linalg.fill(%v0, %C1) : f32, memref<4x8xf32>
|
||||
|
||||
call @matmul_on_buffers(%A, %B, %C) :
|
||||
(memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
|
||||
call @matmul_signed_on_buffers(%A, %B, %C0) :
|
||||
(memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
|
||||
call @matmul_unsigned_on_buffers(%A, %B, %C1) :
|
||||
(memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
|
||||
%res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32>
|
||||
%res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32>
|
||||
|
||||
%0 = arith.addf %res0, %res1 : f32
|
||||
|
||||
// TODO: FFI-based solution to allow testing and printing with python code.
|
||||
return %0 : f32
|
||||
|
@ -157,8 +166,8 @@ def transform(module, boilerplate):
|
|||
|
||||
pm = PassManager.parse(
|
||||
"builtin.func(convert-linalg-to-loops, lower-affine, " +
|
||||
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," +
|
||||
"convert-memref-to-llvm, convert-std-to-llvm," +
|
||||
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
|
||||
+ "convert-memref-to-llvm, convert-std-to-llvm," +
|
||||
"reconcile-unrealized-casts")
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
@ -168,14 +177,21 @@ def test_matmul_builtin():
|
|||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
i8 = IntegerType.get_signless(8)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def matmul_on_buffers(lhs, rhs, out):
|
||||
def matmul_signed_on_buffers(lhs, rhs, out):
|
||||
linalg.matmul(lhs, rhs, outs=[out])
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def matmul_unsigned_on_buffers(lhs, rhs, out):
|
||||
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, matmul_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
|
@ -186,7 +202,9 @@ def test_matmul_builtin():
|
|||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# CHECK: RESULT: 32.0
|
||||
# matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
|
||||
# matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
|
||||
# CHECK: RESULT: 8128
|
||||
|
||||
|
||||
test_matmul_builtin()
|
||||
|
@ -196,14 +214,22 @@ def test_matmul_generic():
|
|||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
i8 = IntegerType.get_signless(8)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def matmul_on_buffers(lhs, rhs, out):
|
||||
def matmul_signed_on_buffers(lhs, rhs, out):
|
||||
linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def matmul_unsigned_on_buffers(lhs, rhs, out):
|
||||
linalg.matmul(
|
||||
lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, matmul_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
|
@ -214,7 +240,9 @@ def test_matmul_generic():
|
|||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# CHECK: RESULT: 32.0
|
||||
# matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
|
||||
# matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
|
||||
# CHECK: RESULT: 8128
|
||||
|
||||
|
||||
test_matmul_generic()
|
||||
|
@ -423,11 +451,7 @@ def test_min_pooling_builtin():
|
|||
MemRefType.get((1, 2, 4, 1), i32))
|
||||
# Set the strides and use the default dilations.
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min(
|
||||
input,
|
||||
shape,
|
||||
outs=[output],
|
||||
strides=[2, 4])
|
||||
linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
||||
|
@ -458,11 +482,7 @@ def test_min_pooling_generic():
|
|||
# Set the strides and use the default dilations.
|
||||
def pooling_on_buffers(input, shape, output):
|
||||
linalg.pooling_nhwc_min(
|
||||
input,
|
||||
shape,
|
||||
outs=[output],
|
||||
strides=[2, 4],
|
||||
emit_generic=True)
|
||||
input, shape, outs=[output], strides=[2, 4], emit_generic=True)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
|
||||
|
||||
|
|
|
@ -61,15 +61,22 @@ struct SerializedAffineMap {
|
|||
AffineMap affineMap() { return affineMapAttr.getValue(); }
|
||||
};
|
||||
|
||||
enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
|
||||
enum class LinalgOperandDefKind {
|
||||
InputTensor,
|
||||
Scalar,
|
||||
OutputTensor,
|
||||
IndexAttr,
|
||||
TypeFnAttr
|
||||
};
|
||||
|
||||
struct LinalgOperandDef {
|
||||
std::string name;
|
||||
LinalgOperandDefUsage usage;
|
||||
LinalgOperandDefKind kind;
|
||||
Optional<std::string> typeVar;
|
||||
Optional<SerializedAffineMap> shapeMap;
|
||||
Optional<SerializedAffineMap> indexAttrMap;
|
||||
Optional<SmallVector<int64_t>> defaultVals;
|
||||
Optional<SmallVector<int64_t>> defaultIndices;
|
||||
Optional<std::string> defaultFn;
|
||||
};
|
||||
|
||||
enum class LinalgIteratorTypeDef {
|
||||
|
@ -91,11 +98,12 @@ struct ScalarArithFn {
|
|||
};
|
||||
|
||||
struct ScalarTypeFn {
|
||||
std::string fnName;
|
||||
std::string typeVar;
|
||||
// NOTE: This must be of arity 1, but to break the self-referential cycle,
|
||||
// we use a heap allocated vector.
|
||||
std::vector<ScalarExpression> operands;
|
||||
Optional<std::string> fnName;
|
||||
Optional<std::string> attrName;
|
||||
};
|
||||
|
||||
struct ScalarExpression {
|
||||
|
@ -180,27 +188,32 @@ struct MappingTraits<LinalgStructuredOpConfig> {
|
|||
/// index attribute symbols. During op creation these symbols are replaced
|
||||
/// by the corresponding `name` index attribue values. Only index attribute
|
||||
/// arguments have an `index_attr_map`.
|
||||
/// - `default_vals`: An optional default initialization for index attribute
|
||||
/// - `default_indices`: An optional default initialization for index
|
||||
/// attribute arguments.
|
||||
/// - `default_fn`: An optional default initialization for function attribute
|
||||
/// arguments.
|
||||
template <>
|
||||
struct MappingTraits<LinalgOperandDef> {
|
||||
static void mapping(IO &io, LinalgOperandDef &info) {
|
||||
io.mapRequired("name", info.name);
|
||||
io.mapRequired("usage", info.usage);
|
||||
io.mapRequired("kind", info.kind);
|
||||
io.mapOptional("type_var", info.typeVar);
|
||||
io.mapOptional("shape_map", info.shapeMap);
|
||||
io.mapOptional("index_attr_map", info.indexAttrMap);
|
||||
io.mapOptional("default_vals", info.defaultVals);
|
||||
io.mapOptional("default_indices", info.defaultIndices);
|
||||
io.mapOptional("default_fn", info.defaultFn);
|
||||
}
|
||||
};
|
||||
|
||||
/// Usage enum for a named argument.
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
||||
io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
|
||||
io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
|
||||
io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
|
||||
struct ScalarEnumerationTraits<LinalgOperandDefKind> {
|
||||
static void enumeration(IO &io, LinalgOperandDefKind &value) {
|
||||
io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
|
||||
io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
|
||||
io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
|
||||
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
|
||||
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -281,9 +294,10 @@ struct MappingTraits<ScalarArithFn> {
|
|||
template <>
|
||||
struct MappingTraits<ScalarTypeFn> {
|
||||
static void mapping(IO &io, ScalarTypeFn &info) {
|
||||
io.mapRequired("fn_name", info.fnName);
|
||||
io.mapRequired("type_var", info.typeVar);
|
||||
io.mapRequired("operands", info.operands);
|
||||
io.mapOptional("fn_name", info.fnName);
|
||||
io.mapOptional("attr_name", info.attrName);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -399,8 +413,9 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
|
|||
|
||||
// Search all argument types.
|
||||
for (const auto &it : llvm::enumerate(args)) {
|
||||
if (it.value().usage != LinalgOperandDefUsage::Input &&
|
||||
it.value().usage != LinalgOperandDefUsage::Output)
|
||||
if (it.value().kind != LinalgOperandDefKind::InputTensor &&
|
||||
it.value().kind != LinalgOperandDefKind::Scalar &&
|
||||
it.value().kind != LinalgOperandDefKind::OutputTensor)
|
||||
continue;
|
||||
if (it.value().typeVar.getValue() == typeVar)
|
||||
return llvm::formatv("block.getArgument({0}).getType()", it.index())
|
||||
|
@ -552,6 +567,8 @@ static const char structuredOpBuilderFormat[] = R"FMT(
|
|||
$_state.addOperands(inputs);
|
||||
$_state.addOperands(outputs);
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
{2}
|
||||
$_state.addAttributes(attributes);
|
||||
$_state.addAttribute(
|
||||
"operand_segment_sizes",
|
||||
$_builder.getI32VectorAttr({{
|
||||
|
@ -562,8 +579,6 @@ static const char structuredOpBuilderFormat[] = R"FMT(
|
|||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs));
|
||||
{2}
|
||||
$_state.addAttributes(attributes);
|
||||
}]>
|
||||
)FMT";
|
||||
|
||||
|
@ -681,42 +696,56 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
|||
|
||||
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
|
||||
|
||||
// Assemble the attribute specific logic required for the op definition.
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage == LinalgOperandDefUsage::IndexAttr;
|
||||
return arg.kind == LinalgOperandDefKind::IndexAttr ||
|
||||
arg.kind == LinalgOperandDefKind::TypeFnAttr;
|
||||
})) {
|
||||
SmallVector<std::string> attrDefs;
|
||||
SmallVector<std::string> attrParams;
|
||||
SmallVector<std::string> attrStmts;
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
continue;
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
assert(arg.defaultVals.hasValue());
|
||||
size_t size = arg.indexAttrMap->affineMap().getNumResults();
|
||||
assert(arg.defaultVals.getValue().size() == size);
|
||||
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
|
||||
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
|
||||
static const char paramFmt[] = "\"Attribute\":${0}";
|
||||
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
|
||||
std::string defaultVals;
|
||||
llvm::raw_string_ostream ss(defaultVals);
|
||||
ss << "{ ";
|
||||
llvm::interleave(
|
||||
arg.defaultVals.getValue(), ss,
|
||||
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
|
||||
", ");
|
||||
ss << " }";
|
||||
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
|
||||
ss.str(), arg.name));
|
||||
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
||||
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
||||
// Add the type conversion attributes to the op definition and builders.
|
||||
if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
|
||||
assert(arg.defaultFn.hasValue());
|
||||
static const char typeFmt[] = "TypeFn::{0}";
|
||||
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
|
||||
attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
|
||||
llvm::formatv(typeFmt, arg.defaultFn),
|
||||
arg.name));
|
||||
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
||||
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
||||
}
|
||||
// Add the index attributes to the op definition and builders.
|
||||
if (arg.kind == LinalgOperandDefKind::IndexAttr) {
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
assert(arg.defaultIndices.hasValue());
|
||||
size_t size = arg.indexAttrMap->affineMap().getNumResults();
|
||||
assert(arg.defaultIndices.getValue().size() == size);
|
||||
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
|
||||
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}";
|
||||
std::string defaultVals;
|
||||
llvm::raw_string_ostream ss(defaultVals);
|
||||
llvm::interleave(
|
||||
arg.defaultIndices.getValue(), ss,
|
||||
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
|
||||
", ");
|
||||
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
|
||||
ss.str(), arg.name));
|
||||
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
||||
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
||||
}
|
||||
}
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.kind == LinalgOperandDefKind::IndexAttr;
|
||||
})) {
|
||||
attrMethods = R"(
|
||||
bool hasDynamicIndexingMaps();
|
||||
LogicalResult verifyIndexingMapRequiredAttributes();
|
||||
)";
|
||||
}
|
||||
attrList = ",\n" + llvm::join(attrDefs, ",\n");
|
||||
attrMethods = R"(
|
||||
bool hasDynamicIndexingMaps();
|
||||
LogicalResult verifyIndexingMapRequiredAttributes();
|
||||
)";
|
||||
attrBuilder = llvm::formatv(
|
||||
structuredOpBuilderFormat, opConfig.metadata->cppClassName,
|
||||
llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
|
||||
|
@ -746,7 +775,9 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
|
|||
// Compute the number of scalar and tensor arguments.
|
||||
int64_t numOfArgs =
|
||||
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage != LinalgOperandDefUsage::IndexAttr;
|
||||
return arg.kind == LinalgOperandDefKind::InputTensor ||
|
||||
arg.kind == LinalgOperandDefKind::Scalar ||
|
||||
arg.kind == LinalgOperandDefKind::OutputTensor;
|
||||
});
|
||||
|
||||
// An operation that accesses only scalars and scalar/rank zero tensors is
|
||||
|
@ -817,7 +848,7 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
|
|||
)FMT";
|
||||
// Update all symbol bindings mapped to an attribute.
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
if (arg.kind != LinalgOperandDefKind::IndexAttr)
|
||||
continue;
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
for (auto &en :
|
||||
|
@ -910,11 +941,11 @@ std::string {0}::getLibraryCallName() {{
|
|||
|
||||
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return arg.usage == LinalgOperandDefUsage::IndexAttr;
|
||||
return arg.kind == LinalgOperandDefKind::IndexAttr;
|
||||
})) {
|
||||
std::vector<std::string> attrVerifications;
|
||||
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::IndexAttr)
|
||||
if (arg.kind != LinalgOperandDefKind::IndexAttr)
|
||||
continue;
|
||||
assert(arg.indexAttrMap.hasValue());
|
||||
// Verify index attribute. Paramters:
|
||||
|
@ -952,7 +983,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
|
|||
// Generates a regionBuilder method. Parameters.
|
||||
// {0}: Class name
|
||||
// {1}: Number of args
|
||||
// {2}: Statements
|
||||
// {2}: Attributes
|
||||
// {3}: Statements
|
||||
static const char structuredOpRegionBuilderFormat[] = R"FMT(
|
||||
void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs) {{
|
||||
|
@ -961,6 +993,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
|||
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
|
||||
SmallVector<Value> yields;
|
||||
{2}
|
||||
{3}
|
||||
helper.yieldOutputs(yields);
|
||||
}
|
||||
)FMT";
|
||||
|
@ -968,9 +1001,27 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
|||
auto &assignments = opConfig.structuredOp->assignments;
|
||||
size_t generatedAssignmentCount = 0;
|
||||
int localCounter = 0;
|
||||
SmallVector<std::string> attrs;
|
||||
SmallVector<std::string> stmts;
|
||||
for (LinalgOperandDef &arg : args) {
|
||||
if (arg.usage != LinalgOperandDefUsage::Output)
|
||||
if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
|
||||
continue;
|
||||
// Obtain the type function attribute values. Parameters.
|
||||
// {0}: attribute name
|
||||
// {1}: default type function name
|
||||
static const char attrDef[] = R"FMT(
|
||||
TypeFn {0}Val = TypeFn::{1};
|
||||
auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
|
||||
return attr.getName() == "{0}"; });
|
||||
if ({0}Iter != attrs.end()) {{
|
||||
if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
|
||||
{0}Val = attr.getValue();
|
||||
}
|
||||
)FMT";
|
||||
attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
|
||||
}
|
||||
for (LinalgOperandDef &arg : args) {
|
||||
if (arg.kind != LinalgOperandDefKind::OutputTensor)
|
||||
continue;
|
||||
|
||||
// Find the assignment that correlates with the argument.
|
||||
|
@ -1048,11 +1099,25 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
|||
<< "an argument type but it does not";
|
||||
return None;
|
||||
}
|
||||
|
||||
// Use the function name or the attribute to build the type function.
|
||||
std::string typeFunc = llvm::formatv(
|
||||
"TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
|
||||
if (expression.typeFn->attrName) {
|
||||
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
|
||||
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
|
||||
arg.name == expression.typeFn->attrName.getValue();
|
||||
})) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "missing type function attribute "
|
||||
<< expression.typeFn->attrName.getValue();
|
||||
}
|
||||
typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
|
||||
}
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(
|
||||
llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});",
|
||||
cppIdent, expression.typeFn->fnName,
|
||||
typeCppValue.getValue(), *operandCppValue));
|
||||
stmts.push_back(llvm::formatv(
|
||||
"Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
|
||||
typeFunc, typeCppValue.getValue(), *operandCppValue));
|
||||
return cppIdent;
|
||||
}
|
||||
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
|
||||
|
@ -1069,6 +1134,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
|||
<< "mismatched number of assignments vs output arguments";
|
||||
|
||||
os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
|
||||
interleaveToString(attrs, "\n "),
|
||||
interleaveToString(stmts, "\n "));
|
||||
}
|
||||
|
||||
|
|
|
@ -6720,6 +6720,22 @@ gentbl_cc_library(
|
|||
],
|
||||
"include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc",
|
||||
),
|
||||
(
|
||||
["-gen-enum-decls"],
|
||||
"include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-enum-defs"],
|
||||
"include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc",
|
||||
),
|
||||
(
|
||||
["-gen-attrdef-decls"],
|
||||
"include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-attrdef-defs"],
|
||||
"include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/Linalg/IR/LinalgOps.td",
|
||||
|
|
Loading…
Reference in New Issue