[mlir][OpDSL] Add type function attributes.
Previously, OpDSL operations used hardcoded type conversion operations (cast or cast_unsigned). Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute, which at operation instantiation time may be set to cast or cast_unsigned. For example, we may define a matmul operation with a cast attribute:
```
@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K),
           B=TensorDef(T2, S.K, S.N),
           C=TensorDef(U, S.M, S.N, output=True),
           cast=TypeFnAttrDef(default=TypeFn.cast)):
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```
When instantiating the operation, the attribute may be set to the desired cast function:
```
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
```
The revision introduces an enum in the Linalg dialect that maps one-to-one to the type functions defined by OpDSL.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D119718
This commit is contained in:
parent
3fe6f9388f
commit
51fdd802c7
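As a minimal sketch of the new mechanism (assuming the OpDSL bindings are importable as `mlir.dialects.linalg.opdsl.lang`; the op name `copy_and_cast` is a hypothetical example and not part of this patch), any structured op definition can expose its own type conversion attribute in the same way as the updated `matmul`:
```
from mlir.dialects.linalg.opdsl.lang import *

# Hypothetical elementwise copy whose type conversion is parameterizable:
# `cast` is a TypeFnAttrDef that defaults to the signed TypeFn.cast function.
@linalg_structured_op
def copy_and_cast(
    I=TensorDef(T1, S.M, S.N),
    O=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  O[D.m, D.n] = cast(U, I[D.m, D.n])
```
Instantiating it as `copy_and_cast(x, outs=[y], cast=TypeFn.cast_unsigned)` should attach `cast = #linalg.type_fn<cast_unsigned>` to the emitted operation, mirroring the matmul changes and tests below.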
@@ -44,6 +44,18 @@ add_dependencies(mlir-headers LinalgOdsGen)
 
 add_mlir_dialect(LinalgOps linalg)
 
+set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
+mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen)
+add_dependencies(mlir-headers MLIRLinalgOpsEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
+mlir_tablegen(LinalgOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(LinalgOpsAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRLinalgOpsAttributesIncGen)
+add_dependencies(mlir-headers MLIRLinalgOpsAttributesIncGen)
+
 add_mlir_doc(LinalgDoc LinalgOps Dialects/ -gen-op-doc)
 add_dependencies(LinalgOpsDocGen LinalgOdsGen)
 
@@ -104,6 +104,19 @@ LogicalResult verifyStructuredOpInterface(Operation *op);
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Linalg Enums
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Linalg Attributes
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc"
+
 //===----------------------------------------------------------------------===//
 // Linalg Interfaces
 //===----------------------------------------------------------------------===//
@@ -13,6 +13,7 @@
 #ifndef LINALG_BASE
 #define LINALG_BASE
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 def Linalg_Dialect : Dialect {
@@ -57,4 +58,17 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
+// Define a TypeFn enum matching the OpDSL TypeFn class.
+def TypeFn : I32EnumAttr<"TypeFn", "", [
+  I32EnumAttrCase<"cast", 0>,
+  I32EnumAttrCase<"cast_unsigned", 1>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // LINALG_BASE
File diff suppressed because it is too large
@@ -37,7 +37,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
 
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-      return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
+      return llvm::cast<LinalgOp>(getOperation()).reifyResultShapes(b,
                                                               reifiedReturnShapes);
     }
   }];
@@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRLinalg
 
   DEPENDS
   MLIRLinalgInterfacesIncGen
+  MLIRLinalgOpsAttributesIncGen
+  MLIRLinalgOpsEnumsIncGen
  MLIRLinalgOpsIncGen
  MLIRLinalgStructuredOpsIncGen
 
@@ -21,6 +21,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -95,6 +96,10 @@ void addNamedOpBuilders(
 }
 
 void mlir::linalg::LinalgDialect::initialize() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -144,3 +149,10 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
   return op->emitError() << "attribute '" << attr.getName()
                          << "' not supported by the linalg dialect";
 }
+
+#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
+
+#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
@@ -36,8 +36,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
-
 /// Forward declarations.
 
 /// Generic entry point to create the block for the region of a LinalgOp.
@@ -232,14 +230,14 @@ public:
     return operand;
   }
 
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value typefn__cast(Type toType, Value operand) {
-    return cast(toType, operand, false);
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value typefn__cast_unsigned(Type toType, Value operand) {
-    return cast(toType, operand, true);
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+    switch (typeFn) {
+    case TypeFn::cast:
+      return cast(toType, operand, false);
+    case TypeFn::cast_unsigned:
+      return cast(toType, operand, true);
+    }
+    llvm_unreachable("unsupported type conversion function");
   }
 
   // NOLINTNEXTLINE(*-identifier-naming): externally called.
@@ -111,7 +111,7 @@ class TensorUse(TensorExpression):
   @property
   def tensor_name(self) -> str:
     name = self.operand_def.name
-    assert name is not None, "TensorDef not attached"
+    assert name is not None, "TensorDef not registered with an op"
     return name
 
   def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
@@ -129,7 +129,8 @@ class TensorUse(TensorExpression):
     return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
 
   def __repr__(self):
-    return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
+    return (f"{self.operand_def.name}"
+            f"[{', '.join([repr(i) for i in self.indices])}]")
 
 
 class TensorArithFn(TensorExpression):
@@ -156,14 +157,22 @@ class TensorArithFn(TensorExpression):
 class TensorTypeFn(TensorExpression):
   """Application of a type conversion function."""
 
-  def __init__(self, type_fn: "TypeFn", type_var: TypeVar,
+  def __init__(self, type_fn: Optional["TypeFn"],
+               operand_def: Optional["OperandDef"], type_var: TypeVar,
                arg: TensorExpression):
+    if bool(type_fn) + bool(operand_def) != 1:
+      raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
     self.type_fn = type_fn
+    self.operand_def = operand_def
     self.type_var = type_var
     self.arg = arg
 
   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
+    if self.operand_def:
+      assert self.operand_def.name, "TypeFnAttr not registered with an op"
+    fn_name = self.type_fn.fn_name if self.type_fn else None
+    attr_name = self.operand_def.name if self.operand_def else None
+    return ScalarTypeFn(fn_name, attr_name, self.type_var,
                         self.arg.to_scalar_expression()).expr()
 
   def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
@@ -171,7 +180,8 @@ class TensorTypeFn(TensorExpression):
     self.arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
-    return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
+    return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
+            f"({self.type_var}, {self.arg})")
 
 
 class TensorReduceFn(TensorExpression):
@@ -260,7 +270,7 @@ class TypeFnType:
     self.fn_name = fn_name
 
   def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
-    return TensorTypeFn(self, type_var, arg)
+    return TensorTypeFn(self, None, type_var, arg)
 
   def __repr__(self):
     return f"{self.fn_name}"
@@ -370,10 +380,11 @@ class ReduceFn:
 
 
 class OperandKind(Enum):
-  InputTensor = 0
-  Scalar = 1
-  OutputTensor = 2
-  IndexAttr = 3
+  INPUT_TENSOR = 0
+  SCALAR = 1
+  OUTPUT_TENSOR = 2
+  INDEX_ATTR = 3
+  TYPE_FN_ATTR = 4
 
 
 class OperandDef:
@@ -388,7 +399,8 @@ class OperandDef:
                type_var: Optional[TypeVar] = None,
                size_exprs: Optional[Sequence[AffineExprDef]] = None,
                index_dims: Optional[Sequence[DimDef]] = None,
-               default_vals: Optional[Sequence[int]] = None):
+               default_indices: Optional[Sequence[int]] = None,
+               default_fn: Optional[str] = None):
     if type_var and not isinstance(type_var, TypeVar):
       raise ValueError(
           f"OperandDef requires a TypeVar but got {repr(type_var)}")
@@ -396,25 +408,40 @@ class OperandDef:
     self.type_var = type_var
     self.size_exprs = size_exprs
     self.index_dims = index_dims
-    self.default_vals = default_vals
+    self.default_indices = default_indices
+    self.default_fn = default_fn
     self.kind = kind
     self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
 
   def attach(self, index: int, name: str, owner: "LinalgOpDef"):
     if self.owner:
-      raise ValueError(f"OperandDef already registered with op: {self}")
+      raise ValueError(f"OperandDef already registered with an op: {self}")
     self.registered_index = index
     self.name = name
     self.owner = owner
 
+  def is_input(self) -> bool:
+    return (self.kind == OperandKind.SCALAR or
+            self.kind == OperandKind.INPUT_TENSOR)
+
+  def is_tensor(self) -> bool:
+    return (self.kind == OperandKind.INPUT_TENSOR or
+            self.kind == OperandKind.OUTPUT_TENSOR)
+
+  def is_attribute(self) -> bool:
+    return (self.kind == OperandKind.INDEX_ATTR or
+            self.kind == OperandKind.TYPE_FN_ATTR)
+
   def __hash__(self):
     return hash(id(self))
 
   def __repr__(self):
     return (f"{self.name}:OperandDef(kind={self.kind.name}, "
-            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
-            f"index_dims={self.index_dims}, default_vals={self.default_vals})")
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
+            f"index_dims={self.index_dims}, "
+            f"default_indices={self.default_indices}, "
+            f"default_fn={self.default_fn})")
 
 
 class TensorDef:
@@ -440,12 +467,12 @@ class TensorDef:
     if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
       raise ValueError(f"TensorDef requires index dims of type DimDef but "
                        f"got {index_dims}")
-    kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
+    kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
     self.operand_def = OperandDef(
         kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
 
   def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
-    assert self.operand_def.owner, "TensorDef is not attached to an op"
+    assert self.operand_def.owner, "TensorDef is not registered with an op"
     state = AffineBuildState(
         global_state=self.operand_def.owner._affine_state,
         allow_new_symbols=False)
@@ -486,12 +513,12 @@ class ScalarDef(TensorExpression):
   """
 
   def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
+    self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
 
   @property
   def scalar_name(self) -> str:
     name = self.operand_def.name
-    assert name is not None, "ScalarDef not attached"
+    assert name is not None, "ScalarDef not registered with an op"
     return name
 
   def to_scalar_expression(self) -> ScalarExpression:
@@ -517,7 +544,26 @@ class IndexAttrDef:
       raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
                        f"but got {len(default)}")
     self.operand_def = OperandDef(
-        OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
+        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+
+
+class TypeFnAttrDef:
+  """Type conversion function attribute definition.
+
+  Type conversion function attributes provide a way to make type conversions
+  parameterizable. Every attribute specifies a default type conversion function
+  that may be overwritten at operation instantiation time.
+  """
+
+  def __init__(self, default: "TypeFnType"):
+    if not isinstance(default, TypeFnType):
+      raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
+                       f"but got {default}")
+    self.operand_def = OperandDef(
+        OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
+
+  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
+    return TensorTypeFn(None, self.operand_def, type_var, arg)
 
 
 ###############################################################################
@@ -615,17 +661,21 @@ class LinalgOpDef:
     if name in self.registered_operands:
       raise ValueError(f"The operand {name} is already registered "
                        f"to {self.registered_operands['name']}")
-    # Ensure output tensors are registered after input tensors and scalars and
-    # attributes are registered after all other operand types.
-    registered_kinds = [
-        operand.kind.value for operand in self.registered_operands.values()
-    ]
-    if registered_kinds:
-      maximum = max(registered_kinds)
-      if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
-        raise ValueError(
-            f"The operand {name} of kind {operand.kind.name} is registered "
-            f"after an operand of kind {OperandKind(maximum).name}")
+    structured_op_methods = [
+        "inputs", "outputs", "result_tensors", "region", "iterator_types",
+        "indexing_maps", "getRegionBuilder", "getLibraryCallName"
+    ]
+    if operand.is_attribute() and name in structured_op_methods:
+      raise ValueError(f"The attribute name {name} conflicts with a structured "
+                       f"op method name")
+    # Ensure output tensors are registered after input tensors and scalars and
+    # attributes are registered after all other operand types.
+    if operand.is_input() and any(
+        not op_def.is_input() for op_def in self.registered_operands.values()):
+      raise ValueError(f"Input {name} registered after an output or attribute")
+    if operand.kind == OperandKind.OUTPUT_TENSOR and any(
+        op_def.is_attribute() for op_def in self.registered_operands.values()):
+      raise ValueError(f"Output {name} registered after an attribute")
     operand.attach(len(self.registered_operands), name, self)
     self.registered_operands[name] = operand
 
@@ -55,28 +55,26 @@ class OperandDefConfig(YAMLObject):
   def name(self) -> str:
     return self.operand_def.name
 
+  @property
+  def kind(self) -> OperandKind:
+    return self.operand_def.kind
+
   @property
   def type_var(self) -> TypeVar:
     return self.operand_def.type_var
 
-  @property
-  def usage(self) -> str:
-    if self.operand_def.kind == OperandKind.IndexAttr:
-      return "IndexAttr"
-    if self.operand_def.kind == OperandKind.OutputTensor:
-      return "Output"
-    return "Input"
-
   def to_yaml_custom_dict(self):
-    self_dict = dict(name=self.name, usage=self.usage)
+    self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
     if self.type_var:
       self_dict["type_var"] = self.type_var.name
     if self.shape_map:
       self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
     if self.index_attr_map:
       self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
-    if self.operand_def.default_vals:
-      self_dict["default_vals"] = self.operand_def.default_vals
+    if self.operand_def.default_indices:
+      self_dict["default_indices"] = self.operand_def.default_indices
+    if self.operand_def.default_fn:
+      self_dict["default_fn"] = self.operand_def.default_fn
     return self_dict
 
   def __repr__(self):
@@ -166,7 +164,7 @@ class LinalgStructuredOpConfig(YAMLObject):
     # Collect all attribute definitions.
     collected_attr_defs = list()
     for operand in registered_operands:
-      if operand.kind == OperandKind.IndexAttr:
+      if operand.is_attribute():
        collected_attr_defs.append(operand)
 
     # Collect all tensors with manual indexing annotation.
@@ -244,12 +242,12 @@ class LinalgStructuredOpConfig(YAMLObject):
 
     # Set the indexing map of all scalar uses to the empty map.
     for operand_config in self.operands.values():
-      if operand_config.operand_def.kind == OperandKind.Scalar:
+      if operand_config.operand_def.kind == OperandKind.SCALAR:
         operand_config.indexing_map = self._get_scalar_map()
 
     # Check all registered tensor and scalar operands have an indexing map.
     for operand in registered_operands:
-      if operand.kind == OperandKind.IndexAttr:
+      if operand.is_attribute():
         continue
       if not (operand in self.operands and self.operands[operand].indexing_map):
         raise ValueError(f"Failed to compute an indexing map for operand "
@@ -311,7 +309,8 @@ class LinalgStructuredOpConfig(YAMLObject):
   def add_operand(self, operand_def: OperandDef):
     if operand_def in self.operands:
       return
-    if operand_def.kind == OperandKind.Scalar:
+    if (operand_def.kind == OperandKind.SCALAR or
+        operand_def.kind == OperandKind.TYPE_FN_ATTR):
       self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:
@@ -323,7 +322,7 @@ class LinalgStructuredOpConfig(YAMLObject):
       assert local_state.local_dim_count == 0
       affine_map = _ir.AffineMap.get(
          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
-      if operand_def.kind == OperandKind.IndexAttr:
+      if operand_def.kind == OperandKind.INDEX_ATTR:
        self.operands[operand_def] = OperandDefConfig(
            operand_def, index_attr_map=affine_map)
      else:
@@ -429,8 +428,7 @@ class LinalgOpConfig(YAMLObject):
       context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
     """Expands a LinalgOpDef into corresponding Linalg configured ops."""
     # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
-    assert len(
-        op_def.comprehensions) == 1, "Only one comprehension supported"
+    assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
     return [
         LinalgOpConfig(
             op_def.metadata,
@@ -129,7 +129,8 @@ def linalg_structured_op(dsl_func=None,
   sig = inspect.signature(dsl_func)
   for param_name, param in sig.parameters.items():
     param_default = param.default
-    if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)):
+    if isinstance(param_default,
+                  (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
       op_def.add_operand(param_name, param_default.operand_def)
     else:
       raise ValueError(
@@ -37,11 +37,21 @@ def isa(cls: Type, ty: Type):
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value, outs: ValueList,
-                                 **attrs: Sequence[int]):
+                                 **attrs: Union[Sequence[int], TypeFnType]):
   all_arg_defs = op_config.ordered_operands
-  in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
-  out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
-  index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
+  in_arg_defs = [
+      d for d in all_arg_defs
+      if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
+  ]
+  out_arg_defs = [
+      d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
+  ]
+  index_attr_arg_defs = [
+      d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
+  ]
+  type_fn_attr_arg_defs = [
+      d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
+  ]
 
   # Verify outs is a sequence or a list of results.
   if not isinstance(outs, (Sequence, OpResultList)):
@@ -56,11 +66,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
     raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
                      f"{len(outs)} for {op_config}")
 
-  # Compute a replacement list for all attribute symbols.
+  # Compute a replacement list for all index attribute symbols.
   expressions = []  # type: Sequence[AffineExpr]
   replacements = []  # type: Sequence[AffineExpr]
   for index_attr in index_attr_arg_defs:
-    index_attr_vals = index_attr.operand_def.default_vals
+    index_attr_vals = index_attr.operand_def.default_indices
     if index_attr.name in attrs:
       index_attr_vals = attrs.get(index_attr.name)
     assert index_attr_vals, "Index attribute has no value"
@@ -125,15 +135,29 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
     array = np.array(index_attr_vals, dtype=np.int64)
     index_attrs[index_attr.name] = DenseElementsAttr.get(array)
 
+  # Compute the type function attribute mapping.
+  type_fn_attr_mapping = {}
+  for type_fn_attr in type_fn_attr_arg_defs:
+    attr_val = type_fn_attr.operand_def.default_fn
+    if type_fn_attr.name in attrs:
+      type_fn = attrs.get(type_fn_attr.name)
+      if not isinstance(type_fn, TypeFnType):
+        raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
+                         f"TypeFnType but got {type(attr_val)}")
+      attr_val = type_fn.fn_name
+    assert attr_val, "Type function attribute has no value"
+    type_fn_attr_mapping[type_fn_attr.name] = attr_val
+
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr,
-          index_attrs, block_arg_types)
+          type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
+          type_fn_attr_mapping, block_arg_types)
 
 
 def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
                                outs: ValueList, **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-     indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
+     indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+     block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   # An operation that accesses only scalars and scalar/rank zero tensors is
@@ -147,10 +171,9 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
     tensor_map = AffineMap.get_identity(rank)
     indexing_maps = []
     for arg_def in all_arg_defs:
-      if arg_def.operand_def.kind == OperandKind.Scalar:
+      if arg_def.operand_def.kind == OperandKind.SCALAR:
         indexing_maps.append(scalar_map)
-      if (arg_def.operand_def.kind == OperandKind.InputTensor or
-          arg_def.operand_def.kind == OperandKind.OutputTensor):
+      if arg_def.operand_def.is_tensor():
         indexing_maps.append(tensor_map)
     indexing_maps_attr = ArrayAttr.get(
         [AffineMapAttr.get(am) for am in indexing_maps])
@@ -169,7 +192,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
   block = generic_op.regions[0].blocks.append(*block_arg_types)
   block_arg_mapping = dict(zip(block_arg_names, block.arguments))
   with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
+                                type_fn_attr_mapping)
     for assignment in op_config.assignments:
       body_builder.assign(assignment)
     body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
@@ -184,7 +208,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
                              op_class_name: str, *ins: Value, outs: ValueList,
                              **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-     indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
+     indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+     block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
@@ -200,6 +225,11 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
   for name, value in index_attrs.items():
     named_op.operation.attributes[name] = value
 
+  # Set the type function attributes.
+  for name, value in type_fn_attr_mapping.items():
+    named_op.operation.attributes[name] = Attribute.parse(
+        f"#linalg.type_fn<{value}>")
+
   linalg.fill_builtin_region(named_op.operation)
 
   if len(result_types) == 1:
@@ -212,9 +242,11 @@ class _BodyBuilder:
   """Constructs a structured op body by evaluating assignments."""
 
   def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value]):
+               block_arg_mapping: Dict[str, Value],
+               type_fn_attr_mapping: Dict[str, str]):
     self.type_mapping = type_mapping
     self.block_arg_mapping = block_arg_mapping
+    self.type_fn_attr_mapping = type_fn_attr_mapping
     self.yield_mapping = dict()  # type: Dict[str, Value]
 
   def assign(self, assignment: ScalarAssign):
@@ -245,7 +277,10 @@ class _BodyBuilder:
       ]
       return fn(*operand_values)
     elif expr.type_fn:
-      fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}")
+      fn_name = expr.type_fn.fn_name
+      if expr.type_fn.attr_name:
+        fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
+      fn = self._get_function(f"_typefn_{fn_name}")
      operand = self.expression(expr.type_fn.operand)
      return fn(expr.type_fn.type_var.name, operand)
    raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
@@ -46,9 +46,10 @@ class ScalarArithFn:
 class ScalarTypeFn:
   """A type of ScalarExpression that applies a type conversion function."""
 
-  def __init__(self, fn_name: str, type_var: TypeVar,
-               operand: "ScalarExpression"):
+  def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
+               type_var: TypeVar, operand: "ScalarExpression"):
     self.fn_name = fn_name
+    self.attr_name = attr_name
     self.type_var = type_var
     self.operand = operand
 
@@ -56,7 +57,8 @@ class ScalarTypeFn:
     return ScalarExpression(type_fn=self)
 
   def __repr__(self):
-    return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})"
+    return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
+            f"({self.type_var}, {self.operand})")
 
 
 class ScalarArg:
@@ -138,12 +140,15 @@ class ScalarExpression(YAMLObject):
       # Note that even though operands must be arity 1, we write it the
       # same way as for apply because it allows handling code to be more
      # generic vs having a special form.
-      return dict(
-          type_fn=dict(
-              fn_name=self.type_fn.fn_name,
-              type_var=self.type_fn.type_var.name,
-              operands=[self.type_fn.operand],
-          ))
+      type_fn_dict = dict(
+          type_var=self.type_fn.type_var.name,
+          operands=[self.type_fn.operand],
+      )
+      if self.type_fn.fn_name:
+        type_fn_dict["fn_name"] = self.type_fn.fn_name
+      if self.type_fn.attr_name:
+        type_fn_dict["attr_name"] = self.type_fn.attr_name
+      return dict(type_fn=type_fn_dict)
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
     elif self.scalar_const:
@@ -10,7 +10,8 @@ Batch = S.Batch
 def matmul(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
   """Performs a matrix multiplication of two 2D inputs.
 
   Numeric casting is performed on the operands to the inner multiply, promoting
@@ -18,7 +19,7 @@ def matmul(
   """
   domain(D.m, D.n, D.k)
   implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 @linalg_structured_op
@@ -36,6 +36,20 @@ func @generalize_matmul_tensor_i16i64i32(%A : tensor<16x8xi16>, %B: tensor<8x32x
 // CHECK-NEXT:   linalg.yield %[[ADD]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
 
+// -----
+
+// Verifies that cast attributes control the cast operations used.
+func @generalize_matmul_tensor_i16i64i32_unsigned(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+                     ins(%A, %B: tensor<16x8xi16>, tensor<8x32xi64>)
+                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_matmul_tensor_i16i64i32_unsigned
+// CHECK:        = arith.extui
+
 // -----
 
 func @generalize_matmul_tensor_i16i64f32(%A : tensor<16x8xi16>, %B: tensor<8x32xi64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -1258,7 +1258,7 @@ class IndexExpr(abc.ABC):
     value = self._emit_expression(expr_to_input_opnd, expr_to_info)
     # Emit the structured op representation for the destination tensor.
     dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
-                             lang.OperandKind.OutputTensor)
+                             lang.OperandKind.OUTPUT_TENSOR)
     dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
     dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
 
@@ -1893,6 +1893,6 @@ def _emit_structured_op_input(
   name = expr.tensor.name
 
   dim_sym = _mlir_symbols_from_index_vars(indices)
-  opnd = lang.OperandDef(lang.OperandKind.InputTensor, lang.T, dim_sym)
+  opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
   op_def.add_operand(name, opnd)
   return opnd
 
@@ -2,7 +2,8 @@
 # RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL
 
 # @linalg_structured_op
-# def test1(O=TensorDef(T, S.M, S.N, output=True)):
+# def test1(O=TensorDef(T, S.M, S.N, output=True),
+#           cast=TypeFnAttrDef(default=TypeFn.cast)):
 #   """Title.
 
 #   Detailed description.
@@ -21,9 +22,13 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: O
-    usage: Output
+    kind: output_tensor
     type_var: T
     shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -39,18 +44,18 @@ structured_op: !LinalgStructuredOpConfig
         operands:
         - !ScalarExpression
           type_fn:
-            fn_name: cast
            type_var: T
            operands:
            - !ScalarExpression
              scalar_const: '42 : i64'
+            attr_name: cast
        - !ScalarExpression
          type_fn:
-            fn_name: cast_unsigned
            type_var: T
            operands:
            - !ScalarExpression
              scalar_index: 1
+            attr_name: cast
 
 # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
 
@@ -61,16 +66,22 @@ structured_op: !LinalgStructuredOpConfig
 
 # ODS: let arguments =
 # ODS-NEXT: Variadic<AnyType>:$inputs,
-# ODS-NEXT: Variadic<AnyShaped>:$outputs
+# ODS-NEXT: Variadic<AnyShaped>:$outputs,
+# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
 
 # ODS: let builders =
 # ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
 # ODS-NEXT: "ValueRange":$outputs,
 # ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
 
+# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+# ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast,
+# ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+
 # ODS: $_state.addOperands(inputs);
 # ODS-NEXT: $_state.addOperands(outputs);
 # ODS-NEXT: $_state.addTypes(resultTensorTypes);
+# ODS-NEXT: $_state.addAttribute("cast", cast)
 # ODS-NEXT: $_state.addAttributes(attributes);
 # ODS-NEXT: $_state.addAttribute(
 # ODS-NEXT: "operand_segment_sizes",
@@ -85,10 +96,18 @@ structured_op: !LinalgStructuredOpConfig
 
 # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
 # IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
+# IMPL: TypeFn castVal = TypeFn::cast;
+# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+# IMPL-NEXT: return attr.getName() == "cast"; });
+# IMPL-NEXT: if (castIter != attrs.end()) {
+# IMPL-NEXT: if (auto attr = castIter->getValue().dyn_cast<TypeFnAttr>())
+# IMPL-NEXT: castVal = attr.getValue();
+# IMPL-NEXT: }
+
 # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
-# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
+# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
 # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
-# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]);
+# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
 # IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]);
 
 
@@ -114,19 +133,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: I
-    usage: Input
+    kind: input_tensor
     type_var: T
     shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: O
-    usage: Output
+    kind: output_tensor
     type_var: T
     shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: strides
-    usage: IndexAttr
+    kind: index_attr
     index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
-    default_vals:
+    default_indices:
     - 1
    - 2
  indexing_maps: !LinalgIndexingMapsConfig
@@ -201,11 +220,11 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: value
-    usage: Input
+    kind: scalar
     type_var: T1
   - !LinalgOperandDefConfig
     name: O
-    usage: Output
+    kind: output_tensor
     type_var: U
     shape_map: affine_map<() -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
@@ -7,30 +7,34 @@ from mlir.dialects.linalg.opdsl.lang import *
 # CHECK-LABEL: matmul
 # CHECK: args:
 # CHECK:     name: A
-# CHECK:     usage: Input
+# CHECK:     kind: input_tensor
 # CHECK:     type_var: T
 # CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 # CHECK:     name: B
-# CHECK:     usage: Input
+# CHECK:     kind: input_tensor
 # CHECK:     type_var: T
 # CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
 # CHECK:     name: C
-# CHECK:     usage: Output
+# CHECK:     kind: output_tensor
 # CHECK:     type_var: U
 # CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK:     name: cast
+# CHECK:     kind: type_fn_attr
+# CHECK:     default_fn: cast
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 # CHECK: ---
 # CHECK-LABEL: fill
 # CHECK: args:
 # CHECK:     name: value
-# CHECK:     usage: Input
+# CHECK:     kind: scalar
 # CHECK-NOT: shape_map:
 # CHECK:     type_var: T
 @linalg_structured_op
@@ -42,17 +46,17 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK-LABEL: strided_copy
 # CHECK: args:
 # CHECK:     name: I
-# CHECK:     usage: Input
+# CHECK:     kind: input_tensor
 # CHECK:     type_var: T
 # CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
 # CHECK:     name: O
-# CHECK:     usage: Output
+# CHECK:     kind: output_tensor
 # CHECK:     type_var: T
 # CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
 # CHECK:     name: strides
-# CHECK:     usage: IndexAttr
+# CHECK:     kind: index_attr
 # CHECK:     index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
-# CHECK:     default_vals:
+# CHECK:     default_indices:
 # CHECK:     - 1
 # CHECK:     - 2
 @linalg_structured_op
@@ -19,16 +19,19 @@ from mlir.dialects.linalg.opdsl.lang import *
 # CHECK:      type_var: U
 # CHECK:      operands:
 # CHECK:      scalar_arg: A
+# CHECK:      attr_name: cast
 # CHECK:    type_fn:
 # CHECK:      type_var: U
 # CHECK:      operands:
 # CHECK:      scalar_arg: B
+# CHECK:      attr_name: cast
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast)):
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 # CHECK: ---
@ -24,19 +24,10 @@ def matmul_mono(
|
||||||
def matmul_poly(
|
def matmul_poly(
|
||||||
A=TensorDef(T1, S.M, S.K),
|
A=TensorDef(T1, S.M, S.K),
|
||||||
B=TensorDef(T2, S.K, S.N),
|
B=TensorDef(T2, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True),
|
||||||
|
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||||
|
|
||||||
|
|
||||||
@linalg_structured_op
|
|
||||||
def matmul_unsigned_poly(
|
|
||||||
A=TensorDef(T1, S.M, S.K),
|
|
||||||
B=TensorDef(T2, S.K, S.N),
|
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
|
||||||
domain(D.m, D.n, D.k)
|
|
||||||
C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
|
|
||||||
U, B[D.k, D.n])
|
|
||||||
|
|
||||||
|
|
||||||
with Context() as ctx, Location.unknown():
|
with Context() as ctx, Location.unknown():
|
||||||
|
@ -92,7 +83,8 @@ with Context() as ctx, Location.unknown():
|
||||||
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
||||||
RankedTensorType.get((4, 8), i32))
|
RankedTensorType.get((4, 8), i32))
|
||||||
def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
|
def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
|
||||||
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
|
return matmul_poly(
|
||||||
|
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
|
||||||
|
|
||||||
# CHECK-LABEL: @test_i8i16i32_matmul
|
# CHECK-LABEL: @test_i8i16i32_matmul
|
||||||
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
|
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
|
||||||
|
@ -143,7 +135,8 @@ with Context() as ctx, Location.unknown():
|
||||||
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
|
||||||
RankedTensorType.get((4, 8), f32))
|
RankedTensorType.get((4, 8), f32))
|
||||||
def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
|
def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
|
||||||
return matmul_unsigned_poly(lhs, rhs, outs=[init_result])
|
return matmul_poly(
|
||||||
|
lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
|
||||||
|
|
||||||
# CHECK-LABEL: @test_f16f16f32_matmul
|
# CHECK-LABEL: @test_f16f16f32_matmul
|
||||||
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
|
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
|
||||||
@@ -6,6 +6,8 @@ from mlir.dialects import linalg
 from mlir.dialects import std
 from mlir.dialects import arith

+from mlir.dialects.linalg.opdsl.lang import *
+

 def run(f):
   print("\nTEST:", f.__name__)
@@ -98,12 +100,14 @@ def testNamedStructuredOpCustomForm():
         init_result = linalg.InitTensorOp([4, 8], f32)
         # First check the named form with custom format
         # CHECK: linalg.matmul
+        # CHECK: cast = #linalg.type_fn<cast_unsigned>
         # CHECK-NOT: linalg.memoized_indexing_maps
         # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
         # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
         # CHECK-SAME: -> tensor<4x8xf32>
         # CHECK-NEXT: return
-        return linalg.matmul(lhs, rhs, outs=[init_result.result])
+        return linalg.matmul(
+            lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)

     print(module)

@@ -9,6 +9,8 @@ from mlir.dialects import std
 from mlir.passmanager import *
 from mlir.execution_engine import *

+from mlir.dialects.linalg.opdsl.lang import *
+

 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -20,21 +22,28 @@ def log(*args):
 matmul_boiler = """
 func @main() -> f32 attributes {llvm.emit_c_interface} {
   %v0 = arith.constant 0.0 : f32
-  %v1 = arith.constant 1.0 : f32
+  %v1 = arith.constant -1 : i8
   %v2 = arith.constant 2.0 : f32

-  %A = memref.alloc() : memref<4x16xf32>
+  %A = memref.alloc() : memref<4x16xi8>
   %B = memref.alloc() : memref<16x8xf32>
-  %C = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %A) : f32, memref<4x16xf32>
+  %C0 = memref.alloc() : memref<4x8xf32>
+  %C1 = memref.alloc() : memref<4x8xf32>
+  linalg.fill(%v1, %A) : i8, memref<4x16xi8>
   linalg.fill(%v2, %B) : f32, memref<16x8xf32>
-  linalg.fill(%v0, %C) : f32, memref<4x8xf32>
+  linalg.fill(%v0, %C0) : f32, memref<4x8xf32>
+  linalg.fill(%v0, %C1) : f32, memref<4x8xf32>

-  call @matmul_on_buffers(%A, %B, %C) :
-    (memref<4x16xf32>, memref<16x8xf32>, memref<4x8xf32>) -> ()
+  call @matmul_signed_on_buffers(%A, %B, %C0) :
+    (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
+  call @matmul_unsigned_on_buffers(%A, %B, %C1) :
+    (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()

   %c0 = arith.constant 0 : index
-  %0 = memref.load %C[%c0, %c0] : memref<4x8xf32>
+  %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32>
+  %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32>
+
+  %0 = arith.addf %res0, %res1 : f32

   // TODO: FFI-based solution to allow testing and printing with python code.
   return %0 : f32
@@ -157,8 +166,8 @@ def transform(module, boilerplate):

   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," +
-      "convert-memref-to-llvm, convert-std-to-llvm," +
+      "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
+      + "convert-memref-to-llvm, convert-std-to-llvm," +
       "reconcile-unrealized-casts")
   pm.run(mod)
   return mod
@@ -168,14 +177,21 @@ def test_matmul_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
+    i8 = IntegerType.get_signless(8)
     with InsertionPoint(module.body):

       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
           MemRefType.get((4, 8), f32))
-      def matmul_on_buffers(lhs, rhs, out):
+      def matmul_signed_on_buffers(lhs, rhs, out):
         linalg.matmul(lhs, rhs, outs=[out])

+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def matmul_unsigned_on_buffers(lhs, rhs, out):
+        linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
+
     execution_engine = ExecutionEngine(transform(module, matmul_boiler))

     # TODO: FFI-based solution to allow testing and printing with python code.
@@ -186,7 +202,9 @@ def test_matmul_builtin():
     execution_engine.invoke("main", res)

     log("RESULT: ", res[0])
-    # CHECK: RESULT: 32.0
+    # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
+    # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
+    # CHECK: RESULT: 8128


 test_matmul_builtin()
@@ -196,14 +214,22 @@ def test_matmul_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
+    i8 = IntegerType.get_signless(8)
     with InsertionPoint(module.body):

       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 16), f32), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
           MemRefType.get((4, 8), f32))
-      def matmul_on_buffers(lhs, rhs, out):
+      def matmul_signed_on_buffers(lhs, rhs, out):
         linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)

+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def matmul_unsigned_on_buffers(lhs, rhs, out):
+        linalg.matmul(
+            lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)
+
     execution_engine = ExecutionEngine(transform(module, matmul_boiler))

     # TODO: FFI-based solution to allow testing and printing with python code.
@@ -214,7 +240,9 @@ def test_matmul_generic():
     execution_engine.invoke("main", res)

     log("RESULT: ", res[0])
-    # CHECK: RESULT: 32.0
+    # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
+    # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
+    # CHECK: RESULT: 8128


 test_matmul_generic()
@@ -423,11 +451,7 @@ def test_min_pooling_builtin():
           MemRefType.get((1, 2, 4, 1), i32))
       # Set the strides and use the default dilations.
       def pooling_on_buffers(input, shape, output):
-        linalg.pooling_nhwc_min(
-            input,
-            shape,
-            outs=[output],
-            strides=[2, 4])
+        linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

     execution_engine = ExecutionEngine(transform(module, pooling_boiler))

@@ -458,11 +482,7 @@ def test_min_pooling_generic():
       # Set the strides and use the default dilations.
       def pooling_on_buffers(input, shape, output):
         linalg.pooling_nhwc_min(
-            input,
-            shape,
-            outs=[output],
-            strides=[2, 4],
-            emit_generic=True)
+            input, shape, outs=[output], strides=[2, 4], emit_generic=True)

     execution_engine = ExecutionEngine(transform(module, pooling_boiler))

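The RESULT comments in the hunks above encode the arithmetic behind the new CHECK value: %A is filled with the i8 constant -1, %B with 2.0, and each output element reduces over K = 16 before @main adds one element from each result buffer. A short sketch that only re-does the arithmetic stated in those comments, to confirm the expected value:

```python
# Signed path: -1 sign-extends, so every element of %C0 is -1 * 2.0 * 16.
signed = -1 * 2.0 * 16              # -32.0
# Unsigned path: the i8 bit pattern of -1 reads as 255, so %C1 holds 255 * 2.0 * 16.
unsigned = (2 ** 8 - 1) * 2.0 * 16  # 8160.0
# main() loads one element of each buffer and adds them before returning.
print(signed + unsigned)            # 8128.0 -> "# CHECK: RESULT: 8128"
```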
@@ -61,15 +61,22 @@ struct SerializedAffineMap {
   AffineMap affineMap() { return affineMapAttr.getValue(); }
 };

-enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
+enum class LinalgOperandDefKind {
+  InputTensor,
+  Scalar,
+  OutputTensor,
+  IndexAttr,
+  TypeFnAttr
+};

 struct LinalgOperandDef {
   std::string name;
-  LinalgOperandDefUsage usage;
+  LinalgOperandDefKind kind;
   Optional<std::string> typeVar;
   Optional<SerializedAffineMap> shapeMap;
   Optional<SerializedAffineMap> indexAttrMap;
-  Optional<SmallVector<int64_t>> defaultVals;
+  Optional<SmallVector<int64_t>> defaultIndices;
+  Optional<std::string> defaultFn;
 };

 enum class LinalgIteratorTypeDef {
@@ -91,11 +98,12 @@ struct ScalarArithFn {
 };

 struct ScalarTypeFn {
-  std::string fnName;
   std::string typeVar;
   // NOTE: This must be of arity 1, but to break the self-referential cycle,
   // we use a heap allocated vector.
   std::vector<ScalarExpression> operands;
+  Optional<std::string> fnName;
+  Optional<std::string> attrName;
 };

 struct ScalarExpression {
@@ -180,27 +188,32 @@ struct MappingTraits<LinalgStructuredOpConfig> {
 /// index attribute symbols. During op creation these symbols are replaced
 /// by the corresponding `name` index attribue values. Only index attribute
 /// arguments have an `index_attr_map`.
-/// - `default_vals`: An optional default initialization for index attribute
+/// - `default_indices`: An optional default initialization for index
+///   attribute arguments.
+/// - `default_fn`: An optional default initialization for function attribute
 ///   arguments.
 template <>
 struct MappingTraits<LinalgOperandDef> {
   static void mapping(IO &io, LinalgOperandDef &info) {
     io.mapRequired("name", info.name);
-    io.mapRequired("usage", info.usage);
+    io.mapRequired("kind", info.kind);
     io.mapOptional("type_var", info.typeVar);
     io.mapOptional("shape_map", info.shapeMap);
     io.mapOptional("index_attr_map", info.indexAttrMap);
-    io.mapOptional("default_vals", info.defaultVals);
+    io.mapOptional("default_indices", info.defaultIndices);
+    io.mapOptional("default_fn", info.defaultFn);
   }
 };

 /// Usage enum for a named argument.
 template <>
-struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
-  static void enumeration(IO &io, LinalgOperandDefUsage &value) {
-    io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
-    io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
-    io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
+struct ScalarEnumerationTraits<LinalgOperandDefKind> {
+  static void enumeration(IO &io, LinalgOperandDefKind &value) {
+    io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
+    io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
+    io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
+    io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
+    io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
   }
 };

@@ -281,9 +294,10 @@ struct MappingTraits<ScalarArithFn> {
 template <>
 struct MappingTraits<ScalarTypeFn> {
   static void mapping(IO &io, ScalarTypeFn &info) {
-    io.mapRequired("fn_name", info.fnName);
     io.mapRequired("type_var", info.typeVar);
     io.mapRequired("operands", info.operands);
+    io.mapOptional("fn_name", info.fnName);
+    io.mapOptional("attr_name", info.attrName);
   }
 };

@@ -399,8 +413,9 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {

   // Search all argument types.
   for (const auto &it : llvm::enumerate(args)) {
-    if (it.value().usage != LinalgOperandDefUsage::Input &&
-        it.value().usage != LinalgOperandDefUsage::Output)
+    if (it.value().kind != LinalgOperandDefKind::InputTensor &&
+        it.value().kind != LinalgOperandDefKind::Scalar &&
+        it.value().kind != LinalgOperandDefKind::OutputTensor)
       continue;
     if (it.value().typeVar.getValue() == typeVar)
       return llvm::formatv("block.getArgument({0}).getType()", it.index())
@@ -552,6 +567,8 @@ static const char structuredOpBuilderFormat[] = R"FMT(
       $_state.addOperands(inputs);
       $_state.addOperands(outputs);
       $_state.addTypes(resultTensorTypes);
+      {2}
+      $_state.addAttributes(attributes);
       $_state.addAttribute(
         "operand_segment_sizes",
         $_builder.getI32VectorAttr({{
@@ -562,8 +579,6 @@ static const char structuredOpBuilderFormat[] = R"FMT(
         $_state,
         TypeRange(inputs),
         TypeRange(outputs));
-      {2}
-      $_state.addAttributes(attributes);
     }]>
 )FMT";

@@ -681,42 +696,56 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,

   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");

-  // Assemble the attribute specific logic required for the op definition.
   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
-        return arg.usage == LinalgOperandDefUsage::IndexAttr;
+        return arg.kind == LinalgOperandDefKind::IndexAttr ||
+               arg.kind == LinalgOperandDefKind::TypeFnAttr;
       })) {
     SmallVector<std::string> attrDefs;
     SmallVector<std::string> attrParams;
     SmallVector<std::string> attrStmts;
     for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
-      if (arg.usage != LinalgOperandDefUsage::IndexAttr)
-        continue;
-      assert(arg.indexAttrMap.hasValue());
-      assert(arg.defaultVals.hasValue());
-      size_t size = arg.indexAttrMap->affineMap().getNumResults();
-      assert(arg.defaultVals.getValue().size() == size);
-      static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
-      static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
       static const char paramFmt[] = "\"Attribute\":${0}";
       static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
-      std::string defaultVals;
-      llvm::raw_string_ostream ss(defaultVals);
-      ss << "{ ";
-      llvm::interleave(
-          arg.defaultVals.getValue(), ss,
-          [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
-          ", ");
-      ss << " }";
-      attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
-                                       ss.str(), arg.name));
-      attrParams.push_back(llvm::formatv(paramFmt, arg.name));
-      attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
+      // Add the type conversion attributes to the op definition and builders.
+      if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
+        assert(arg.defaultFn.hasValue());
+        static const char typeFmt[] = "TypeFn::{0}";
+        static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
+        attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
+                                         llvm::formatv(typeFmt, arg.defaultFn),
+                                         arg.name));
+        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
+        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
+      }
+      // Add the index attributes to the op definition and builders.
+      if (arg.kind == LinalgOperandDefKind::IndexAttr) {
+        assert(arg.indexAttrMap.hasValue());
+        assert(arg.defaultIndices.hasValue());
+        size_t size = arg.indexAttrMap->affineMap().getNumResults();
+        assert(arg.defaultIndices.getValue().size() == size);
+        static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
+        static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}";
+        std::string defaultVals;
+        llvm::raw_string_ostream ss(defaultVals);
+        llvm::interleave(
+            arg.defaultIndices.getValue(), ss,
+            [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
+            ", ");
+        attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
+                                         ss.str(), arg.name));
+        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
+        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
+      }
+    }
+    if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+          return arg.kind == LinalgOperandDefKind::IndexAttr;
+        })) {
+      attrMethods = R"(
+        bool hasDynamicIndexingMaps();
+        LogicalResult verifyIndexingMapRequiredAttributes();
+      )";
     }
     attrList = ",\n" + llvm::join(attrDefs, ",\n");
-    attrMethods = R"(
-      bool hasDynamicIndexingMaps();
-      LogicalResult verifyIndexingMapRequiredAttributes();
-    )";
     attrBuilder = llvm::formatv(
         structuredOpBuilderFormat, opConfig.metadata->cppClassName,
         llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
@@ -746,7 +775,9 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
   // Compute the number of scalar and tensor arguments.
   int64_t numOfArgs =
       llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
-        return arg.usage != LinalgOperandDefUsage::IndexAttr;
+        return arg.kind == LinalgOperandDefKind::InputTensor ||
+               arg.kind == LinalgOperandDefKind::Scalar ||
+               arg.kind == LinalgOperandDefKind::OutputTensor;
       });

   // An operation that accesses only scalars and scalar/rank zero tensors is
@@ -817,7 +848,7 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
 )FMT";
   // Update all symbol bindings mapped to an attribute.
   for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
-    if (arg.usage != LinalgOperandDefUsage::IndexAttr)
+    if (arg.kind != LinalgOperandDefKind::IndexAttr)
       continue;
     assert(arg.indexAttrMap.hasValue());
     for (auto &en :
@@ -910,11 +941,11 @@ std::string {0}::getLibraryCallName() {{

   // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
-        return arg.usage == LinalgOperandDefUsage::IndexAttr;
+        return arg.kind == LinalgOperandDefKind::IndexAttr;
       })) {
     std::vector<std::string> attrVerifications;
     for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
-      if (arg.usage != LinalgOperandDefUsage::IndexAttr)
+      if (arg.kind != LinalgOperandDefKind::IndexAttr)
         continue;
       assert(arg.indexAttrMap.hasValue());
       // Verify index attribute. Paramters:
@@ -952,7 +983,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
 // Generates a regionBuilder method. Parameters.
 // {0}: Class name
 // {1}: Number of args
-// {2}: Statements
+// {2}: Attributes
+// {3}: Statements
 static const char structuredOpRegionBuilderFormat[] = R"FMT(
 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
                         Block &block, ArrayRef<NamedAttribute> attrs) {{
@@ -961,6 +993,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
   RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
   SmallVector<Value> yields;
   {2}
+  {3}
   helper.yieldOutputs(yields);
 }
 )FMT";
@@ -968,9 +1001,27 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
   auto &assignments = opConfig.structuredOp->assignments;
   size_t generatedAssignmentCount = 0;
   int localCounter = 0;
+  SmallVector<std::string> attrs;
   SmallVector<std::string> stmts;
   for (LinalgOperandDef &arg : args) {
-    if (arg.usage != LinalgOperandDefUsage::Output)
+    if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
+      continue;
+    // Obtain the type function attribute values. Parameters.
+    // {0}: attribute name
+    // {1}: default type function name
+    static const char attrDef[] = R"FMT(
+  TypeFn {0}Val = TypeFn::{1};
+  auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+                                return attr.getName() == "{0}"; });
+  if ({0}Iter != attrs.end()) {{
+    if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
+      {0}Val = attr.getValue();
+  }
+)FMT";
+    attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
+  }
+  for (LinalgOperandDef &arg : args) {
+    if (arg.kind != LinalgOperandDefKind::OutputTensor)
       continue;

     // Find the assignment that correlates with the argument.
@@ -1048,11 +1099,25 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
             << "an argument type but it does not";
         return None;
       }
+
+      // Use the function name or the attribute to build the type function.
+      std::string typeFunc = llvm::formatv(
+          "TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
+      if (expression.typeFn->attrName) {
+        if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
+              return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
+                     arg.name == expression.typeFn->attrName.getValue();
+            })) {
+          emitError(genContext.getLoc())
+              << "missing type function attribute "
+              << expression.typeFn->attrName.getValue();
+        }
+        typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
+      }
       std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
-      stmts.push_back(
-          llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});",
-                        cppIdent, expression.typeFn->fnName,
-                        typeCppValue.getValue(), *operandCppValue));
+      stmts.push_back(llvm::formatv(
+          "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
+          typeFunc, typeCppValue.getValue(), *operandCppValue));
       return cppIdent;
     }
     emitError(genContext.getLoc()) << "unknown ScalarExpression type";
@@ -1069,6 +1134,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
         << "mismatched number of assignments vs output arguments";

   os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
+                      interleaveToString(attrs, "\n "),
                       interleaveToString(stmts, "\n "));
 }

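The YAML traits above rename the serialized keys (`usage` becomes `kind`, `default_vals` becomes `default_indices`) and add the optional `default_fn` and `attr_name` keys. A rough sketch of what operand records look like under the new schema, parsed here with PyYAML; the concrete names and the two-symbol affine map are invented for illustration, and only the key names follow the mapping code above:

```python
import yaml

# Hypothetical fragment: key names mirror MappingTraits<LinalgOperandDef>,
# values are illustrative only.
fragment = yaml.safe_load("""
args:
  - name: cast
    kind: type_fn_attr
    default_fn: cast
  - name: strides
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1] -> (s0, s1)>
    default_indices:
    - 1
    - 2
""")
assert fragment["args"][0]["kind"] == "type_fn_attr"
assert fragment["args"][1]["default_indices"] == [1, 2]
```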
@@ -6720,6 +6720,22 @@ gentbl_cc_library(
             ],
             "include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc",
         ),
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc",
+        ),
+        (
+            ["-gen-attrdef-decls"],
+            "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc",
+        ),
+        (
+            ["-gen-attrdef-defs"],
+            "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc",
+        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/Linalg/IR/LinalgOps.td",