diff --git a/docs/api/api_python/mindspore.ops.functional.rst b/docs/api/api_python/mindspore.ops.functional.rst index d5e2a09e0c5..50c8867b955 100644 --- a/docs/api/api_python/mindspore.ops.functional.rst +++ b/docs/api/api_python/mindspore.ops.functional.rst @@ -239,6 +239,8 @@ Reduction函数 mindspore.ops.amax mindspore.ops.amin mindspore.ops.argmin + mindspore.ops.cummax + mindspore.ops.cummin mindspore.ops.logsumexp mindspore.ops.max mindspore.ops.mean @@ -265,8 +267,6 @@ Reduction函数 :template: classtemplate.rst mindspore.ops.approximate_equal - mindspore.ops.cummax - mindspore.ops.cummin mindspore.ops.equal mindspore.ops.ge mindspore.ops.gt diff --git a/docs/api/api_python_en/mindspore.ops.functional.rst b/docs/api/api_python_en/mindspore.ops.functional.rst index 79540d14a81..bd6faeb4c2c 100644 --- a/docs/api/api_python_en/mindspore.ops.functional.rst +++ b/docs/api/api_python_en/mindspore.ops.functional.rst @@ -241,6 +241,8 @@ Reduction Functions mindspore.ops.amax mindspore.ops.amin mindspore.ops.argmin + mindspore.ops.cummax + mindspore.ops.cummin mindspore.ops.logsumexp mindspore.ops.max mindspore.ops.mean @@ -267,8 +269,6 @@ Comparison Functions :template: classtemplate.rst mindspore.ops.approximate_equal - mindspore.ops.cummax - mindspore.ops.cummin mindspore.ops.equal mindspore.ops.ge mindspore.ops.gt diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc index 538f56b53cf..0e4ed5449df 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.cc +++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc @@ -793,7 +793,7 @@ void GetControlOpInput(const std::shared_ptr &graph_compiler, con } } if (args_tuple_num == 0) { - (void)args->emplace_back(value); + args->emplace_back(value); front_index++; } } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.cc index 0987ac42324..88dc3474b54 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.cc @@ -64,7 +64,7 @@ inline void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_ } template -inline void Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, +inline void Copy(T *input, const T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, size_t start, size_t end) { for (size_t i = start; i < end; i++) { size_t k1 = i / dim2 % dim0; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.h index 48b4ddfa9ad..0afdfa0f3e9 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cumsum_cpu_kernel.h @@ -58,7 +58,7 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod { static std::vector> func_list_; CumSumLaunchFunc kernel_func_; int axis_{0}; - ShapeVector shape_{}; + std::vector shape_{}; size_t stride_{0}; size_t stride2_{0}; size_t dims_[3]{0}; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.cc index 95cfcd4cab2..5f7b3917fa5 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "plugin/device/cpu/kernel/is_close_cpu_kernel.h" -#include +#include #include #include #include "mindspore/core/ops/is_close.h" @@ -26,15 +26,16 @@ namespace kernel { namespace { constexpr size_t kIsCloseInputsNum = 2; constexpr size_t kIsCloseOutputsNum = 1; +constexpr size_t kIsCloseMaxDim = 10; template inline bool IsClose(T a, T b, float rtol, float atol, bool equal_nan) { - if (a == b) { + if (std::equal_to()(a, b)) { return true; } if (equal_nan && std::isnan(a) && std::isnan(b)) { return true; } - if (atol == 0 && rtol == 0) { + if (std::equal_to()(atol, 0) && std::equal_to()(rtol, 0)) { return false; } auto left_side = std::abs(a - b); @@ -42,27 +43,26 @@ inline bool IsClose(T a, T b, float rtol, float atol, bool equal_nan) { return std::isfinite(left_side) && left_side <= right_side; } -void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVector &output_shape, - std::vector *index_list) { +void GetBroadCastIndex(const std::vector &unaligned_input_shape, const std::vector &output_shape, + std::vector *index_list) { // Given unaligned input shape and output shape, this function returns the mapping // from indices of output (logical) to corespondingly real input indices (physical). // The return will write to index_list, whose size is equal to total elements of output. - constexpr int MaxDim = 10; - int64_t logical_shape[MaxDim]; - int64_t physical_shape[MaxDim]; - int64_t size = 0, output_size = 1; + size_t logical_shape[kIsCloseMaxDim]; + size_t physical_shape[kIsCloseMaxDim]; + size_t size = 0, output_size = 1; // Align input shape to output shape by filling one into the outermost dimension. - ShapeVector input_shape(output_shape.size()); + std::vector input_shape(output_shape.size()); for (size_t i = 0, j = output_shape.size() - unaligned_input_shape.size(); i < output_shape.size(); i++) { input_shape[i] = i < j ? 1 : unaligned_input_shape[i - j]; } // Get logical shape and physical shape of input. Moreover, we will merge the dimensions with same // (logical or physical) property. - for (int i = SizeToInt(output_shape.size()) - 1; i >= 0;) { - int64_t stride = 1; + for (size_t i = output_shape.size(); i > 0;) { + size_t stride = 1; bool change = false, is_valid = false; - while (i >= 0 && input_shape[i] == output_shape[i]) { - stride *= output_shape[i]; + while (i > 0 && input_shape[i - 1] == output_shape[i - 1]) { + stride *= output_shape[i - 1]; change = is_valid = true; --i; } @@ -73,8 +73,8 @@ void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVect } change = false; stride = 1; - while (i >= 0 && input_shape[i] == 1) { - stride *= output_shape[i]; + while (i > 0 && input_shape[i - 1] == 1) { + stride *= output_shape[i - 1]; change = is_valid = true; --i; } @@ -90,13 +90,13 @@ void GetBroadCastIndex(const ShapeVector &unaligned_input_shape, const ShapeVect } } // Get the flatten input indices according to "logical_shape" and "physical_shape". - int64_t offset = 1; - int64_t stride = 1; + size_t offset = 1; + size_t stride = 1; index_list->resize(output_size); (*index_list)[0] = 0; // First element is set to 0. - for (int64_t i = 0; i < size; ++i) { - int64_t increment = (logical_shape[i] == physical_shape[i] ? stride : 0); - for (int64_t j = 0; j < (physical_shape[i] - 1) * offset; ++j) { + for (size_t i = 0; i < size; ++i) { + size_t increment = (logical_shape[i] == physical_shape[i] ? stride : 0); + for (size_t j = 0; j < (physical_shape[i] - 1) * offset; ++j) { (*index_list)[offset + j] = (*index_list)[j] + increment; } offset *= physical_shape[i]; @@ -125,9 +125,9 @@ int IsCloseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std: if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { return ret; } - auto input_shape = inputs.at(kIndex0)->GetShapeVector(); - auto other_shape = inputs.at(kIndex1)->GetShapeVector(); - auto output_shape = outputs.at(kIndex0)->GetShapeVector(); + auto input_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector()); + auto other_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector()); + auto output_shape = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector()); is_need_broadcast_ = input_shape != other_shape; if (is_need_broadcast_) { GetBroadCastIndex(input_shape, output_shape, &index_list1_); @@ -137,7 +137,7 @@ int IsCloseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std: } template -bool IsCloseCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &workspace, +bool IsCloseCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, const std::vector &outputs) { CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIsCloseInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kIsCloseOutputsNum, kernel_name_); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.h index c805355aeb4..ddaf73b3065 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/is_close_cpu_kernel.h @@ -55,8 +55,8 @@ class IsCloseCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper< bool equal_nan_{true}; // Broadcast related. - std::vector index_list1_{}; - std::vector index_list2_{}; + std::vector index_list1_{}; + std::vector index_list2_{}; bool is_need_broadcast_{false}; }; } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc index 50f61c2ae06..6b7f062fd7d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc @@ -173,22 +173,21 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vectorGetKernelThreadNum()); - ParallelLaunch(task, element_size, block_size, this); + ParallelLaunch(task, element_size, 0, this, pool_); if (invalid_index_pos != -1) { std::stringstream indices_ss; std::stringstream input_shape_ss; + auto pos = LongToSize(invalid_index_pos); for (size_t i = 0; i < slice_size_; i++) { if (i > 0) { indices_ss << ", "; input_shape_ss << ", "; } - indices_ss << std::to_string(indices[invalid_index_pos + i]); + indices_ss << std::to_string(indices[pos + i]); input_shape_ss << std::to_string(input_shape_[i]); } - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'[" - << indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "]."; + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << pos << "-th value of 'indices'[" << indices_ss.str() + << "] is out of range[" + input_shape_ss.str() + "]."; return false; } return true; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_cpu_kernel.cc index bddff0c7280..638a98c01b3 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_segment_mean_cpu_kernel.cc @@ -42,21 +42,21 @@ bool SparseSegmentMeanCpuKernelMod::Init(const BaseOperatorPtr &base_operator, int SparseSegmentMeanCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs, - const std::map &inputsOnHost) { + const std::map &) { int ret = KernelMod::Resize(base_operator, inputs, outputs); if (ret != KRET_OK) { return ret; } - auto x_shape = inputs.at(kIndex0)->GetShapeVector(); - auto indices_shape = inputs.at(kIndex1)->GetShapeVector(); - auto y_shape = outputs.at(kIndex0)->GetShapeVector(); - batch_rank_ = LongToSize(base_operator->get_batch_rank()); - batch_size_ = std::accumulate(x_shape.begin(), x_shape.begin() + batch_rank_, size_t(1), std::multiplies{}); - outer_size_ = LongToSize(x_shape.at(batch_rank_)); - inner_size_ = std::accumulate(x_shape.begin() + batch_rank_ + 1, x_shape.end(), size_t(1), std::multiplies{}); + auto x_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector()); + auto indices_shape = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector()); + auto y_shape = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector()); + auto batch_rank = base_operator->get_batch_rank(); + batch_size_ = std::accumulate(x_shape.begin(), x_shape.begin() + batch_rank, size_t(1), std::multiplies{}); + outer_size_ = x_shape.at(LongToSize(batch_rank)); + inner_size_ = std::accumulate(x_shape.begin() + batch_rank + 1, x_shape.end(), size_t(1), std::multiplies{}); x_size_ = inner_size_ * outer_size_; - indices_size_ = LongToSize(indices_shape.at(batch_rank_)); - y_size_ = std::accumulate(y_shape.begin() + batch_rank_, y_shape.end(), size_t(1), std::multiplies{}); + indices_size_ = indices_shape.at(LongToSize(batch_rank)); + y_size_ = std::accumulate(y_shape.begin() + batch_rank, y_shape.end(), size_t(1), std::multiplies{}); return ret; } @@ -107,11 +107,11 @@ bool SparseSegmentMeanCpuKernelMod::LaunchKernel(const std::vector &outputs, const std::map &inputsOnHost) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) override { return kernel_func_(this, inputs, outputs); } @@ -59,7 +59,6 @@ class SparseSegmentMeanCpuKernelMod : public NativeCpuKernelMod { size_t indices_size_{1}; size_t x_size_{1}; size_t y_size_{1}; - size_t batch_rank_{0}; size_t batch_size_{1}; }; } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h index 5bdcbb3b8e4..5e270e994a8 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h @@ -105,6 +105,7 @@ class LambGpuKernelMod : public NativeGpuKernelMod { const std::vector &outputs) override { auto kernel_ptr = std::dynamic_pointer_cast(base_operator); kernel_name_ = kernel_ptr->name(); + InitResource(); return true; } @@ -114,7 +115,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod { MS_LOG(EXCEPTION) << "For 'Lamb', the number of inputs should be " << INPUT_NUM << ", but got " << inputs.size(); } - InitResource(); InitParamSizeByType(); auto variable_int64_shape = inputs[kVarIndex]->GetShapeVector(); diff --git a/mindspore/core/ops/index_fill.cc b/mindspore/core/ops/index_fill.cc index d656ffc69a6..6457af4c07b 100644 --- a/mindspore/core/ops/index_fill.cc +++ b/mindspore/core/ops/index_fill.cc @@ -56,8 +56,8 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector args; - (void)args.insert({"x", x_type}); - (void)args.insert({"value", value_type}); + (void)args.emplace("x", x_type); + (void)args.emplace("value", x_type); (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_data_types, prim_name); return x_type; } @@ -78,7 +78,7 @@ abstract::ShapePtr IndexFillInferShape(const PrimitivePtr &primitive, const std: // Input 'index' must be a scalar/vector. auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; auto index_rank = SizeToLong(index_shape.size()); - (void)CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name); + CheckAndConvertUtils::CheckInRange("rank of 'index'", index_rank, kIncludeBoth, {0, 1}, prim_name); // Input 'value' must be a tensor with a value or a scalar. if (input_args[kInputIndex3]->isa()) { diff --git a/mindspore/core/ops/is_close.cc b/mindspore/core/ops/is_close.cc index c2ab6e20337..ddca34f6942 100644 --- a/mindspore/core/ops/is_close.cc +++ b/mindspore/core/ops/is_close.cc @@ -53,14 +53,16 @@ abstract::ShapePtr IsCloseInferShape(const PrimitivePtr &primitive, const std::v << ", 'other' size: " << other_shape[i] << "."; } } - if (input_size > MAX) + if (input_size > MAX) { MS_EXCEPTION(ValueError) << "For '" << op_name << "', the size of tensor 'input' must be less than [2147483648], but got: " << input_size << "."; - if (other_size > MAX) + } + if (other_size > MAX) { MS_EXCEPTION(ValueError) << "For '" << op_name << "', the size of tensor 'other' must be less than [2147483648], but got: " << other_size << "."; + } } return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/matrix_diag_v3.cc b/mindspore/core/ops/matrix_diag_v3.cc index c64317cb1df..a61e83e3bcd 100644 --- a/mindspore/core/ops/matrix_diag_v3.cc +++ b/mindspore/core/ops/matrix_diag_v3.cc @@ -30,7 +30,7 @@ namespace mindspore { namespace ops { namespace { -int64_t GetTensorValue(AbstractBasePtr arg, const std::string &prim_name, const std::string &arg_name) { +int64_t GetTensorValue(const AbstractBasePtr &arg, const std::string &prim_name, const std::string &arg_name) { if (!arg->isa() || !arg->BuildValue()->isa()) { MS_EXCEPTION(TypeError) << "For " << prim_name << ", the input '" << arg_name << "' must be const Tensor."; } @@ -60,9 +60,9 @@ ShapeVector GetOutputShape(const std::vector &x_shape, int64_t lower_di MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1], but got k[0] is " << lower_diag_index << ", k[1] is " << upper_diag_index << "."; } - CheckAndConvertUtils::CheckInteger("rank of 'x'", x_rank, kGreaterEqual, number_two, prim_name); + (void)CheckAndConvertUtils::CheckInteger("rank of 'x'", x_rank, kGreaterEqual, number_two, prim_name); auto num_diags = upper_diag_index - lower_diag_index + 1; - if (x_shape[x_rank - number_two] != num_diags) { + if (x_shape[LongToSize(x_rank - number_two)] != num_diags) { MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input x_shape[-2] doesn't match with k value."; } (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - number_two); @@ -73,10 +73,12 @@ ShapeVector GetOutputShape(const std::vector &x_shape, int64_t lower_di int64_t max_diag_len = x_shape.back(); int64_t min_num_rows = max_diag_len - std::min(upper_diag_index, int64_t(0)); int64_t min_num_cols = max_diag_len + std::max(lower_diag_index, int64_t(0)); - if (row_val != -1 && row_val < min_num_rows) + if (row_val != -1 && row_val < min_num_rows) { MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows is too small."; - if (col_val != -1 && col_val < min_num_cols) + } + if (col_val != -1 && col_val < min_num_cols) { MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of columns is too small."; + } if (row_val == -1 && col_val == -1) { row_val = std::max(min_num_rows, min_num_cols); col_val = row_val; @@ -212,8 +214,8 @@ AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const Pri auto align_ptr = primitive->GetAttr(kAlign); MS_EXCEPTION_IF_NULL(align_ptr); auto align = GetValue(align_ptr); - CheckAndConvertUtils::CheckString(kAlign, align, {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"}, - primitive->name()); + (void)CheckAndConvertUtils::CheckString(kAlign, align, {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"}, + primitive->name()); auto infer_type = MatrixDiagV3InferType(primitive, input_args); auto infer_shape = MatrixDiagV3InferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); diff --git a/mindspore/core/ops/sparse_segment_mean.cc b/mindspore/core/ops/sparse_segment_mean.cc index a8c9a4cabe6..b92f1616f88 100644 --- a/mindspore/core/ops/sparse_segment_mean.cc +++ b/mindspore/core/ops/sparse_segment_mean.cc @@ -60,7 +60,7 @@ abstract::ShapePtr SparseSegmentMeanInferShape(const PrimitivePtr &prim, if (!input_args[kInputIndex2]->BuildValue()->isa()) { // The real output shape relies on the last value of 'segment_ids', we have already added dependency map. // The framework ensures the `else` branch will be executed, so min/max shape are not necessary to set. - out_shape[batch_rank] = abstract::Shape::SHP_ANY; + out_shape[LongToSize(batch_rank)] = abstract::Shape::SHP_ANY; return std::make_shared(out_shape); } else { auto segment_ids = input_args[kInputIndex2]->cast(); @@ -85,7 +85,7 @@ abstract::ShapePtr SparseSegmentMeanInferShape(const PrimitivePtr &prim, if (segment_num <= 0) { MS_LOG(EXCEPTION) << "For '" << prim_name << "', the input 'segment_ids' must be non-negative."; } - out_shape[batch_rank] = segment_num; + out_shape[LongToSize(batch_rank)] = segment_num; return std::make_shared(out_shape); } } diff --git a/mindspore/python/mindspore/scipy/linalg.py b/mindspore/python/mindspore/scipy/linalg.py index 4969779196f..fd54466cd78 100755 --- a/mindspore/python/mindspore/scipy/linalg.py +++ b/mindspore/python/mindspore/scipy/linalg.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Linear algebra submodule""" +from __future__ import absolute_import from .ops import Cholesky from .ops import Eigh from .ops import LU diff --git a/mindspore/python/mindspore/scipy/optimize/_bfgs.py b/mindspore/python/mindspore/scipy/optimize/_bfgs.py index 15d9d7ad671..66a42fbfcf8 100644 --- a/mindspore/python/mindspore/scipy/optimize/_bfgs.py +++ b/mindspore/python/mindspore/scipy/optimize/_bfgs.py @@ -72,9 +72,9 @@ class MinimizeBfgs(nn.Cell): def construct(self, x0, maxiter=None, norm=mnp.inf, gtol=1e-5, line_search_maxiter=10): # Constant tensors which avoid loop unrolling - _BOOL_FALSE = _to_tensor(False) - _INT_ZERO = _to_tensor(0) - _INT_ONE = _to_tensor(1) + const_bool_false = _to_tensor(False) + const_int_zero = _to_tensor(0) + const_int_one = _to_tensor(1) if maxiter is None: maxiter = mnp.size(x0) * 200 @@ -86,18 +86,18 @@ class MinimizeBfgs(nn.Cell): state = { "converged": _norm(g_0, ord_=mnp.inf) < gtol, - "failed": _BOOL_FALSE, - "k": _INT_ZERO, - "nfev": _INT_ONE, - "ngev": _INT_ONE, - "nhev": _INT_ZERO, + "failed": const_bool_false, + "k": const_int_zero, + "nfev": const_int_one, + "ngev": const_int_one, + "nhev": const_int_zero, "x_k": x0, "f_k": f_0, "g_k": g_0, "H_k": identity, "old_old_fval": f_0 + _norm(g_0) / 2, - "status": _INT_ZERO, - "line_search_status": _INT_ZERO + "status": const_int_zero, + "line_search_status": const_int_zero } while state["k"] < maxiter: diff --git a/mindspore/python/mindspore/scipy/optimize/line_search.py b/mindspore/python/mindspore/scipy/optimize/line_search.py index 261ade35db8..c294d98c92b 100644 --- a/mindspore/python/mindspore/scipy/optimize/line_search.py +++ b/mindspore/python/mindspore/scipy/optimize/line_search.py @@ -86,13 +86,13 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0, Tries cubic, quadratic, and bisection methods of zooming. """ # Constant tensors which avoid loop unrolling - _FLOAT_ONE = _to_tensor(1., dtype=a_low.dtype) - _BOOL_FALSE = _to_tensor(False) - _INT_ZERO = _to_tensor(0) + const_float_one = _to_tensor(1., dtype=a_low.dtype) + const_bool_false = _to_tensor(False) + const_int_zero = _to_tensor(0) state = { - "done": _BOOL_FALSE, - "failed": _BOOL_FALSE, - "j": _INT_ZERO, + "done": const_bool_false, + "failed": const_bool_false, + "j": const_int_zero, "a_low": a_low, "phi_low": phi_low, "dphi_low": dphi_low, @@ -101,12 +101,12 @@ def _zoom(fn, a_low, phi_low, dphi_low, a_high, phi_high, dphi_high, phi_0, g_0, "dphi_high": dphi_high, "a_rec": (a_low + a_high) / 2., "phi_rec": (phi_low + phi_high) / 2., - "a_star": _FLOAT_ONE, + "a_star": const_float_one, "phi_star": phi_low, "dphi_star": dphi_low, "g_star": g_0, - "nfev": _INT_ZERO, - "ngev": _INT_ZERO, + "nfev": const_int_zero, + "ngev": const_int_zero, } if mnp.logical_not(is_run): @@ -195,41 +195,41 @@ class LineSearch(nn.Cell): return fkk, gkk, mnp.dot(gkk, pk) # Constant tensors which avoid loop unrolling - _FLOAT_ZERO = _to_tensor(0., dtype=xk.dtype) - _FLOAT_ONE = _to_tensor(1., dtype=xk.dtype) - _BOOL_FALSE = _to_tensor(False) - _INT_ZERO = _to_tensor(0) - _INT_ONE = _to_tensor(1) + const_float_zero = _to_tensor(0., dtype=xk.dtype) + const_float_one = _to_tensor(1., dtype=xk.dtype) + const_bool_false = _to_tensor(False) + const_int_zero = _to_tensor(0) + const_int_one = _to_tensor(1) if old_fval is None or gfk is None: - nfev, ngev = _INT_ONE, _INT_ONE - phi_0, g_0, dphi_0 = fval_and_grad(_FLOAT_ZERO) + nfev, ngev = const_int_one, const_int_one + phi_0, g_0, dphi_0 = fval_and_grad(const_float_zero) else: - nfev, ngev = _INT_ZERO, _INT_ZERO + nfev, ngev = const_int_zero, const_int_zero phi_0, g_0 = old_fval, gfk dphi_0 = mnp.dot(g_0, pk) if old_old_fval is None: - start_value = _FLOAT_ONE + start_value = const_float_one else: old_phi0 = old_old_fval candidate_start_value = 1.01 * 2 * (phi_0 - old_phi0) / dphi_0 start_value = mnp.where( mnp.isfinite(candidate_start_value), - mnp.minimum(candidate_start_value, _FLOAT_ONE), - _FLOAT_ONE + mnp.minimum(candidate_start_value, const_float_one), + const_float_one ) state = { - "done": _BOOL_FALSE, - "failed": _BOOL_FALSE, - "i": _INT_ONE, - "a_i": _FLOAT_ZERO, + "done": const_bool_false, + "failed": const_bool_false, + "i": const_int_one, + "a_i": const_float_zero, "phi_i": phi_0, "dphi_i": dphi_0, "nfev": nfev, "ngev": ngev, - "a_star": _FLOAT_ZERO, + "a_star": const_float_zero, "phi_star": phi_0, "dphi_star": dphi_0, "g_star": g_0, diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py index 1c40006c234..26dbd55f8ac 100644 --- a/mindspore/python/mindspore/scipy/sparse/linalg.py +++ b/mindspore/python/mindspore/scipy/sparse/linalg.py @@ -78,19 +78,19 @@ def _batch_gmres(A, b, x0, tol, restart, maxiter, M, atol): It does not allow for early termination, but has much less overhead on GPUs. """ # Constant tensor which avoids loop unrolling - _INT_ZERO = _to_tensor(0) + const_int_zero = _to_tensor(0) dtype = b.dtype _, b_norm = _safe_normalize(b) atol = mnp.maximum(tol * b_norm, _to_tensor(atol), dtype=dtype) residual = _matvec(M, b - _matvec(A, x0)) unit_residual, residual_norm = _safe_normalize(residual) - k = _INT_ZERO + k = const_int_zero x = x0 while k < maxiter and residual_norm > atol: pad_width = ((0, 0),) * unit_residual.ndim + ((0, restart),) V = mnp.pad(unit_residual[..., None], pad_width=pad_width) H = mnp.eye(restart, restart + 1, dtype=dtype) - k_iter = _INT_ZERO + k_iter = const_int_zero breakdown = _to_tensor(False) while k_iter < restart and mnp.logical_not(breakdown): V, H, breakdown = arnoldi_iteration(k_iter, A, M, V, H) @@ -103,7 +103,7 @@ def _batch_gmres(A, b, x0, tol, restart, maxiter, M, atol): residual = _matvec(M, b - _matvec(A, x)) unit_residual, residual_norm = _safe_normalize(residual) k += 1 - return x, F.select(residual_norm > atol, k, _INT_ZERO) + return x, F.select(residual_norm > atol, k, const_int_zero) def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol): @@ -112,7 +112,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol): the GMRES process using Givens rotations. This improves numerical stability and gives a free estimate of the residual norm that allows for early termination within a single "restart". """ - _INT_ZERO = _to_tensor(0) + const_int_zero = _to_tensor(0) _, b_norm = _safe_normalize(b) atol = mnp.maximum(tol * b_norm, atol) @@ -123,7 +123,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol): r = _matvec(M, b - _matvec(A, x0)) r, r_norm = _safe_normalize(r) - iters = _INT_ZERO + iters = const_int_zero while iters < maxiter and r_norm > atol: V = mnp.pad(r[..., None], ((0, 0),) * r.ndim + ((0, restart),)) dtype = mnp.result_type(b) @@ -134,13 +134,13 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol): beta_vec = mnp.zeros((restart + 1), dtype=dtype) beta_vec[0] = r_norm - k = _INT_ZERO + k = const_int_zero err = r_norm while mnp.logical_and(mnp.less(k, restart), mnp.less(ptol, err)): V, R, _ = arnoldi_iteration(k, A, M, V, R) # Givens rotation row_k = R[k, :] - i = _INT_ZERO + i = const_int_zero while i < k: row_k = rotate_vectors(row_k, i, givens[i, 0], givens[i, 1]) i += 1 @@ -168,7 +168,7 @@ def _incremental_gmres(A, b, x0, tol, restart, maxiter, M, atol): r, r_norm = _safe_normalize(r) x0 = x iters += 1 - return x0, F.select(r_norm > atol, iters, _INT_ZERO) + return x0, F.select(r_norm > atol, iters, const_int_zero) class GMRES(nn.Cell): @@ -360,13 +360,13 @@ def _cg(A, b, x0, tol, atol, maxiter, M): building blocks for iterative methods', 1994, pg. 12-14 """ # Constant tensor which avoids loop unrolling - _INT_ZERO = _to_tensor(0) + const_int_zero = _to_tensor(0) atol_ = mnp.maximum(atol, tol * _norm(b)) r = b - _matvec(A, x0) z = p = _matvec(M, r) rho = mnp.dot(r, z) - k = _INT_ZERO + k = const_int_zero x = x0 while k < maxiter and _norm(r) > atol_: q = _matvec(A, p) @@ -382,7 +382,7 @@ def _cg(A, b, x0, tol, atol, maxiter, M): k += 1 - return x, F.select(_norm(r) > atol_, k, _INT_ZERO) + return x, F.select(_norm(r) > atol_, k, const_int_zero) class CG(nn.Cell): @@ -535,20 +535,20 @@ class BiCGStab(nn.Cell): def construct(self, b, x0, tol, atol, maxiter): # Constant tensors which avoid loop unrolling - _INT_ZERO = _to_tensor(0) - _INT_NEG_ONE = _to_tensor(-1) + const_int_zero = _to_tensor(0) + const_int_neg_one = _to_tensor(-1) - _float_one = _to_tensor(1., dtype=b.dtype) + const_float_one = _to_tensor(1., dtype=b.dtype) atol_ = mnp.maximum(atol, tol * _norm(b)) r = r_tilde = v = p = b - _matvec(self.A, x0) - rho = alpha = omega = _float_one - k = _INT_ZERO + rho = alpha = omega = const_float_one + k = const_int_zero x = x0 while k < maxiter: rho_ = mnp.dot(r_tilde, r) if rho_ == 0. or omega == 0.: - k = _INT_NEG_ONE + k = const_int_neg_one break beta = rho_ / rho * (alpha / omega) @@ -572,7 +572,7 @@ class BiCGStab(nn.Cell): rho = rho_ k += 1 - return x, F.select(k == _INT_NEG_ONE or k >= maxiter, k, _INT_ZERO) + return x, F.select(k == const_int_neg_one or k >= maxiter, k, const_int_zero) def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None): diff --git a/mindspore/python/mindspore/scipy/utils_const.py b/mindspore/python/mindspore/scipy/utils_const.py index c7cf80c7e3d..c974f5b1366 100644 --- a/mindspore/python/mindspore/scipy/utils_const.py +++ b/mindspore/python/mindspore/scipy/utils_const.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """internal graph-compatible utility functions""" +from __future__ import absolute_import from types import FunctionType from collections.abc import Iterable from .. import context