!47601 solve the bug of uniform core dump

Merge pull request !47601 from zong_shuai/uniform_debug_1
This commit is contained in:
i-robot 2023-01-07 07:58:56 +00:00 committed by Gitee
commit f4493f2c49
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 1 additions and 5 deletions

View File

@ -104,8 +104,7 @@ int UniformCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) { if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret; return ret;
} }
std::vector<int64_t> input_shape = inputs.at(kIndex0)->GetShapeVector(); input_elements_ = SizeOf(inputs.at(kIndex0)->GetShapeVector());
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
return ret; return ret;
} }
@ -118,7 +117,6 @@ bool UniformCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
InitMSPhiloxRandom(seed_, offset_); InitMSPhiloxRandom(seed_, offset_);
auto y = reinterpret_cast<T *>(outputs[0]->addr); auto y = reinterpret_cast<T *>(outputs[0]->addr);
input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), int64_t(1), std::multiplies<int64_t>());
for (int64_t i = 0; i < input_elements_; i++) { for (int64_t i = 0; i < input_elements_; i++) {
y[i] = static_cast<T>(RandFloat() * (to_ - from_) + from_); y[i] = static_cast<T>(RandFloat() * (to_ - from_) + from_);
} }

View File

@ -68,8 +68,6 @@ class UniformCpuKernelMod : public NativeCpuKernelMod {
static std::vector<std::pair<KernelAttr, UniformFunc>> func_list_; static std::vector<std::pair<KernelAttr, UniformFunc>> func_list_;
UniformFunc kernel_func_; UniformFunc kernel_func_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
int64_t input_elements_; int64_t input_elements_;
float from_{0.0}; float from_{0.0};
float to_{1.0}; float to_{1.0};