pr to master #8

Open
m7grui4p8 wants to merge 201 commits from p69201753/mindspore:cpu-kernel-reuse-1 into master
1 changed files with 10 additions and 3 deletions
Showing only changes of commit d1d28fb032 - Show all commits

View File

@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
if (axis_addr->isa<ValueTuple>()) {
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
if (attr_axis.size() > shape_.size()) {
@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
axis_.push_back(shape_.size() - 1);
} else {
for (auto axis : attr_axis) {
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= (shape_.size())) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
axis_.push_back(IntToSize(axis));
}
}
} else if (axis_addr->isa<Int32Imm>()) {
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
if (axis >= 0 && IntToSize(axis) >= shape_.size()) {
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= shape_.size()) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
axis_.push_back(IntToSize(axis));
} else {
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
}