forked from mindspore-Ecosystem/mindspore
!20049 [LITE][TRAIN]fix reduce infer
Merge pull request !20049 from yefeng/131-train-fix_infer
This commit is contained in:
commit
be52337747
|
@ -70,7 +70,7 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
|||
bool keep_dims = param->keep_dims_;
|
||||
int out_shape[MAX_SHAPE_SIZE] = {0};
|
||||
const size_t out_shape_size = 0;
|
||||
if (inputs_size == 1) {
|
||||
if (inputs_size == 1 || (inputs_size == 2 && inputs[1]->shape_size_ == 1 && inputs[1]->shape_[0] == 0)) {
|
||||
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
|
||||
}
|
||||
// get axes from input tensor
|
||||
|
|
Loading…
Reference in New Issue