[mlir][linalg] Add a few unary operations.
Add the unary operations abs, ceil, floor, and negf to the C++ API and the Python OpDSL API, and add test cases for each.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D121339
commit 13d3307176 (parent e0f549a43a)
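What this enables from the Python side, end to end: a minimal sketch, assuming an MLIR build with the Python bindings enabled. Here elemwise_unary_poly is modeled on the in-tree elemwise_unary OpDSL definition (the test file touched below defines its own variant of the same shape), and floor_example is a hypothetical name used only for illustration.

from mlir.ir import (Context, F32Type, InsertionPoint, Location, Module,
                     RankedTensorType)
from mlir.dialects import builtin
from mlir.dialects.linalg.opdsl.lang import *

# Shape-polymorphic unary op that takes the function as an attribute,
# modeled on the in-tree elemwise_unary definition.
@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
  O[None] = fun(cast(U, I[None]))

with Context() as ctx, Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
    def floor_example(input, init_result):
      # UnaryFn.floor is one of the newly added cases; the emitted op body
      # contains math.floor.
      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)

  print(module)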
@@ -61,7 +61,11 @@ def Linalg_Dialect : Dialect {
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"exp", 0>,
-  I32EnumAttrCase<"log", 1>
+  I32EnumAttrCase<"log", 1>,
+  I32EnumAttrCase<"abs", 2>,
+  I32EnumAttrCase<"ceil", 3>,
+  I32EnumAttrCase<"floor", 4>,
+  I32EnumAttrCase<"negf", 5>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
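Each case surfaces in the IR as a #linalg.unary_fn<...> attribute, the mnemonic the tests below rely on. A quick round-trip sketch, assuming the upstream Python bindings:

from mlir.ir import Attribute, Context

with Context():
  # Parse and re-print every case of the enum; all six should round-trip.
  for name in ("exp", "log", "abs", "ceil", "floor", "negf"):
    print(Attribute.parse(f"#linalg.unary_fn<{name}>"))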
@@ -144,6 +144,14 @@ public:
       return builder.create<math::ExpOp>(arg.getLoc(), arg);
     case UnaryFn::log:
       return builder.create<math::LogOp>(arg.getLoc(), arg);
+    case UnaryFn::abs:
+      return builder.create<math::AbsOp>(arg.getLoc(), arg);
+    case UnaryFn::ceil:
+      return builder.create<math::CeilOp>(arg.getLoc(), arg);
+    case UnaryFn::floor:
+      return builder.create<math::FloorOp>(arg.getLoc(), arg);
+    case UnaryFn::negf:
+      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
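The switch above is the complete enum-to-payload-op mapping at this revision; note that math.abs here is the floating-point absolute value (the split into separate integer and float abs ops happened later upstream). The same mapping, condensed into Python data as a sketch (the OpDSL emitter changes later in this diff implement it one method at a time):

from mlir.dialects import arith, math

# Sketch only: the op constructors match those used by the emitter below.
UNARY_BUILDERS = {
    "exp": lambda x: math.ExpOp(x).result,
    "log": lambda x: math.LogOp(x).result,
    "abs": lambda x: math.AbsOp(x).result,    # float abs at this revision
    "ceil": lambda x: math.CeilOp(x).result,
    "floor": lambda x: math.FloorOp(x).result,
    "negf": lambda x: arith.NegFOp(x).result,
}

def build_unary(name, value):
  """Build the payload op for a UnaryFn case, as the C++ switch does."""
  try:
    return UNARY_BUILDERS[name](value)
  except KeyError:
    raise NotImplementedError(f"unsupported unary function: {name}")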
@@ -274,6 +274,10 @@ class UnaryFn:
   """Unary function namespace."""
   exp = UnaryFnType("exp")
   log = UnaryFnType("log")
+  abs = UnaryFnType("abs")
+  ceil = UnaryFnType("ceil")
+  floor = UnaryFnType("floor")
+  negf = UnaryFnType("negf")
 
 
 class BinaryFnType:
@@ -390,6 +390,26 @@ class _BodyBuilder:
       return math.LogOp(x).result
     raise NotImplementedError("Unsupported 'log' operand: {x}")
 
+  def _unary_abs(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.AbsOp(x).result
+    raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+  def _unary_ceil(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.CeilOp(x).result
+    raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+  def _unary_floor(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.FloorOp(x).result
+    raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+  def _unary_negf(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return arith.NegFOp(x).result
+    raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return arith.AddFOp(lhs, rhs).result
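One detail worth flagging in these additions: the {x} placeholders in the NotImplementedError messages are plain strings, not f-strings, so the offending operand is never interpolated into the error text. A hypothetical helper (not part of this commit) that factors out the repeated guard-and-raise pattern with that fixed:

from typing import Callable

from mlir.ir import Value

def emit_float_unary(name: str, build: Callable[[Value], Value], x: Value,
                     is_float: Callable) -> Value:
  # Pass the emitter's _is_floating_point_type helper as is_float.
  if is_float(x.type):
    return build(x)
  raise NotImplementedError(f"Unsupported '{name}' operand: {x}")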
@@ -298,6 +298,54 @@ func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>)
 
 // -----
 
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_abs
+// CHECK: = math.abs
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_ceil
+// CHECK: = math.ceil
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_floor
+// CHECK: = math.floor
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
+                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_negf
+// CHECK: = arith.negf
+
+// -----
+
 // Verifies the default value of the fun attribute is an add op.
 func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
   %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
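The new test IR also round-trips through the parser, which is a convenient way to poke at the attribute syntax interactively. A sketch assuming the upstream Python bindings at this revision (functions are still spelled func @... here):

from mlir.ir import Context, Module

IR = """
func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
                             ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0: tensor<4x8xf32>
}
"""

with Context():
  # Parse and re-print; the #linalg.unary_fn<abs> attribute survives intact.
  print(Module.parse(IR))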
@@ -11,7 +11,7 @@ from mlir.dialects.linalg.opdsl.lang import *
 # fill, matmul, convolution, or pooling tests. The features include:
 # - constant defined in the body
 # - fix/predefined types
-# - exponential functions
+# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
 # - custom op names.
@@ -89,6 +89,46 @@ with Context() as ctx, Location.unknown():
   def test_f32_elemwise_log(input, init_result):
     return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+  # CHECK-LABEL: @test_f32_elemwise_abs
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[EXP:.+]] = math.abs %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_abs(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+  # CHECK-LABEL: @test_f32_elemwise_ceil
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_ceil(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+  # CHECK-LABEL: @test_f32_elemwise_floor
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_floor(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+  # CHECK-LABEL: @test_f32_elemwise_neg
+  # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+  # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
+  # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+  # CHECK-NEXT: -> tensor<4x16xf32>
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+  def test_f32_elemwise_neg(input, init_result):
+    return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
   # Just check that we don't assert out on name mismatch.
   # CHECK-LABEL: @test_non_default_op_name
   @builtin.FuncOp.from_py_func(