forked from mindspore-Ecosystem/mindspore
!26985 fix the negative axis problem of reduce_eliminate
Merge pull request !26985 from huangbingjian/reduce_eliminate
This commit is contained in:
commit 58fc9a6d81
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -67,14 +67,20 @@ class ReduceOneEliminater : public AnfVisitor {
     }
 
     // {_Reduce, X, axis} -> {Reshape, X, new_shape}
+    size_t x_shape_size = x_shape_.size();
+    std::vector<int64_t> positive_axis;
+    std::transform(axis_.begin(), axis_.end(), std::back_inserter(positive_axis),
+                   [x_shape_size](int64_t idx) { return idx < 0 ? idx + x_shape_size : idx; });
+
     std::vector<ValuePtr> elements;
-    for (size_t i = 0; i < x_shape_.size(); i++) {
-      auto iter = find(axis_.begin(), axis_.end(), i);
-      if (iter == axis_.end()) {
+    for (size_t i = 0; i < x_shape_size; i++) {
+      auto iter = find(positive_axis.begin(), positive_axis.end(), i);
+      if (iter == positive_axis.end()) {
         ValuePtr s = MakeValue(x_shape_[i]);
         elements.push_back(s);
       }
     }
 
     auto new_shape = std::make_shared<ValueTuple>(elements);
     auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast<PrimitivePtr>();
     auto node_abstract = node->abstract();
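Note on the change above: the pass builds the Reshape target shape from every dimension whose index is not in the reduce axis list. Before the fix, a negative axis such as -1 was compared directly against the non-negative loop index and never matched, so a dimension that should have been dropped was kept. A minimal Python sketch of the corrected shape computation (a hypothetical stand-alone helper for illustration, not MindSpore code):

def reduce_to_reshape_shape(x_shape, axis):
    # Hypothetical illustration of the fixed logic, not part of MindSpore.
    # Wrap negative axes by the rank first, then keep every dimension whose
    # index is not being reduced.
    rank = len(x_shape)
    positive_axis = [a + rank if a < 0 else a for a in axis]
    return tuple(dim for i, dim in enumerate(x_shape) if i not in positive_axis)

# With the old comparison, -1 never equals a non-negative index, so the last
# dimension would have been kept; after normalization it is correctly dropped.
assert reduce_to_reshape_shape((1, 1, 1), (-1,)) == (1, 1)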
@@ -19,6 +19,7 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 
@@ -100,3 +101,30 @@ def test_dynamic_reduce_mean(dtype, shape, axis, keep_dims):
     error = np.ones(shape=expect.shape) * 1.0e-5
     assert np.all(diff < error)
     assert output.shape == expect.shape
+
+
+class ReduceMeanNegativeNet(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.mean0 = P.ReduceMean(True)
+        self.mean1 = P.ReduceMean(False)
+
+    def construct(self, x):
+        t = self.mean0(x, ())
+        return self.mean1(t, (-1,))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_reduce_mean_negative():
+    """
+    Feature: ALL To ALL
+    Description: test cases for ReduceMean with negative axis.
+    Expectation: the result match expectation
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = Tensor([[[1, 2, 3,], [3, 2, 1]]], mstype.float32)
+    net = ReduceMeanNegativeNet()
+    out = net(x)
+    assert out.shape == (1, 1)
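For reference, the shapes the new test expects can be checked with NumPy alone. This sketch assumes that ReduceMean with axis=() averages over all dimensions, as in the test above; it is not part of the test file.

import numpy as np

x = np.array([[[1, 2, 3], [3, 2, 1]]], dtype=np.float32)  # shape (1, 2, 3)
t = x.mean(axis=None, keepdims=True)                      # mean over all dims -> shape (1, 1, 1)
out = t.mean(axis=-1, keepdims=False)                     # negative axis drops last dim -> (1, 1)
assert out.shape == (1, 1)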