bugfix:Attr of axis of ReduceSum is invalid

This commit is contained in:
lizhenyu 2020-06-29 11:38:08 +08:00
parent f6c80b22f9
commit 7bbe183e6f
1 changed files with 2 additions and 1 deletions

View File

@ -94,7 +94,8 @@ class ArrayReduceGpuKernel : public GpuKernel {
}
int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size());
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>()) {
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>() ||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueList>()) {
auto attr_axis = GetAttr<std::vector<int>>(kernel_node, "axis");
if (attr_axis.empty()) {
axis_.push_back(-1);