[mlir][OpDSL] Rename functions to make signedness explicit (NFC).

The revision renames the following OpDSL functions:
```
TypeFn.cast -> TypeFn.cast_signed
BinaryFn.min -> BinaryFn.min_signed
BinaryFn.max -> BinaryFn.max_signed
```
The corresponding enum values on the C++ side are renamed accordingly:
```
#linalg.type_fn<cast> -> #linalg.type_fn<cast_signed>
#linalg.binary_fn<min> -> #linalg.binary_fn<min_signed>
#linalg.binary_fn<max> -> #linalg.binary_fn<max_signed>
```
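For illustration, a minimal OpDSL definition written against the new names could look like the following sketch (hypothetical; the op name `matmul_signed` is not part of this revision):
```
from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def matmul_signed(A=TensorDef(T1, S.M, S.K),
                  B=TensorDef(T2, S.K, S.N),
                  C=TensorDef(U, S.M, S.N, output=True)):
  domain(D.m, D.n, D.k)
  # TypeFn.cast_signed (formerly TypeFn.cast) treats the signless integer
  # operands as signed when converting them to the output element type U.
  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
```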

Depends On D120110

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120562
gysit 2022-03-01 08:10:51 +00:00
parent db85cd729a
commit e9085d0d25
18 changed files with 247 additions and 236 deletions


@ -56,7 +56,8 @@ def matmul(A=TensorDef(T1, S.M, S.K),
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C[D.m, D.n] += TypeFn.cast_signed(
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
```
Here we have a simple type polymorphic contraction that takes arguments `A` and
@ -160,7 +161,7 @@ def pooling_poly(
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(U,
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
```
@ -182,8 +183,8 @@ A number of unary and binary arithmetic functions are supported:
* `BinaryFn.add(a, b)` (also via overloading the binary `+` operator)
* `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator)
* `BinaryFn.max(a, b)`
* `BinaryFn.min(a, b)`
* `BinaryFn.max_signed(a, b)`
* `BinaryFn.min_signed(a, b)`
* `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator)
* `BinaryFn.max_unsigned(a, b)`
* `BinaryFn.min_unsigned(a, b)`
@ -198,8 +199,8 @@ reduction functions can appear as the outermost function on the RHS:
* `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
* `ReduceFn.mul`
* `ReduceFn.max`
* `ReduceFn.min`
* `ReduceFn.max_signed`
* `ReduceFn.min_signed`
* `ReduceFn.max_unsigned`
* `ReduceFn.min_unsigned`
@ -208,11 +209,11 @@ functions that treat integers as signed or unsigned values.
Additionally, type conversion functions cast an operand to a target type:
* `TypeFn.cast(TypeVar, operand)`
* `TypeFn.cast_signed(TypeVar, operand)`
* `TypeFn.cast_unsigned(TypeVar, operand)`
As the integer types are signless, signedness is implemented by different
functions that treat integers as signed (`TypeFn.cast`) or unsigned
functions that treat integers as signed (`TypeFn.cast_signed`) or unsigned
(`TypeFn.cast_unsigned`) values.
There are also special forms:
@ -235,12 +236,12 @@ def elemwise_binary(
rhs=TensorDef(T2),
O=TensorDef(U, output=True),
fun=BinaryFnAttrDef(default=BinaryFn.add),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
```
The `fun` and `cast` function attributes by default are aliases for their
default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
default values `BinaryFn.add` and `TypeFn.cast_signed`, respectively. When
instantiating the operation, the function attributes may be set to other
functions using optional named arguments:
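A minimal sketch of such an instantiation, assuming tensor values `lhs` and `rhs` and an `init_result` as in the Python tests further down in this patch:
```
# Override the renamed defaults via optional named arguments: fun selects
# the binary function, cast the signedness-explicit conversion function.
result = linalg.elemwise_binary(
    lhs, rhs,
    outs=[init_result],
    fun=BinaryFn.mul,
    cast=TypeFn.cast_unsigned)
```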
@ -265,26 +266,27 @@ output types of constructed ops. An exception are predefined types such as
computations with a type that is independent of the input and output types. For
example, parts of floating point computation may require double precision
arithmetic despite all inputs and outputs being single precision values.
Assignment expressions with no `TypeFn.cast` calls will generally require
Assignment expressions with no `TypeFn.cast_signed` calls will generally require
uniform types throughout and will fail to verify if violated. The presence of a
`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric
type conversion between element types that can be derived from inputs and
outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar`
first argument are emitted as `type_fn` primitives in the YAML definition.
`TypeFn.cast_signed` or `TypeFn.cast_unsigned` allows for a limited form of
numeric type conversion between element types that can be derived from inputs
and outputs (and in the future, attributes). `TypeFn.cast_signed` calls with a
`TypeVar` first argument are emitted as `type_fn` primitives in the YAML
definition.
Casting will perform `int<->float` and `index->int` type conversions and will
perform any necessary extension or truncation within the type family. The
integer types themselves are signless and signedness is implemented by
functions/operations. The `TypeFn.cast` function treats all integers as signed,
while `TypeFn.cast_unsigned` treats them as unsigned.
functions/operations. The `TypeFn.cast_signed` function treats all integers as
signed, while `TypeFn.cast_unsigned` treats them as unsigned.
The following examples illustrate the lowering of signed and unsigned functions:
* cast(I32 -> I64) -> `arith.ExtSIOp`
* cast(F32 -> I32) -> `arith.FPToSIOp`
* cast_signed(I32 -> I64) -> `arith.ExtSIOp`
* cast_signed(F32 -> I32) -> `arith.FPToSIOp`
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
* max -> `arith.MaxSIOp`
* max_signed -> `arith.MaxSIOp`
* max_unsigned -> `arith.MaxUIOp`
Not all functions are applicable for all numeric types, and on mismatch, op
@ -302,7 +304,7 @@ An example for a rank polymorphic operation is `fill`:
@linalg_structured_op
def fill(value=ScalarDef(T1),
O=TensorDef(U, output=True)):
O[None] = TypeFn.cast(U, value)
O[None] = TypeFn.cast_signed(U, value)
```
The operation sets the elements of the output tensor `O` to `value`. All


@ -68,10 +68,10 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
}
def BinaryFn : I32EnumAttr<"BinaryFn", "", [
I32EnumAttrCase<"add", 0>,
I32EnumAttrCase<"mul", 1>,
I32EnumAttrCase<"max", 2>,
I32EnumAttrCase<"min", 3>,
I32EnumAttrCase<"sub", 4>,
I32EnumAttrCase<"sub", 1>,
I32EnumAttrCase<"mul", 2>,
I32EnumAttrCase<"max_signed", 3>,
I32EnumAttrCase<"min_signed", 4>,
I32EnumAttrCase<"max_unsigned", 5>,
I32EnumAttrCase<"min_unsigned", 6>
]> {
@ -79,7 +79,7 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
let cppNamespace = "::mlir::linalg";
}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast", 0>,
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
]> {
let genSpecializedAttr = 0;


@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
default_fn: cast_signed
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
@ -83,7 +83,7 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
default_fn: cast_signed
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
@ -145,7 +145,7 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
default_fn: cast_signed
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@ -324,7 +324,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -332,7 +332,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -345,7 +345,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -353,7 +353,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -424,7 +424,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: AccumType
operands:
- !ScalarExpression
@ -432,7 +432,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: AccumType
operands:
- !ScalarExpression
@ -493,7 +493,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -501,7 +501,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -577,7 +577,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -585,7 +585,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -598,7 +598,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -606,7 +606,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -665,7 +665,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -673,7 +673,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -732,7 +732,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -740,7 +740,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -800,7 +800,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -808,7 +808,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -866,7 +866,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -874,7 +874,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -933,7 +933,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -941,7 +941,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1002,7 +1002,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1010,7 +1010,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1074,7 +1074,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1082,7 +1082,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1158,7 +1158,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1166,7 +1166,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1256,7 +1256,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1264,7 +1264,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1372,7 +1372,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1380,7 +1380,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1393,7 +1393,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1401,7 +1401,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1491,7 +1491,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1499,7 +1499,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1591,7 +1591,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1599,7 +1599,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1674,7 +1674,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1682,7 +1682,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1767,7 +1767,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1775,7 +1775,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1876,7 +1876,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1884,7 +1884,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1897,7 +1897,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1905,7 +1905,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1991,7 +1991,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -1999,7 +1999,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2102,7 +2102,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2110,7 +2110,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2123,7 +2123,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2131,7 +2131,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2210,7 +2210,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2282,14 +2282,14 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: max
fn_name: max_signed
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2440,14 +2440,14 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: max
fn_name: max_signed
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2519,14 +2519,14 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: min
fn_name: min_signed
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2690,7 +2690,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2768,14 +2768,14 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: max
fn_name: max_signed
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2853,14 +2853,14 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: min
fn_name: min_signed
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2897,7 +2897,7 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -2950,7 +2950,7 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: T
operands:
- !ScalarExpression
@ -2971,7 +2971,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: F64
operands:
- !ScalarExpression
@ -2979,7 +2979,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: F64
operands:
- !ScalarExpression
@ -3000,7 +3000,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3023,7 +3023,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3033,7 +3033,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3041,7 +3041,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3049,7 +3049,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3057,7 +3057,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: I32
operands:
- !ScalarExpression
@ -3079,7 +3079,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: F64
operands:
- !ScalarExpression
@ -3130,7 +3130,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
@ -3143,7 +3143,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression


@ -160,22 +160,22 @@ public:
if (allFloatingPoint)
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max:
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min:
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min_signed:
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
@ -191,7 +191,7 @@ public:
// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
case TypeFn::cast:
case TypeFn::cast_signed:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);


@ -305,10 +305,10 @@ class BinaryFn:
- max_unsigned -> `arith.MaxUIOp`
"""
add = BinaryFnType("add")
mul = BinaryFnType("mul")
max = BinaryFnType("max")
min = BinaryFnType("min")
sub = BinaryFnType("sub")
mul = BinaryFnType("mul")
max_signed = BinaryFnType("max_signed")
min_signed = BinaryFnType("min_signed")
max_unsigned = BinaryFnType("max_unsigned")
min_unsigned = BinaryFnType("min_unsigned")
@ -334,14 +334,14 @@ class TypeFn:
"""Type conversion function namespace.
As the integer types are signless, signedness is implemented by different cast
functions that treat integers as signed (`cast`) or unsigned
functions that treat integers as signed (`cast_signed`) or unsigned
(`cast_unsigned`) values.
Examples:
- cast(I32 -> I64) -> `arith.ExtSIOp`
- cast_signed(I32 -> I64) -> `arith.ExtSIOp`
- cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
"""
cast = TypeFnType("cast")
cast_signed = TypeFnType("cast_signed")
cast_unsigned = TypeFnType("cast_unsigned")
@ -389,8 +389,8 @@ class ReduceFnType:
class ReduceFn:
add = ReduceFnType(BinaryFn.add)
mul = ReduceFnType(BinaryFn.mul)
max = ReduceFnType(BinaryFn.max)
min = ReduceFnType(BinaryFn.min)
max_signed = ReduceFnType(BinaryFn.max_signed)
min_signed = ReduceFnType(BinaryFn.min_signed)
max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
min_unsigned = ReduceFnType(BinaryFn.min_unsigned)


@ -370,7 +370,7 @@ class _BodyBuilder:
raise ValueError(f"Unable to cast body expression from {operand_type} to "
f"{to_type}")
def _type_cast(self, type_var_name: str, operand: Value) -> Value:
def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, False)
def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
@ -407,7 +407,7 @@ class _BodyBuilder:
return arith.MulIOp(lhs, rhs).result
raise NotImplementedError(f"Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max(self, lhs: Value, rhs: Value) -> Value:
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
@ -422,7 +422,7 @@ class _BodyBuilder:
raise NotImplementedError(
"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
def _binary_min(self, lhs: Value, rhs: Value) -> Value:
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):


@ -11,7 +11,7 @@ def elemwise_unary(
I=TensorDef(T1),
O=TensorDef(U, output=True),
fun=UnaryFnAttrDef(default=UnaryFn.exp),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
"""Applies the unary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
@ -26,7 +26,7 @@ def elemwise_binary(
rhs=TensorDef(T2),
O=TensorDef(U, output=True),
fun=BinaryFnAttrDef(default=BinaryFn.add),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
"""Applies the binary function fun elementwise.
Numeric casting is performed on the input operand, promoting it to the same
@ -40,7 +40,7 @@ def matmul(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
"""Performs a matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
@ -82,8 +82,9 @@ def quantized_matmul(
matmul.
"""
domain(D.m, D.n, D.k)
C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * (
TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp))
C[D.m, D.n] += (
TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp))
@linalg_structured_op
@ -103,8 +104,8 @@ def mmt4d(
"""
domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
implements(ContractionOpInterface)
accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast(
TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast(
accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
@ -121,7 +122,8 @@ def batch_matmul(
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m,
D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n])
D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
U, B[D.b, D.k, D.n])
@linalg_structured_op
@ -139,9 +141,9 @@ def quantized_batch_matmul(
matmul.
"""
domain(D.b, D.m, D.n, D.k)
C[D.b, D.m,
D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * (
TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp))
C[D.b, D.m, D.n] += (
TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
@linalg_structured_op
@ -156,7 +158,7 @@ def matvec(
"""
domain(D.m, D.n)
implements(ContractionOpInterface)
x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n])
x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
@linalg_structured_op
@ -171,7 +173,7 @@ def vecmat(
"""
domain(D.n, D.m)
implements(ContractionOpInterface)
x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n])
x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
@linalg_structured_op
@ -186,7 +188,8 @@ def batch_matvec(
"""
domain(D.b, D.m, D.k)
implements(ContractionOpInterface)
C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k])
C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
U, B[D.b, D.k])
@linalg_structured_op
@ -198,7 +201,7 @@ def dot(
them to the same data type as the accumulator/output.
"""
implements(ContractionOpInterface)
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
@linalg_structured_op
@ -213,7 +216,8 @@ def conv_1d(
"""
implements(ConvolutionOpInterface)
domain(D.ow, D.kw)
O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw])
O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
U, K[D.kw])
@linalg_structured_op
@ -228,8 +232,8 @@ def conv_2d(
"""
implements(ConvolutionOpInterface)
domain(D.oh, D.ow, D.kh, D.kw)
O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast(
U, K[D.kh, D.kw])
O[D.oh, D.ow] += TypeFn.cast_signed(
U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
@linalg_structured_op
@ -244,9 +248,9 @@ def conv_3d(
"""
implements(ConvolutionOpInterface)
domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
O[D.od, D.oh,
D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow +
D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw])
O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
U, K[D.kd, D.kh, D.kw])
@linalg_structured_op
@ -264,8 +268,8 @@ def conv_1d_nwc_wcf(
implements(ConvolutionOpInterface)
domain(D.n, D.ow, D.f, D.kw, D.c)
O[D.n, D.ow,
D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f])
D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
@linalg_structured_op
@ -287,9 +291,9 @@ def conv_2d_nhwc_hwcf(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += TypeFn.cast(
O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f])
D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
@linalg_structured_op
@ -315,10 +319,11 @@ def conv_2d_nhwc_hwcf_q(
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow,
D.f] += (TypeFn.cast(
D.f] += (TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
TypeFn.cast(U, IZp)) * (
TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp))
TypeFn.cast_signed(U, IZp)) * (
TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) -
TypeFn.cast_signed(U, KZp))
@linalg_structured_op
@ -340,9 +345,9 @@ def conv_2d_nchw_fchw(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += TypeFn.cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw])
O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
@linalg_structured_op
@ -360,9 +365,9 @@ def conv_3d_ndhwc_dhwcf(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast(
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast(
D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
U, K[D.kd, D.kh, D.kw, D.c, D.f])
@ -382,8 +387,8 @@ def depthwise_conv_1d_nwc_wc(
implements(ConvolutionOpInterface)
domain(D.n, D.ow, D.ic, D.kw)
O[D.n, D.ow, D.ic] += \
TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
TypeFn.cast(U, K[D.kw, D.ic])
TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
TypeFn.cast_signed(U, K[D.kw, D.ic])
@linalg_structured_op
@ -402,9 +407,9 @@ def depthwise_conv_2d_nhwc_hwc(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast(
O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic])
D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
@linalg_structured_op
@ -424,11 +429,11 @@ def depthwise_conv_2d_nhwc_hwc_q(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
O[D.n, D.oh, D.ow,
D.ic] += ((TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
TypeFn.cast(U, IZp)) *
(TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp)))
O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
TypeFn.cast_signed(U, IZp)) *
(TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
TypeFn.cast_signed(U, KZp)))
@linalg_structured_op
@ -446,9 +451,9 @@ def depthwise_conv_2d_nhwc_hwcm(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast(
O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm])
D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
@linalg_structured_op
@ -469,10 +474,11 @@ def depthwise_conv_2d_nhwc_hwcm_q(
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic,
D.cm] += ((TypeFn.cast(
D.cm] += ((TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
TypeFn.cast(U, IZp)) *
(TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp)))
TypeFn.cast_signed(U, IZp)) *
(TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
TypeFn.cast_signed(U, KZp)))
@linalg_structured_op
@ -490,7 +496,7 @@ def pooling_nhwc_sum(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
@ -509,8 +515,8 @@ def pooling_nhwc_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@ -549,8 +555,8 @@ def pooling_nchw_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
TypeFn.cast_signed(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,]))
@ -570,8 +576,8 @@ def pooling_nhwc_min(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@ -610,7 +616,7 @@ def pooling_ndhwc_sum(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast(
O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c])
@ -630,8 +636,8 @@ def pooling_ndhwc_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
TypeFn.cast(
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
TypeFn.cast_signed(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c]))
@ -651,8 +657,8 @@ def pooling_ndhwc_min(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
TypeFn.cast(
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
TypeFn.cast_signed(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c]))
@ -665,7 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
"""
O[None] = TypeFn.cast(U, value)
O[None] = TypeFn.cast_signed(U, value)
@linalg_structured_op
@ -685,15 +691,15 @@ def fill_rng_2d(
the range of the generated random numbers.
"""
domain(D.m, D.n)
multiplier = TypeFn.cast(I32, const(1103515245))
increment = TypeFn.cast(I32, const(12345))
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
offset = TypeFn.cast(F64, const(2147483647))
multiplier = TypeFn.cast_signed(I32, const(1103515245))
increment = TypeFn.cast_signed(I32, const(12345))
rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
offset = TypeFn.cast_signed(F64, const(2147483647))
scaling = (max - min) * inv_range
O[D.m, D.n] = TypeFn.cast(T,
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
O[D.m, D.n] = TypeFn.cast_signed(
T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
@linalg_structured_op
@ -706,4 +712,4 @@ def soft_plus_2d(
"""
domain(D.m, D.n)
O[D.m, D.n] = \
UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n])))
UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n])))


@ -3,7 +3,7 @@
# @linalg_structured_op
# def test1(O=TensorDef(T, S.M, S.N, output=True),
# cast=TypeFnAttrDef(default=TypeFn.cast)):
# cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
# """Title.
# Detailed description.
@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast
default_fn: cast_signed
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@ -70,7 +70,7 @@ structured_op: !LinalgStructuredOpConfig
# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
# ODS: let builders =
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@ -99,7 +99,7 @@ structured_op: !LinalgStructuredOpConfig
# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL: TypeFn castVal = TypeFn::cast;
# IMPL: TypeFn castVal = TypeFn::cast_signed;
# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
# IMPL-NEXT: return attr.getName() == "cast"; });
# IMPL-NEXT: if (castIter != attrs.end()) {
@ -209,7 +209,7 @@ structured_op: !LinalgStructuredOpConfig
# Detailed description.
# """
# O[None] = TypeFn.cast(U, value)
# O[None] = TypeFn.cast_signed(U, value)
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
@ -241,7 +241,7 @@ structured_op: !LinalgStructuredOpConfig
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression


@ -26,7 +26,7 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: default_fn: exp
# CHECK: name: cast
# CHECK: kind: type_fn_attr
# CHECK: default_fn: cast
# CHECK: default_fn: cast_signed
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
@ -34,7 +34,7 @@ def matmul(
C=TensorDef(U, S.M, S.N, output=True),
bfn=BinaryFnAttrDef(default=BinaryFn.mul),
ufn=UnaryFnAttrDef(default=UnaryFn.exp),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))


@ -35,7 +35,7 @@ def matmul(
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
mul=BinaryFnAttrDef(default=BinaryFn.mul),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
@ -63,13 +63,13 @@ def matmul(
# CHECK: scalar_const: '3.1415926535897931 : f64'
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: fn_name: cast
# CHECK: fn_name: cast_signed
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '42 : i64'
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: fn_name: cast
# CHECK: fn_name: cast_signed
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_fn:
@ -81,9 +81,9 @@ def matmul(
def constants(
O=TensorDef(T, S.M, S.K, output=True),
exp=UnaryFnAttrDef(default=UnaryFn.exp)):
pi = TypeFn.cast(T, const(3.1415926535897931))
cst42 = TypeFn.cast(T, const(42))
cst1000 = TypeFn.cast(T, exp(const(1e+3)))
pi = TypeFn.cast_signed(T, const(3.1415926535897931))
cst42 = TypeFn.cast_signed(T, const(42))
cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000


@ -19,9 +19,9 @@ def conv_poly(
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c])
D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
with Context() as ctx, Location.unknown():


@ -13,7 +13,7 @@ T2 = TV.T2
@linalg_structured_op
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
O[None] = TypeFn.cast(U, value)
O[None] = TypeFn.cast_signed(U, value)
with Context() as ctx, Location.unknown():


@ -25,7 +25,7 @@ def matmul_poly(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast)):
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
domain(D.m, D.n, D.k)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@ -21,22 +21,23 @@ def fill_rng_poly(
max=ScalarDef(F64),
seed=ScalarDef(I32),
O=TensorDef(T, S.M, S.N, output=True)):
multiplier = TypeFn.cast(I32, const(1103515245))
increment = TypeFn.cast(I32, const(12345))
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
offset = TypeFn.cast(F64, const(2147483647))
multiplier = TypeFn.cast_signed(I32, const(1103515245))
increment = TypeFn.cast_signed(I32, const(12345))
rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
offset = TypeFn.cast_signed(F64, const(2147483647))
scaling = (max - min) * inv_range
O[D.m, D.n] = TypeFn.cast(T,
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
O[D.m, D.n] = TypeFn.cast_signed(
T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
@linalg_structured_op
def soft_plus_poly(
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
O[D.m, D.n] = UnaryFn.log(
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n])))
TypeFn.cast_signed(U, const(1.0)) +
TypeFn.cast_signed(U, UnaryFn.exp(I[D.m, D.n])))
@linalg_structured_op(op_name="custom_op_name")


@ -16,8 +16,8 @@ def pooling_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
reduce=BinaryFnAttrDef(default=BinaryFn.max),
cast=TypeFnAttrDef(default=TypeFn.cast),
reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
@ -99,7 +99,7 @@ with Context() as ctx, Location.unknown():
input,
shape,
outs=[init_result],
reduce=BinaryFn.min,
reduce=BinaryFn.min_signed,
strides=[2, 4],
dilations=[1, 2])
@ -131,7 +131,7 @@ with Context() as ctx, Location.unknown():
input,
shape,
outs=[init_result],
reduce=BinaryFn.min,
reduce=BinaryFn.min_signed,
strides=[2, 4],
dilations=[1, 2])


@ -13,4 +13,5 @@ def matmul(
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
U, B[D.k, D.n])


@ -24,7 +24,8 @@ def matmul(
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
domain(D.m, D.n, D.k)
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
U, B[D.k, D.n])
# Verifies that assignment to a scalar (represented as [None]) is represented
@ -42,7 +43,7 @@ def matmul(
# CHECK-NEXT: - reduction
@linalg_structured_op
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
# Verifies that the index_dims of shape-only operands translate to correct
@ -65,4 +66,4 @@ def pool(
K=TensorDef(T, S.K, index_dims=[D.k]),
O=TensorDef(U, S.O, output=True)):
domain(D.o, D.k)
O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k])
O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])


@ -99,7 +99,7 @@ def testNamedStructuredOpCustomForm():
init_result = linalg.InitTensorOp([4, 8], f32)
# Check for the named form with custom format
# CHECK: linalg.elemwise_unary
# CHECK-SAME: cast = #linalg.type_fn<cast>
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
# CHECK-SAME: fun = #linalg.unary_fn<exp>
# CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
@ -137,7 +137,7 @@ def testNamedStructuredOpGenericForm():
# CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: cast = #linalg.type_fn<cast>
# CHECK-NEXT: cast = #linalg.type_fn<cast_signed>
# CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])