bugfix for svd
This commit is contained in:
parent
8f1b9ce5c6
commit
b90f0820b3
|
@ -124,7 +124,7 @@ bool KLDivLossCpuKernelMod::LaunchNoneReduction(const std::vector<AddressPtr> &i
|
|||
for (size_t i = start; i < end; ++i) {
|
||||
T diff = static_cast<T>(array_log[iter.GetInputPosB()] - array_x[iter.GetInputPosA()]);
|
||||
array_y[i] = static_cast<T>(diff * array_target[iter.GetInputPosB()]);
|
||||
if (std::isnan(static_cast<float>(array_log[iter.GetInputPosB()]))) {
|
||||
if (std::isnan(static_cast<float>(array_y[i]))) {
|
||||
array_y[i] = static_cast<T>(0);
|
||||
}
|
||||
iter.GenNextPos();
|
||||
|
@ -156,7 +156,7 @@ bool KLDivLossCpuKernelMod::LaunchOther(const std::vector<AddressPtr> &inputs, c
|
|||
for (size_t i = start; i < end; ++i) {
|
||||
T diff = static_cast<T>(array_log[iter.GetInputPosB()] - array_x[iter.GetInputPosA()]);
|
||||
array_tmp[i] = static_cast<T>(diff * array_target[iter.GetInputPosB()]);
|
||||
if (std::isnan(static_cast<float>(array_log[iter.GetInputPosB()]))) {
|
||||
if (std::isnan(static_cast<float>(array_tmp[i]))) {
|
||||
array_tmp[i] = static_cast<T>(0);
|
||||
}
|
||||
iter.GenNextPos();
|
||||
|
|
|
@ -82,6 +82,7 @@ int SvdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
|
|||
|
||||
num_of_rows_ = input_shape[dim - kDim2];
|
||||
num_of_cols_ = input_shape[dim - kDim1];
|
||||
batch_size_ = 1;
|
||||
for (size_t i = 0; i < dim - kDim2; i++) {
|
||||
batch_size_ = batch_size_ * input_shape[i];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue