diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc index 8d8039fbb89..bba56c0e15e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_logic_cpu_kernel.cc @@ -521,7 +521,21 @@ static std::map}, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), - SpecializeArithLogFunc}}}, + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + SpecializeArithLogFunc}}}, {kLogicalAnd, {{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeArithLogFunc}}}, diff --git a/mindspore/core/ops/less_equal.cc b/mindspore/core/ops/less_equal.cc index 3f92568f5fa..73baee6127b 100644 --- a/mindspore/core/ops/less_equal.cc +++ b/mindspore/core/ops/less_equal.cc @@ -47,8 +47,8 @@ TypePtr LessEqualInferType(const PrimitivePtr &prim, const std::vector &input_args) { - auto infer_shape = LessEqualInferShape(primitive, input_args); auto infer_type = LessEqualInferType(primitive, input_args); + auto infer_shape = LessEqualInferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(LessEqual, prim::kPrimLessEqual, LessEqualInfer, nullptr, true); diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/less_equal.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/less_equal.py new file mode 100644 index 00000000000..74d377a936c --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/less_equal.py @@ -0,0 +1,41 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LessEqual op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +less_equal_op_info = AiCPURegOp("LessEqual") \ + .fusion_type("OPAQUE") \ + .input(0, "x1", "required") \ + .input(1, "x2", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(less_equal_op_info) +def _less_equal_aicpu(): + """LessEqual aicpu register""" + return