forked from OSchip/llvm-project
[mlir][OpDSL] Rename function to make signedness explicit (NFC).
The revision renames the following OpDSL functions: ``` TypeFn.cast -> TypeFn.cast_signed BinaryFn.min -> BinaryFn.min_signed BinaryFn.max -> BinaryFn.max_signed ``` The corresponding enum values on the C++ side are renamed accordingly: ``` #linalg.type_fn<cast> -> #linalg.type_fn<cast_signed> #linalg.binary_fn<min> -> #linalg.binary_fn<min_signed> #linalg.binary_fn<max> -> #linalg.binary_fn<max_signed> ``` Depends On D120110 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120562
This commit is contained in:
parent
db85cd729a
commit
e9085d0d25
|
@ -56,7 +56,8 @@ def matmul(A=TensorDef(T1, S.M, S.K),
|
|||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C[D.m, D.n] += TypeFn.cast_signed(
|
||||
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
|
||||
```
|
||||
|
||||
Here we have a simple type polymorphic contraction that takes arguments `A` and
|
||||
|
@ -160,7 +161,7 @@ def pooling_poly(
|
|||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(U,
|
||||
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||
```
|
||||
|
||||
|
@ -182,8 +183,8 @@ A number of unary and binary arithmetic functions are supported:
|
|||
|
||||
* `BinaryFn.add(a, b)` (also via overloading the binary `+` operator)
|
||||
* `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator)
|
||||
* `BinaryFn.max(a, b)`
|
||||
* `BinaryFn.min(a, b)`
|
||||
* `BinaryFn.max_signed(a, b)`
|
||||
* `BinaryFn.min_signed(a, b)`
|
||||
* `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator)
|
||||
* `BinaryFn.max_unsigned(a, b)`
|
||||
* `BinaryFn.min_unsigned(a, b)`
|
||||
|
@ -198,8 +199,8 @@ reduction functions can appear as the outermost function on the RHS:
|
|||
|
||||
* `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
|
||||
* `ReduceFn.mul`
|
||||
* `ReduceFn.max`
|
||||
* `ReduceFn.min`
|
||||
* `ReduceFn.max_signed`
|
||||
* `ReduceFn.min_signed`
|
||||
* `ReduceFn.max_unsigned`
|
||||
* `ReduceFn.min_unsigned`
|
||||
|
||||
|
@ -208,11 +209,11 @@ functions that treat integers as signed or unsigned values.
|
|||
|
||||
Additionally, type conversion functions cast an operand to a target type:
|
||||
|
||||
* `TypeFn.cast(TypeVar, operand)`
|
||||
* `TypeFn.cast_signed(TypeVar, operand)`
|
||||
* `TypeFn.cast_unsigned(TypeVar, operand)`
|
||||
|
||||
As the integer types are signless, signedness is implement by different
|
||||
functions that treat integers as signed (`TypeFn.cast`) or unsigned
|
||||
functions that treat integers as signed (`TypeFn.cast_signed`) or unsigned
|
||||
(`TypeFn.cast_unsigned`) values.
|
||||
|
||||
There are also special forms:
|
||||
|
@ -235,12 +236,12 @@ def elemwise_binary(
|
|||
rhs=TensorDef(T2),
|
||||
O=TensorDef(U, output=True),
|
||||
fun=BinaryFnAttrDef(default=BinaryFn.add),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
|
||||
```
|
||||
|
||||
The `fun` and `cast` function attributes by default are aliases for their
|
||||
default values `BinaryFn.add` and `TypeFn.cast`, respectively. When
|
||||
default values `BinaryFn.add` and `TypeFn.cast_signed`, respectively. When
|
||||
instantiating the operation, the function attributes may be set to other
|
||||
functions using optional named arguments:
|
||||
|
||||
|
@ -265,26 +266,27 @@ output types of constructed ops. An exception are predefined types such as
|
|||
computations with a type that is independent of the input and output types. For
|
||||
example, parts of floating point computation may require double precision
|
||||
arithmetic despite all inputs and outputs being single precision values.
|
||||
Assignment expressions with no `TypeFn.cast` calls will generally require
|
||||
Assignment expressions with no `TypeFn.cast_signed` calls will generally require
|
||||
uniform types throughout and will fail to verify if violated. The presence of a
|
||||
`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric
|
||||
type conversion between element types that can be derived from inputs and
|
||||
outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar`
|
||||
first argument are emitted as `type_fn` primitives in the YAML definition.
|
||||
`TypeFn.cast_signed` or `TypeFn.cast_unsigned` allows for a limited form of
|
||||
numeric type conversion between element types that can be derived from inputs
|
||||
and outputs (and in the future, attributes). `TypeFn.cast_signed` calls with a
|
||||
`TypeVar` first argument are emitted as `type_fn` primitives in the YAML
|
||||
definition.
|
||||
|
||||
Casting will perform `int<->float` and `index->int` type conversions and will
|
||||
perform any necessary extension or truncation within the type family. The
|
||||
integer types themselves are signless and signedness is implemented by
|
||||
functions/operations. The `TypeFn.cast` function treats all integers as signed,
|
||||
while `TypeFn.cast_unsigned` treats them as unsigned.
|
||||
functions/operations. The `TypeFn.cast_signed` function treats all integers as
|
||||
signed, while `TypeFn.cast_unsigned` treats them as unsigned.
|
||||
|
||||
The following examples illustrate the lowering of signed and unsigned functions:
|
||||
|
||||
* cast(I32 -> I64) -> `arith.ExtSIOp`
|
||||
* cast(F32 -> I32) -> `arith.FPToSIOp`
|
||||
* cast_signed(I32 -> I64) -> `arith.ExtSIOp`
|
||||
* cast_signed(F32 -> I32) -> `arith.FPToSIOp`
|
||||
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
|
||||
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
|
||||
* max -> `arith.MaxSIOp`
|
||||
* max_signed -> `arith.MaxSIOp`
|
||||
* max_unsinged -> `arith.MaxUIOp`
|
||||
|
||||
Not all functions are applicable for all numeric types, and on mismatch, op
|
||||
|
@ -302,7 +304,7 @@ An example for a rank polymorphic operation is `fill`:
|
|||
@linalg_structured_op
|
||||
def fill(value=ScalarDef(T1),
|
||||
O=TensorDef(U, output=True)):
|
||||
O[None] = TypeFn.cast(U, value)
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
```
|
||||
|
||||
The operation sets the elements of the output tensor `O` to `value`. All
|
||||
|
|
|
@ -68,10 +68,10 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
|
|||
}
|
||||
def BinaryFn : I32EnumAttr<"BinaryFn", "", [
|
||||
I32EnumAttrCase<"add", 0>,
|
||||
I32EnumAttrCase<"mul", 1>,
|
||||
I32EnumAttrCase<"max", 2>,
|
||||
I32EnumAttrCase<"min", 3>,
|
||||
I32EnumAttrCase<"sub", 4>,
|
||||
I32EnumAttrCase<"sub", 1>,
|
||||
I32EnumAttrCase<"mul", 2>,
|
||||
I32EnumAttrCase<"max_signed", 3>,
|
||||
I32EnumAttrCase<"min_signed", 4>,
|
||||
I32EnumAttrCase<"max_unsigned", 5>,
|
||||
I32EnumAttrCase<"min_unsigned", 6>
|
||||
]> {
|
||||
|
@ -79,7 +79,7 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
|
|||
let cppNamespace = "::mlir::linalg";
|
||||
}
|
||||
def TypeFn : I32EnumAttr<"TypeFn", "", [
|
||||
I32EnumAttrCase<"cast", 0>,
|
||||
I32EnumAttrCase<"cast_signed", 0>,
|
||||
I32EnumAttrCase<"cast_unsigned", 1>
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
|
|
|
@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !LinalgOperandDefConfig
|
||||
name: cast
|
||||
kind: type_fn_attr
|
||||
default_fn: cast
|
||||
default_fn: cast_signed
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<() -> ()>
|
||||
|
@ -83,7 +83,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !LinalgOperandDefConfig
|
||||
name: cast
|
||||
kind: type_fn_attr
|
||||
default_fn: cast
|
||||
default_fn: cast_signed
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<() -> ()>
|
||||
|
@ -145,7 +145,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !LinalgOperandDefConfig
|
||||
name: cast
|
||||
kind: type_fn_attr
|
||||
default_fn: cast
|
||||
default_fn: cast_signed
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
|
||||
|
@ -324,7 +324,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -332,7 +332,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -345,7 +345,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -353,7 +353,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -424,7 +424,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: AccumType
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -432,7 +432,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: AccumType
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -493,7 +493,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -501,7 +501,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -577,7 +577,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -585,7 +585,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -598,7 +598,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -606,7 +606,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -665,7 +665,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -673,7 +673,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -732,7 +732,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -740,7 +740,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -800,7 +800,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -808,7 +808,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -866,7 +866,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -874,7 +874,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -933,7 +933,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -941,7 +941,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1002,7 +1002,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1010,7 +1010,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1074,7 +1074,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1082,7 +1082,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1158,7 +1158,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1166,7 +1166,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1256,7 +1256,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1264,7 +1264,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1372,7 +1372,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1380,7 +1380,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1393,7 +1393,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1401,7 +1401,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1491,7 +1491,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1499,7 +1499,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1591,7 +1591,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1599,7 +1599,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1674,7 +1674,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1682,7 +1682,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1767,7 +1767,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1775,7 +1775,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1876,7 +1876,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1884,7 +1884,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1897,7 +1897,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1905,7 +1905,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1991,7 +1991,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -1999,7 +1999,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2102,7 +2102,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2110,7 +2110,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2123,7 +2123,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2131,7 +2131,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2210,7 +2210,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2282,14 +2282,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: binary
|
||||
fn_name: max
|
||||
fn_name: max_signed
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2440,14 +2440,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: binary
|
||||
fn_name: max
|
||||
fn_name: max_signed
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2519,14 +2519,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: binary
|
||||
fn_name: min
|
||||
fn_name: min_signed
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2690,7 +2690,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2768,14 +2768,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: binary
|
||||
fn_name: max
|
||||
fn_name: max_signed
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2853,14 +2853,14 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: binary
|
||||
fn_name: min
|
||||
fn_name: min_signed
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: O
|
||||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2897,7 +2897,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2950,7 +2950,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2971,7 +2971,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -2979,7 +2979,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3000,7 +3000,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3023,7 +3023,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3033,7 +3033,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3041,7 +3041,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3049,7 +3049,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3057,7 +3057,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: I32
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3079,7 +3079,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: F64
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3130,7 +3130,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
@ -3143,7 +3143,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
|
|
@ -160,22 +160,22 @@ public:
|
|||
if (allFloatingPoint)
|
||||
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::mul:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::max:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::min:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::sub:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::mul:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::max_signed:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::min_signed:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
|
||||
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::max_unsigned:
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
|
||||
|
@ -191,7 +191,7 @@ public:
|
|||
// Build the type functions defined by OpDSL.
|
||||
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
|
||||
switch (typeFn) {
|
||||
case TypeFn::cast:
|
||||
case TypeFn::cast_signed:
|
||||
return cast(toType, operand, false);
|
||||
case TypeFn::cast_unsigned:
|
||||
return cast(toType, operand, true);
|
||||
|
|
|
@ -305,10 +305,10 @@ class BinaryFn:
|
|||
- max_unsinged -> `arith.MaxUIOp`
|
||||
"""
|
||||
add = BinaryFnType("add")
|
||||
mul = BinaryFnType("mul")
|
||||
max = BinaryFnType("max")
|
||||
min = BinaryFnType("min")
|
||||
sub = BinaryFnType("sub")
|
||||
mul = BinaryFnType("mul")
|
||||
max_signed = BinaryFnType("max_signed")
|
||||
min_signed = BinaryFnType("min_signed")
|
||||
max_unsigned = BinaryFnType("max_unsigned")
|
||||
min_unsigned = BinaryFnType("min_unsigned")
|
||||
|
||||
|
@ -334,14 +334,14 @@ class TypeFn:
|
|||
"""Type conversion function namespace.
|
||||
|
||||
As the integer types are signless, signedness is implement by different cast
|
||||
functions that treat integers as signed (`cast`) or unsigned
|
||||
functions that treat integers as signed (`cast_signed`) or unsigned
|
||||
(`cast_unsigned`) values.
|
||||
|
||||
Examples:
|
||||
- cast(I32 -> I64) -> `arith.ExtSIOp`
|
||||
- cast_signed(I32 -> I64) -> `arith.ExtSIOp`
|
||||
- cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
|
||||
"""
|
||||
cast = TypeFnType("cast")
|
||||
cast_signed = TypeFnType("cast_signed")
|
||||
cast_unsigned = TypeFnType("cast_unsigned")
|
||||
|
||||
|
||||
|
@ -389,8 +389,8 @@ class ReduceFnType:
|
|||
class ReduceFn:
|
||||
add = ReduceFnType(BinaryFn.add)
|
||||
mul = ReduceFnType(BinaryFn.mul)
|
||||
max = ReduceFnType(BinaryFn.max)
|
||||
min = ReduceFnType(BinaryFn.min)
|
||||
max_signed = ReduceFnType(BinaryFn.max_signed)
|
||||
min_signed = ReduceFnType(BinaryFn.min_signed)
|
||||
max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
|
||||
min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
|
||||
|
||||
|
|
|
@ -370,7 +370,7 @@ class _BodyBuilder:
|
|||
raise ValueError(f"Unable to cast body expression from {operand_type} to "
|
||||
f"{to_type}")
|
||||
|
||||
def _type_cast(self, type_var_name: str, operand: Value) -> Value:
|
||||
def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
|
||||
return self._cast(type_var_name, operand, False)
|
||||
|
||||
def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
|
||||
|
@ -407,7 +407,7 @@ class _BodyBuilder:
|
|||
return arith.MulIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_max(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MaxFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
|
@ -422,7 +422,7 @@ class _BodyBuilder:
|
|||
raise NotImplementedError(
|
||||
"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_min(self, lhs: Value, rhs: Value) -> Value:
|
||||
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
return arith.MinFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
|
|
|
@ -11,7 +11,7 @@ def elemwise_unary(
|
|||
I=TensorDef(T1),
|
||||
O=TensorDef(U, output=True),
|
||||
fun=UnaryFnAttrDef(default=UnaryFn.exp),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
"""Applies the unary function fun elementwise.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
|
@ -26,7 +26,7 @@ def elemwise_binary(
|
|||
rhs=TensorDef(T2),
|
||||
O=TensorDef(U, output=True),
|
||||
fun=BinaryFnAttrDef(default=BinaryFn.add),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
"""Applies the binary function fun elementwise.
|
||||
|
||||
Numeric casting is performed on the input operand, promoting it to the same
|
||||
|
@ -40,7 +40,7 @@ def matmul(
|
|||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
"""Performs a matrix multiplication of two 2D inputs.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
|
@ -82,8 +82,9 @@ def quantized_matmul(
|
|||
matmul.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * (
|
||||
TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp))
|
||||
C[D.m, D.n] += (
|
||||
TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
|
||||
TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -103,8 +104,8 @@ def mmt4d(
|
|||
"""
|
||||
domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
|
||||
implements(ContractionOpInterface)
|
||||
accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast(
|
||||
TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast(
|
||||
accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
|
||||
TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
|
||||
TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
|
||||
|
||||
|
||||
|
@ -121,7 +122,8 @@ def batch_matmul(
|
|||
domain(D.b, D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m,
|
||||
D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n])
|
||||
D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.b, D.k, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -139,9 +141,9 @@ def quantized_batch_matmul(
|
|||
matmul.
|
||||
"""
|
||||
domain(D.b, D.m, D.n, D.k)
|
||||
C[D.b, D.m,
|
||||
D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * (
|
||||
TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp))
|
||||
C[D.b, D.m, D.n] += (
|
||||
TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
|
||||
TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -156,7 +158,7 @@ def matvec(
|
|||
"""
|
||||
domain(D.m, D.n)
|
||||
implements(ContractionOpInterface)
|
||||
x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n])
|
||||
x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -171,7 +173,7 @@ def vecmat(
|
|||
"""
|
||||
domain(D.n, D.m)
|
||||
implements(ContractionOpInterface)
|
||||
x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n])
|
||||
x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -186,7 +188,8 @@ def batch_matvec(
|
|||
"""
|
||||
domain(D.b, D.m, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k])
|
||||
C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.b, D.k])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -198,7 +201,7 @@ def dot(
|
|||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
implements(ContractionOpInterface)
|
||||
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
|
||||
C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -213,7 +216,8 @@ def conv_1d(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.ow, D.kw)
|
||||
O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw])
|
||||
O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
|
||||
U, K[D.kw])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -228,8 +232,8 @@ def conv_2d(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.oh, D.ow, D.kh, D.kw)
|
||||
O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast(
|
||||
U, K[D.kh, D.kw])
|
||||
O[D.oh, D.ow] += TypeFn.cast_signed(
|
||||
U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -244,9 +248,9 @@ def conv_3d(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
|
||||
O[D.od, D.oh,
|
||||
D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow +
|
||||
D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw])
|
||||
O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
|
||||
U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
|
||||
U, K[D.kd, D.kh, D.kw])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -264,8 +268,8 @@ def conv_1d_nwc_wcf(
|
|||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.ow, D.f, D.kw, D.c)
|
||||
O[D.n, D.ow,
|
||||
D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f])
|
||||
D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -287,9 +291,9 @@ def conv_2d_nhwc_hwcf(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.f] += TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f])
|
||||
D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -315,10 +319,11 @@ def conv_2d_nhwc_hwcf_q(
|
|||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow,
|
||||
D.f] += (TypeFn.cast(
|
||||
D.f] += (TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
|
||||
TypeFn.cast(U, IZp)) * (
|
||||
TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp))
|
||||
TypeFn.cast_signed(U, IZp)) * (
|
||||
TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) -
|
||||
TypeFn.cast_signed(U, KZp))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -340,9 +345,9 @@ def conv_2d_nchw_fchw(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||
O[D.n, D.f, D.oh, D.ow] += TypeFn.cast(
|
||||
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw])
|
||||
O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
|
||||
D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -360,9 +365,9 @@ def conv_3d_ndhwc_dhwcf(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
|
||||
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast(
|
||||
O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast(
|
||||
D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
|
||||
U, K[D.kd, D.kh, D.kw, D.c, D.f])
|
||||
|
||||
|
||||
|
@ -382,8 +387,8 @@ def depthwise_conv_1d_nwc_wc(
|
|||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.ow, D.ic, D.kw)
|
||||
O[D.n, D.ow, D.ic] += \
|
||||
TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
|
||||
TypeFn.cast(U, K[D.kw, D.ic])
|
||||
TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
|
||||
TypeFn.cast_signed(U, K[D.kw, D.ic])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -402,9 +407,9 @@ def depthwise_conv_2d_nhwc_hwc(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic])
|
||||
D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -424,11 +429,11 @@ def depthwise_conv_2d_nhwc_hwc_q(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow,
|
||||
D.ic] += ((TypeFn.cast(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
|
||||
TypeFn.cast(U, IZp)) *
|
||||
(TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp)))
|
||||
O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
|
||||
TypeFn.cast_signed(U, IZp)) *
|
||||
(TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
|
||||
TypeFn.cast_signed(U, KZp)))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -446,9 +451,9 @@ def depthwise_conv_2d_nhwc_hwcm(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm])
|
||||
D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -469,10 +474,11 @@ def depthwise_conv_2d_nhwc_hwcm_q(
|
|||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.ic,
|
||||
D.cm] += ((TypeFn.cast(
|
||||
D.cm] += ((TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
|
||||
TypeFn.cast(U, IZp)) *
|
||||
(TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp)))
|
||||
TypeFn.cast_signed(U, IZp)) *
|
||||
(TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
|
||||
TypeFn.cast_signed(U, KZp)))
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -490,7 +496,7 @@ def pooling_nhwc_sum(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
|
||||
|
||||
|
||||
|
@ -509,8 +515,8 @@ def pooling_nhwc_max(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
|
||||
TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||
|
||||
|
||||
|
@ -549,8 +555,8 @@ def pooling_nchw_max(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
|
||||
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
|
||||
TypeFn.cast_signed(
|
||||
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW,]))
|
||||
|
||||
|
@ -570,8 +576,8 @@ def pooling_nhwc_min(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
|
||||
TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||
|
||||
|
||||
|
@ -610,7 +616,7 @@ def pooling_ndhwc_sum(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW, D.c])
|
||||
|
||||
|
@ -630,8 +636,8 @@ def pooling_ndhwc_max(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
|
||||
TypeFn.cast_signed(
|
||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||
|
||||
|
@ -651,8 +657,8 @@ def pooling_ndhwc_min(
|
|||
"""
|
||||
implements(ConvolutionOpInterface)
|
||||
domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
|
||||
TypeFn.cast(
|
||||
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
|
||||
TypeFn.cast_signed(
|
||||
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
|
||||
D.ow * S.SW + D.kw * S.DW, D.c]))
|
||||
|
||||
|
@ -665,7 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
|||
accesses only and is thus rank polymorphic. Numeric casting is performed on
|
||||
the value operand, promoting it to the same data type as the output.
|
||||
"""
|
||||
O[None] = TypeFn.cast(U, value)
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -685,15 +691,15 @@ def fill_rng_2d(
|
|||
the range of the generated random numbers.
|
||||
"""
|
||||
domain(D.m, D.n)
|
||||
multiplier = TypeFn.cast(I32, const(1103515245))
|
||||
increment = TypeFn.cast(I32, const(12345))
|
||||
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
|
||||
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
|
||||
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
|
||||
offset = TypeFn.cast(F64, const(2147483647))
|
||||
multiplier = TypeFn.cast_signed(I32, const(1103515245))
|
||||
increment = TypeFn.cast_signed(I32, const(12345))
|
||||
rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
|
||||
rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
|
||||
inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
|
||||
offset = TypeFn.cast_signed(F64, const(2147483647))
|
||||
scaling = (max - min) * inv_range
|
||||
O[D.m, D.n] = TypeFn.cast(T,
|
||||
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
|
||||
O[D.m, D.n] = TypeFn.cast_signed(
|
||||
T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
@ -706,4 +712,4 @@ def soft_plus_2d(
|
|||
"""
|
||||
domain(D.m, D.n)
|
||||
O[D.m, D.n] = \
|
||||
UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n])))
|
||||
UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n])))
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
# @linalg_structured_op
|
||||
# def test1(O=TensorDef(T, S.M, S.N, output=True),
|
||||
# cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
# cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
# """Title.
|
||||
|
||||
# Detailed description.
|
||||
|
@ -28,7 +28,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !LinalgOperandDefConfig
|
||||
name: cast
|
||||
kind: type_fn_attr
|
||||
default_fn: cast
|
||||
default_fn: cast_signed
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
|
||||
|
@ -70,7 +70,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
# ODS: let arguments =
|
||||
# ODS-NEXT: Variadic<AnyType>:$inputs,
|
||||
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
|
||||
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast">:$cast
|
||||
# ODS-NEXT: DefaultValuedAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
|
||||
|
||||
# ODS: let builders =
|
||||
# ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
|
@ -99,7 +99,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
|
||||
# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
|
||||
# IMPL: TypeFn castVal = TypeFn::cast;
|
||||
# IMPL: TypeFn castVal = TypeFn::cast_signed;
|
||||
# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
||||
# IMPL-NEXT: return attr.getName() == "cast"; });
|
||||
# IMPL-NEXT: if (castIter != attrs.end()) {
|
||||
|
@ -209,7 +209,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
|
||||
# Detailed description.
|
||||
# """
|
||||
# O[None] = TypeFn.cast(U, value)
|
||||
# O[None] = TypeFn.cast_signed(U, value)
|
||||
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
|
@ -241,7 +241,7 @@ structured_op: !LinalgStructuredOpConfig
|
|||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
fn_name: cast_signed
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
|
|
|
@ -26,7 +26,7 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||
# CHECK: default_fn: exp
|
||||
# CHECK: name: cast
|
||||
# CHECK: kind: type_fn_attr
|
||||
# CHECK: default_fn: cast
|
||||
# CHECK: default_fn: cast_signed
|
||||
@linalg_structured_op
|
||||
def matmul(
|
||||
A=TensorDef(T, S.M, S.K),
|
||||
|
@ -34,7 +34,7 @@ def matmul(
|
|||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
bfn=BinaryFnAttrDef(default=BinaryFn.mul),
|
||||
ufn=UnaryFnAttrDef(default=UnaryFn.exp),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
|
||||
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ def matmul(
|
|||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
mul=BinaryFnAttrDef(default=BinaryFn.mul),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
|
||||
|
||||
|
||||
|
@ -63,13 +63,13 @@ def matmul(
|
|||
# CHECK: scalar_const: '3.1415926535897931 : f64'
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: fn_name: cast
|
||||
# CHECK: fn_name: cast_signed
|
||||
# CHECK: type_var: T
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_const: '42 : i64'
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: fn_name: cast
|
||||
# CHECK: fn_name: cast_signed
|
||||
# CHECK: type_var: T
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_fn:
|
||||
|
@ -81,9 +81,9 @@ def matmul(
|
|||
def constants(
|
||||
O=TensorDef(T, S.M, S.K, output=True),
|
||||
exp=UnaryFnAttrDef(default=UnaryFn.exp)):
|
||||
pi = TypeFn.cast(T, const(3.1415926535897931))
|
||||
cst42 = TypeFn.cast(T, const(42))
|
||||
cst1000 = TypeFn.cast(T, exp(const(1e+3)))
|
||||
pi = TypeFn.cast_signed(T, const(3.1415926535897931))
|
||||
cst42 = TypeFn.cast_signed(T, const(42))
|
||||
cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
|
||||
O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ def conv_poly(
|
|||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
|
||||
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
|
||||
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
|
||||
D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c])
|
||||
D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
|
||||
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
|
|
|
@ -13,7 +13,7 @@ T2 = TV.T2
|
|||
|
||||
@linalg_structured_op
|
||||
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||
O[None] = TypeFn.cast(U, value)
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
|
|
|
@ -25,7 +25,7 @@ def matmul_poly(
|
|||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast)):
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
domain(D.m, D.n, D.k)
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
|
|
@ -21,22 +21,23 @@ def fill_rng_poly(
|
|||
max=ScalarDef(F64),
|
||||
seed=ScalarDef(I32),
|
||||
O=TensorDef(T, S.M, S.N, output=True)):
|
||||
multiplier = TypeFn.cast(I32, const(1103515245))
|
||||
increment = TypeFn.cast(I32, const(12345))
|
||||
rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment
|
||||
rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment
|
||||
inv_range = TypeFn.cast(F64, const(2.3283064e-10))
|
||||
offset = TypeFn.cast(F64, const(2147483647))
|
||||
multiplier = TypeFn.cast_signed(I32, const(1103515245))
|
||||
increment = TypeFn.cast_signed(I32, const(12345))
|
||||
rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
|
||||
rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
|
||||
inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
|
||||
offset = TypeFn.cast_signed(F64, const(2147483647))
|
||||
scaling = (max - min) * inv_range
|
||||
O[D.m, D.n] = TypeFn.cast(T,
|
||||
(offset + TypeFn.cast(F64, rand2)) * scaling + min)
|
||||
O[D.m, D.n] = TypeFn.cast_signed(
|
||||
T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def soft_plus_poly(
|
||||
I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
|
||||
O[D.m, D.n] = UnaryFn.log(
|
||||
TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n])))
|
||||
TypeFn.cast_signed(U, const(1.0)) +
|
||||
TypeFn.cast_signed(U, UnaryFn.exp(I[D.m, D.n])))
|
||||
|
||||
|
||||
@linalg_structured_op(op_name="custom_op_name")
|
||||
|
|
|
@ -16,8 +16,8 @@ def pooling_poly(
|
|||
I=TensorDef(T1, S.N, S.H, S.W, S.C),
|
||||
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
|
||||
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
|
||||
reduce=BinaryFnAttrDef(default=BinaryFn.max),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast),
|
||||
reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
|
||||
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
|
||||
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
|
||||
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
|
||||
|
@ -99,7 +99,7 @@ with Context() as ctx, Location.unknown():
|
|||
input,
|
||||
shape,
|
||||
outs=[init_result],
|
||||
reduce=BinaryFn.min,
|
||||
reduce=BinaryFn.min_signed,
|
||||
strides=[2, 4],
|
||||
dilations=[1, 2])
|
||||
|
||||
|
@ -131,7 +131,7 @@ with Context() as ctx, Location.unknown():
|
|||
input,
|
||||
shape,
|
||||
outs=[init_result],
|
||||
reduce=BinaryFn.min,
|
||||
reduce=BinaryFn.min_signed,
|
||||
strides=[2, 4],
|
||||
dilations=[1, 2])
|
||||
|
||||
|
|
|
@ -13,4 +13,5 @@ def matmul(
|
|||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.k, D.n])
|
||||
|
|
|
@ -24,7 +24,8 @@ def matmul(
|
|||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
domain(D.m, D.n, D.k)
|
||||
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
|
||||
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.k, D.n])
|
||||
|
||||
|
||||
# Verifies that assignment to a scalar (represented as [None]) is represented
|
||||
|
@ -42,7 +43,7 @@ def matmul(
|
|||
# CHECK-NEXT: - reduction
|
||||
@linalg_structured_op
|
||||
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
|
||||
C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m])
|
||||
C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
|
||||
|
||||
|
||||
# Verifies that the index_dims of shape-only operands translate to correct
|
||||
|
@ -65,4 +66,4 @@ def pool(
|
|||
K=TensorDef(T, S.K, index_dims=[D.k]),
|
||||
O=TensorDef(U, S.O, output=True)):
|
||||
domain(D.o, D.k)
|
||||
O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k])
|
||||
O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
|
||||
|
|
|
@ -99,7 +99,7 @@ def testNamedStructuredOpCustomForm():
|
|||
init_result = linalg.InitTensorOp([4, 8], f32)
|
||||
# Check for the named form with custom format
|
||||
# CHECK: linalg.elemwise_unary
|
||||
# CHECK-SAME: cast = #linalg.type_fn<cast>
|
||||
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
|
||||
# CHECK-SAME: fun = #linalg.unary_fn<exp>
|
||||
# CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
|
||||
unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
|
||||
|
@ -137,7 +137,7 @@ def testNamedStructuredOpGenericForm():
|
|||
# CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
|
||||
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
|
||||
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
|
||||
# CHECK-NEXT: cast = #linalg.type_fn<cast>
|
||||
# CHECK-NEXT: cast = #linalg.type_fn<cast_signed>
|
||||
# CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
|
||||
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
|
||||
return linalg.matmul(lhs, rhs, outs=[init_result.result])
|
||||
|
|
Loading…
Reference in New Issue