From 1bf956787d15d7f0c3accb46af82a168c2223955 Mon Sep 17 00:00:00 2001 From: zheng_pengfei <18865382565@163.com> Date: Mon, 20 Dec 2021 10:56:57 +0800 Subject: [PATCH] [feat] [assistant] [I48OB0] maximum with dynamic infer shape --- mindspore/core/ops/maximum.cc | 61 +++++++++++++++---- mindspore/core/ops/maximum.h | 3 +- .../mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../mindspore/ops/_op_impl/tbe/maximum_ds.py | 40 ++++++++++++ .../mindspore/ops/operations/math_ops.py | 9 --- 5 files changed, 91 insertions(+), 23 deletions(-) create mode 100644 mindspore/python/mindspore/ops/_op_impl/tbe/maximum_ds.py diff --git a/mindspore/core/ops/maximum.cc b/mindspore/core/ops/maximum.cc index e8450bed5c7..01ef0dc922d 100644 --- a/mindspore/core/ops/maximum.cc +++ b/mindspore/core/ops/maximum.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,33 +18,68 @@ #include #include "ops/maximum.h" #include "ops/op_utils.h" -#include "utils/check_convert_utils.h" namespace mindspore { namespace ops { namespace { -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::ShapePtr MaximumInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - return BroadCastInferShape(op_name, input_args); -} - -TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + auto prim_name = primitive->name(); + const int64_t kInputNum = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum, + prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } + return BroadCastInferShape(prim_name, input_args); +} + +TypePtr MaximumInferType(const PrimitivePtr &prim, const std::vector &input_args) { + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto op_name = prim->name(); + const int64_t kInputNum = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum, + op_name); std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + + auto type_x = input_args[0]->BuildType(); + auto type_y = input_args[1]->BuildType(); + MS_EXCEPTION_IF_NULL(type_x); + MS_EXCEPTION_IF_NULL(type_y); + if (type_x->isa() || type_y->isa()) { + if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeComplex64) { + return type_x; + } else if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeFloat32) { + return type_x; + } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeComplex128) { + return type_x; + } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeFloat64) { + return type_x; + } else if (type_x->type_id() == kNumberTypeFloat32 && type_y->type_id() == kNumberTypeComplex64) { + return type_y; + } else if (type_x->type_id() == kNumberTypeFloat64 && type_y->type_id() == kNumberTypeComplex128) { + return type_y; + } else { + MS_EXCEPTION(TypeError) + << "Complex math binary op expecting Tensor [complex64, complex64],[complex64, float32], [float32, " + "complex64],[complex128, complex128],[complex128, float64], [float64, complex128], but got[" + << type_x->ToString() << ", " << type_y->ToString() << "]."; + } + } + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name()); + return type_x; } } // namespace - AbstractBasePtr MaximumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return std::make_shared(InferType(primitive, input_args), - InferShape(primitive, input_args)->shape()); + auto infer_type = MaximumInferType(primitive, input_args); + auto infer_shape = MaximumInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } -REGISTER_PRIMITIVE_C(kNameMaximum, Maximum); +REGISTER_PRIMITIVE_EVAL_IMPL(Maximum, prim::kPrimMaximum, MaximumInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/maximum.h b/mindspore/core/ops/maximum.h index bee84ac83c4..27bd9908719 100644 --- a/mindspore/core/ops/maximum.h +++ b/mindspore/core/ops/maximum.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,6 +39,7 @@ class MS_CORE_API Maximum : public PrimitiveC { }; AbstractBasePtr MaximumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); +using kPrimMaximumPtr = std::shared_ptr; } // namespace ops } // namespace mindspore #endif // MINDSPORE_CORE_OPS_MAXIMUM_H_ diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index a6415c0f964..aa61a721ff9 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -234,6 +234,7 @@ from .logsoftmax_ds import _logsoftmax_ds_tbe from .select import _select_tbe from .pow import _pow_tbe from .maximum import _maximum_tbe +from .maximum_ds import _maximum_ds_tbe from .minimum import _minimum_tbe from .minimum_ds import _minimum_ds_tbe from .minimum_grad import _minimum_grad_tbe diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/maximum_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/maximum_ds.py new file mode 100644 index 00000000000..987a890cb39 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/maximum_ds.py @@ -0,0 +1,40 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Maximum op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +maximum_op_info = TBERegOp("Maximum") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("maximum.so") \ + .compute_cost(10) \ + .kernel_name("maximum") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(maximum_op_info) +def _maximum_ds_tbe(): + """Maximum TBE register""" + return diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index 0a98ece9b19..beba13a302c 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -2535,15 +2535,6 @@ class Maximum(_MathBinaryOp): Float32 """ - def infer_value(self, x, y): - if x is not None and y is not None: - x = x.asnumpy() - y = y.asnumpy() - out = np.maximum(x, y) - out = np.array(out, x.dtype) - return Tensor(out) - return None - class RealDiv(_MathBinaryOp): """