[mlir][OpDSL] Add arithmetic function attributes.
The revision extends OpDSL with unary and binary function attributes. A function attribute makes the operations used in the body of a structured operation configurable. For example, a pooling operation may take an aggregation function attribute that specifies whether the op implements a min or a max pooling. The goal of this revision is to define fewer but more flexible operations. We may thus, for example, define an elementwise op:

```
linalg.elem(lhs, rhs, outs=[out], op=BinaryFn.mul)
```

If the op argument is not set, the default operation is used.

Depends On D120109

Reviewed By: nicolasvasilache, aartbik

Differential Revision: https://reviews.llvm.org/D120110
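For reference, a configurable elementwise op of this kind is defined in OpDSL as follows. This is taken from the `elemwise_binary` definition added by this patch; the `linalg.elem` name and `op` argument in the message above are illustrative only:

```python
from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def elemwise_binary(
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(U, output=True),
    # Function attributes with their default functions.
    fun=BinaryFnAttrDef(default=BinaryFn.add),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  # Overriding fun/cast at instantiation time changes the computation
  # without requiring a new op definition.
  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
```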
parent 5d91a8a707
commit 24357fec8d
@@ -107,12 +107,12 @@ copy_and_scale(val, in_tensor, outs=[out_tensor])

## Index Attributes

-Attributes are compile-time constant parameters only accessible in index
+Index attributes are compile-time constant parameters only accessible in index
expressions. They can be used to parameterize the access pattern of a structured
operation, for example, by setting its strides. They cannot take part in the
actual computation.

-The following example demonstrates the use of attributes:
+The following example demonstrates the use of index attributes:

```python
@linalg_structured_op

@@ -136,9 +136,9 @@ The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
index expressions of the operation instance. If no strides are provided the
`default` vector elements are used instead.

-Attributes are currently limited to integer vectors and only accessible in index
-expressions. An operation may have multiple attributes all of them placed at the
-end of the parameter list after the output tensors.
+Index attributes are currently limited to integer vectors and only accessible in
+index expressions. An operation may have multiple attributes, all of them placed
+at the end of the parameter list after the output tensors.

## Shape-Only Tensors

@@ -220,6 +220,43 @@ There are also special forms:
* `const(value)` returns a constant value.
* `index(dim)` returns the iteration index in the given dimension `dim`.

## Function Attributes

Function attributes are compile-time constant function parameters. They can be
used to parameterize the computation performed by a structured operation, for
example, to support signed and unsigned computations.

The following example demonstrates the use of function attributes:

```python
@linalg_structured_op
def elemwise_binary(
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(U, output=True),
    fun=BinaryFnAttrDef(default=BinaryFn.add),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
```

The `fun` and `cast` function attributes by default are aliases for their
default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
instantiating the operation, the function attributes may be set to other
functions using optional named arguments:

```python
elemwise_binary(lhs, rhs, outs=[out_tensor],
                fun=BinaryFn.mul, cast=TypeFn.cast_unsigned)
```

In the example, the `fun` and `cast` arguments adapt the body of the operation
to implement multiplication and unsigned casts instead of addition and signed
casts.

OpDSL supports unary, binary, and type conversion function attributes. An
operation can take multiple attributes of different kinds placed at the end of
the parameter list.
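Function attributes compose with index attributes and reductions. The following sketch, adapted from the pooling test added in this change, defines a single pooling op whose aggregation function, cast, strides, and dilations are all configurable (the `input`, `shape`, and `init_result` values are assumed to be suitable tensors):

```python
@linalg_structured_op
def pooling_poly(
    I=TensorDef(T1, S.N, S.H, S.W, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    reduce=BinaryFnAttrDef(default=BinaryFn.max),
    cast=TypeFnAttrDef(default=TypeFn.cast),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
  # The binary function attribute is used as the reduction function.
  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))

# One definition now covers max/min and signed/unsigned pooling:
pooling_poly(input, shape, outs=[init_result],
             reduce=BinaryFn.min_unsigned, cast=TypeFn.cast_unsigned,
             strides=[2, 4], dilations=[1, 2])
```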

## Types

All types in assignment expressions are late bound based on actual input and
@@ -58,7 +58,26 @@ def Linalg_Dialect : Dialect {
  }];
}

-// Define a TypeFn enum matching the OpDSL TypeFn class.
+// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
  I32EnumAttrCase<"exp", 0>,
  I32EnumAttrCase<"log", 1>
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::linalg";
}
def BinaryFn : I32EnumAttr<"BinaryFn", "", [
  I32EnumAttrCase<"add", 0>,
  I32EnumAttrCase<"mul", 1>,
  I32EnumAttrCase<"max", 2>,
  I32EnumAttrCase<"min", 3>,
  I32EnumAttrCase<"sub", 4>,
  I32EnumAttrCase<"max_unsigned", 5>,
  I32EnumAttrCase<"min_unsigned", 6>
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::linalg";
}
def TypeFn : I32EnumAttr<"TypeFn", "", [
  I32EnumAttrCase<"cast", 0>,
  I32EnumAttrCase<"cast_unsigned", 1>

@@ -67,6 +86,12 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
  let cppNamespace = "::mlir::linalg";
}

def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
  let assemblyFormat = "`<` $value `>`";
}
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
  let assemblyFormat = "`<` $value `>`";
}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
  let assemblyFormat = "`<` $value `>`";
}
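With this `assemblyFormat`, the new attributes print and parse as `#linalg.unary_fn<exp>`, `#linalg.binary_fn<mul>`, and `#linalg.type_fn<cast_unsigned>`; this is the syntax the IR tests further down rely on.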
@@ -1,6 +1,120 @@
### AUTOGENERATED from core_named_ops.py
### To regenerate, run: bin/update_core_linalg_named_ops.sh
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: elemwise_unary
  cpp_class_name: ElemwiseUnaryOp
  doc: |-
    Applies the unary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: I
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: fun
    kind: unary_fn_attr
    default_fn: exp
  - !LinalgOperandDefConfig
    name: cast
    kind: type_fn_attr
    default_fn: cast
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<() -> ()>
    - affine_map<() -> ()>
  iterator_types: []
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: unary
        attr_name: fun
        operands:
        - !ScalarExpression
          scalar_fn:
            kind: type
            attr_name: cast
            type_var: U
            operands:
            - !ScalarExpression
              scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: elemwise_binary
  cpp_class_name: ElemwiseBinaryOp
  doc: |-
    Applies the binary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: lhs
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: rhs
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: fun
    kind: binary_fn_attr
    default_fn: add
  - !LinalgOperandDefConfig
    name: cast
    kind: type_fn_attr
    default_fn: cast
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<() -> ()>
    - affine_map<() -> ()>
    - affine_map<() -> ()>
  iterator_types: []
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        attr_name: fun
        operands:
        - !ScalarExpression
          scalar_fn:
            kind: type
            attr_name: cast
            type_var: U
            operands:
            - !ScalarExpression
              scalar_arg: lhs
        - !ScalarExpression
          scalar_fn:
            kind: type
            attr_name: cast
            type_var: U
            operands:
            - !ScalarExpression
              scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: matmul
  cpp_class_name: MatmulOp
@@ -108,17 +108,9 @@ static LogicalResult foldMemRefCast(Operation *op) {
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
-// The public methods on this class are referenced directly from generated code
-// and bind by name to math functions in the DSL as:
-//   `unary__{fnName}`
-//   `binary__{fnName}`
-// Examples:
-//   `binary__add`
-//   `binary__mul`
-//   `unary__exp`
-//   `unary__log`
-// The naming convention is intentional in order to match snake-cased DSL names.
-// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
+// The public methods on this class are referenced directly from generated code.
+// Helper build the unary, binary, and type conversion functions defined by the
+// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no

@@ -142,6 +134,98 @@ public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    if (!allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (binaryFn) {
    case BinaryFn::add:
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max:
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min:
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }
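
  // A sketch of how generated code uses these helpers (based on the IMPL
  // FileCheck lines in the ODS generator test below, not on this diff):
  //   UnaryFn funVal = UnaryFn::exp;  // default, possibly overridden by attrs
  //   Value v0 = helper.buildTypeFn(castVal, block.getArgument(0).getType(),
  //                                 block.getArgument(0));
  //   Value v1 = helper.buildUnaryFn(funVal, v0);
  //   helper.yieldOutputs({v1});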

  void yieldOutputs(ValueRange values) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
@@ -193,136 +277,6 @@ public:
    return operand;
  }

-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
-    switch (typeFn) {
-    case TypeFn::cast:
-      return cast(toType, operand, false);
-    case TypeFn::cast_unsigned:
-      return cast(toType, operand, true);
-    }
-    llvm_unreachable("unsupported type conversion function");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__add(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value unary__exp(Value x) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(x))
-      return builder.create<math::ExpOp>(x.getLoc(), x);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value unary__log(Value x) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(x))
-      return builder.create<math::LogOp>(x.getLoc(), x);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__sub(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__mul(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__max(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__max_unsigned(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__min(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  // NOLINTNEXTLINE(*-identifier-naming): externally called.
-  Value binary__min_unsigned(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs))
-      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
-    if (isInteger(lhs))
-      return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
-    llvm_unreachable("unsupported non numeric type");
-  }
-
-  void yieldOutputs(ValueRange values) {
-    assert(!values.empty() && "linalg ops must yield outputs");
-    if (values.empty())
-      return;
-    Value first = values.front();
-    OpBuilder builder = getBuilder();
-    builder.create<YieldOp>(first.getLoc(), values);
-  }
-
-  Value constant(const std::string &value) {
-    OpBuilder builder = getBuilder();
-    Location loc = builder.getUnknownLoc();
-    Attribute valueAttr = parseAttribute(value, builder.getContext());
-    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
-                                             valueAttr);
-  }
-
-  Value index(int64_t dim) {
-    OpBuilder builder = getBuilder();
-    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
-  }
-
-  Type getIntegerType(unsigned width) {
-    return IntegerType::get(context, width);
-  }
-
-  Type getFloat32Type() { return Float32Type::get(context); }
-
-  Type getFloat64Type() { return Float64Type::get(context); }
-
-private:
-  MLIRContext *context;
-  Block &block;

  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

@@ -331,6 +285,9 @@ private:
    builder.setInsertionPointToEnd(&block);
    return builder;
  }

+  MLIRContext *context;
+  Block &block;
};

} // namespace
@@ -126,7 +126,7 @@ class TensorUse(TensorExpression):
    return rhs_dims - lhs_dims

  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs)
+    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)

  def __repr__(self):
    return (f"{self.operand_def.name}"

@@ -183,8 +183,14 @@ class TensorReduceFn(TensorExpression):
                       f"bound to its lhs: {self}")
    full_args = [self.lhs.to_scalar_expression()
                ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name,
-                    None, None, full_args).expr()
+    fn_name = None
+    attr_name = None
+    if self.reduce_use.binary_fn:
+      fn_name = self.reduce_use.binary_fn.fn_name
+    if self.reduce_use.binary_attr:
+      attr_name = self.reduce_use.binary_attr.operand_def.name
+    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
+                    full_args).expr()

  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
    for arg in self.args:

@@ -257,8 +263,8 @@ class UnaryFnType:
  def __init__(self, fn_name: str):
    self.fn_name = fn_name

-  def __call__(self, exp: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp])
+  def __call__(self, arg: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])

  def __repr__(self):
    return f"{self.fn_name}"

@@ -345,16 +351,21 @@ class ReduceFnUse:
  A reduction use specifies the reduction function and dimensions.
  """

-  def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef):
+  def __init__(self, binary_fn: Optional[BinaryFnType],
+               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
+    if bool(binary_fn) + bool(binary_attr) != 1:
+      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
    self.binary_fn = binary_fn
+    self.binary_attr = binary_attr
    self.reduce_dims = reduce_dims

  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
    return TensorReduceFn(self, args)

  def __repr__(self):
-    return (f"reduce_{self.binary_fn.fn_name}"
-            f"({', '.join(repr(d) for d in self.reduce_dims)})")
+    fn = self.binary_fn if self.binary_fn else self.binary_attr
+    return (
+        f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")


class ReduceFnType:

@@ -369,10 +380,10 @@ class ReduceFnType:
    self.binary_fn = binary_fn

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.binary_fn, *reduce_dims)
+    return ReduceFnUse(self.binary_fn, None, *reduce_dims)

  def __repr__(self):
-    return (f"reduce_{self.binary_fn.fn_name}")
+    return f"reduce_{repr(self.binary_fn)}"


class ReduceFn:

@@ -394,7 +405,9 @@ class OperandKind(Enum):
  SCALAR = 1
  OUTPUT_TENSOR = 2
  INDEX_ATTR = 3
-  TYPE_FN_ATTR = 4
+  UNARY_FN_ATTR = 4
+  BINARY_FN_ATTR = 5
+  TYPE_FN_ATTR = 6


class OperandDef:

@@ -441,6 +454,8 @@ class OperandDef:

  def is_attribute(self) -> bool:
    return (self.kind == OperandKind.INDEX_ATTR or
            self.kind == OperandKind.UNARY_FN_ATTR or
            self.kind == OperandKind.BINARY_FN_ATTR or
            self.kind == OperandKind.TYPE_FN_ATTR)

  def __hash__(self):

@@ -557,6 +572,49 @@ class IndexAttrDef:
        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)


class UnaryFnAttrDef:
  """Unary function attribute definition.

  Unary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default unary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "UnaryFnType"):
    if not isinstance(default, UnaryFnType):
      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, arg: TensorExpression) -> TensorFn:
    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])


class BinaryFnAttrDef:
  """Binary function attribute definition.

  Binary function attributes provide a way to make the arithmetic computation
  parametrizable. Every attribute specifies a default binary function
  that may be overwritten at operation instantiation time.
  """

  def __init__(self, default: "BinaryFnType"):
    if not isinstance(default, BinaryFnType):
      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
                       f"but got {default}")
    self.operand_def = OperandDef(
        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)

  def __call__(self, arg0: TensorExpression,
               arg1: TensorExpression) -> TensorFn:
    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
                    [arg0, arg1])

  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
    return ReduceFnUse(None, self, *reduce_dims)
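
  # A sketch (not part of this diff): through __getitem__ above, a binary
  # function attribute can drive a reduction inside an op body, e.g.:
  #   reduce=BinaryFnAttrDef(default=BinaryFn.max)
  #   O[D.n] = reduce[D.k](cast(U, I[D.n, D.k]))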


class TypeFnAttrDef:
  """Type conversion function attribute definition.
@@ -309,8 +309,8 @@ class LinalgStructuredOpConfig(YAMLObject):
  def add_operand(self, operand_def: OperandDef):
    if operand_def in self.operands:
      return
-    if (operand_def.kind == OperandKind.SCALAR or
-        operand_def.kind == OperandKind.TYPE_FN_ATTR):
+    if not (operand_def.is_tensor() or
+            operand_def.kind == OperandKind.INDEX_ATTR):
      self.operands[operand_def] = OperandDefConfig(operand_def)
      return
    with self.context:
@@ -130,7 +130,8 @@ def linalg_structured_op(dsl_func=None,
  for param_name, param in sig.parameters.items():
    param_default = param.default
    if isinstance(param_default,
-                  (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)):
+                  (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
+                   BinaryFnAttrDef, TypeFnAttrDef)):
      op_def.add_operand(param_name, param_default.operand_def)
    else:
      raise ValueError(
@@ -41,7 +41,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
  all_arg_defs = op_config.ordered_operands
  in_arg_defs = [
      d for d in all_arg_defs
-      if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR
+      if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
  ]
  out_arg_defs = [
      d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR

@@ -49,8 +49,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
  index_attr_arg_defs = [
      d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
  ]
-  type_fn_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR
+  fn_attr_arg_defs = [
+      d for d in all_arg_defs if d.kind in [
+          OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
+          OperandKind.TYPE_FN_ATTR
+      ]
  ]

  # Verify outs is a sequence or a list of results.

@@ -135,28 +138,38 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
    array = np.array(index_attr_vals, dtype=np.int64)
    index_attrs[index_attr.name] = DenseElementsAttr.get(array)

-  # Compute the type function attribute mapping.
-  type_fn_attr_mapping = {}
-  for type_fn_attr in type_fn_attr_arg_defs:
-    attr_val = type_fn_attr.operand_def.default_fn
-    if type_fn_attr.name in attrs:
-      type_fn = attrs.get(type_fn_attr.name)
-      if not isinstance(type_fn, TypeFnType):
-        raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type "
-                         f"TypeFnType but got {type(attr_val)}")
-      attr_val = type_fn.fn_name
-    assert attr_val, "Type function attribute has no value"
-    type_fn_attr_mapping[type_fn_attr.name] = attr_val
+  # Compute the function attribute mapping.
+  fn_attr_mapping = {}
+  for fn_attr in fn_attr_arg_defs:
+    attr_val = fn_attr.operand_def.default_fn
+    attr_kind = fn_attr.kind
+    if fn_attr.name in attrs:
+      fn = attrs.get(fn_attr.name)
+      if attr_kind == OperandKind.UNARY_FN_ATTR:
+        if not isinstance(fn, UnaryFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"UnaryFnType but got {type(attr_val)}")
+      elif attr_kind == OperandKind.BINARY_FN_ATTR:
+        if not isinstance(fn, BinaryFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"BinaryFnType but got {type(attr_val)}")
+      else:
+        if not isinstance(fn, TypeFnType):
+          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
+                           f"TypeFnType but got {type(attr_val)}")
+      attr_val = fn.fn_name
+    assert attr_val, "Function attribute has no value"
+    fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)

  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
          type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
-          type_fn_attr_mapping, block_arg_types)
+          fn_attr_mapping, block_arg_types)


def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
                               outs: ValueList, **attrs: Sequence[int]):
  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-     indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+     indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
     block_arg_types = \
     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)

@@ -193,7 +206,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
  block_arg_mapping = dict(zip(block_arg_names, block.arguments))
  with InsertionPoint(block):
    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
-                                type_fn_attr_mapping)
+                                fn_attr_mapping)
    for assignment in op_config.assignments:
      body_builder.assign(assignment)
    body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))

@@ -208,7 +221,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
                             op_class_name: str, *ins: Value, outs: ValueList,
                             **attrs: Sequence[int]):
  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-     indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \
+     indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
     block_arg_types = \
     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)

@@ -225,10 +238,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
  for name, value in index_attrs.items():
    named_op.operation.attributes[name] = value

-  # Set the type function attributes.
-  for name, value in type_fn_attr_mapping.items():
+  # Compute the function attributes by combining operand kind and function name.
+  for name, (fn_name, kind) in fn_attr_mapping.items():
+    assert kind.name.lower().endswith("_attr")
+    enum_name = kind.name.lower()[:-5]
    named_op.operation.attributes[name] = Attribute.parse(
-        f"#linalg.type_fn<{value}>")
+        f"#linalg.{enum_name}<{fn_name}>")

  linalg.fill_builtin_region(named_op.operation)

@@ -242,11 +257,11 @@ class _BodyBuilder:
  """Constructs a structured op body by evaluating assignments."""

  def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value],
-               type_fn_attr_mapping: Dict[str, str]):
+               block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
+                                                                          str]):
    self.type_mapping = type_mapping
    self.block_arg_mapping = block_arg_mapping
-    self.type_fn_attr_mapping = type_fn_attr_mapping
+    self.fn_attr_mapping = fn_attr_mapping
    self.yield_mapping = dict() # type: Dict[str, Value]

  def assign(self, assignment: ScalarAssign):

@@ -270,21 +285,18 @@ class _BodyBuilder:
      dim_attr = IntegerAttr.get(
          IntegerType.get_signless(64), expr.scalar_index.dim)
      return linalg.IndexOp(dim_attr).result
-    elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE:
-      kind = expr.scalar_fn.kind.name.lower()
-      fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}")
-      operand_values = [
-          self.expression(operand) for operand in expr.scalar_fn.operands
-      ]
-      return fn(*operand_values)
-    elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE:
+    elif expr.scalar_fn:
      kind = expr.scalar_fn.kind.name.lower()
      fn_name = expr.scalar_fn.fn_name
      if expr.scalar_fn.attr_name:
-        fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
+        fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
      fn = self._get_function(f"_{kind}_{fn_name}")
-      operand_value = self.expression(expr.scalar_fn.operands[0])
-      return fn(expr.scalar_fn.type_var.name, operand_value)
+      operand_values = [
+          self.expression(operand) for operand in expr.scalar_fn.operands
+      ]
+      if expr.scalar_fn.kind == FunctionKind.TYPE:
+        operand_values = [expr.scalar_fn.type_var.name] + operand_values
+      return fn(*operand_values)
    raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")

  def yield_outputs(self, *output_names: str):
@@ -6,6 +6,35 @@ T2 = TV.T2
Batch = S.Batch


@linalg_structured_op
def elemwise_unary(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  """Applies the unary function fun elementwise.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  O[None] = fun(cast(U, I[None]))


@linalg_structured_op
def elemwise_binary(
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(U, output=True),
    fun=BinaryFnAttrDef(default=BinaryFn.add),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  """Applies the binary function fun elementwise.

  Numeric casting is performed on the input operand, promoting it to the same
  data type as the accumulator/output.
  """
  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
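
# Example use (a sketch mirroring the tests in this change): the function
# attributes are set via optional named arguments at instantiation time, e.g.
#   linalg.elemwise_binary(lhs, rhs, outs=[out],
#                          fun=BinaryFn.mul, cast=TypeFn.cast_unsigned)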


@linalg_structured_op
def matmul(
    A=TensorDef(T1, S.M, S.K),
@@ -292,16 +292,48 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16

// -----

-func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
+// Verifies the default value of the fun attribute is an exp op.
+func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
}

-// CHECK-LABEL: @generalize_soft_plus_2d_f32
-// CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
-// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
-// CHECK-NEXT: linalg.yield %[[LOG]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
+// CHECK-LABEL: @generalize_elemwise_exp
+// CHECK: = math.exp

// -----

// Verifies the fun attribute controls the unary function used.
func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_log
// CHECK: = math.log

// -----

// Verifies the default value of the fun attribute is an add op.
func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_add
// CHECK: = arith.addf

// -----

// Verifies the fun attribute controls the binary function used.
func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
                              ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}

// CHECK-LABEL: @generalize_elemwise_mul
// CHECK: = arith.mulf
@@ -111,7 +111,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]);
+# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);


# @linalg_structured_op

@@ -255,14 +255,15 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
# IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(


# @linalg_structured_op
-# def test4(O=TensorDef(T, S.M, S.N, output=True)):
+# def test4(O=TensorDef(T, S.M, S.N, output=True),
+#           unary_fun=UnaryFnAttrDef(default=UnaryFn.exp),
+#           binary_fun=BinaryFnAttrDef(default=BinaryFn.add)):
#   """Title.

#   Detailed description.
#   """
-#   O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n])
+#   O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n])

--- !LinalgOpConfig
metadata: !LinalgOpMetadata

@@ -279,6 +280,14 @@ structured_op: !LinalgStructuredOpConfig
    kind: output_tensor
    type_var: T
    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: unary_fun
    kind: unary_fn_attr
    default_fn: exp
  - !LinalgOperandDefConfig
    name: binary_fun
    kind: binary_fn_attr
    default_fn: add
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>

@@ -291,21 +300,36 @@ structured_op: !LinalgStructuredOpConfig
      value: !ScalarExpression
        scalar_fn:
          kind: binary
-          fn_name: add
+          attr_name: binary_fun
          operands:
          - !ScalarExpression
            scalar_fn:
              kind: unary
-              fn_name: exp
+              attr_name: unary_fun
              operands:
              - !ScalarExpression
                scalar_arg: O
          - !ScalarExpression
            scalar_arg: O

# ODS-LABEL: def Test4Op : LinalgStructuredBase_Op<"test4"

# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
# ODS-NEXT: DefaultValuedAttr<UnaryFnAttr, "UnaryFn::exp">:$unary_fun,
# ODS-NEXT: DefaultValuedAttr<BinaryFnAttr, "BinaryFn::add">:$binary_fun

# ODS: "Attribute":$unary_fun, "Attribute":$binary_fun,

# ODS: $_state.addAttribute("unary_fun", unary_fun)
# ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun)

# IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL: UnaryFn unary_funVal = UnaryFn::exp
# IMPL: BinaryFn binary_funVal = BinaryFn::add

-# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0))
-# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0))
+# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
+# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]])
@@ -18,6 +18,12 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: kind: output_tensor
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: bfn
# CHECK: kind: binary_fn_attr
# CHECK: default_fn: mul
# CHECK: name: ufn
# CHECK: kind: unary_fn_attr
# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast

@@ -26,8 +32,10 @@ def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    bfn=BinaryFnAttrDef(default=BinaryFn.mul),
    ufn=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+  C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# CHECK: ---
@@ -10,10 +10,12 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: arg: C
# CHECK: value:
# CHECK: scalar_fn:
# CHECK: kind: binary
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_fn:
-# CHECK: fn_name: mul
+# CHECK: kind: binary
+# CHECK: attr_name: mul
# CHECK: operands:
# CHECK: scalar_fn:
# CHECK: kind: type

@@ -32,8 +34,9 @@ def matmul(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    mul=BinaryFnAttrDef(default=BinaryFn.mul),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+  C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


# CHECK: ---

@@ -69,14 +72,21 @@ def matmul(
# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
-# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
+# CHECK: scalar_fn:
+# CHECK: kind: unary
+# CHECK: attr_name: exp
+# CHECK: operands:
+# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@linalg_structured_op
-def constants(O=TensorDef(T, S.M, S.K, output=True)):
+def constants(
+    O=TensorDef(T, S.M, S.K, output=True),
+    exp=UnaryFnAttrDef(default=UnaryFn.exp)):
  pi = TypeFn.cast(T, const(3.1415926535897931))
  cst42 = TypeFn.cast(T, const(42))
-  cst1000 = TypeFn.cast(T, const(1e+3))
+  cst1000 = TypeFn.cast(T, exp(const(1e+3)))
  O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000


# CHECK: ---
# CHECK-LABEL: indices
# CHECK: assignments:
@@ -12,55 +12,18 @@ T2 = TV.T2


@linalg_structured_op
-def pooling_max_poly(
+def pooling_poly(
    I=TensorDef(T1, S.N, S.H, S.W, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    reduce=BinaryFnAttrDef(default=BinaryFn.max),
+    cast=TypeFnAttrDef(default=TypeFn.cast),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
-      TypeFn.cast(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_max_unsigned_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_min_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
-      TypeFn.cast(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
-
-
-@linalg_structured_op
-def pooling_min_unsigned_poly(
-    I=TensorDef(T1, S.N, S.H, S.W, S.C),
-    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
-      TypeFn.cast_unsigned(
-          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))


with Context() as ctx, Location.unknown():

@@ -88,7 +51,7 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), i32))
  def test_f32i32_max_pooling(input, shape, init_result):
-    return pooling_max_poly(
+    return pooling_poly(
        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

  # CHECK-LABEL: @test_f32i32_max_unsigned_pooling

@@ -99,8 +62,14 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), i32))
  def test_f32i32_max_unsigned_pooling(input, shape, init_result):
-    return pooling_max_unsigned_poly(
-        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+    return pooling_poly(
+        input,
+        shape,
+        outs=[init_result],
+        reduce=BinaryFn.max_unsigned,
+        cast=TypeFn.cast_unsigned,
+        strides=[2, 4],
+        dilations=[1, 2])

  # CHECK-LABEL: @test_f32f32_max_pooling
  # CHECK: linalg.generic

@@ -115,7 +84,7 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), f32))
  def test_f32f32_max_pooling(input, shape, init_result):
-    return pooling_max_poly(
+    return pooling_poly(
        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

  # CHECK-LABEL: @test_f32i32_min_pooling

@@ -126,8 +95,13 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), i32))
  def test_f32i32_min_pooling(input, shape, init_result):
-    return pooling_min_poly(
-        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+    return pooling_poly(
+        input,
+        shape,
+        outs=[init_result],
+        reduce=BinaryFn.min,
+        strides=[2, 4],
+        dilations=[1, 2])

  # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
  # CHECK: = arith.fptoui

@@ -137,8 +111,14 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), i32))
  def test_f32i32_min_unsigned_pooling(input, shape, init_result):
-    return pooling_min_unsigned_poly(
-        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+    return pooling_poly(
+        input,
+        shape,
+        outs=[init_result],
+        reduce=BinaryFn.min_unsigned,
+        cast=TypeFn.cast_unsigned,
+        strides=[2, 4],
+        dilations=[1, 2])

  # CHECK-LABEL: @test_f32f32_min_pooling
  # CHECK: = arith.minf

@@ -147,8 +127,13 @@ with Context() as ctx, Location.unknown():
      RankedTensorType.get((2, 2), f32),
      RankedTensorType.get((1, 2, 4, 1), f32))
  def test_f32f32_min_pooling(input, shape, init_result):
-    return pooling_min_poly(
-        input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+    return pooling_poly(
+        input,
+        shape,
+        outs=[init_result],
+        reduce=BinaryFn.min,
+        strides=[2, 4],
+        dilations=[1, 2])

  print(module)
@@ -94,20 +94,27 @@ def testNamedStructuredOpCustomForm():
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                 f32))
+        RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
    def named_form(lhs, rhs):
      init_result = linalg.InitTensorOp([4, 8], f32)
-      # First check the named form with custom format
-      # CHECK: linalg.matmul
-      # CHECK: cast = #linalg.type_fn<cast_unsigned>
-      # CHECK-NOT: linalg.memoized_indexing_maps
-      # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
-      # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
-      # CHECK-SAME: -> tensor<4x8xf32>
-      # CHECK-NEXT: return
-      return linalg.matmul(
-          lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned)
+      # Check for the named form with custom format
+      # CHECK: linalg.elemwise_unary
+      # CHECK-SAME: cast = #linalg.type_fn<cast>
+      # CHECK-SAME: fun = #linalg.unary_fn<exp>
+      # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+      unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+      # CHECK: linalg.elemwise_binary
+      # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
+      # CHECK-SAME: fun = #linalg.binary_fn<mul>
+      # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+      # CHECK: return
+      binary_result = linalg.elemwise_binary(
+          lhs,
+          rhs,
+          outs=[init_result.result],
+          fun=BinaryFn.mul,
+          cast=TypeFn.cast_unsigned)
+      return unary_result, binary_result

  print(module)

@@ -130,7 +137,8 @@ def testNamedStructuredOpGenericForm():
    # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
    # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
    # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
-    # CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
+    # CHECK-NEXT: cast = #linalg.type_fn<cast>
+    # CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
    # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
    return linalg.matmul(lhs, rhs, outs=[init_result.result])
@ -19,6 +19,37 @@ def log(*args):
|
|||
sys.stderr.flush()
|
||||
|
||||
|
||||
elemwise_boiler = """
|
||||
func @main() -> f32 attributes {llvm.emit_c_interface} {
|
||||
%v0 = arith.constant 0.0 : f32
|
||||
%v1 = arith.constant 1.0 : f32
|
||||
%v2 = arith.constant 2.0 : f32
|
||||
|
||||
%lhs = memref.alloc() : memref<4x8xf32>
|
||||
%rhs = memref.alloc() : memref<4x8xf32>
|
||||
%O0 = memref.alloc() : memref<4x8xf32>
|
||||
%O1 = memref.alloc() : memref<4x8xf32>
|
||||
linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
|
||||
linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
|
||||
linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
|
||||
linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
|
||||
|
||||
call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
|
||||
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
|
||||
call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
|
||||
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
|
||||
%res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
|
||||
|
||||
%0 = arith.addf %res0, %res1 : f32
|
||||
|
||||
// TODO: FFI-based solution to allow testing and printing with python code.
|
||||
return %0 : f32
|
||||
}
|
||||
"""
|
||||
|
||||
matmul_boiler = """
|
||||
func @main() -> f32 attributes {llvm.emit_c_interface} {
|
||||
%v0 = arith.constant 0.0 : f32
|
||||
|
@ -166,13 +197,93 @@ def transform(module, boilerplate):
|
|||
|
||||
pm = PassManager.parse(
|
||||
"builtin.func(convert-linalg-to-loops, lower-affine, " +
|
||||
"convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm,"
|
||||
+ "convert-memref-to-llvm, convert-std-to-llvm," +
|
||||
"convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
|
||||
+ "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm," +
|
||||
"reconcile-unrealized-casts")
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def test_elemwise_builtin():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
i8 = IntegerType.get_signless(8)
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def elemwise_exp_add_on_buffers(lhs, rhs, out):
|
||||
linalg.elemwise_unary(lhs, outs=[out])
|
||||
linalg.elemwise_binary(out, rhs, outs=[out])
|
||||
|
||||
@builtin.FuncOp.from_py_func(
|
||||
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
|
||||
MemRefType.get((4, 8), f32))
|
||||
def elemwise_log_mul_on_buffers(lhs, rhs, out):
|
||||
linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
|
||||
linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
|
||||
|
||||
execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
|
||||
|
||||
# TODO: FFI-based solution to allow testing and printing with python code.
|
||||
# Prepare arguments: one result f32.
|
||||
# Arguments must be passed as pointers.
|
||||
c_float_p = ctypes.c_float * 1
|
||||
res = c_float_p(-1.)
|
||||
execution_engine.invoke("main", res)
|
||||
|
||||
log("RESULT: ", res[0])
|
||||
# elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
|
||||
# elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
|
||||
# CHECK: RESULT: 4.71828
|
||||
|
||||
|
||||
test_elemwise_builtin()
|
||||
|
||||
|
||||
+def test_elemwise_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i8 = IntegerType.get_signless(8)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_exp_add_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
+        linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((4, 8), f32))
+      def elemwise_log_mul_on_buffers(lhs, rhs, out):
+        linalg.elemwise_unary(
+            lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
+        linalg.elemwise_binary(
+            out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result f32.
+    # Arguments must be passed as pointers.
+    c_float_p = ctypes.c_float * 1
+    res = c_float_p(-1.)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+    # CHECK: RESULT: 4.71828
+
+
+test_elemwise_generic()
+
+
def test_matmul_builtin():
  with Context() as ctx, Location.unknown():
    module = Module.create()

@@ -66,6 +66,8 @@ enum class LinalgOperandDefKind {
   Scalar,
   OutputTensor,
   IndexAttr,
+  UnaryFnAttr,
+  BinaryFnAttr,
   TypeFnAttr
 };

@@ -208,6 +210,8 @@ struct ScalarEnumerationTraits<LinalgOperandDefKind> {
    io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
    io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
    io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
+   io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
+   io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
    io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
  }
};

@@ -430,6 +434,45 @@ static ScalarAssign *findAssignment(StringRef name,
  return nullptr;
}

+// Return true if the operand is a function attribute.
+static bool isFunctionAttribute(LinalgOperandDefKind kind) {
+  return kind == LinalgOperandDefKind::UnaryFnAttr ||
+         kind == LinalgOperandDefKind::BinaryFnAttr ||
+         kind == LinalgOperandDefKind::TypeFnAttr;
+}
+
+// Return true if the operand is an attribute.
+static bool isAttribute(LinalgOperandDefKind kind) {
+  return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
+}
+
+// Get the enum name for the given operand kind.
+std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
+  switch (kind) {
+  case LinalgOperandDefKind::UnaryFnAttr:
+    return std::string("UnaryFn");
+  case LinalgOperandDefKind::BinaryFnAttr:
+    return std::string("BinaryFn");
+  case LinalgOperandDefKind::TypeFnAttr:
+    return std::string("TypeFn");
+  default:
+    break;
+  }
+  llvm_unreachable("unsupported function attribute kind");
+}
+
+// Get the enum name for the given function kind.
+std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
+  switch (kind) {
+  case ScalarFnKind::Unary:
+    return std::string("UnaryFn");
+  case ScalarFnKind::Binary:
+    return std::string("BinaryFn");
+  case ScalarFnKind::Type:
+    return std::string("TypeFn");
+  }
+  llvm_unreachable("unsupported function kind");
+}
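
Note that the two converters return the same enum names (`UnaryFn`, `BinaryFn`, `TypeFn`) for corresponding kinds, which is what lets an attribute-backed function and an inline function share the emission logic further below. A small sanity sketch of the new helpers (illustration only, not part of the patch):

```c++
// Relations implied by the predicates and switch statements above.
assert(isFunctionAttribute(LinalgOperandDefKind::BinaryFnAttr));
assert(!isFunctionAttribute(LinalgOperandDefKind::IndexAttr));
assert(isAttribute(LinalgOperandDefKind::IndexAttr));
assert(!isAttribute(LinalgOperandDefKind::Scalar));
assert(convertOperandKindToEnumName(LinalgOperandDefKind::UnaryFnAttr) ==
       convertFunctionKindToEnumName(ScalarFnKind::Unary));
assert(convertOperandKindToEnumName(LinalgOperandDefKind::TypeFnAttr) ==
       convertFunctionKindToEnumName(ScalarFnKind::Type));
```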
+
//===----------------------------------------------------------------------===//
// Templates
//===----------------------------------------------------------------------===//

@@ -693,8 +736,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
    interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");

  if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
-       return arg.kind == LinalgOperandDefKind::IndexAttr ||
-              arg.kind == LinalgOperandDefKind::TypeFnAttr;
+       return isAttribute(arg.kind);
      })) {
    SmallVector<std::string> attrDefs;
    SmallVector<std::string> attrParams;

@@ -703,13 +745,14 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
      static const char paramFmt[] = "\"Attribute\":${0}";
      static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
      // Add the type conversion attributes to the op definition and builders.
-     if (arg.kind == LinalgOperandDefKind::TypeFnAttr) {
+     if (isFunctionAttribute(arg.kind)) {
        assert(arg.defaultFn.hasValue());
-       static const char typeFmt[] = "TypeFn::{0}";
+       std::string enumName = convertOperandKindToEnumName(arg.kind);
+       static const char typeFmt[] = "{0}::{1}";
        static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
-       attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr",
-                                        llvm::formatv(typeFmt, arg.defaultFn),
-                                        arg.name));
+       attrDefs.push_back(llvm::formatv(
+           defFmt, llvm::formatv("{0}Attr", enumName),
+           llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
      }
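
For a concrete sense of the strings this emits: with the `elemwise_binary` example from the commit message, where `fun=BinaryFnAttrDef(default=BinaryFn.add)`, the `typeFmt`/`defFmt` expansion above would produce an ODS attribute declaration along these lines (a sketch; the attribute name `fun` and default `add` are taken from that example):

```c++
// Sketch: instantiating the format strings for a BinaryFn attribute named
// "fun" with default function "add".
std::string enumName = "BinaryFn"; // convertOperandKindToEnumName(BinaryFnAttr)
std::string def = llvm::formatv("DefaultValuedAttr<{0}, \"{1}\">:${2}",
                                llvm::formatv("{0}Attr", enumName),
                                llvm::formatv("{0}::{1}", enumName, "add"),
                                "fun");
// def == R"(DefaultValuedAttr<BinaryFnAttr, "BinaryFn::add">:$fun)"
```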
@@ -1000,21 +1043,24 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
  SmallVector<std::string> attrs;
  SmallVector<std::string> stmts;
  for (LinalgOperandDef &arg : args) {
-   if (arg.kind != LinalgOperandDefKind::TypeFnAttr)
+   if (!isFunctionAttribute(arg.kind))
      continue;
    // Obtain the type function attribute values. Parameters.
-   // {0}: attribute name
-   // {1}: default type function name
+   // {0}: enum name
+   // {1}: attribute name
+   // {2}: default type function name
    static const char attrDef[] = R"FMT(
- TypeFn {0}Val = TypeFn::{1};
- auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
-                              return attr.getName() == "{0}"; });
- if ({0}Iter != attrs.end()) {{
-   if (auto attr = {0}Iter->getValue().dyn_cast<TypeFnAttr>())
-     {0}Val = attr.getValue();
+ {0} {1}Val = {0}::{2};
+ auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+                              return attr.getName() == "{1}"; });
+ if ({1}Iter != attrs.end()) {{
+   if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
+     {1}Val = attr.getValue();
  }
)FMT";
-   attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn));
+   std::string enumName = convertOperandKindToEnumName(arg.kind);
+   attrs.push_back(
+       llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
  }
  for (LinalgOperandDef &arg : args) {
    if (arg.kind != LinalgOperandDefKind::OutputTensor)
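
Instantiated for the same example attribute, the `attrDef` template above expands in the generated `regionBuilder` to roughly the following lookup code (a sketch assuming `{0}`=`BinaryFn`, `{1}`=`fun`, `{2}`=`add`):

```c++
// Read the "fun" attribute if present; otherwise keep the default BinaryFn::add.
BinaryFn funVal = BinaryFn::add;
auto funIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
                             return attr.getName() == "fun"; });
if (funIter != attrs.end()) {
  if (auto attr = funIter->getValue().dyn_cast<BinaryFnAttr>())
    funVal = attr.getValue();
}
```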
@@ -1056,11 +1102,47 @@ if ({0}Iter != attrs.end()) {{
          cppIdent, *expression.index));
      return cppIdent;
    }
-   if (expression.scalarFn &&
-       expression.scalarFn->kind != ScalarFnKind::Type) {
-     // Apply function.
-     // Recursively generate operands.
+   if (expression.scalarFn) {
+     std::string enumName =
+         convertFunctionKindToEnumName(expression.scalarFn->kind);
+
+     // Get the function or attribute name.
+     assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
+     std::string funcType;
+     if (expression.scalarFn->fnName) {
+       funcType = llvm::formatv("{0}::{1}", enumName,
+                                *expression.scalarFn->fnName);
+     }
+     if (expression.scalarFn->attrName) {
+       if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
+             return isFunctionAttribute(arg.kind) &&
+                    arg.name == expression.scalarFn->attrName.getValue();
+           })) {
+         emitError(genContext.getLoc())
+             << "missing function attribute "
+             << expression.scalarFn->attrName.getValue();
+       }
+       funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
+     }
+     assert(!funcType.empty());
+
+     // Add the optional type parameter to the operands.
      SmallVector<std::string> operandCppValues;
+     if (expression.scalarFn->kind == ScalarFnKind::Type) {
+       assert(expression.scalarFn->typeVar.hasValue());
+       Optional<std::string> typeCppValue =
+           findTypeValue(expression.scalarFn->typeVar.getValue(), args);
+       if (!typeCppValue) {
+         emitError(genContext.getLoc())
+             << "type variable " << expression.scalarFn->typeVar.getValue()
+             << ", used in a type conversion, must map to a predefined or "
+             << "an argument type but it does not";
+         return None;
+       }
+       operandCppValues.push_back(typeCppValue.getValue());
+     }
+
+     // Collect the scalar operands.
      for (ScalarExpression &operand : expression.scalarFn->operands) {
        auto operandCppValue = generateExpression(operand);
        if (!operandCppValue)

|
@ -1068,59 +1150,11 @@ if ({0}Iter != attrs.end()) {{
|
|||
operandCppValues.push_back(*operandCppValue);
|
||||
}
|
||||
|
||||
std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary
|
||||
? "unary"
|
||||
: "binary";
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(
|
||||
llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent,
|
||||
prefix, expression.scalarFn->fnName,
|
||||
interleaveToString(operandCppValues, ", ")));
|
||||
return cppIdent;
|
||||
}
|
||||
if (expression.scalarFn &&
|
||||
expression.scalarFn->kind == ScalarFnKind::Type) {
|
||||
// Symbolic cast.
|
||||
// Operands must be arity 1.
|
||||
if (expression.scalarFn->operands.size() != 1) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "type conversion operand arity must be 1";
|
||||
return None;
|
||||
}
|
||||
Optional<std::string> operandCppValue =
|
||||
generateExpression(expression.scalarFn->operands[0]);
|
||||
if (!operandCppValue)
|
||||
return None;
|
||||
|
||||
assert(expression.scalarFn->typeVar.hasValue());
|
||||
Optional<std::string> typeCppValue =
|
||||
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
|
||||
if (!typeCppValue) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "type variable " << expression.scalarFn->typeVar.getValue()
|
||||
<< ", used in a type conversion, must map to a predefined or "
|
||||
<< "an argument type but it does not";
|
||||
return None;
|
||||
}
|
||||
|
||||
// Use the function name or the attribute to build the type function.
|
||||
std::string typeFunc = llvm::formatv(
|
||||
"TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
|
||||
if (expression.scalarFn->attrName) {
|
||||
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
|
||||
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
|
||||
arg.name == expression.scalarFn->attrName.getValue();
|
||||
})) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "missing type function attribute "
|
||||
<< expression.scalarFn->attrName.getValue();
|
||||
}
|
||||
typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
|
||||
}
|
||||
// Call the function builder.
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(llvm::formatv(
|
||||
"Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent,
|
||||
typeFunc, typeCppValue.getValue(), *operandCppValue));
|
||||
"Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
|
||||
funcType, interleaveToString(operandCppValues, ", ")));
|
||||
return cppIdent;
|
||||
}
|
||||
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
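
With the branches unified, inline functions and function attributes differ only in the first builder argument: a fixed enum constant such as `BinaryFn::mul` versus the `<name>Val` variable read from the attribute. For the `elemwise_binary` body `fun(cast(U, lhs[None]), cast(U, rhs[None]))`, the emitted statements would look roughly like this (a sketch; the value numbering and block-argument indices are assumed for illustration):

```c++
// Cast both inputs to the output element type, then apply the combiner
// selected by the "fun" attribute.
Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                  block.getArgument(0));
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                  block.getArgument(1));
Value value3 = helper.buildBinaryFn(funVal, value1, value2);
```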