forked from mindspore-Ecosystem/mindspore
make shape to {1} for Svd's u,v if compute_uv is False
This commit is contained in:
parent
d6d132e3ae
commit
2b391b2285
|
@ -87,8 +87,8 @@ void SvdGpuKernelMod::InitSizeLists() {
|
||||||
output_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
|
output_size_list_.push_back(batch_size_ * n_ * p_ * unit_size_);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
output_size_list_.push_back(0);
|
output_size_list_.push_back(1);
|
||||||
output_size_list_.push_back(0);
|
output_size_list_.push_back(1);
|
||||||
}
|
}
|
||||||
// for dev_info
|
// for dev_info
|
||||||
workspace_size_list_.push_back(batch_size_ * sizeof(int));
|
workspace_size_list_.push_back(batch_size_ * sizeof(int));
|
||||||
|
|
|
@ -65,8 +65,8 @@ abstract::BaseShapePtr SvdInferShape(const PrimitivePtr &prim, const std::vector
|
||||||
v_shape[v_shape.size() - kIndexOne] = p;
|
v_shape[v_shape.size() - kIndexOne] = p;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
u_shape = {0};
|
u_shape = {1};
|
||||||
v_shape = {0};
|
v_shape = {1};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<abstract::BaseShapePtr> shape_tuple;
|
std::vector<abstract::BaseShapePtr> shape_tuple;
|
||||||
|
|
Loading…
Reference in New Issue