!44848 [assistant] [ops] add LessEqual

Merge pull request !44848 from lxw/master
This commit is contained in:
i-robot 2022-11-18 02:02:50 +00:00 committed by Gitee
commit 01ef7c3d1e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 57 additions and 2 deletions

View File

@ -521,7 +521,21 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
SpecializeArithLogFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float16>}}},
{kLogicalAnd,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}},

View File

@ -47,8 +47,8 @@ TypePtr LessEqualInferType(const PrimitivePtr &prim, const std::vector<AbstractB
MIND_API_OPERATOR_IMPL(LessEqual, BaseOperator);
AbstractBasePtr LessEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_shape = LessEqualInferShape(primitive, input_args);
auto infer_type = LessEqualInferType(primitive, input_args);
auto infer_shape = LessEqualInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LessEqual, prim::kPrimLessEqual, LessEqualInfer, nullptr, true);

View File

@ -0,0 +1,41 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LessEqual op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
less_equal_op_info = AiCPURegOp("LessEqual") \
.fusion_type("OPAQUE") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(less_equal_op_info)
def _less_equal_aicpu():
"""LessEqual aicpu register"""
return