bugfix for svd

This commit is contained in:
zhuyuxiao 2022-11-03 10:11:43 +08:00
parent 8f1b9ce5c6
commit b90f0820b3
2 changed files with 3 additions and 2 deletions

View File

@ -124,7 +124,7 @@ bool KLDivLossCpuKernelMod::LaunchNoneReduction(const std::vector<AddressPtr> &i
for (size_t i = start; i < end; ++i) {
T diff = static_cast<T>(array_log[iter.GetInputPosB()] - array_x[iter.GetInputPosA()]);
array_y[i] = static_cast<T>(diff * array_target[iter.GetInputPosB()]);
if (std::isnan(static_cast<float>(array_log[iter.GetInputPosB()]))) {
if (std::isnan(static_cast<float>(array_y[i]))) {
array_y[i] = static_cast<T>(0);
}
iter.GenNextPos();
@ -156,7 +156,7 @@ bool KLDivLossCpuKernelMod::LaunchOther(const std::vector<AddressPtr> &inputs, c
for (size_t i = start; i < end; ++i) {
T diff = static_cast<T>(array_log[iter.GetInputPosB()] - array_x[iter.GetInputPosA()]);
array_tmp[i] = static_cast<T>(diff * array_target[iter.GetInputPosB()]);
if (std::isnan(static_cast<float>(array_log[iter.GetInputPosB()]))) {
if (std::isnan(static_cast<float>(array_tmp[i]))) {
array_tmp[i] = static_cast<T>(0);
}
iter.GenNextPos();

View File

@ -82,6 +82,7 @@ int SvdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
num_of_rows_ = input_shape[dim - kDim2];
num_of_cols_ = input_shape[dim - kDim1];
batch_size_ = 1;
for (size_t i = 0; i < dim - kDim2; i++) {
batch_size_ = batch_size_ * input_shape[i];
}