Fix diag op floating point exception, core dump.

This commit is contained in:
shaw_zhang 2023-03-07 11:04:17 +08:00
parent 8900c69557
commit 02239dc20c
1 changed files with 4 additions and 0 deletions

View File

@ -110,6 +110,10 @@ int DiagGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
for (size_t i = LongToSize(batch_rank_); i < input_shape.size(); ++i) {
input_size_ *= LongToSize(input_shape[i]);
}
if (input_size_ == 0) {
MS_LOG(ERROR) << kernel_name_ << "input size should should be larger than 0, but got: " << input_size_;
return KRET_RESIZE_FAILED;
}
// Get the output size of each batch.
auto output = outputs.at(kIndex0);