fix bug when reduce axis have duplicate elements

This commit is contained in:
buxue 2022-06-09 20:59:42 +08:00
parent 90ac7781ba
commit 45c5b4801f
2 changed files with 6 additions and 10 deletions

View File

@ -86,22 +86,18 @@ void ReduceCpuKernelFunc<T>::InitFunc(const CNodePtr &kernel_node) {
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
UpdateAxis(prim, kernel_node, kernel_name_, &axis_);
// Delete the duplicate axis.
auto last = std::unique(axis_.begin(), axis_.end());
axis_.erase(last, axis_.end());
int64_t dimension = SizeToLong(input_shape_.size());
if (axis_.size() > LongToSize(dimension)) {
MS_LOG(EXCEPTION) << "For reduce, the axis is " << axis_
<< ", it's elements number is more than the input dimension " << dimension;
}
(void)std::transform(axis_.begin(), axis_.end(), axis_.begin(), [dimension](const auto &a) {
(void)std::for_each(axis_.begin(), axis_.end(), [dimension](auto &a) {
if (a < -dimension || a >= dimension) {
MS_LOG(EXCEPTION) << "For reduce, the each axis element should be in [" << -dimension << ", " << dimension
<< "), but got " << a;
}
return a < 0 ? dimension + a : a;
a = a < 0 ? dimension + a : a;
});
// Delete the duplicate axis.
sort(axis_.begin(), axis_.end());
auto last = std::unique(axis_.begin(), axis_.end());
axis_.erase(last, axis_.end());
if constexpr (std::is_same<T, bool>::value) {
if (kernel_name_ == prim::kPrimReduceAll->name()) {

View File

@ -30,7 +30,7 @@ class NetReduce(nn.Cell):
self.axis0 = 0
self.axis1 = 1
self.axis2 = -1
self.axis3 = (0, 1)
self.axis3 = (0, 1, -2)
self.axis4 = (0, 1, 2)
self.axis5 = (-1,)
self.axis6 = ()