forked from mindspore-Ecosystem/mindspore
commit
2298c98932
|
@ -38,37 +38,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
|
||||
}
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
||||
|
||||
if (axis_addr->isa<ValueTuple>()) {
|
||||
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
|
||||
if (attr_axis.size() > shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size();
|
||||
} else if (attr_axis.empty()) {
|
||||
axis_.push_back(shape_.size() - 1);
|
||||
} else {
|
||||
for (auto axis : attr_axis) {
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= (shape_.size())) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis_.push_back(IntToSize(axis));
|
||||
}
|
||||
}
|
||||
} else if (axis_addr->isa<Int32Imm>()) {
|
||||
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis_.push_back(IntToSize(axis));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
|
||||
}
|
||||
CheckAxis(kernel_node);
|
||||
for (size_t i = 0; i < shape_.size(); ++i) {
|
||||
if (shape_[i] <= 0) {
|
||||
MS_LOG(EXCEPTION) << "shape value is invalid.";
|
||||
|
@ -114,6 +84,41 @@ bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
return true;
|
||||
}
|
||||
|
||||
void ReduceCPUKernel::CheckAxis(const CNodePtr &kernel_node) {
|
||||
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
||||
if (axis_addr->isa<ValueTuple>()) {
|
||||
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
|
||||
if (attr_axis.size() > shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size();
|
||||
} else if (attr_axis.empty()) {
|
||||
for (size_t i = 0; i < shape_.size(); ++i) {
|
||||
axis_.push_back(i);
|
||||
}
|
||||
} else {
|
||||
for (auto axis : attr_axis) {
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= (shape_.size())) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis_.push_back(IntToSize(axis));
|
||||
}
|
||||
}
|
||||
} else if (axis_addr->isa<Int32Imm>()) {
|
||||
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
|
||||
while (axis < 0) {
|
||||
axis += SizeToInt(shape_.size());
|
||||
}
|
||||
if (IntToSize(axis) >= shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis_.push_back(IntToSize(axis));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
|
||||
}
|
||||
}
|
||||
|
||||
void ReduceCPUKernel::ConvertDataToOutput(const float *new_input, float *output) {
|
||||
if (reduce_type_ == kReduceTypeMax) {
|
||||
for (size_t i = 0; i < left_dims_; ++i) {
|
||||
|
|
|
@ -35,6 +35,7 @@ class ReduceCPUKernel : public CPUKernel {
|
|||
void Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &input_axis, const int shape_size, float *output);
|
||||
void ConvertDataToOutput(const float *input, float *output);
|
||||
void CheckAxis(const CNodePtr &kernel_node);
|
||||
size_t reduce_type_ = 0;
|
||||
std::vector<size_t> axis_;
|
||||
std::vector<size_t> shape_;
|
||||
|
|
|
@ -32,6 +32,8 @@ class NetReduce(nn.Cell):
|
|||
self.axis2 = -1
|
||||
self.axis3 = (0, 1)
|
||||
self.axis4 = (0, 1, 2)
|
||||
self.axis5 = (-1,)
|
||||
self.axis6 = ()
|
||||
self.reduce_mean = P.ReduceMean(False)
|
||||
self.reduce_sum = P.ReduceSum(False)
|
||||
self.reduce_max = P.ReduceMax(False)
|
||||
|
@ -46,7 +48,10 @@ class NetReduce(nn.Cell):
|
|||
self.reduce_sum(indice, self.axis0),
|
||||
self.reduce_sum(indice, self.axis2),
|
||||
self.reduce_max(indice, self.axis0),
|
||||
self.reduce_max(indice, self.axis2))
|
||||
self.reduce_max(indice, self.axis2),
|
||||
self.reduce_max(indice, self.axis5),
|
||||
self.reduce_max(indice, self.axis6))
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -69,6 +74,8 @@ def test_reduce():
|
|||
print(output[6])
|
||||
print(output[7])
|
||||
print(output[8])
|
||||
print(output[9])
|
||||
print(output[10])
|
||||
expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32)
|
||||
expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype(
|
||||
np.float32)
|
||||
|
@ -88,6 +95,7 @@ def test_reduce():
|
|||
assert (output[6].asnumpy() == expect_6).all()
|
||||
assert (output[7].asnumpy() == expect_7).all()
|
||||
assert (output[8].asnumpy() == expect_8).all()
|
||||
|
||||
assert (output[9].asnumpy() == expect_8).all()
|
||||
assert (output[10].asnumpy() == 5.0).all()
|
||||
|
||||
test_reduce()
|
||||
|
|
Loading…
Reference in New Issue