[mlir][OpDSL] Add arithmetic function attributes.

The revision extends OpDSL with unary and binary function attributes. A function attribute, makes the operations used in the body of a structured operation configurable. For example, a pooling operation may take an aggregation function attribute that specifies if the op shall implement a min or a max pooling. The goal of this revision is to define less and more flexible operations.

We may thus for example define an element wise op:
```
linalg.elem(lhs, rhs, outs=[out], op=BinaryFn.mul)
```
If the op argument is not set the default operation is used.

Depends On D120109

Reviewed By: nicolasvasilache, aartbik

Differential Revision: https://reviews.llvm.org/D120110
This commit is contained in:
gysit 2022-03-01 07:40:06 +00:00
parent 5d91a8a707
commit 24357fec8d
17 changed files with 807 additions and 362 deletions

View File

@ -107,12 +107,12 @@ copy_and_scale(val, in_tensor, outs=[out_tensor])
## Index Attributes
Attributes are compile-time constant parameters only accessible in index
Index attributes are compile-time constant parameters only accessible in index
expressions. They can be used to parameterize the access pattern of a structured
operation, for example, by setting its strides. They cannot take part in the
actual computation.
The following example demonstrates the use of attributes:
The following example demonstrates the use of index attributes:
```python
@linalg_structured_op
@ -136,9 +136,9 @@ The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
index expressions of the operation instance. If no strides are provided the
`default` vector elements are used instead.
Attributes are currently limited to integer vectors and only accessible in index
expressions. An operation may have multiple attributes all of them placed at the
end of the parameter list after the output tensors.
Index attributes are currently limited to integer vectors and only accessible in
index expressions. An operation may have multiple attributes all of them placed
at the end of the parameter list after the output tensors.
## Shape-Only Tensors
@ -220,6 +220,43 @@ There are also special forms:
* `const(value)` returns a constant value.
* `index(dim)` returns the iteration index in the given dimension `dim`.
## Function Attributes
Function attributes are compile-time constant function parameters. They can be
used to parameterize the computation performed by a structured operation, for
example, to support signed and unsigned computations.
The following example demonstrates the use of function attributes:
```python
@linalg_structured_op
def elemwise_binary(
lhs=TensorDef(T1),
rhs=TensorDef(T2),
O=TensorDef(U, output=True),
fun=BinaryFnAttrDef(default=BinaryFn.add),
cast=TypeFnAttrDef(default=TypeFn.cast)):
O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
```
The `fun` and `cast` function attributes by default are aliases for their
default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
instantiating the operation, the function attributes may be set to other
functions using optional named arguments:
```python
elemwise_binary(lhs, rhs, outs=[out_tensor],
fun=BinaryFn.mul, cast=TypeFn.cast_unsigned)
```
In the example, the `fun` and `cast` arguments adapt the body of the operation
to implement multiplication and unsigned casts instead of addition and signed
casts.
OpDSL supports unary, binary, and type conversion function attributes. An
operation can take multiple attributes of different kinds placed at the end of
the parameter list.
## Types
All types in assignment expressions are late bound based on actual input and

View File

@ -58,7 +58,26 @@ def Linalg_Dialect : Dialect {
}];
}
// Define a TypeFn enum matching the OpDSL TypeFn class.
// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"exp", 0>,
I32EnumAttrCase<"log", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
def BinaryFn : I32EnumAttr<"BinaryFn", "", [
I32EnumAttrCase<"add", 0>,
I32EnumAttrCase<"mul", 1>,
I32EnumAttrCase<"max", 2>,
I32EnumAttrCase<"min", 3>,
I32EnumAttrCase<"sub", 4>,
I32EnumAttrCase<"max_unsigned", 5>,
I32EnumAttrCase<"min_unsigned", 6>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
@ -67,6 +86,12 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
let cppNamespace = "::mlir::linalg";
}
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
}
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
let assemblyFormat = "`<` $value `>`";
}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}

View File

@ -1,6 +1,120 @@
### AUTOGENERATED from core_named_ops.py
### To regenerate, run: bin/update_core_linalg_named_ops.sh
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_unary
cpp_class_name: ElemwiseUnaryOp
doc: |-
Applies the unary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: fun
kind: unary_fn_attr
default_fn: exp
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
attr_name: fun
operands:
- !ScalarExpression
scalar_fn:
kind: type
attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_binary
cpp_class_name: ElemwiseBinaryOp
doc: |-
Applies the binary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: lhs
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: rhs
kind: input_tensor
type_var: T2
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: fun
kind: binary_fn_attr
default_fn: add
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: binary
attr_name: fun
operands:
- !ScalarExpression
scalar_fn:
kind: type
attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: lhs
- !ScalarExpression
scalar_fn:
kind: type
attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp

View File

@ -108,17 +108,9 @@ static LogicalResult foldMemRefCast(Operation *op) {
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math functions in the DSL as:
// `unary__{fnName}`
// `binary__{fnName}`
// Examples:
// `binary__add`
// `binary__mul`
// `unary__exp`
// `unary__log`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
// The public methods on this class are referenced directly from generated code.
// Helper build the unary, binary, and type conversion functions defined by the
// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
@ -142,6 +134,98 @@ public:
RegionBuilderHelper(MLIRContext *context, Block &block)
: context(context), block(block) {}
// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
if (!isFloatingPoint(arg))
llvm_unreachable("unsupported non numeric type");
OpBuilder builder = getBuilder();
switch (unaryFn) {
case UnaryFn::exp:
return builder.create<math::ExpOp>(arg.getLoc(), arg);
case UnaryFn::log:
return builder.create<math::LogOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
// Build the binary functions defined by OpDSL.
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
if (!allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
OpBuilder builder = getBuilder();
switch (binaryFn) {
case BinaryFn::add:
if (allFloatingPoint)
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max:
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min:
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
}
llvm_unreachable("unsupported binary function");
}
// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
case TypeFn::cast:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);
}
llvm_unreachable("unsupported type conversion function");
}
void yieldOutputs(ValueRange values) {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
builder.create<YieldOp>(loc, values);
}
Value constant(const std::string &value) {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
valueAttr);
}
Value index(int64_t dim) {
OpBuilder builder = getBuilder();
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
return IntegerType::get(context, width);
}
Type getFloat32Type() { return Float32Type::get(context); }
Type getFloat64Type() { return Float64Type::get(context); }
private:
// Generates operations to cast the given operand to a specified type.
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
@ -193,136 +277,6 @@ public:
return operand;
}
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
case TypeFn::cast:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);
}
llvm_unreachable("unsupported type conversion function");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value unary__exp(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::ExpOp>(x.getLoc(), x);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value unary__log(Value x) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(x))
return builder.create<math::LogOp>(x.getLoc(), x);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value binary__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
void yieldOutputs(ValueRange values) {
assert(!values.empty() && "linalg ops must yield outputs");
if (values.empty())
return;
Value first = values.front();
OpBuilder builder = getBuilder();
builder.create<YieldOp>(first.getLoc(), values);
}
Value constant(const std::string &value) {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
valueAttr);
}
Value index(int64_t dim) {
OpBuilder builder = getBuilder();
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
return IntegerType::get(context, width);
}
Type getFloat32Type() { return Float32Type::get(context); }
Type getFloat64Type() { return Float64Type::get(context); }
private:
MLIRContext *context;
Block &block;
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
@ -331,6 +285,9 @@ private:
builder.setInsertionPointToEnd(&block);
return builder;
}
MLIRContext *context;
Block &block;
};
} // namespace

View File

@ -126,7 +126,7 @@ class TensorUse(TensorExpression):
return rhs_dims - lhs_dims
def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
def __repr__(self):
return (f"{self.operand_def.name}"
@ -183,8 +183,14 @@ class TensorReduceFn(TensorExpression):
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
None, None, full_args).expr()
fn_name = None
attr_name = None
if self.reduce_use.binary_fn:
fn_name = self.reduce_use.binary_fn.fn_name
if self.reduce_use.binary_attr:
attr_name = self.reduce_use.binary_attr.operand_def.name
return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
@ -257,8 +263,8 @@ class UnaryFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
def __call__(self, exp: TensorExpression) -> "TensorFn":
return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
def __call__(self, arg: TensorExpression) -> "TensorFn":
return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
def __repr__(self):
return f"{self.fn_name}"
@ -345,16 +351,21 @@ class ReduceFnUse:
A reduction use specifies the reduction function and dimensions.
"""
def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
def __init__(self, binary_fn: Optional[BinaryFnType],
binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
if bool(binary_fn) + bool(binary_attr) != 1:
raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
self.binary_fn = binary_fn
self.binary_attr = binary_attr
self.reduce_dims = reduce_dims
def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
return TensorReduceFn(self, args)
def __repr__(self):
return (f"reduce_{self.binary_fn.fn_name}"
f"({', '.join(repr(d) for d in self.reduce_dims)})")
fn = self.binary_fn if self.binary_fn else self.binary_attr
return (
f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
class ReduceFnType:
@ -369,10 +380,10 @@ class ReduceFnType:
self.binary_fn = binary_fn
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(self.binary_fn, *reduce_dims)
return ReduceFnUse(self.binary_fn, None, *reduce_dims)
def __repr__(self):
return (f"reduce_{self.binary_fn.fn_name}")
return f"reduce_{repr(self.binary_fn)}"
class ReduceFn:
@ -394,7 +405,9 @@ class OperandKind(Enum):
SCALAR = 1
OUTPUT_TENSOR = 2
INDEX_ATTR = 3
TYPE_FN_ATTR = 4
UNARY_FN_ATTR = 4
BINARY_FN_ATTR = 5
TYPE_FN_ATTR = 6
class OperandDef:
@ -441,6 +454,8 @@ class OperandDef:
def is_attribute(self) -> bool:
return (self.kind == OperandKind.INDEX_ATTR or
self.kind == OperandKind.UNARY_FN_ATTR or
self.kind == OperandKind.BINARY_FN_ATTR or
self.kind == OperandKind.TYPE_FN_ATTR)
def __hash__(self):
@ -557,6 +572,49 @@ class IndexAttrDef:
OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
class UnaryFnAttrDef:
"""Unary function attribute definition.
Unary function attributes provide a way to make the arithmetic computation
parametrizable. Every attribute specifies a default unary function
that may be overwritten at operation instantiation time.
"""
def __init__(self, default: "UnaryFnType"):
if not isinstance(default, UnaryFnType):
raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
f"but got {default}")
self.operand_def = OperandDef(
OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
def __call__(self, arg: TensorExpression) -> TensorFn:
return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
class BinaryFnAttrDef:
"""Binary function attribute definition.
Binary function attributes provide a way to make the arithmetic computation
parametrizable. Every attribute specifies a default binary function
that may be overwritten at operation instantiation time.
"""
def __init__(self, default: "BinaryFnType"):
if not isinstance(default, BinaryFnType):
raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
f"but got {default}")
self.operand_def = OperandDef(
OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
def __call__(self, arg0: TensorExpression,
arg1: TensorExpression) -> TensorFn:
return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
[arg0, arg1])
def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(None, self, *reduce_dims)
class TypeFnAttrDef:
"""Type conversion function attribute definition.

View File

@ -309,8 +309,8 @@ class LinalgStructuredOpConfig(YAMLObject):
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
if (operand_def.kind == OperandKind.SCALAR or
operand_def.kind == OperandKind.TYPE_FN_ATTR):
if not (operand_def.is_tensor() or
operand_def.kind == OperandKind.INDEX_ATTR):
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:

View File

@ -130,7 +130,8 @@ def linalg_structured_op(dsl_func=None,
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(param_default,
(TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
(TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
BinaryFnAttrDef, TypeFnAttrDef)):
op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(

View File

@ -41,7 +41,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
all_arg_defs = op_config.ordered_operands
in_arg_defs = [
d for d in all_arg_defs
if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
]
out_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
@ -49,8 +49,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
index_attr_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
]
type_fn_attr_arg_defs = [
d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
fn_attr_arg_defs = [
d for d in all_arg_defs if d.kind in [
OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
OperandKind.TYPE_FN_ATTR
]
]
# Verify outs is a sequence or a list of results.
@ -135,28 +138,38 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
array = np.array(index_attr_vals, dtype=np.int64)
index_attrs[index_attr.name] = DenseElementsAttr.get(array)
# Compute the type function attribute mapping.
type_fn_attr_mapping = {}
for type_fn_attr in type_fn_attr_arg_defs:
attr_val = type_fn_attr.operand_def.default_fn
if type_fn_attr.name in attrs:
type_fn = attrs.get(type_fn_attr.name)
if not isinstance(type_fn, TypeFnType):
raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
f"TypeFnType but got {type(attr_val)}")
attr_val = type_fn.fn_name
assert attr_val, "Type function attribute has no value"
type_fn_attr_mapping[type_fn_attr.name] = attr_val
# Compute the function attribute mapping.
fn_attr_mapping = {}
for fn_attr in fn_attr_arg_defs:
attr_val = fn_attr.operand_def.default_fn
attr_kind = fn_attr.kind
if fn_attr.name in attrs:
fn = attrs.get(fn_attr.name)
if attr_kind == OperandKind.UNARY_FN_ATTR:
if not isinstance(fn, UnaryFnType):
raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
f"UnaryFnType but got {type(attr_val)}")
elif attr_kind == OperandKind.BINARY_FN_ATTR:
if not isinstance(fn, BinaryFnType):
raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
f"BinaryFnType but got {type(attr_val)}")
else:
if not isinstance(fn, TypeFnType):
raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
f"TypeFnType but got {type(attr_val)}")
attr_val = fn.fn_name
assert attr_val, "Function attribute has no value"
fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
type_fn_attr_mapping, block_arg_types)
fn_attr_mapping, block_arg_types)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@ -193,7 +206,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
block_arg_mapping = dict(zip(block_arg_names, block.arguments))
with InsertionPoint(block):
body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
type_fn_attr_mapping)
fn_attr_mapping)
for assignment in op_config.assignments:
body_builder.assign(assignment)
body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
@ -208,7 +221,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@ -225,10 +238,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
# Set the type function attributes.
for name, value in type_fn_attr_mapping.items():
# Compute the function attributes by combining operand kind and function name.
for name, (fn_name, kind) in fn_attr_mapping.items():
assert kind.name.lower().endswith("_attr")
enum_name = kind.name.lower()[:-5]
named_op.operation.attributes[name] = Attribute.parse(
f"#linalg.type_fn<{value}>")
f"#linalg.{enum_name}<{fn_name}>")
linalg.fill_builtin_region(named_op.operation)
@ -242,11 +257,11 @@ class _BodyBuilder:
"""Constructs a structured op body by evaluating assignments."""
def __init__(self, type_mapping: Dict[str, Type],
block_arg_mapping: Dict[str, Value],
type_fn_attr_mapping: Dict[str, str]):
block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
str]):
self.type_mapping = type_mapping
self.block_arg_mapping = block_arg_mapping
self.type_fn_attr_mapping = type_fn_attr_mapping
self.fn_attr_mapping = fn_attr_mapping
self.yield_mapping = dict() # type: Dict[str, Value]
def assign(self, assignment: ScalarAssign):
@ -270,21 +285,18 @@ class _BodyBuilder:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
kind = expr.scalar_fn.kind.name.lower()
fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
operand_values = [
self.expression(operand) for operand in expr.scalar_fn.operands
]
return fn(*operand_values)
elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
elif expr.scalar_fn:
kind = expr.scalar_fn.kind.name.lower()
fn_name = expr.scalar_fn.fn_name
if expr.scalar_fn.attr_name:
fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
fn = self._get_function(f"_{kind}_{fn_name}")
operand_value = self.expression(expr.scalar_fn.operands[0])
return fn(expr.scalar_fn.type_var.name, operand_value)
operand_values = [
self.expression(operand) for operand in expr.scalar_fn.operands
]
if expr.scalar_fn.kind == FunctionKind.TYPE:
operand_values = [expr.scalar_fn.type_var.name] + operand_values
return fn(*operand_values)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def yield_outputs(self, *output_names: str):

View File

@ -6,6 +6,35 @@ T2 = TV.T2
Batch = S.Batch
@linalg_structured_op
def elemwise_unary(
I=TensorDef(T1),
O=TensorDef(U, output=True),
fun=UnaryFnAttrDef(default=UnaryFn.exp),
cast=TypeFnAttrDef(default=TypeFn.cast)):
"""Applies the unary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
O[None] = fun(cast(U, I[None]))
@linalg_structured_op
def elemwise_binary(
lhs=TensorDef(T1),
rhs=TensorDef(T2),
O=TensorDef(U, output=True),
fun=BinaryFnAttrDef(default=BinaryFn.add),
cast=TypeFnAttrDef(default=TypeFn.cast)):
"""Applies the binary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),

View File

@ -292,16 +292,48 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16
// -----
func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
// Verifies the default value of the fun attribute is an exp op.
func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
return %0: tensor<4x8xf32>
}
// CHECK-LABEL: @generalize_soft_plus_2d_f32
// CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT: linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
// CHECK-LABEL: @generalize_elemwise_exp
// CHECK: = math.exp
// -----
// Verifies the fun attribute controls the unary function used.
func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
return %0: tensor<4x8xf32>
}
// CHECK-LABEL: @generalize_elemwise_log
// CHECK: = math.log
// -----
// Verifies the default value of the fun attribute is an add op.
func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
return %0: tensor<4x8xf32>
}
// CHECK-LABEL: @generalize_elemwise_add
// CHECK: = arith.addf
// -----
// Verifies the fun attribute controls the binary function used.
func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
return %0: tensor<4x8xf32>
}
// CHECK-LABEL: @generalize_elemwise_mul
// CHECK: = arith.mulf

View File

@ -111,7 +111,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);
# @linalg_structured_op
@ -255,14 +255,15 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
# IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
# @linalg_structured_op
# def test4(O=TensorDef(T, S.M, S.N, output=True)):
# def test4(O=TensorDef(T, S.M, S.N, output=True),
# unary_fun=UnaryFnAttrDef(default=UnaryFn.exp),
# binary_fun=BinaryFnAttrDef(default=BinaryFn.add)):
# """Title.
# Detailed description.
# """
# O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
# O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n])
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
@ -279,6 +280,14 @@ structured_op: !LinalgStructuredOpConfig
kind: output_tensor
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: unary_fun
kind: unary_fn_attr
default_fn: exp
- !LinalgOperandDefConfig
name: binary_fun
kind: binary_fn_attr
default_fn: add
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@ -291,21 +300,36 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: add
attr_name: binary_fun
operands:
- !ScalarExpression
scalar_fn:
kind: unary
fn_name: exp
attr_name: unary_fun
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_arg: O
# ODS-LABEL: def Test4Op : LinalgStructuredBase_Op<"test4"
# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
# ODS-NEXT: DefaultValuedAttr<UnaryFnAttr, "UnaryFn::exp">:$unary_fun,
# ODS-NEXT: DefaultValuedAttr<BinaryFnAttr, "BinaryFn::add">:$binary_fun
# ODS: "Attribute":$unary_fun, "Attribute":$binary_fun,
# ODS: $_state.addAttribute("unary_fun", unary_fun)
# ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun)
# IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL: UnaryFn unary_funVal = UnaryFn::exp
# IMPL: BinaryFn binary_funVal = BinaryFn::add
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]])

View File

@ -18,6 +18,12 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: bfn
# CHECK: kind: binary_fn_attr
# CHECK: default_fn: mul
# CHECK: name: ufn
# CHECK: kind: unary_fn_attr
# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast
@ -26,8 +32,10 @@ def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
bfn=BinaryFnAttrDef(default=BinaryFn.mul),
ufn=UnaryFnAttrDef(default=UnaryFn.exp),
cast=TypeFnAttrDef(default=TypeFn.cast)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---

View File

@ -10,10 +10,12 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: arg: C
# CHECK: value:
# CHECK: scalar_fn:
# CHECK: kind: binary
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_fn:
# CHECK: fn_name: mul
# CHECK: kind: binary
# CHECK: attr_name: mul
# CHECK: operands:
# CHECK: scalar_fn:
# CHECK: kind: type
@ -32,8 +34,9 @@ def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
mul=BinaryFnAttrDef(default=BinaryFn.mul),
cast=TypeFnAttrDef(default=TypeFn.cast)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
# CHECK: ---
@ -69,14 +72,21 @@ def matmul(
# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
# CHECK: scalar_fn:
# CHECK: kind: unary
# CHECK: attr_name: exp
# CHECK: operands:
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
def constants(O=TensorDef(T, S.M, S.K, output=True)):
def constants(
O=TensorDef(T, S.M, S.K, output=True),
exp=UnaryFnAttrDef(default=UnaryFn.exp)):
pi = TypeFn.cast(T, const(3.1415926535897931))
cst42 = TypeFn.cast(T, const(42))
cst1000 = TypeFn.cast(T, const(1e+3))
cst1000 = TypeFn.cast(T, exp(const(1e+3)))
O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
# CHECK: ---
# CHECK-LABEL: indices
# CHECK: assignments:

View File

@ -12,55 +12,18 @@ T2 = TV.T2
@linalg_structured_op
def pooling_max_poly(
def pooling_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
reduce=BinaryFnAttrDef(default=BinaryFn.max),
cast=TypeFnAttrDef(default=TypeFn.cast),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_max_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_min_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def pooling_min_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]))
with Context() as ctx, Location.unknown():
@ -88,7 +51,7 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_max_pooling(input, shape, init_result):
return pooling_max_poly(
return pooling_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_max_unsigned_pooling
@ -99,8 +62,14 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_max_unsigned_pooling(input, shape, init_result):
return pooling_max_unsigned_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
return pooling_poly(
input,
shape,
outs=[init_result],
reduce=BinaryFn.max_unsigned,
cast=TypeFn.cast_unsigned,
strides=[2, 4],
dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_max_pooling
# CHECK: linalg.generic
@ -115,7 +84,7 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), f32))
def test_f32f32_max_pooling(input, shape, init_result):
return pooling_max_poly(
return pooling_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_pooling
@ -126,8 +95,13 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_min_pooling(input, shape, init_result):
return pooling_min_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
return pooling_poly(
input,
shape,
outs=[init_result],
reduce=BinaryFn.min,
strides=[2, 4],
dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_unsigned_pooling
# CHECK: = arith.fptoui
@ -137,8 +111,14 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_min_unsigned_pooling(input, shape, init_result):
return pooling_min_unsigned_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
return pooling_poly(
input,
shape,
outs=[init_result],
reduce=BinaryFn.min_unsigned,
cast=TypeFn.cast_unsigned,
strides=[2, 4],
dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_min_pooling
# CHECK: = arith.minf
@ -147,8 +127,13 @@ with Context() as ctx, Location.unknown():
RankedTensorType.get((2, 2), f32),
RankedTensorType.get((1, 2, 4, 1), f32))
def test_f32f32_min_pooling(input, shape, init_result):
return pooling_min_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
return pooling_poly(
input,
shape,
outs=[init_result],
reduce=BinaryFn.min,
strides=[2, 4],
dilations=[1, 2])
print(module)

View File

@ -94,20 +94,27 @@ def testNamedStructuredOpCustomForm():
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
f32))
RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# First check the named form with custom format
# CHECK: linalg.matmul
# CHECK: cast = #linalg.type_fn<cast_unsigned>
# CHECK-NOT: linalg.memoized_indexing_maps
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
# CHECK-SAME: -> tensor<4x8xf32>
# CHECK-NEXT: return
return linalg.matmul(
lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
# Check for the named form with custom format
# CHECK: linalg.elemwise_unary
# CHECK-SAME: cast = #linalg.type_fn<cast>
# CHECK-SAME: fun = #linalg.unary_fn<exp>
# CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
# CHECK: linalg.elemwise_binary
# CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
# CHECK-SAME: fun = #linalg.binary_fn<mul>
# CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
# CHECK: return
binary_result = linalg.elemwise_binary(
lhs,
rhs,
outs=[init_result.result],
fun=BinaryFn.mul,
cast=TypeFn.cast_unsigned)
return unary_result, binary_result
print(module)
@ -130,7 +137,8 @@ def testNamedStructuredOpGenericForm():
# CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
# CHECK-NEXT: cast = #linalg.type_fn<cast>
# CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])

View File

@ -19,6 +19,37 @@ def log(*args):
sys.stderr.flush()
elemwise_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
%v0 = arith.constant 0.0 : f32
%v1 = arith.constant 1.0 : f32
%v2 = arith.constant 2.0 : f32
%lhs = memref.alloc() : memref<4x8xf32>
%rhs = memref.alloc() : memref<4x8xf32>
%O0 = memref.alloc() : memref<4x8xf32>
%O1 = memref.alloc() : memref<4x8xf32>
linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
%c0 = arith.constant 0 : index
%res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
%res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
%0 = arith.addf %res0, %res1 : f32
// TODO: FFI-based solution to allow testing and printing with python code.
return %0 : f32
}
"""
matmul_boiler = """
func @main() -> f32 attributes {llvm.emit_c_interface} {
%v0 = arith.constant 0.0 : f32
@ -166,13 +197,93 @@ def transform(module, boilerplate):
pm = PassManager.parse(
"builtin.func(convert-linalg-to-loops, lower-affine, " +
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
+ "convert-memref-to-llvm, convert-std-to-llvm," +
"convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
+ "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm," +
"reconcile-unrealized-casts")
pm.run(mod)
return mod
def test_elemwise_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i8 = IntegerType.get_signless(8)
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_exp_add_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out])
linalg.elemwise_binary(out, rhs, outs=[out])
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_log_mul_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
# Prepare arguments: one result f32.
# Arguments must be passed as pointers.
c_float_p = ctypes.c_float * 1
res = c_float_p(-1.)
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
# elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
# elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
# CHECK: RESULT: 4.71828
test_elemwise_builtin()
def test_elemwise_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
i8 = IntegerType.get_signless(8)
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_exp_add_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
@builtin.FuncOp.from_py_func(
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
MemRefType.get((4, 8), f32))
def elemwise_log_mul_on_buffers(lhs, rhs, out):
linalg.elemwise_unary(
lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
linalg.elemwise_binary(
out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
# TODO: FFI-based solution to allow testing and printing with python code.
# Prepare arguments: one result f32.
# Arguments must be passed as pointers.
c_float_p = ctypes.c_float * 1
res = c_float_p(-1.)
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
# elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
# elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
# CHECK: RESULT: 4.71828
test_elemwise_generic()
def test_matmul_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()

View File

@ -66,6 +66,8 @@ enum class LinalgOperandDefKind {
Scalar,
OutputTensor,
IndexAttr,
UnaryFnAttr,
BinaryFnAttr,
TypeFnAttr
};
@ -208,6 +210,8 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
}
};
@ -430,6 +434,45 @@ static ScalarAssign *findAssignment(StringRef name,
return nullptr;
}
// Return true if the operand is a function attribute.
static bool isFunctionAttribute(LinalgOperandDefKind kind) {
return kind == LinalgOperandDefKind::UnaryFnAttr ||
kind == LinalgOperandDefKind::BinaryFnAttr ||
kind == LinalgOperandDefKind::TypeFnAttr;
}
// Return true if the operand is an attribute.
static bool isAttribute(LinalgOperandDefKind kind) {
return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
}
// Get the enum name for the given operand kind.
std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
switch (kind) {
case LinalgOperandDefKind::UnaryFnAttr:
return std::string("UnaryFn");
case LinalgOperandDefKind::BinaryFnAttr:
return std::string("BinaryFn");
case LinalgOperandDefKind::TypeFnAttr:
return std::string("TypeFn");
default:
break;
}
llvm_unreachable("unsupported function attribute kind");
}
// Get the enum name for the given function kind.
std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
switch (kind) {
case ScalarFnKind::Unary:
return std::string("UnaryFn");
case ScalarFnKind::Binary:
return std::string("BinaryFn");
case ScalarFnKind::Type:
return std::string("TypeFn");
}
}
//===----------------------------------------------------------------------===//
// Templates
//===----------------------------------------------------------------------===//
@ -693,8 +736,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::IndexAttr ||
arg.kind == LinalgOperandDefKind::TypeFnAttr;
return isAttribute(arg.kind);
})) {
SmallVector<std::string> attrDefs;
SmallVector<std::string> attrParams;
@ -703,13 +745,14 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
static const char paramFmt[] = "\"Attribute\":${0}";
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
// Add the type conversion attributes to the op definition and builders.
if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
if (isFunctionAttribute(arg.kind)) {
assert(arg.defaultFn.hasValue());
static const char typeFmt[] = "TypeFn::{0}";
std::string enumName = convertOperandKindToEnumName(arg.kind);
static const char typeFmt[] = "{0}::{1}";
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
llvm::formatv(typeFmt, arg.defaultFn),
arg.name));
attrDefs.push_back(llvm::formatv(
defFmt, llvm::formatv("{0}Attr", enumName),
llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
@ -1000,21 +1043,24 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
SmallVector<std::string> attrs;
SmallVector<std::string> stmts;
for (LinalgOperandDef &arg : args) {
if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
if (!isFunctionAttribute(arg.kind))
continue;
// Obtain the type function attribute values. Parameters.
// {0}: attribute name
// {1}: default type function name
// {0}: enum name
// {1}: attribute name
// {2}: default type function name
static const char attrDef[] = R"FMT(
TypeFn {0}Val = TypeFn::{1};
auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
return attr.getName() == "{0}"; });
if ({0}Iter != attrs.end()) {{
if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
{0}Val = attr.getValue();
{0} {1}Val = {0}::{2};
auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
return attr.getName() == "{1}"; });
if ({1}Iter != attrs.end()) {{
if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
{1}Val = attr.getValue();
}
)FMT";
attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
std::string enumName = convertOperandKindToEnumName(arg.kind);
attrs.push_back(
llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
}
for (LinalgOperandDef &arg : args) {
if (arg.kind != LinalgOperandDefKind::OutputTensor)
@ -1056,11 +1102,47 @@ if ({0}Iter != attrs.end()) {{
cppIdent, *expression.index));
return cppIdent;
}
if (expression.scalarFn &&
expression.scalarFn->kind != ScalarFnKind::Type) {
// Apply function.
// Recursively generate operands.
if (expression.scalarFn) {
std::string enumName =
convertFunctionKindToEnumName(expression.scalarFn->kind);
// Get the function or attribute name.
assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
std::string funcType;
if (expression.scalarFn->fnName) {
funcType = llvm::formatv("{0}::{1}", enumName,
*expression.scalarFn->fnName);
}
if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return isFunctionAttribute(arg.kind) &&
arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing function attribute "
<< expression.scalarFn->attrName.getValue();
}
funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
}
assert(!funcType.empty());
// Add the optional type parameter to the operands.
SmallVector<std::string> operandCppValues;
if (expression.scalarFn->kind == ScalarFnKind::Type) {
assert(expression.scalarFn->typeVar.hasValue());
Optional<std::string> typeCppValue =
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
if (!typeCppValue) {
emitError(genContext.getLoc())
<< "type variable " << expression.scalarFn->typeVar.getValue()
<< ", used in a type conversion, must map to a predefined or "
<< "an argument type but it does not";
return None;
}
operandCppValues.push_back(typeCppValue.getValue());
}
// Collect the scalar operands.
for (ScalarExpression &operand : expression.scalarFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
@ -1068,59 +1150,11 @@ if ({0}Iter != attrs.end()) {{
operandCppValues.push_back(*operandCppValue);
}
std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
? "unary"
: "binary";
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
prefix, expression.scalarFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
if (expression.scalarFn &&
expression.scalarFn->kind == ScalarFnKind::Type) {
// Symbolic cast.
// Operands must be arity 1.
if (expression.scalarFn->operands.size() != 1) {
emitError(genContext.getLoc())
<< "type conversion operand arity must be 1";
return None;
}
Optional<std::string> operandCppValue =
generateExpression(expression.scalarFn->operands[0]);
if (!operandCppValue)
return None;
assert(expression.scalarFn->typeVar.hasValue());
Optional<std::string> typeCppValue =
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
if (!typeCppValue) {
emitError(genContext.getLoc())
<< "type variable " << expression.scalarFn->typeVar.getValue()
<< ", used in a type conversion, must map to a predefined or "
<< "an argument type but it does not";
return None;
}
// Use the function name or the attribute to build the type function.
std::string typeFunc = llvm::formatv(
"TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing type function attribute "
<< expression.scalarFn->attrName.getValue();
}
typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
}
// Call the function builder.
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv(
"Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
typeFunc, typeCppValue.getValue(), *operandCppValue));
"Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
funcType, interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
emitError(genContext.getLoc()) << "unknown ScalarExpression type";