forked from mindspore-Ecosystem/mindspore
Fix bug in reshape strategy search when Reshape is the first operator
This commit is contained in:
parent
f887618662
commit
f7189adb91
|
@ -39,7 +39,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStr
|
||||||
Status ReshapeInfo::InferDevMatrixShape() {
|
Status ReshapeInfo::InferDevMatrixShape() {
|
||||||
Strategys stra = strategy_->GetInputDim();
|
Strategys stra = strategy_->GetInputDim();
|
||||||
input_strategy_ = stra.at(0);
|
input_strategy_ = stra.at(0);
|
||||||
dev_matrix_shape_.push_back(input_strategy_[0]);
|
dev_matrix_shape_ = stra.at(0);
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,17 +162,13 @@ Status ReshapeInfo::InferTensorMap() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape tensor_map_index_input;
|
Shape tensor_map_index_input;
|
||||||
tensor_map_index_input.push_back(0);
|
for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
|
||||||
|
tensor_map_index_input.push_back((int64_t)(inputs_shape_[0].size() - j - 1));
|
||||||
for (size_t j = 1; j < inputs_shape_[0].size(); ++j) {
|
|
||||||
tensor_map_index_input.push_back(MAP_NONE);
|
|
||||||
}
|
}
|
||||||
inputs_tensor_map_.push_back(tensor_map_index_input);
|
inputs_tensor_map_.push_back(tensor_map_index_input);
|
||||||
|
|
||||||
Shape tensor_map_index_output;
|
Shape tensor_map_index_output;
|
||||||
tensor_map_index_output.push_back(0);
|
for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
|
||||||
|
|
||||||
for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
|
|
||||||
tensor_map_index_output.push_back(MAP_NONE);
|
tensor_map_index_output.push_back(MAP_NONE);
|
||||||
}
|
}
|
||||||
outputs_tensor_map_.push_back(tensor_map_index_output);
|
outputs_tensor_map_.push_back(tensor_map_index_output);
|
||||||
|
@ -186,8 +182,7 @@ Status ReshapeInfo::InferTensorMap() {
|
||||||
Strategys ReshapeInfo::GetOutputsStrategy() {
|
Strategys ReshapeInfo::GetOutputsStrategy() {
|
||||||
Strategys outputs_strategy;
|
Strategys outputs_strategy;
|
||||||
Dimensions strategy;
|
Dimensions strategy;
|
||||||
strategy.push_back(input_strategy_[0]);
|
for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
|
||||||
for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
|
|
||||||
strategy.push_back(1);
|
strategy.push_back(1);
|
||||||
}
|
}
|
||||||
outputs_strategy.push_back(strategy);
|
outputs_strategy.push_back(strategy);
|
||||||
|
|
|
@ -74,7 +74,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
|
||||||
reshape->Init(strategy);
|
reshape->Init(strategy);
|
||||||
Shape dev_matrix_shape = reshape->dev_matrix_shape();
|
Shape dev_matrix_shape = reshape->dev_matrix_shape();
|
||||||
|
|
||||||
Shape expect = {4, 8};
|
Shape expect = {4, 1, 1, 1, 8};
|
||||||
ASSERT_EQ(dev_matrix_shape, expect);
|
ASSERT_EQ(dev_matrix_shape, expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
|
||||||
reshape->Init(strategy);
|
reshape->Init(strategy);
|
||||||
Shape dev_matrix_shape = reshape->dev_matrix_shape();
|
Shape dev_matrix_shape = reshape->dev_matrix_shape();
|
||||||
|
|
||||||
Shape expect = {32};
|
Shape expect = {32, 1, 1, 1};
|
||||||
ASSERT_EQ(dev_matrix_shape, expect);
|
ASSERT_EQ(dev_matrix_shape, expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) {
|
||||||
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
||||||
|
|
||||||
Shape input_slice_shape_expect = {8, 512, 7, 7};
|
Shape input_slice_shape_expect = {8, 512, 7, 7};
|
||||||
Shape output_slice_shape_expect = {8, 25088};
|
Shape output_slice_shape_expect = {32, 25088};
|
||||||
|
|
||||||
TensorInfo input_tensor_info = inputs.at(0);
|
TensorInfo input_tensor_info = inputs.at(0);
|
||||||
TensorInfo output_tensor_info = outputs.at(0);
|
TensorInfo output_tensor_info = outputs.at(0);
|
||||||
|
@ -119,7 +119,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) {
|
||||||
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
||||||
|
|
||||||
Shape input_slice_shape_expect = {1, 512, 7, 7};
|
Shape input_slice_shape_expect = {1, 512, 7, 7};
|
||||||
Shape output_slice_shape_expect = {1, 25088};
|
Shape output_slice_shape_expect = {32, 25088};
|
||||||
|
|
||||||
TensorInfo input_tensor_info = inputs.at(0);
|
TensorInfo input_tensor_info = inputs.at(0);
|
||||||
TensorInfo output_tensor_info = outputs.at(0);
|
TensorInfo output_tensor_info = outputs.at(0);
|
||||||
|
@ -139,8 +139,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
|
||||||
std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
|
std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
||||||
|
|
||||||
TensorMap input_expect = {1, -1, -1, -1};
|
TensorMap input_expect = {4, 3, 2, 1};
|
||||||
TensorMap output_expect = {1, -1};
|
TensorMap output_expect = {-1, -1};
|
||||||
|
|
||||||
TensorInfo input_tensor_info = inputs.at(0);
|
TensorInfo input_tensor_info = inputs.at(0);
|
||||||
TensorInfo output_tensor_info = outputs.at(0);
|
TensorInfo output_tensor_info = outputs.at(0);
|
||||||
|
@ -160,8 +160,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) {
|
||||||
std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
|
std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
|
||||||
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
|
||||||
|
|
||||||
TensorMap input_expect = {0, -1, -1, -1};
|
TensorMap input_expect = {3, 2, 1, 0};
|
||||||
TensorMap output_expect = {0, -1};
|
TensorMap output_expect = {-1, -1};
|
||||||
|
|
||||||
TensorInfo input_tensor_info = inputs.at(0);
|
TensorInfo input_tensor_info = inputs.at(0);
|
||||||
TensorInfo output_tensor_info = outputs.at(0);
|
TensorInfo output_tensor_info = outputs.at(0);
|
||||||
|
|
|
@ -74,12 +74,12 @@ def test_reshape_unexpand_1():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
|
self.mul = P.Mul().shard(((1, 1, 8), (1, 8)))
|
||||||
self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
|
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, data):
|
||||||
weight = self.reshape(self.mul_weight, (1, 128, 96))
|
x = self.reshape(self.mul_weight, (1, 128, 96))
|
||||||
out = self.mul(x, weight)
|
out = self.mul(x, self.mul_weight)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
size = 8
|
size = 8
|
||||||
|
@ -236,3 +236,25 @@ def test_reshape_unexpand_7():
|
||||||
net = GradWrap(NetWithLoss(Net()))
|
net = GradWrap(NetWithLoss(Net()))
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
_executor.compile(net, x)
|
_executor.compile(net, x)
|
||||||
|
|
||||||
|
def test_reshape_unexpand_8():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
|
||||||
|
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
|
||||||
|
|
||||||
|
def construct(self, data):
|
||||||
|
x = self.reshape(self.mul_weight, (1, 128, 96))
|
||||||
|
out = self.mul(x, self.mul_weight)
|
||||||
|
return out
|
||||||
|
|
||||||
|
size = 8
|
||||||
|
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||||
|
x = Tensor(np.ones([128, 96]), dtype=ms.float32)
|
||||||
|
|
||||||
|
net = GradWrap(NetWithLoss(Net()))
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
net.set_auto_parallel()
|
||||||
|
_executor.compile(net, x)
|
||||||
|
|
Loading…
Reference in New Issue