!44848 [assistant] [ops] add LessEqual
Merge pull request !44848 from lxw/master
This commit is contained in:
commit
01ef7c3d1e
|
@ -521,7 +521,21 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
SpecializeArithLogFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float16>}}},
|
||||
{kLogicalAnd,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}},
|
||||
|
|
|
@ -47,8 +47,8 @@ TypePtr LessEqualInferType(const PrimitivePtr &prim, const std::vector<AbstractB
|
|||
MIND_API_OPERATOR_IMPL(LessEqual, BaseOperator);
|
||||
AbstractBasePtr LessEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_shape = LessEqualInferShape(primitive, input_args);
|
||||
auto infer_type = LessEqualInferType(primitive, input_args);
|
||||
auto infer_shape = LessEqualInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LessEqual, prim::kPrimLessEqual, LessEqualInfer, nullptr, true);
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LessEqual op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
less_equal_op_info = AiCPURegOp("LessEqual") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(less_equal_op_info)
|
||||
def _less_equal_aicpu():
|
||||
"""LessEqual aicpu register"""
|
||||
return
|
Loading…
Reference in New Issue