!15766 modify tile op

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-04-28 09:17:17 +08:00 committed by Gitee
commit cbca36d723
4 changed files with 79 additions and 11 deletions

View File

@ -45,7 +45,6 @@ class SelectInfo : public OperatorInfo {
protected:
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;

View File

@ -65,8 +65,19 @@ Status TileInfo::GetAttrs() {
return SUCCESS;
}
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
Shapes multiples = {full_multiples_};
Shape tmp;
for (size_t i = 0; i < full_multiples_.size(); ++i) {
if (full_multiples_[i] != 1) {
tmp.push_back(full_multiples_[i]);
} else {
tmp.push_back(inputs_shape_[0][i]);
}
}
Shapes multiples = {tmp};
MS_LOG(INFO) << name_ << ": The input shape is " << ShapeToString(inputs_shape_[0]) << ", the multiples is "
<< ShapeToString(full_multiples_) << ", so the 'shape' can be split is " << ShapeToString(tmp);
return CheckStrategyValue(strategy, multiples);
}
@ -74,7 +85,7 @@ Status TileInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << "The strategy is empty";
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
if (full_multiples_.size() != stra[0].size()) {
@ -86,6 +97,9 @@ Status TileInfo::InferDevMatrixShape() {
slice_multiples_ = full_multiples_;
for (size_t i = 0; i < full_multiples_.size(); ++i) {
if (full_multiples_[i] == 1) {
continue;
}
slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
}
return SUCCESS;
@ -95,13 +109,18 @@ Status TileInfo::InferTensorMap() {
TensorMap input_tensor_map;
TensorMap output_tensor_map;
if (inputs_shape_.empty() || outputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
MS_LOG(ERROR) << name_ << ": The inputs or outputs' shape is empty";
return FAILED;
}
// the input tensor cannot be split
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
input_tensor_map.push_back(MAP_NONE);
input_tensor_map.push_back(inputs_shape_[0].size() - i - 1);
}
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
if (full_multiples_[i] != 1) {
input_tensor_map[i] = MAP_NONE;
}
}
// cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices.
@ -163,11 +182,11 @@ Status TileInfo::InferTensorInfo() {
void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 3) {
MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
}
if (!IsValueNode<ValueTuple>(cnode->input(2))) {
MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
MS_LOG(EXCEPTION) << name_ << ": The input[2] of tile cnode is not ValueTuple.";
}
auto func_graph = cnode->func_graph();
@ -199,7 +218,14 @@ Status TileInfo::GenerateStrategies(int64_t stage_id) {
Shape multiples_split(full_multiples_.size(), 1);
Shapes splittable_inputs = {multiples_split};
// if some dimension of multiples > 1, split the multiple; otherwise, split the dimension of input
std::vector<StrategyPtr> sp_vector;
Shape tmp_input_shape = full_multiples_;
for (size_t i = 0; i < full_multiples_.size(); ++i) {
if (full_multiples_[i] == 0) {
tmp_input_shape[i] = inputs_shape_[0][i];
}
}
Shapes tmp_inputs_shape = {full_multiples_};
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
return FAILED;

View File

@ -94,6 +94,15 @@ def test_select_model_parallel():
compile_net(net)
def test_select_mirror():
context.set_auto_parallel_context(
parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2, 2), (1, 2, 2))
strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
net = Net(_w1, _w2, strategy1, strategy2)
compile_net(net)
def test_select_auto_parallel():
context.set_auto_parallel_context(
parallel_mode="auto_parallel", device_num=8, global_rank=0)

View File

@ -52,20 +52,38 @@ class Net2(Cell):
out = self.tile(out, (8, 8, 4, 2))
return out
class Net3(Cell):
def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.tile = P.Tile().shard(strategy2)
if is_parameter:
self.weight = Parameter(weight, "w1")
else:
self.weight = weight
self.mul2 = P.Mul()
def construct(self, x, b):
out = self.tile(self.weight, (8, 1, 1))
out = self.mul(x, out)
return out
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
context.set_context(save_graphs=False)
def compile_net(net, x=_b, b=_b):
context.set_context(save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
_executor.compile(train_net, x, b)
context.reset_auto_parallel_context()
@ -101,6 +119,14 @@ def test_tile_tensor_no_full_split():
compile_net(net)
def test_tile_tensor_no_full_split2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1), (2, 2, 1))
strategy2 = ((2, 2, 1),)
net = Net3(_w1, strategy1, strategy2)
compile_net(net, _x1, _b)
def test_tile_output():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
@ -108,6 +134,7 @@ def test_tile_output():
net = Net2(_w2, strategy1, strategy2)
compile_net(net)
def test_tile_output_no_full_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
@ -123,7 +150,14 @@ def test_tile_no_strategy():
net = Net2(_w2, strategy1, strategy2)
compile_net(net)
def test_tile_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w2)
compile_net(net)
def test_tile_auto_parallel_2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = Net3(_w1)
compile_net(net, _x1, _b)