From 55b991177bb50ceaa1ddb433ae63d883bf1e7f77 Mon Sep 17 00:00:00 2001
From: z00512249
Date: Tue, 13 Sep 2022 17:33:03 +0800
Subject: [PATCH] fix Image, log_matrix_determinant, lp_norm kernel bugs

---
 .../device/cpu/kernel/unary_op_cpu_kernel.cc  | 73 ++++++++-----------
 mindspore/core/ops/log_matrix_determinant.cc  |  4 +
 mindspore/core/ops/lp_norm.cc                 | 23 +++---
 mindspore/core/ops/matrix_determinant.cc      |  3 +
 mindspore/python/mindspore/common/tensor.py   |  2 +-
 .../mindspore/ops/function/math_func.py       |  2 +-
 mindspore/python/mindspore/ops/functional.py  |  1 -
 .../test_fused_cast_adam_weight_decay_cpu.py  |  3 +-
 .../parallel/test_auto_parallel_for_loop.py   |  7 +-
 .../test_auto_parallel_for_loop_reshape.py    |  5 +-
 10 files changed, 63 insertions(+), 60 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/unary_op_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/unary_op_cpu_kernel.cc
index 7f87cb6ce6f..f3c82b9746c 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/unary_op_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/unary_op_cpu_kernel.cc
@@ -55,61 +55,49 @@ void Ceil(const float *input, float *output, size_t start, size_t end) {
 
 template <typename T, typename S>
 void Real(const T *input, S *output, size_t start, size_t end) {
-  if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::real(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Real, it's output data type only support these types: float or double";
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::real(input[i]));
   }
 }
 
 template <typename T, typename S>
 void Imag(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (!std::is_same<S, std::complex<float>>::value && !std::is_same<S, std::complex<double>>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::imag(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Imag, it's output data type only support these types: float or double";
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::imag(input[i]));
   }
 }
 
 template <typename T, typename S>
 void Conj(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (std::is_same<T, S>::value &&
-                (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::conj(input[i]));
+  if constexpr (std::is_same<T, S>::value) {
+    if constexpr ((std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(std::conj(input[i]));
+      }
+    } else {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(input[i]);
+      }
     }
   } else {
-    MS_LOG(EXCEPTION) << "For Conj, it's output data type only support these types: complex<float> or complex<double>";
+    MS_LOG(EXCEPTION) << "For Conj, it's output data type not equal to input data type.";
   }
 }
 
 template <typename T, typename S>
 void UnaryOpCpuKernelFunc<T, S>::GetUnaryOpFunc() {
-  // only support float
+  const std::map<std::string, std::function<void(const T *, S *, size_t, size_t)>> kCommonSupportedMap = {
+    {prim::kPrimReal->name(), &Real<T, S>},
+    {prim::kPrimImag->name(), &Imag<T, S>},
+    {prim::kPrimConj->name(), &Conj<T, S>}};
   if constexpr (std::is_same<T, float>::value) {
-    static std::map<std::string, std::function<void(const T *, S *, size_t, size_t)>> kFloatSupportedMap = {
-      {prim::kPrimCeil->name(), &Ceil}};
-    auto iter = kFloatSupportedMap.find(kernel_name_);
-    if (iter != kFloatSupportedMap.end()) {
-      unary_op_func_ = iter->second;
+    if (kernel_name_ == prim::kPrimCeil->name()) {
+      unary_op_func_ = &Ceil;
       return;
     }
   }
-
-  if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
-    static std::map<std::string, std::function<void(const T *, S *, size_t, size_t)>> kComplexSupportedTypeMap = {
-      {prim::kPrimReal->name(), &Real<T, S>},
-      {prim::kPrimImag->name(), &Imag<T, S>},
-      {prim::kPrimConj->name(), &Conj<T, S>}};
-    auto iter = kComplexSupportedTypeMap.find(kernel_name_);
-    if (iter != kComplexSupportedTypeMap.end()) {
-      unary_op_func_ = iter->second;
-      return;
-    }
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", only support these types: Real, Imag, Conj currently, but got "
-                      << kernel_name_;
+  auto iter = kCommonSupportedMap.find(kernel_name_);
+  if (iter != kCommonSupportedMap.end()) {
+    unary_op_func_ = iter->second;
   }
 }
@@ -126,15 +114,12 @@ bool UnaryOpCpuKernelFunc<T, S>::RunFunc(const std::vector<AddressPtr> &inputs,
   auto output = outputs.front();
   const auto input_addr = reinterpret_cast<T *>(input->addr);
   auto output_addr = reinterpret_cast<S *>(output->addr);
-  if (unary_op_func_ != nullptr) {
-    ParallelLaunchAutoSearch(
-      std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
-      output->size / sizeof(S), this, &parallel_search_info_);
-  } else {
-    if (memcpy_s(output_addr, output->size, input_addr, input->size) != EOK) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does memory copy fail.";
-    }
+  if (unary_op_func_ == nullptr) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it has no cpu backend implements.";
   }
+  ParallelLaunchAutoSearch(
+    std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
+    output->size / sizeof(S), this, &parallel_search_info_);
   return true;
 }
@@ -192,7 +177,9 @@ std::map>>
     SpecializeUnaryFunc},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     SpecializeUnaryFunc<double, double>},
-   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>}}},
+   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>},
+   {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+    SpecializeUnaryFunc<uint8_t, uint8_t>}}},
  {kConj,
   {{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
     SpecializeUnaryFunc<complex128, complex128>},
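As a usage sketch (not part of the patch): the hunks above make Real, Imag and Conj share one dispatch map, keep Ceil on its float-only path, and raise an exception instead of silently falling back to memcpy when no kernel matches. The snippet below shows roughly how these CPU kernels are reached from Python; the printed values in the comments are illustrative, and actual dtype coverage depends on the kernel-attr registrations in the last hunk.

```python
# Illustrative sketch only: exercises the unary CPU kernels (Real, Imag, Conj, Ceil)
# whose dispatch logic is reworked in the hunks above.
import numpy as np
import mindspore as ms
from mindspore import Tensor, context, ops

context.set_context(device_target="CPU")

x = Tensor(np.array([1 + 2j, 3 - 4j]), ms.complex64)
print(ops.Real()(x))  # expected roughly [1.  3.]
print(ops.Imag()(x))  # expected roughly [2. -4.]
print(ops.Conj()(x))  # expected roughly [1.-2.j  3.+4.j]

y = Tensor(np.array([1.2, -0.5]), ms.float32)
print(ops.Ceil()(y))  # expected roughly [2. -0.]
```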
diff --git a/mindspore/core/ops/log_matrix_determinant.cc b/mindspore/core/ops/log_matrix_determinant.cc
index 30ea982dab8..7a455a10c9a 100644
--- a/mindspore/core/ops/log_matrix_determinant.cc
+++ b/mindspore/core/ops/log_matrix_determinant.cc
@@ -30,6 +30,10 @@ abstract::TupleShapePtr LogMatrixDeterminantInferShape(const PrimitivePtr &primi
                                                        const std::vector<AbstractBasePtr> &input_args) {
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(x_shape)) {
+    abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
+  }
   auto x_rank = SizeToLong(x_shape.size());
   constexpr int64_t number1 = 1;
   constexpr int64_t number2 = 2;
diff --git a/mindspore/core/ops/lp_norm.cc b/mindspore/core/ops/lp_norm.cc
index 0edd04b7264..ccd39e28301 100644
--- a/mindspore/core/ops/lp_norm.cc
+++ b/mindspore/core/ops/lp_norm.cc
@@ -34,6 +34,9 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
     MS_EXCEPTION_IF_NULL(item);
   }
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(input_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+  }
   auto output_shape = input_shape;
   auto input_rank = SizeToLong(input_shape.size());
   auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr("axis"));
@@ -44,13 +47,18 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
   } else {
     CheckAndConvertUtils::CheckInRange("axis size", axis.size(), kIncludeNeither, {0, input_rank + 1}, prim_name);
   }
-  if (axis.size() > 1) {
-    for (size_t i = 0; i < axis.size(); ++i) {
-      CheckAndConvertUtils::CheckInRange("axis value", axis[i], kIncludeLeft, {-input_rank, input_rank}, prim_name);
-      if (axis[i] < 0) {
-        axis[i] += input_rank;
-      }
+  for (int64_t &axi : axis) {
+    CheckAndConvertUtils::CheckInRange("axis value", axi, kIncludeLeft, {-input_rank, input_rank}, prim_name);
+    if (axi < 0) {
+      axi += input_rank;
     }
+  }
+  bool invalid_axis = std::any_of(axis.begin(), axis.end(), [&input_rank](int64_t axis) { return axis >= input_rank; });
+  if (invalid_axis) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of axis is out of range (-" << input_rank << ", "
+                             << input_rank << ").";
+  }
+  if (axis.size() > 1) {
     constexpr int64_t place_holder = INT64_MAX;
     for (size_t i = 0; i < axis.size(); ++i) {
       auto temp = axis;
@@ -77,9 +85,6 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
       }
     }
   } else {
-    if (axis[0] < 0) {
-      axis[0] += input_rank;
-    }
     if (!keep_dims) {
       (void)output_shape.erase(output_shape.begin() + axis[0]);
     } else {
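The LpNorm change above now range-checks every axis, wraps negative axes, and rejects out-of-range values with a ValueError before the reduced shape is built. For illustration only, a pure-Python paraphrase of that axis handling (the function and variable names below are made up; this is not MindSpore code) could look like:

```python
# Illustration only: a Python paraphrase of the axis normalization now done in
# the LpNorm shape inference above. Names here are hypothetical.
def normalize_lp_norm_axis(axis, input_rank):
    normalized = []
    for a in axis:
        # mirrors CheckInRange("axis value", ..., kIncludeLeft, {-rank, rank})
        # plus the follow-up any_of(axis >= rank) check
        if not -input_rank <= a < input_rank:
            raise ValueError(f"the value of axis is out of range (-{input_rank}, {input_rank}).")
        normalized.append(a + input_rank if a < 0 else a)  # wrap negative axes
    return normalized


print(normalize_lp_norm_axis([-1, 0], 3))  # [2, 0]
```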
diff --git a/mindspore/core/ops/matrix_determinant.cc b/mindspore/core/ops/matrix_determinant.cc
index 77c6dfd78be..87253c3120e 100644
--- a/mindspore/core/ops/matrix_determinant.cc
+++ b/mindspore/core/ops/matrix_determinant.cc
@@ -30,6 +30,9 @@ abstract::ShapePtr MatrixDeterminantInferShape(const PrimitivePtr &primitive,
                                                const std::vector<AbstractBasePtr> &input_args) {
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(x_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+  }
   auto x_rank = SizeToLong(x_shape.size());
   constexpr int64_t number1 = 1;
   constexpr int64_t number2 = 2;
diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py
index b71b716b82e..628881ae569 100644
--- a/mindspore/python/mindspore/common/tensor.py
+++ b/mindspore/python/mindspore/common/tensor.py
@@ -1798,7 +1798,7 @@ class Tensor(Tensor_):
            ValueError: If either of `atol` and `rtol` is less than zero.
 
         Supported Platforms:
-            ``CPU``
+            ``Ascend`` ``GPU`` ``CPU``
 
         Examples:
            >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py
index 9204ee07c33..4bddebd912e 100644
--- a/mindspore/python/mindspore/ops/function/math_func.py
+++ b/mindspore/python/mindspore/ops/function/math_func.py
@@ -2936,7 +2936,7 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
        ValueError: If either of `atol` and `rtol` is less than zero.
 
     Supported Platforms:
-        ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
        >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index e2366d2b4fe..e72225a3fa9 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -54,7 +54,6 @@ merge = P.Merge()
 geswitch = P.GeSwitch()
 strided_slice = P.StridedSlice()
 check_bprop = P.CheckBprop()
-sqrt = P.Sqrt()
 reduce_sum = P.ReduceSum()
 reduce_max = P.ReduceMax()
 reduce_min = P.ReduceMin()
diff --git a/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py b/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py
index 0270af6c369..7041d41b812 100644
--- a/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py
+++ b/tests/st/heterogeneous/test_fused_cast_adam_weight_decay_cpu.py
@@ -115,11 +115,12 @@ class GlobalNorm(nn.Cell):
         super(GlobalNorm, self).__init__()
         self.norm = nn.Norm()
         self.hyper_map = C.HyperMap()
+        self.sqrt = P.Sqrt()
 
     def construct(self, grads):
         """Calculate global norm construct"""
         square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum))
+        global_norms = self.sqrt(F.addn(square_sum))
         return global_norms
 
 
diff --git a/tests/ut/python/parallel/test_auto_parallel_for_loop.py b/tests/ut/python/parallel/test_auto_parallel_for_loop.py
index 252107d2415..d67d8ee7917 100644
--- a/tests/ut/python/parallel/test_auto_parallel_for_loop.py
+++ b/tests/ut/python/parallel/test_auto_parallel_for_loop.py
@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -63,11 +63,14 @@ class LayerNorm(nn.Cell):
         self.mul = P.Mul()
         self.div = P.RealDiv()
         self.square = P.Square()
+        self.sqrt = P.Sqrt()
 
     def construct(self, x):
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        add_variance = self.add(variance, self.eps)
+        sqrt_variance = self.sqrt(add_variance)
+        output = self.div(self.sub(x, mean), sqrt_variance)
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         return rescaled_output
 
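With the module-level `sqrt = P.Sqrt()` alias removed from functional.py, the tests above switch to holding their own Sqrt primitive. A migration sketch for downstream cells that relied on `F.sqrt` follows; it is illustrative only, and the class name is hypothetical.

```python
# Migration sketch, not part of the patch: replace uses of the removed F.sqrt
# alias with a Sqrt primitive owned by the cell, as the updated tests do.
import mindspore.nn as nn
from mindspore.ops import operations as P


class GlobalNormSketch(nn.Cell):
    """Hypothetical cell mirroring the GlobalNorm update above."""

    def __init__(self):
        super(GlobalNormSketch, self).__init__()
        self.sqrt = P.Sqrt()  # own the primitive instead of importing F.sqrt
        self.addn = P.AddN()

    def construct(self, square_sum):
        # square_sum: tuple of squared-norm scalars, e.g. produced by a hyper_map
        return self.sqrt(self.addn(square_sum))
```

Code that prefers a functional call can also use the functional `mindspore.ops.sqrt` interface where it is available in the target version.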
diff --git a/tests/ut/python/parallel/test_auto_parallel_for_loop_reshape.py b/tests/ut/python/parallel/test_auto_parallel_for_loop_reshape.py
index 1d4464e27a0..44f6b049868 100644
--- a/tests/ut/python/parallel/test_auto_parallel_for_loop_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_for_loop_reshape.py
@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -65,6 +65,7 @@ class LayerNorm(nn.Cell):
         self.square = P.Square()
         self.reshape = P.Reshape()
         self.shape = P.Shape()
+        self.sqrt = P.Sqrt()
 
     def construct(self, x):
         x_origin_shape = self.shape(x)
@@ -74,7 +75,7 @@ class LayerNorm(nn.Cell):
         x = self.reshape(x, x_target_shape)
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        output = self.div(self.sub(x, mean), self.sqrt(self.add(variance, self.eps)))
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         output_shape = self.shape(rescaled_output) + (1,)
         rescaled_output = self.reshape(rescaled_output, output_shape)