test_micro_batch_Interleaved

lilei 2021-12-05 10:04:39 +08:00
parent c88da99f77
commit e933aa268b
5 changed files with 168 additions and 4 deletions

View File

@@ -545,7 +545,8 @@ AnfNodePtr GetPreNode(const AnfNodePtr &node) {
      continue;
    }
    (void)node_queue.erase(node_queue.begin());
-   if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
+   auto prim = GetCNodePrimitive(cur_node);
+   if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD) && !prim->HasAttr("realdiv_flag")) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
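Context for this change: MicroBatchInterleaved (later in this commit) appends a RealDiv tagged with the "realdiv_flag" primitive attribute to average the interleaved outputs, and GetPreNode must not pick that averaging node as the pipeline end node. A minimal sketch of how the flag travels with the primitive on the Python side (uses MindSpore's Primitive.add_prim_attr and attrs; the assert is illustrative, not part of this commit):

from mindspore.ops import operations as P

# Tag the division primitive; the attribute is carried into the compiled
# graph, so the C++ pass above can reject it via prim->HasAttr("realdiv_flag").
realdiv = P.RealDiv().add_prim_attr("realdiv_flag", True)
assert "realdiv_flag" in realdiv.attrs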

View File

@@ -18,7 +18,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
-    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, PipelineCell
+    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
from ..layer.timedistributed import TimeDistributed

@@ -30,6 +30,7 @@ __all__ = [
    "TrainOneStepCell",
    "WithLossCell",
    "WithGradCell",
+   "MicroBatchInterleaved",
    "PipelineCell",
    "WithEvalCell",
    "GetNextSingleOp",

View File

@@ -480,6 +480,7 @@ class MicroBatchInterleaved(Cell):
        self.network = network
        self.interleave_num = interleave_num
        self.interleave_inputs = nn.CellList()
+       self.realdiv = P.RealDiv().add_prim_attr("realdiv_flag", True)
        for _ in range(interleave_num):
            interleave_data = _MicroBatch(interleave_num)
            interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)

@@ -490,7 +491,7 @@
        for i in range(self.interleave_num):
            interleave_input = self.interleave_inputs[i](i, *inputs)
            output += self.network(*interleave_input)
-       return output / self.interleave_num
+       return self.realdiv(output, self.interleave_num)


class PipelineCell(Cell):
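A minimal standalone usage sketch of the changed wrapper (the Dense network and shapes are illustrative assumptions; only MicroBatchInterleaved and its flagged RealDiv averaging come from this commit):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.wrap.cell_wrapper import MicroBatchInterleaved

# Each step is cut into interleave_num micro-batches along dim 0; the
# per-slice outputs are summed, then averaged by the flagged RealDiv above.
net = MicroBatchInterleaved(nn.Dense(32, 16), 2)
x = Tensor(np.ones([64, 32]), ms.float32)  # batch dim must divide by interleave_num
out = net(x)                               # expected shape: (32, 16)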

View File

@@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.wrap.cell_wrapper import MicroBatchInterleaved


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul1 = P.MatMul().shard(strategy1)
        self.matmul2 = P.MatMul().shard(strategy2)

    def construct(self, x, y, b):
        out = self.matmul1(x, y)
        out = self.matmul2(out, b)
        return out


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = P.ReLU()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


def compile_net(net, x, y, b):
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def test_micro_batch_interleaved():
    """
    Feature: test MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: compile done without error.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0, gradients_mean=True)
    strategy1 = ((4, 2), (2, 1))
    strategy2 = ((2, 4), (4, 1))
    micro_batch_interleaved = 2
    net = MicroBatchInterleaved(NetWithLoss(Net(strategy1, strategy2)), micro_batch_interleaved)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32 * micro_batch_interleaved, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64 * micro_batch_interleaved, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
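Why y and b in this test are scaled by micro_batch_interleaved: every input is sliced along dim 0 into interleave_num pieces, so each interleaved slice must still produce valid MatMul shapes. A quick numpy check of that arithmetic (assuming the dim-0 slicing done by the internal _MicroBatch helper):

import numpy as np

interleave = 2
x = np.ones((128, 32))
y = np.ones((32 * interleave, 64))
b = np.ones((64 * interleave, 64))
# Each micro-batch sees x: (64, 32), y: (32, 64), b: (64, 64),
# so both MatMuls are well-formed per slice.
out = x[:128 // interleave] @ y[:32] @ b[:64]
assert out.shape == (64, 64)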

View File

@@ -21,7 +21,7 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
-from mindspore.nn.wrap.cell_wrapper import PipelineCell
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, MicroBatchInterleaved


class DatasetLenet():

@@ -263,3 +263,93 @@ def test_pipeline_split_shared_parameter_stage1_opt_shard():
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_with_micro_batch_interleaved_stage0():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"


def test_pipeline_split_with_micro_batch_interleaved_stage1():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage0_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit2(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage1_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit2(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
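For reading the parameter assertions in the stage tests above, the wrapper nesting unwinds as follows (a comment sketch using only names from this diff):

# net                                -> PipelineCell(..., 4)
# net.network                        -> MicroBatchInterleaved(...)
# net.network.network                -> PipelineSplit / PipelineSplit2
# net.network.network.cell.block[k]  -> parameters of pipeline stage k
#
# Hence the stage-0 run (global_rank=0) must hold no stage-1 parameters
# ("cell.block.1.param*"), and the stage-1 run (global_rank=4) no stage-0 ones.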