forked from mindspore-Ecosystem/mindspore
fix bug when reduce axis have duplicate elements
This commit is contained in:
parent
90ac7781ba
commit
45c5b4801f
|
@ -86,22 +86,18 @@ void ReduceCpuKernelFunc<T>::InitFunc(const CNodePtr &kernel_node) {
|
|||
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
UpdateAxis(prim, kernel_node, kernel_name_, &axis_);
|
||||
// Delete the duplicate axis.
|
||||
auto last = std::unique(axis_.begin(), axis_.end());
|
||||
axis_.erase(last, axis_.end());
|
||||
int64_t dimension = SizeToLong(input_shape_.size());
|
||||
if (axis_.size() > LongToSize(dimension)) {
|
||||
MS_LOG(EXCEPTION) << "For reduce, the axis is " << axis_
|
||||
<< ", it's elements number is more than the input dimension " << dimension;
|
||||
}
|
||||
(void)std::transform(axis_.begin(), axis_.end(), axis_.begin(), [dimension](const auto &a) {
|
||||
(void)std::for_each(axis_.begin(), axis_.end(), [dimension](auto &a) {
|
||||
if (a < -dimension || a >= dimension) {
|
||||
MS_LOG(EXCEPTION) << "For reduce, the each axis element should be in [" << -dimension << ", " << dimension
|
||||
<< "), but got " << a;
|
||||
}
|
||||
return a < 0 ? dimension + a : a;
|
||||
a = a < 0 ? dimension + a : a;
|
||||
});
|
||||
// Delete the duplicate axis.
|
||||
sort(axis_.begin(), axis_.end());
|
||||
auto last = std::unique(axis_.begin(), axis_.end());
|
||||
axis_.erase(last, axis_.end());
|
||||
|
||||
if constexpr (std::is_same<T, bool>::value) {
|
||||
if (kernel_name_ == prim::kPrimReduceAll->name()) {
|
||||
|
|
|
@ -30,7 +30,7 @@ class NetReduce(nn.Cell):
|
|||
self.axis0 = 0
|
||||
self.axis1 = 1
|
||||
self.axis2 = -1
|
||||
self.axis3 = (0, 1)
|
||||
self.axis3 = (0, 1, -2)
|
||||
self.axis4 = (0, 1, 2)
|
||||
self.axis5 = (-1,)
|
||||
self.axis6 = ()
|
||||
|
|
Loading…
Reference in New Issue