!20049 [LITE][TRAIN]fix reduce infer

Merge pull request !20049 from yefeng/131-train-fix_infer
This commit is contained in:
i-robot 2021-07-13 06:10:13 +00:00 committed by Gitee
commit be52337747
1 changed files with 1 additions and 1 deletions

View File

@ -70,7 +70,7 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
bool keep_dims = param->keep_dims_;
int out_shape[MAX_SHAPE_SIZE] = {0};
const size_t out_shape_size = 0;
if (inputs_size == 1) {
if (inputs_size == 1 || (inputs_size == 2 && inputs[1]->shape_size_ == 1 && inputs[1]->shape_[0] == 0)) {
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
}
// get axes from input tensor