!41590 Add repeat computation support for pynative shard
Merge pull request !41590 from liuluobin/master_repeat
Commit e4884bff6e
@@ -418,7 +418,6 @@ static bool SetStrategyForShard(const FuncGraphPtr &root, const std::vector<AnfN
  constexpr size_t kShardParameterPlanIndex = 4;
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimShard)) {
      root->set_flag(kPynativeShard, true);
      auto cnode = node->cast<CNodePtr>();
      auto vnode = cnode->input(kShardFnIndex)->cast<ValueNodePtr>();
      auto in_strategy = cnode->input(kShardInStrategyIndex);
@@ -128,7 +128,7 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {

static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  std::set<FuncGraphPtr> graph_sets;
  if (!parallel::IsAutoParallelCareGraph(root) && !parallel::IsPynativeParallel()) {
  if (!parallel::IsAutoParallelCareGraph(root)) {
    return graph_sets;
  }
  std::set<AnfNodePtr> input_parameters;
@@ -217,6 +217,16 @@ static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<Anf
  }
}

// If graph has shard node, set flag 'kPynativeShard' for root graph
void SetPynativeShardFlagIfHasShardNode(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimShard)) {
      root->set_flag(parallel::kPynativeShard, true);
      break;
    }
  }
}

// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
#ifdef WITH_BACKEND

@@ -237,6 +247,7 @@ bool PipelineSplit(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  SetPynativeShardFlagIfHasShardNode(root, all_nodes);
  if (!HasVirtualDataset(all_nodes)) {
    InsertVirtualDataset(root, all_nodes);
  }
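For orientation, the new pass above is a single scan over the flattened node list that marks the root graph as soon as a Shard primitive is found. Below is an illustrative Python analogue, not MindSpore source: the flag dictionary and the list of primitive names are stand-ins for the C++ FuncGraph and AnfNode structures.

def set_pynative_shard_flag(root_flags, all_nodes):
    # Stand-in for SetPynativeShardFlagIfHasShardNode: one pass over the nodes,
    # set the flag on the first Shard primitive, then stop scanning.
    for node in all_nodes:
        if node == "Shard":  # stands in for IsPrimitiveCNode(node, prim::kPrimShard)
            root_flags["pynative_shard"] = True  # stands in for root->set_flag(kPynativeShard, true)
            break
    return root_flags

flags = set_pynative_shard_flag({}, ["Load", "MatMul", "Shard", "ReLU"])
assert flags == {"pynative_shard": True}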
@@ -40,7 +40,8 @@ from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTenso
    _ms_memory_recycle
from mindspore.parallel._tensor import _load_tensor_by_layout
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _get_pipeline_stages, \
    _is_pynative_parallel
from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
@@ -291,6 +292,9 @@ class _MindsporeFunctionExecutor:
    def _parallel_process_for_ms_function(self, phase):
        """Set parameter and optimizer states data according to sliced shape for shard"""
        obj = self.shard_parent_obj if self.obj is None else self.obj
        if not isinstance(obj, ms.nn.Cell):
            return

        obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
        obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=True)
@@ -342,7 +346,7 @@ class _MindsporeFunctionExecutor:
            str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        if is_pynative_parallel():
        if _is_pynative_parallel():
            generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))

        # Add key with obj
@@ -383,11 +387,11 @@ class _MindsporeFunctionExecutor:
        is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)

        # init sliced parameter and optimizer state
        if is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
        if _is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
            self._parallel_process_for_ms_function(phase)

        # init the rest optimizer states
        if is_pynative_parallel() and _pynative_executor.get_optimizer():
        if _is_pynative_parallel() and _pynative_executor.get_optimizer():
            opt_states = _pynative_executor.get_optimizer().trainable_params()
            self._optimizer_state_init(opt_states)
@@ -591,7 +595,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
        if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
            process_obj = args[0]
        # only the function or cell instance wrapped by shard will fall into this branch
        if is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
        if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
            process_obj = args[0]
        args = args[1:]
        out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
@@ -668,13 +672,6 @@ def _function_forbid_reuse(func):
    return func


def is_pynative_parallel():
    run_mode = context.get_context('mode')
    parallel_mode = context.get_auto_parallel_context('parallel_mode')
    return run_mode == context.PYNATIVE_MODE and parallel_mode in (
        context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)


def _get_auto_split_param_names(parameter_layout_dict):
    auto_split_param_names = []
    for key, value in parameter_layout_dict.items():
@@ -20,12 +20,11 @@ from types import FunctionType, MethodType

from mindspore import log as logger
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
    _get_parallel_mode, _get_enable_parallel_optimizer
    _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from mindspore.common import dtype as mstype
from mindspore.common.api import is_pynative_parallel
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
@@ -359,7 +358,7 @@ class TrainOneStepCell(Cell):
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
                            is_pynative_parallel()
                            _is_pynative_parallel()
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
@@ -22,6 +22,7 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.parallel._utils import _sens_divided_by_device_num_if_recomputation
from mindspore import log as logger
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
    TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
@@ -381,6 +382,7 @@ class GradOperation(GradOperation_):
        elif self.pynative_:
            @_wrap_func
            def after_grad(*args, **kwargs):
                args, kwargs = _sens_divided_by_device_num_if_recomputation(grad_.sens_param, args, kwargs)
                self._pynative_forward_run(fn, grad_, args, kwargs)
                _pynative_executor.grad(fn, grad_, weights, self.grad_position, *args, **kwargs)
                out = _pynative_executor(fn, grad_.sens_param, *args, **kwargs)
@@ -878,10 +880,8 @@ class Shard(Shard_):
        if context.get_context("mode") != context.PYNATIVE_MODE or \
                context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
            raise AssertionError(f"'Shard' only supports auto parallel under PyNative mode")
        if context.get_context("device_target") not in ["Ascend"]:
            raise AssertionError(f"'Shard' now only supports 'Ascend'")
        if context.get_auto_parallel_context("full_batch"):
            raise AssertionError(f"'Shard' doesn't support 'full_batch'. Please set 'full_batch' as False")
        if context.get_context("device_target") not in ["Ascend", "GPU"]:
            raise AssertionError(f"'Shard' now only supports 'Ascend' and 'GPU'")
        if context.get_auto_parallel_context("search_mode") != "sharding_propagation":
            raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
        if not isinstance(in_strategy, tuple):
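For reference, a context configuration that passes the checks above might look like the following. This is a minimal sketch only: PyNative mode, auto_parallel, and sharding_propagation are required by the assertions, while the GPU target and device_num=8 are illustrative choices.

from mindspore import context
from mindspore.context import ParallelMode

# Sketch only: the device target check now accepts "Ascend" or "GPU".
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                  search_mode="sharding_propagation",
                                  device_num=8)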
@@ -20,6 +20,7 @@
import numpy as np

import mindspore as ms
from mindspore import log as logger
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common import Tensor
from mindspore.common._decorator import deprecated
@@ -80,8 +81,9 @@ shard_fn = Shard()

def shard(fn, in_strategy, out_strategy, parameter_plan=None, device="Ascend", level=0):
    """Apply distributed process for fn"""
    if not isinstance(fn, ms.nn.Cell):
        raise TypeError(f"Type of fn must be 'Cell', but got type {type(fn)}")
    if not isinstance(fn, (ms.nn.Cell)):
        logger.warning("'fn' is not a mindspore.nn.Cell, and when it derivable contains parameters, "
                       "the gradient calculation may be incorrect.")
    return shard_fn(fn, in_strategy, out_strategy, parameter_plan, device, level)
@@ -43,6 +43,13 @@ def _is_in_auto_parallel_mode():
    return _get_parallel_mode() in [ms.ParallelMode.SEMI_AUTO_PARALLEL, ms.ParallelMode.AUTO_PARALLEL]


def _is_pynative_parallel():
    run_mode = context.get_context('mode')
    parallel_mode = context.get_auto_parallel_context('parallel_mode')
    return run_mode == context.PYNATIVE_MODE and parallel_mode in (
        context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)


def _get_full_batch():
    """Get whether to use full_batch."""
    return auto_parallel_context().get_full_batch()
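As a quick illustration of when the new helper evaluates to True, the same condition can be reproduced with the public context API. This is a sketch; user code is not expected to import the private helper itself.

from mindspore import context
from mindspore.context import ParallelMode

context.set_context(mode=context.PYNATIVE_MODE)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

# Mirrors the body of _is_pynative_parallel(): PyNative execution mode combined
# with a (semi-)auto parallel mode.
pynative_parallel = (context.get_context('mode') == context.PYNATIVE_MODE and
                     context.get_auto_parallel_context('parallel_mode') in
                     (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL))
assert pynative_parallel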
@@ -380,3 +387,34 @@ def _infer_rank_list(train_map, predict_map=None):
        else:
            ret[param_name] = (rank_list, False)
    return ret


def _sens_divided_by_device_num_if_recomputation(sens_param, args, kwargs):
    """
    If in pynative parallel and full_batch is True, divide sens by device num to ensure that the gradients is right.
    """
    if not _is_pynative_parallel() or not _get_full_batch():
        return args, kwargs
    if not sens_param:
        logger.warning(
            "When pynative parallel and full_batch=True, the 'sens_param' should be set to True, "
            "otherwise the gradients may be wrong.")
        return args, kwargs

    device_num = _get_device_num()
    logger.info(f"When pynative_parallel and full_batch=True, "
                f"the 'sens' will be divided by device num({device_num})")
    sens = kwargs['sens'] if 'sens' in kwargs.keys() else args[-1]

    if isinstance(sens, tuple):
        new_sens = ()
        for item in sens:
            new_sens += (item / device_num,)
    else:
        new_sens = sens / device_num

    if not 'sens' in kwargs.keys():
        args = args[:-1] + (new_sens,)
    else:
        kwargs['sens'] = new_sens
    return args, kwargs
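A worked example of the scaling performed above, with made-up numbers: assuming pynative parallel with full_batch=True and a device_num of 8, every element of a tuple-valued 'sens' is divided by the device count.

# Pure-Python sketch of the sens scaling, matching the loop in the function above.
device_num = 8
sens = (4.0, 2.0)  # sensitivity values passed as the last positional argument

new_sens = ()
for item in sens:
    new_sens += (item / device_num,)

print(new_sens)  # (0.5, 0.25): each gradient sensitivity is scaled down by the device count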
@@ -31,13 +31,6 @@ from mindspore.train.model import Model
from mindspore.context import ParallelMode
import mindspore.dataset as ds

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", max_device_memory="25GB")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
                                  search_mode="sharding_propagation", device_num=8)
np.random.seed(42)


def weight_variable():
    return TruncatedNormal(0.01)
@@ -366,12 +359,7 @@ class ModelCallback(Callback):
        self.loss_list.append(result.asnumpy().mean())


def test_train_feed(num_classes=65536):
    '''
    Feature: shard function for cell to enable parallel execution under PyNative mode
    Description: Test a shrunk version of ResNet50 with a alternative execution of shard and pynative
    Expectation: Run success
    '''
def train_feed(num_classes, expect_out):
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
@@ -385,6 +373,35 @@ def test_train_feed(num_classes=65536):
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [11.259036, 11.015917, 10.599615]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)


def test_train_feed_ascend():
    '''
    Feature: shard function for cell to enable parallel execution under PyNative mode in Ascend
    Description: Test a shrunk version of ResNet50 with a alternative execution of shard and pynative
    Expectation: Run success
    '''
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", max_device_memory="25GB")
    context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    init()
    context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      search_mode="sharding_propagation", device_num=8)
    np.random.seed(42)
    train_feed(num_classes=65536, expect_out=[11.259036, 11.015917, 10.599615])


def test_train_feed_gpu():
    '''
    Feature: shard function for cell to enable parallel execution under PyNative mode in GPU
    Description: Test a shrunk version of ResNet50 with a alternative execution of shard and pynative
    Expectation: Run success
    '''
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    init()
    context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      search_mode="sharding_propagation", device_num=8)
    np.random.seed(42)
    train_feed(num_classes=65536, expect_out=[54.420227, 54.950275, 54.788376])
@@ -16,15 +16,29 @@
import os
import pytest


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_cell_shard():
def test_cell_shard_ascend():
    '''
    Feature: shard function for cell to enable parallel execution under PyNative mode
    Description: Test a shrunk version of ResNet50 with a alternative execution of shard and pynative
    Expectation: Run success
    '''
    ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py")
    ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py::test_train_feed_ascend")
    assert ret == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_cell_shard_gpu():
    '''
    Feature: shard function for cell to enable parallel execution under PyNative mode
    Description: Test a shrunk version of ResNet50 with a alternative execution of shard and pynative
    Expectation: Run success
    '''
    ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v cell_shard.py::test_train_feed_gpu")
    assert ret == 0