forked from mindspore-Ecosystem/mindspore
fix core dump issue - SquareSumAll op
This commit is contained in:
parent
ace944e14b
commit
3f1006e446
|
@ -55,7 +55,7 @@ void SquareSumAllCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
int64_t batch_rank = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kBatchRank);
|
||||
batch_rank_ = LongToSize(batch_rank);
|
||||
}
|
||||
batch_size_ =
|
||||
num_batch_ =
|
||||
std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_, size_t(1), std::multiplies<size_t>());
|
||||
x_size_ = std::accumulate(input_shape.begin() + batch_rank_, input_shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
}
|
||||
|
@ -81,6 +81,9 @@ void SquareSumAllCpuKernelMod::InitInputOutputSize(const CNodePtr &kernel_node)
|
|||
bool SquareSumAllCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool ret = true;
|
||||
if (input_size_ == 0) {
|
||||
return ret;
|
||||
}
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSquareSumAllInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSquareSumAllOutputsNum, kernel_name_);
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
|
@ -107,8 +110,7 @@ bool SquareSumAllCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
|
|||
auto task = std::bind(SquareSum<T>, input_0_addr, input_1_addr, workspace_0_addr, workspace_1_addr, x_size_,
|
||||
std::placeholders::_1, std::placeholders::_2);
|
||||
ParallelLaunchAutoSearch(task, input_size_ * kSquareSumAllInputsNum, this, ¶llel_search_info_);
|
||||
size_t num_batches = input_size_ / x_size_;
|
||||
for (size_t i = 0; i < num_batches; i++) {
|
||||
for (size_t i = 0; i < num_batch_; i++) {
|
||||
output_0_addr[i] = static_cast<T>(workspace_0_addr[i]);
|
||||
output_1_addr[i] = static_cast<T>(workspace_1_addr[i]);
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ class SquareSumAllCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
size_t input_size_;
|
||||
TypeId dtype_;
|
||||
size_t batch_rank_{0};
|
||||
size_t batch_size_{1};
|
||||
size_t num_batch_{1};
|
||||
size_t x_size_{1};
|
||||
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override;
|
||||
|
|
Loading…
Reference in New Issue