From aa82a2dd2cdb3e569c886d4a7faa58197c04395a Mon Sep 17 00:00:00 2001 From: zlq2020 Date: Wed, 1 Mar 2023 14:24:16 +0800 Subject: [PATCH] fix eye bug --- .jenkins/check/config/whitelizard.txt | 1 + mindspore/core/ops/eye.cc | 32 +++++++++++-------- .../mindspore/ops/function/math_func.py | 1 + 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 1ef6537bc0e..1bd68bd67de 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -36,6 +36,7 @@ mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d mindspore/mindspore/python/mindspore/ops/function/nn_func.py:pad mindspore/mindspore/python/mindspore/ops/function/math_func.py:cov mindspore/mindspore/python/mindspore/ops/function/math_func.py:norm +mindspore/mindspore/python/mindspore/ops/function/math_func.py:einsum mindspore/mindspore/python/mindspore/context.py:set_auto_parallel_context mindspore/mindspore/python/mindspore/common/tensor.py:__init__ mindspore/mindspore/python/mindspore/common/parameter.py:set_data diff --git a/mindspore/core/ops/eye.cc b/mindspore/core/ops/eye.cc index 5400d3d2a8e..c883ff27cef 100644 --- a/mindspore/core/ops/eye.cc +++ b/mindspore/core/ops/eye.cc @@ -63,20 +63,6 @@ abstract::ShapePtr EyeInferShape(const PrimitivePtr &primitive, const std::vecto return std::make_shared(state_shape); } -TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); - auto dtype_value = input_args[2]->BuildValue(); - if (!dtype_value->isa()) { - MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!"; - } - auto output_type = dtype_value->cast(); - const std::set valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, - kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool}; - auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); - return dtype_ret; -} - void EyeCheck(const std::vector &input_args) { if (!input_args[0]->isa()) { MS_EXCEPTION(TypeError) << "For Eye, 'n' must be int, but got AnyValue!"; @@ -86,6 +72,8 @@ void EyeCheck(const std::vector &input_args) { } auto n_ptr_ = input_args[0]->BuildValue(); auto m_ptr_ = input_args[1]->BuildValue(); + MS_EXCEPTION_IF_NULL(n_ptr_); + MS_EXCEPTION_IF_NULL(m_ptr_); if (!n_ptr_->isa() && !n_ptr_->isa()) { MS_EXCEPTION(TypeError) << "For Eye, the dtype of n is invalid!"; } @@ -93,11 +81,27 @@ void EyeCheck(const std::vector &input_args) { MS_EXCEPTION(TypeError) << "For Eye, the dtype of m is invalid!"; } auto dtype_value_c = input_args[2]->BuildValue(); + MS_EXCEPTION_IF_NULL(dtype_value_c); if (!dtype_value_c->isa()) { MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!"; } } +TypePtr EyeInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + EyeCheck(input_args); + auto dtype_value = input_args[2]->BuildValue(); + if (!dtype_value->isa()) { + MS_EXCEPTION(TypeError) << "For Eye, the dtype of Eye is invalid!"; + } + auto output_type = dtype_value->cast(); + const std::set valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, + kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBool}; + auto dtype_ret = CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name); + return dtype_ret; +} + ValuePtr EyeInferValue(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); auto prim_name = prim->name(); diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index 56f6495a895..7b2875dff40 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -9445,6 +9445,7 @@ def einsum(equation, *operands): >>> x = mindspore.Tensor([1, 2, 3, 4], mindspore.float32) >>> y = mindspore.Tensor([1, 2], mindspore.float32) >>> output = ops.einsum(x, [..., 1], y, [..., 2], [..., 1, 2]) + >>> print(output) [[1. 2.] [2. 4.] [3. 6.]