!1205 Gpu UnsortedSegmentSum fix

Merge pull request !1205 from chenweifeng/unsorted_segment_sum
This commit is contained in:
mindspore-ci-bot 2020-05-20 16:00:42 +08:00 committed by Gitee
commit df1eb2f65d
1 changed files with 10 additions and 5 deletions

View File

@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_dim0_ = input_shapes[0]; auto axis = ids_shapes.size();
for (size_t i = 1; i < input_shapes.size(); i++) { for (size_t i = 0; i < input_shapes.size(); i++) {
if (i < axis) {
input_dim0_ *= input_shapes[i];
} else {
input_dim1_ *= input_shapes[i]; input_dim1_ *= input_shapes[i];
} }
}
output_dim0_ = output_shapes[0]; output_dim0_ = output_shapes[0];
for (size_t i = 1; i < output_shapes.size(); i++) { for (size_t j = 1; j < output_shapes.size(); j++) {
output_dim1_ *= output_shapes[i]; output_dim1_ *= output_shapes[j];
} }
InitSizeLists(); InitSizeLists();