UnsortedSegmentSum kernel support Nd

This commit is contained in:
wilfChen 2020-05-16 17:06:47 +08:00
parent e42631c127
commit 83151509dc
1 changed files with 10 additions and 5 deletions

View File

@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_dim0_ = input_shapes[0];
for (size_t i = 1; i < input_shapes.size(); i++) {
input_dim1_ *= input_shapes[i];
auto axis = ids_shapes.size();
for (size_t i = 0; i < input_shapes.size(); i++) {
if (i < axis) {
input_dim0_ *= input_shapes[i];
} else {
input_dim1_ *= input_shapes[i];
}
}
output_dim0_ = output_shapes[0];
for (size_t i = 1; i < output_shapes.size(); i++) {
output_dim1_ *= output_shapes[i];
for (size_t j = 1; j < output_shapes.size(); j++) {
output_dim1_ *= output_shapes[j];
}
InitSizeLists();