From f7189adb91e6c47c6dfd0ec08281dc47696b7e4d Mon Sep 17 00:00:00 2001 From: yao_yf Date: Wed, 21 Oct 2020 17:18:36 +0800 Subject: [PATCH] fix bug in reshape strategy search when reshape is first operator --- .../parallel/ops_info/reshape_info.cc | 15 +++------ .../ut/cpp/parallel/ops_info/reshape_test.cc | 16 +++++----- .../python/parallel/test_reshape_unexpand.py | 32 ++++++++++++++++--- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index 0c50b847a4b..e74665d4a8f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -39,7 +39,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStr Status ReshapeInfo::InferDevMatrixShape() { Strategys stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); - dev_matrix_shape_.push_back(input_strategy_[0]); + dev_matrix_shape_ = stra.at(0); return SUCCESS; } @@ -162,17 +162,13 @@ Status ReshapeInfo::InferTensorMap() { } Shape tensor_map_index_input; - tensor_map_index_input.push_back(0); - - for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { - tensor_map_index_input.push_back(MAP_NONE); + for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { + tensor_map_index_input.push_back((int64_t)(inputs_shape_[0].size() - j - 1)); } inputs_tensor_map_.push_back(tensor_map_index_input); Shape tensor_map_index_output; - tensor_map_index_output.push_back(0); - - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + for (size_t j = 0; j < outputs_shape_[0].size(); ++j) { tensor_map_index_output.push_back(MAP_NONE); } outputs_tensor_map_.push_back(tensor_map_index_output); @@ -186,8 +182,7 @@ Status ReshapeInfo::InferTensorMap() { Strategys ReshapeInfo::GetOutputsStrategy() { Strategys outputs_strategy; Dimensions strategy; - strategy.push_back(input_strategy_[0]); - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + for (size_t j = 0; j < outputs_shape_[0].size(); ++j) { strategy.push_back(1); } outputs_strategy.push_back(strategy); diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc index 2818e6c4030..a15cbf2adc5 100644 --- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc @@ -74,7 +74,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) { reshape->Init(strategy); Shape dev_matrix_shape = reshape->dev_matrix_shape(); - Shape expect = {4, 8}; + Shape expect = {4, 1, 1, 1, 8}; ASSERT_EQ(dev_matrix_shape, expect); } @@ -85,7 +85,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape2) { reshape->Init(strategy); Shape dev_matrix_shape = reshape->dev_matrix_shape(); - Shape expect = {32}; + Shape expect = {32, 1, 1, 1}; ASSERT_EQ(dev_matrix_shape, expect); } @@ -98,7 +98,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) { std::vector outputs = reshape->outputs_tensor_info(); Shape input_slice_shape_expect = {8, 512, 7, 7}; - Shape output_slice_shape_expect = {8, 25088}; + Shape output_slice_shape_expect = {32, 25088}; TensorInfo input_tensor_info = inputs.at(0); TensorInfo output_tensor_info = outputs.at(0); @@ -119,7 +119,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) { std::vector outputs = reshape->outputs_tensor_info(); Shape input_slice_shape_expect = {1, 512, 7, 7}; - Shape output_slice_shape_expect = {1, 25088}; + Shape output_slice_shape_expect = {32, 25088}; TensorInfo input_tensor_info = inputs.at(0); TensorInfo output_tensor_info = outputs.at(0); @@ -139,8 +139,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) { std::vector inputs = reshape->inputs_tensor_info(); std::vector outputs = reshape->outputs_tensor_info(); - TensorMap input_expect = {1, -1, -1, -1}; - TensorMap output_expect = {1, -1}; + TensorMap input_expect = {4, 3, 2, 1}; + TensorMap output_expect = {-1, -1}; TensorInfo input_tensor_info = inputs.at(0); TensorInfo output_tensor_info = outputs.at(0); @@ -160,8 +160,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) { std::vector inputs = reshape->inputs_tensor_info(); std::vector outputs = reshape->outputs_tensor_info(); - TensorMap input_expect = {0, -1, -1, -1}; - TensorMap output_expect = {0, -1}; + TensorMap input_expect = {3, 2, 1, 0}; + TensorMap output_expect = {-1, -1}; TensorInfo input_tensor_info = inputs.at(0); TensorInfo output_tensor_info = outputs.at(0); diff --git a/tests/ut/python/parallel/test_reshape_unexpand.py b/tests/ut/python/parallel/test_reshape_unexpand.py index 792b982b655..d0144bebdf0 100644 --- a/tests/ut/python/parallel/test_reshape_unexpand.py +++ b/tests/ut/python/parallel/test_reshape_unexpand.py @@ -74,12 +74,12 @@ def test_reshape_unexpand_1(): def __init__(self): super().__init__() self.reshape = P.Reshape() - self.mul = P.Mul().shard(((1, 8), (1, 1, 8))) - self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight") + self.mul = P.Mul().shard(((1, 1, 8), (1, 8))) + self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") - def construct(self, x): - weight = self.reshape(self.mul_weight, (1, 128, 96)) - out = self.mul(x, weight) + def construct(self, data): + x = self.reshape(self.mul_weight, (1, 128, 96)) + out = self.mul(x, self.mul_weight) return out size = 8 @@ -236,3 +236,25 @@ def test_reshape_unexpand_7(): net = GradWrap(NetWithLoss(Net())) net.set_auto_parallel() _executor.compile(net, x) + +def test_reshape_unexpand_8(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.mul = P.Mul().shard(((1, 4, 2), (4, 2))) + self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") + + def construct(self, data): + x = self.reshape(self.mul_weight, (1, 128, 96)) + out = self.mul(x, self.mul_weight) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([128, 96]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x)