kldivlossGrad bugfix for precision problem on cpu at dynamic-shape

This commit is contained in:
zhangyanhui 2023-02-18 10:58:54 +08:00
parent 978cfb6392
commit 14d3f8ecb4
1 changed files with 3 additions and 0 deletions

View File

@ -73,6 +73,7 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
return ret; return ret;
} }
input_grad_shape_size_ = 1;
std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector(); std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector();
if (input_grad_shape.size() >= 1) { if (input_grad_shape.size() >= 1) {
for (size_t i = 0; i < input_grad_shape.size(); ++i) { for (size_t i = 0; i < input_grad_shape.size(); ++i) {
@ -80,6 +81,7 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
} }
} }
input_x_shape_size_ = 1;
std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector(); std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector();
if (input_x_shape.size() >= 1) { if (input_x_shape.size() >= 1) {
for (size_t i = 0; i < input_x_shape.size(); ++i) { for (size_t i = 0; i < input_x_shape.size(); ++i) {
@ -87,6 +89,7 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
} }
} }
input_target_shape_size_ = 1;
std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector(); std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector();
if (input_target_shape.size() >= 1) { if (input_target_shape.size() >= 1) {
for (size_t i = 0; i < input_target_shape.size(); ++i) { for (size_t i = 0; i < input_target_shape.size(); ++i) {