!41590 Add repeat computation support for pynative shard

Merge pull request !41590 from liuluobin/master_repeat
i-robot 2022-09-21 01:24:28 +00:00 committed by Gitee
commit e4884bff6e
9 changed files with 116 additions and 39 deletions
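For context, a minimal usage sketch of the PyNative shard path this commit extends. The network, the strategies, and the launch setup below are illustrative assumptions rather than code from this patch; the exposure path ms.shard is likewise assumed, the wrapper's signature matches the shard() function shown later in the diff, and an 8-card job launched via mpirun with init() is presumed.

# Illustrative sketch only: a Cell wrapped by shard() under PyNative auto_parallel.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication import init

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
init()  # assumes the process was started with mpirun -n 8
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.AUTO_PARALLEL,
                                  search_mode="sharding_propagation", device_num=8)

class MatMulNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = ms.ops.MatMul()

    def construct(self, x, w):
        return self.matmul(x, w)

net = MatMulNet()
# Example strategies whose product (2) is smaller than device_num (8): the sharded
# computation is repeated across groups of cards, the case this commit adds support for.
sharded_net = ms.shard(net, in_strategy=((2, 1), (1, 1)), out_strategy=None)  # out_strategy value assumed
out = sharded_net(Tensor(np.ones((32, 64)), ms.float32),
                  Tensor(np.ones((64, 16)), ms.float32))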

View File

@ -418,7 +418,6 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
constexpr size_t kShardParameterPlanIndex = 4;
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimShard)) {
root->set_flag(kPynativeShard, true);
auto cnode = node->cast<CNodePtr>();
auto vnode = cnode->input(kShardFnIndex)->cast<ValueNodePtr>();
auto in_strategy = cnode->input(kShardInStrategyIndex);

View File

@ -128,7 +128,7 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
std::set<FuncGraphPtr> graph_sets;
if (!parallel::IsAutoParallelCareGraph(root) && !parallel::IsPynativeParallel()) {
if (!parallel::IsAutoParallelCareGraph(root)) {
return graph_sets;
}
std::set<AnfNodePtr> input_parameters;
@ -217,6 +217,16 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
}
}
// If the graph has a shard node, set the 'kPynativeShard' flag on the root graph
void SetPynativeShardFlagIfHasShardNode(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimShard)) {
root->set_flag(parallel::kPynativeShard, true);
break;
}
}
}
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
#ifdef WITH_BACKEND
@ -237,6 +247,7 @@ bool PipelineSplit(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
SetPynativeShardFlagIfHasShardNode(root, all_nodes);
if (!HasVirtualDataset(all_nodes)) {
InsertVirtualDataset(root, all_nodes);
}

View File

@ -40,7 +40,8 @@ from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTenso
_ms_memory_recycle
from mindspore.parallel._tensor import _load_tensor_by_layout
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages, \
_is_pynative_parallel
from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
@ -291,6 +292,9 @@ class _MindsporeFunctionExecutor:
def _parallel_process_for_ms_function(self, phase):
"""Set parameter and optimizer states data according to sliced shape for shard"""
obj = self.shard_parent_obj if self.obj is None else self.obj
if not isinstance(obj, ms.nn.Cell):
return
obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
replace = obj.init_parameters_data(auto_parallel_mode=True)
@ -342,7 +346,7 @@ class _MindsporeFunctionExecutor:
str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
if _pynative_executor.grad_flag():
generate_name = generate_name + ".grad"
if is_pynative_parallel():
if _is_pynative_parallel():
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
# Add key with obj
@ -383,11 +387,11 @@ class _MindsporeFunctionExecutor:
is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
# init sliced parameter and optimizer state
if is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
if _is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
self._parallel_process_for_ms_function(phase)
# init the rest optimizer states
if is_pynative_parallel() and _pynative_executor.get_optimizer():
if _is_pynative_parallel() and _pynative_executor.get_optimizer():
opt_states = _pynative_executor.get_optimizer().trainable_params()
self._optimizer_state_init(opt_states)
@ -591,7 +595,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
# only the function or cell instance wrapped by shard will fall into this branch
if is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
process_obj = args[0]
args = args[1:]
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
@ -668,13 +672,6 @@ def _function_forbid_reuse(func):
return func
def is_pynative_parallel():
run_mode = context.get_context('mode')
parallel_mode = context.get_auto_parallel_context('parallel_mode')
return run_mode == context.PYNATIVE_MODE and parallel_mode in (
context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
def _get_auto_split_param_names(parameter_layout_dict):
auto_split_param_names = []
for key, value in parameter_layout_dict.items():

View File

@ -20,12 +20,11 @@ from types import FunctionType, MethodType
from mindspore import log as logger
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
_get_parallel_mode, _get_enable_parallel_optimizer
_get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from mindspore.common import dtype as mstype
from mindspore.common.api import is_pynative_parallel
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
@ -359,7 +358,7 @@ class TrainOneStepCell(Cell):
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
is_pynative_parallel()
_is_pynative_parallel()
if self.reducer_flag:
self.mean = _get_gradients_mean()
self.degree = _get_device_num()

View File

@ -22,6 +22,7 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.parallel._utils import _sens_divided_by_device_num_if_recomputation
from mindspore import log as logger
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
@ -381,6 +382,7 @@ class GradOperation(GradOperation_):
elif self.pynative_:
@_wrap_func
def after_grad(*args, **kwargs):
args, kwargs = _sens_divided_by_device_num_if_recomputation(grad_.sens_param, args, kwargs)
self._pynative_forward_run(fn, grad_, args, kwargs)
_pynative_executor.grad(fn, grad_, weights, self.grad_position, *args, **kwargs)
out = _pynative_executor(fn, grad_.sens_param, *args, **kwargs)
@ -878,10 +880,8 @@ class Shard(Shard_):
if context.get_context("mode") != context.PYNATIVE_MODE or \
context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode")
if context.get_context("device_target") not in ["Ascend"]:
raise AssertionError(f"'Shard' now only supports 'Ascend'")
if context.get_auto_parallel_context("full_batch"):
raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False")
if context.get_context("device_target") not in ["Ascend", "GPU"]:
raise AssertionError(f"'Shard' now only supports 'Ascend' and 'GPU'")
if context.get_auto_parallel_context("search_mode") != "sharding_propagation":
raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
if not isinstance(in_strategy, tuple):
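Taken together, the checks above now amount to the following context requirements; a minimal sketch whose values mirror the assertions in this hunk (the GPU target is newly accepted, and the earlier full_batch=False requirement is dropped):

# Context setup accepted by the Shard checks above after this change.
from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")  # "Ascend" is also accepted
context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                  search_mode="sharding_propagation")
# full_batch no longer has to be False; the removed assertion used to enforce that.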

View File

@ -20,6 +20,7 @@
import numpy as np
import mindspore as ms
from mindspore import log as logger
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common import Tensor
from mindspore.common._decorator import deprecated
@ -80,8 +81,9 @@ shard_fn = Shard()
def shard(fn, in_strategy, out_strategy, parameter_plan=None, device="Ascend", level=0):
"""Apply distributed process for fn"""
if not isinstance(fn, ms.nn.Cell):
raise TypeError(f"Type of fn must be 'Cell', but got type {type(fn)}")
if not isinstance(fn, (ms.nn.Cell)):
logger.warning("'fn' is not a mindspore.nn.Cell, and when it derivable contains parameters, "
"the gradient calculation may be incorrect.")
return shard_fn(fn, in_strategy, out_strategy, parameter_plan, device, level)
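Since the type check above is now a warning instead of an error, a plain Python function can be passed as well; a small sketch (the function body and strategies are illustrative, the call uses the module-level shard() defined above, and the PyNative auto_parallel context required by Shard is assumed to be set already):

# Illustrative only: a parameter-free function passed to the shard() wrapper above.
from mindspore import ops

def stateless_fn(x, w):
    # No Parameters inside, so the warning above does not affect gradient correctness.
    return ops.matmul(x, w)

fn_shard = shard(stateless_fn, in_strategy=((2, 1), (1, 1)), out_strategy=None)  # warns, still works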

View File

@ -43,6 +43,13 @@ def _is_in_auto_parallel_mode():
return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]
def _is_pynative_parallel():
run_mode = context.get_context('mode')
parallel_mode = context.get_auto_parallel_context('parallel_mode')
return run_mode == context.PYNATIVE_MODE and parallel_mode in (
context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
def _get_full_batch():
"""Get whether to use full_batch."""
return auto_parallel_context().get_full_batch()
@ -380,3 +387,34 @@ def _infer_rank_list(train_map, predict_map=None):
else:
ret[param_name] = (rank_list, False)
return ret
def _sens_divided_by_device_num_if_recomputation(sens_param, args, kwargs):
"""
In pynative parallel with full_batch=True, divide 'sens' by the device num so that the gradients are correct.
"""
if not _is_pynative_parallel() or not _get_full_batch():
return args, kwargs
if not sens_param:
logger.warning(
"When pynative parallel and full_batch=True, the 'sens_param' should be set to True, "
"otherwise the gradients may be wrong.")
return args, kwargs
device_num = _get_device_num()
logger.info(f"When pynative_parallel and full_batch=True, "
f"the 'sens' will be divided by device num({device_num})")
sens = kwargs['sens'] if 'sens' in kwargs else args[-1]
if isinstance(sens, tuple):
new_sens = ()
for item in sens:
new_sens += (item / device_num,)
else:
new_sens = sens / device_num
if 'sens' not in kwargs:
args = args[:-1] + (new_sens,)
else:
kwargs['sens'] = new_sens
return args, kwargs
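A small arithmetic illustration of what the helper above does to 'sens'; the context checks are stripped out and device_num is hardcoded, whereas the real path reads it from _get_device_num():

# Mechanics of the division above, shown on plain Python values.
device_num = 8                          # what _get_device_num() would return on an 8-card job
args = ("x", "label", 16.0)             # the trailing positional argument plays the role of sens
new_args = args[:-1] + (args[-1] / device_num,)
assert new_args == ("x", "label", 2.0)  # sens is scaled down so the reduced gradients stay correct

kwargs = {"sens": (8.0, 16.0)}          # a tuple sens is divided element-wise, as in the loop above
kwargs["sens"] = tuple(item / device_num for item in kwargs["sens"])
assert kwargs["sens"] == (1.0, 2.0)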

View File

@ -31,13 +31,6 @@ from mindspore.train.model import Model
from mindspore.context import ParallelMode
import mindspore.dataset as ds
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", max_device_memory="25GB")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
search_mode="sharding_propagation", device_num=8)
np.random.seed(42)
def weight_variable():
return TruncatedNormal(0.01)
@ -366,12 +359,7 @@ class ModelCallback(Callback):
self.loss_list.append(result.asnumpy().mean())
def test_train_feed(num_classes=65536):
'''
Feature: shard function for cell to enable parallel execution under PyNative mode
Description: Test a shrunk version of ResNet50 with alternating execution of shard and pynative
Expectation: Run success
'''
def train_feed(num_classes, expect_out):
parallel_callback = ModelCallback()
data_gen = DataGenerator()
_, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
@ -385,6 +373,35 @@ def test_train_feed(num_classes=65536):
model = Model(net, loss_fn=loss, optimizer=opt)
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
loss_value = np.array(parallel_callback.loss_list)
expect_out = [11.259036, 11.015917, 10.599615]
print(loss_value)
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
def test_train_feed_ascend():
'''
Feature: shard function for cell to enable parallel execution under PyNative mode in Ascend
Description: Test a shrunk version of ResNet50 with alternating execution of shard and pynative
Expectation: Run success
'''
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", max_device_memory="25GB")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
search_mode="sharding_propagation", device_num=8)
np.random.seed(42)
train_feed(num_classes=65536, expect_out=[11.259036, 11.015917, 10.599615])
def test_train_feed_gpu():
'''
Feature: shard function for cell to enable parallel execution under PyNative mode in GPU
Description: Test a shrunk version of ResNet50 with alternating execution of shard and pynative
Expectation: Run success
'''
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
search_mode="sharding_propagation", device_num=8)
np.random.seed(42)
train_feed(num_classes=65536, expect_out=[54.420227, 54.950275, 54.788376])

View File

@ -16,15 +16,29 @@
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_cell_shard():
def test_cell_shard_ascend():
'''
Feature: shard function for cell to enable parallel execution under PyNative mode
Description: Test a shrunk version of ResNet50 with alternating execution of shard and pynative
Expectation: Run success
'''
ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py")
ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py::test_train_feed_ascend")
assert ret == 0
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_cell_shard_gpu():
'''
Feature: shard function for cell to enable parallel execution under PyNative mode
Description: Test a shrunk version of ResNet50 with alternating execution of shard and pynative
Expectation: Run success
'''
ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py::test_train_feed_gpu")
assert ret == 0