fix core dump issue - SquareSumAll op

This commit is contained in:
TinaMengtingZhang 2022-08-09 14:50:23 +00:00 committed by Unknown
parent ace944e14b
commit 3f1006e446
2 changed files with 6 additions and 4 deletions

View File

@ -55,7 +55,7 @@ void SquareSumAllCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
int64_t batch_rank = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kBatchRank);
batch_rank_ = LongToSize(batch_rank);
}
batch_size_ =
num_batch_ =
std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_, size_t(1), std::multiplies<size_t>());
x_size_ = std::accumulate(input_shape.begin() + batch_rank_, input_shape.end(), size_t(1), std::multiplies<size_t>());
}
@ -81,6 +81,9 @@ void SquareSumAllCpuKernelMod::InitInputOutputSize(const CNodePtr &kernel_node)
bool SquareSumAllCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool ret = true;
if (input_size_ == 0) {
return ret;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSquareSumAllInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSquareSumAllOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat16) {
@ -107,8 +110,7 @@ bool SquareSumAllCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
auto task = std::bind(SquareSum<T>, input_0_addr, input_1_addr, workspace_0_addr, workspace_1_addr, x_size_,
std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task, input_size_ * kSquareSumAllInputsNum, this, &parallel_search_info_);
size_t num_batches = input_size_ / x_size_;
for (size_t i = 0; i < num_batches; i++) {
for (size_t i = 0; i < num_batch_; i++) {
output_0_addr[i] = static_cast<T>(workspace_0_addr[i]);
output_1_addr[i] = static_cast<T>(workspace_1_addr[i]);
}

View File

@ -38,7 +38,7 @@ class SquareSumAllCpuKernelMod : public DeprecatedNativeCpuKernelMod {
size_t input_size_;
TypeId dtype_;
size_t batch_rank_{0};
size_t batch_size_{1};
size_t num_batch_{1};
size_t x_size_{1};
void InitInputOutputSize(const CNodePtr &kernel_node) override;