clear warnings

This commit is contained in:
baihuawei 2020-09-22 19:43:08 +08:00
parent e7296ffd69
commit 3583e63901
3 changed files with 47 additions and 33 deletions

View File

@ -38,37 +38,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
}
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
if (axis_addr->isa<ValueTuple>()) {
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
if (attr_axis.size() > shape_.size()) {
MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size();
} else if (attr_axis.empty()) {
axis_.push_back(shape_.size() - 1);
} else {
for (auto axis : attr_axis) {
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= (shape_.size())) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis_.push_back(IntToSize(axis));
}
}
} else if (axis_addr->isa<Int32Imm>()) {
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= shape_.size()) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis_.push_back(IntToSize(axis));
} else {
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
}
CheckAxis(kernel_node);
for (size_t i = 0; i < shape_.size(); ++i) {
if (shape_[i] <= 0) {
MS_LOG(EXCEPTION) << "shape value is invalid.";
@ -114,6 +84,41 @@ bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
void ReduceCPUKernel::CheckAxis(const CNodePtr &kernel_node) {
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
if (axis_addr->isa<ValueTuple>()) {
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
if (attr_axis.size() > shape_.size()) {
MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size();
} else if (attr_axis.empty()) {
for (size_t i = 0; i < shape_.size(); ++i) {
axis_.push_back(i);
}
} else {
for (auto axis : attr_axis) {
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= (shape_.size())) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis_.push_back(IntToSize(axis));
}
}
} else if (axis_addr->isa<Int32Imm>()) {
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= shape_.size()) {
MS_LOG(EXCEPTION) << "axis value is oversize.";
}
axis_.push_back(IntToSize(axis));
} else {
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
}
}
void ReduceCPUKernel::ConvertDataToOutput(const float *new_input, float *output) {
if (reduce_type_ == kReduceTypeMax) {
for (size_t i = 0; i < left_dims_; ++i) {

View File

@ -35,6 +35,7 @@ class ReduceCPUKernel : public CPUKernel {
void Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
const std::vector<size_t> &input_axis, const int shape_size, float *output);
void ConvertDataToOutput(const float *input, float *output);
void CheckAxis(const CNodePtr &kernel_node);
size_t reduce_type_ = 0;
std::vector<size_t> axis_;
std::vector<size_t> shape_;

View File

@ -32,6 +32,8 @@ class NetReduce(nn.Cell):
self.axis2 = -1
self.axis3 = (0, 1)
self.axis4 = (0, 1, 2)
self.axis5 = (-1,)
self.axis6 = ()
self.reduce_mean = P.ReduceMean(False)
self.reduce_sum = P.ReduceSum(False)
self.reduce_max = P.ReduceMax(False)
@ -46,7 +48,10 @@ class NetReduce(nn.Cell):
self.reduce_sum(indice, self.axis0),
self.reduce_sum(indice, self.axis2),
self.reduce_max(indice, self.axis0),
self.reduce_max(indice, self.axis2))
self.reduce_max(indice, self.axis2),
self.reduce_max(indice, self.axis5),
self.reduce_max(indice, self.axis6))
@pytest.mark.level0
@ -69,6 +74,8 @@ def test_reduce():
print(output[6])
print(output[7])
print(output[8])
print(output[9])
print(output[10])
expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32)
expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype(
np.float32)
@ -88,6 +95,7 @@ def test_reduce():
assert (output[6].asnumpy() == expect_6).all()
assert (output[7].asnumpy() == expect_7).all()
assert (output[8].asnumpy() == expect_8).all()
assert (output[9].asnumpy() == expect_8).all()
assert (output[10].asnumpy() == 5.0).all()
test_reduce()