forked from mindspore-Ecosystem/mindspore
!1205 Gpu UnsortedSegmentSum fix
Merge pull request !1205 from chenweifeng/unsorted_segment_sum
This commit is contained in:
commit
df1eb2f65d
|
@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
|||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
input_dim0_ = input_shapes[0];
|
||||
for (size_t i = 1; i < input_shapes.size(); i++) {
|
||||
input_dim1_ *= input_shapes[i];
|
||||
auto axis = ids_shapes.size();
|
||||
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||
if (i < axis) {
|
||||
input_dim0_ *= input_shapes[i];
|
||||
} else {
|
||||
input_dim1_ *= input_shapes[i];
|
||||
}
|
||||
}
|
||||
|
||||
output_dim0_ = output_shapes[0];
|
||||
for (size_t i = 1; i < output_shapes.size(); i++) {
|
||||
output_dim1_ *= output_shapes[i];
|
||||
for (size_t j = 1; j < output_shapes.size(); j++) {
|
||||
output_dim1_ *= output_shapes[j];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
|
Loading…
Reference in New Issue