clean code warnings for svd, scatter_nd_update

This commit is contained in:
huanghui 2022-07-28 15:19:07 +08:00
parent deabda1856
commit 7b0aac92dd
4 changed files with 18 additions and 18 deletions

View File

@ -33,17 +33,17 @@ bool Compute(const ComputeParams<T> *params, const size_t start, const size_t en
T *x = params->x_; T *x = params->x_;
int *indices = params->indices_; int *indices = params->indices_;
T *updates = params->updates_; T *updates = params->updates_;
std::vector<int> *out_strides = params->out_strides_; std::vector<size_t> *out_strides = params->out_strides_;
MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(indices); MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(updates); MS_EXCEPTION_IF_NULL(updates);
MS_EXCEPTION_IF_NULL(out_strides); MS_EXCEPTION_IF_NULL(out_strides);
for (int i = SizeToInt(start); i < SizeToInt(end); ++i) { for (size_t i = start; i < end; ++i) {
int offset = 0; size_t offset = 0;
std::vector<int> local_indices; std::vector<size_t> local_indices;
for (int j = 0; j < params->indices_unit_rank_; ++j) { for (size_t j = 0; j < params->indices_unit_rank_; ++j) {
auto index = indices[i * params->indices_unit_rank_ + j]; auto index = IntToSize(indices[i * params->indices_unit_rank_ + j]);
(void)local_indices.emplace_back(index); (void)local_indices.emplace_back(index);
if (index < 0) { if (index < 0) {
MS_LOG(ERROR) << "For '" << kKernelName MS_LOG(ERROR) << "For '" << kKernelName
@ -104,20 +104,20 @@ void ScatterUpdateCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
<< indices_shape[i]; << indices_shape[i];
} }
} }
indices_unit_rank_ = SizeToInt(indices_unit_rank); indices_unit_rank_ = indices_unit_rank;
unit_size_ = 1; unit_size_ = 1;
for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) { for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) {
unit_size_ *= SizeToInt(updates_shape[i]); unit_size_ *= updates_shape[i];
} }
num_units_ = 1; num_units_ = 1;
num_units_ *= updates_shape[indices_shape.size() - 2]; num_units_ *= updates_shape[indices_shape.size() - 2];
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) { for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
num_units_ *= updates_shape[i]; num_units_ *= updates_shape[i];
} }
int out_stride = 1; size_t out_stride = 1;
out_strides_.push_back(out_stride); out_strides_.push_back(out_stride);
for (int i = indices_unit_rank_ - 2; i >= 0; i--) { for (int i = indices_unit_rank_ - 2; i >= 0; i--) {
out_stride *= shape[i + 1]; out_stride *= LongToSize(shape[i + 1]);
out_strides_.push_back(out_stride); out_strides_.push_back(out_stride);
} }
reverse(out_strides_.begin(), out_strides_.end()); reverse(out_strides_.begin(), out_strides_.end());

View File

@ -30,9 +30,9 @@ struct ComputeParams {
T *x_{nullptr}; T *x_{nullptr};
int *indices_{nullptr}; int *indices_{nullptr};
T *updates_{nullptr}; T *updates_{nullptr};
int unit_size_{0}; size_t unit_size_{0};
int indices_unit_rank_{0}; size_t indices_unit_rank_{0};
std::vector<int> *out_strides_{nullptr}; std::vector<size_t> *out_strides_{nullptr};
size_t x_mem_size_{0}; size_t x_mem_size_{0};
}; };
@ -55,10 +55,10 @@ class ScatterUpdateCpuKernelMod : public DeprecatedNativeCpuKernelMod {
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown}; TypeId dtype_{kTypeUnknown};
int unit_size_{0}; size_t unit_size_{0};
size_t num_units_{0}; size_t num_units_{0};
int indices_unit_rank_{0}; size_t indices_unit_rank_{0};
std::vector<int> out_strides_; std::vector<size_t> out_strides_;
}; };
class ScatterNdUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod { class ScatterNdUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod {

View File

@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <vector>
#include "mindapi/ir/type.h" #include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
@ -39,7 +40,7 @@ abstract::BaseShapePtr SvdInferShape(const PrimitivePtr &prim, const std::vector
auto n = a_shape[ndim - kIndexOne]; auto n = a_shape[ndim - kIndexOne];
auto p = std::min(m, n); auto p = std::min(m, n);
auto s_shape = ShapeVector(a_shape.begin(), a_shape.end() - kIndexOne); auto s_shape = ShapeVector(a_shape.begin(), a_shape.end() - SizeToLong(kIndexOne));
s_shape[s_shape.size() - kIndexOne] = p; s_shape[s_shape.size() - kIndexOne] = p;
auto u_shape = ShapeVector(a_shape.begin(), a_shape.end()); auto u_shape = ShapeVector(a_shape.begin(), a_shape.end());
auto v_shape = ShapeVector(a_shape.begin(), a_shape.end()); auto v_shape = ShapeVector(a_shape.begin(), a_shape.end());

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CORE_OPS_SVD_H_ #ifndef MINDSPORE_CORE_OPS_SVD_H_
#define MINDSPORE_CORE_OPS_SVD_H_ #define MINDSPORE_CORE_OPS_SVD_H_
#include <string> #include <string>
#include <vector>
#include <memory> #include <memory>
#include "ops/base_operator.h" #include "ops/base_operator.h"