pr to master #8
|
@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
||||
|
||||
if (axis_addr->isa<ValueTuple>()) {
|
||||
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
|
||||
if (attr_axis.size() > shape_.size()) {
|
||||
|
@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
axis_.push_back(shape_.size() - 1);
|
||||
} else {
|
||||
for (auto axis : attr_axis) {
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= (shape_.size())) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
||||
axis_.push_back(IntToSize(axis));
|
||||
}
|
||||
}
|
||||
} else if (axis_addr->isa<Int32Imm>()) {
|
||||
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
|
||||
if (axis >= 0 && IntToSize(axis) >= shape_.size()) {
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
||||
axis_.push_back(IntToSize(axis));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue