!4284 remove to_full_tensor and load_inputs in execute stage

Merge pull request !4284 from yao_yf/remove_to_full_tensor_and_load_inputs_in_exexute_stage
This commit is contained in:
mindspore-ci-bot 2020-08-14 18:05:44 +08:00 committed by Gitee
commit f41c21c5fa
13 changed files with 77 additions and 75 deletions

View File

@ -23,7 +23,7 @@ from mindspore import log as logger
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
@ -402,6 +402,11 @@ class _Executor:
logger.debug("%r graph has existed.", phase)
return phase, False
is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
_, args_list = _generate_pip_args(obj, *args_full)
result = self._executor.compile(obj, args_list, phase, use_vm)
self.compile_cache[phase] = phase
if not result:
@ -423,7 +428,7 @@ class _Executor:
self._updata_param_node_default_input(phase, replace)
# set parallel inputs in sink mode
if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
if auto_parallel_mode and is_sink_mode:
obj.set_parallel_input_with_inputs(*args)
# the following GE init process is not needed when use vm or ms backend

View File

@ -31,7 +31,6 @@ from ..ops.functional import cast
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
class Cell:
"""
Base class for all neural networks.
@ -87,6 +86,7 @@ class Cell:
self._bprop_debug = False
self._already_run = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
@property
def already_run(self):
@ -445,6 +445,7 @@ class Cell:
Returns:
Object, the result of executing.
"""
self._auto_parallel_compile_and_run = True
_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
if self._auto_parallel_mode:
@ -452,12 +453,13 @@ class Cell:
# get parallel inputs in sink mode, parallel inputs set in _executor.compile
parallel_inputs_run = self._parallel_inputs_run
else:
# set parallel inputs in normal mode
self._parallel_inputs_run = self._load_inputs(*inputs)
parallel_inputs_run = self._parallel_inputs_run
parallel_inputs_run = inputs
return _executor(self, *parallel_inputs_run, phase=self.phase)
return _executor(self, *inputs, phase=self.phase)
def auto_parallel_compile_and_run(self):
    """Return the `_auto_parallel_compile_and_run` flag stored on this object."""
    compile_and_run_flag = self._auto_parallel_compile_and_run
    return compile_and_run_flag
def exec_checkpoint_graph(self):
"""Executes saving checkpoint graph operation."""
_executor(self, phase='save')

View File

@ -121,9 +121,8 @@ class EmbeddingLookup(Cell):
When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
specified 'axis = 0' to lookup table.
In field slice mode, the manual_shapes should be given. It is a tuple ,where
the element is (vocab[i], offset[i]), vocab[i] is the row numbers for i-th
part and offset[i] is the feature id offset for i-th part. The feature id in
i-th part will be subtracted by offset[i] to ensure the id start from 0.
the element is vocab[i], vocab[i] is the row numbers for i-th
part.
Args:
vocab_size (int): Size of the dictionary of embeddings.

View File

@ -14,7 +14,11 @@
# ============================================================================
"""Utils of auto parallel"""
import numpy as np
from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
@ -37,6 +41,52 @@ def _need_to_full():
and (not full_batch))
return need
def _to_full_shapes(shapes, device_num):
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
new_shapes = []
for shape in shapes:
new_shape = ()
for i, item in enumerate(shape):
if i == 0:
new_shape += (item * device_num,)
else:
new_shape += (item,)
new_shapes.append(new_shape)
return new_shapes
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
    """
    Convert inputs to tensors whose batch dimension is expanded by device_num.

    Adapts to the "feed the data from host" solution: each input is copied into a
    zero-filled array whose first dimension is device_num times larger, at the
    slice selected by global_rank.

    Args:
        elem (Union[Tensor, numpy.ndarray, tuple, list]): A single input or a sequence of inputs.
        device_num (int): Multiplier applied to the first dimension of every input.
        global_rank (int): Index of the slice that receives the input data.
        scaling_sens: Optional value; when truthy it is appended to the result as a float32 Tensor.

    Returns:
        tuple[Tensor], the expanded tensors, followed by the scaling_sens tensor if one was added.

    Raises:
        ValueError: If global_rank is not smaller than device_num, or if an element is
            neither a Tensor nor a numpy.ndarray.
    """
    inputs = elem if isinstance(elem, (tuple, list)) else [elem]
    if global_rank >= device_num:
        raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                         "the device num is {}".format(global_rank, device_num))

    full_tensors = []
    for data in inputs:
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        local_shape = data.shape
        ms_dtype = data.dtype
        full_shape = tuple(dim * device_num if axis == 0 else dim
                           for axis, dim in enumerate(local_shape))
        # The per-device batch size is the first dimension; it stays 1 for an empty shape.
        per_device = next(iter(local_shape), 1)
        full_numpy = np.zeros(full_shape, dtype_to_nptype(ms_dtype))
        offset = global_rank * per_device
        full_numpy[offset: offset + per_device] = data.asnumpy()
        full_tensors.append(Tensor(full_numpy))
    if scaling_sens:
        full_tensors.append(Tensor(scaling_sens, mstype.float32))
    return tuple(full_tensors)
def _get_mirror_mean():
"""Get if using mirror_mean."""

View File

@ -145,41 +145,6 @@ def _to_tensor(elem, scaling_sens=None):
return lst[0] if len(lst) == 1 else tuple(lst)
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
    """
    Convert inputs to tensors whose batch dimension is expanded by device_num.

    Adapts to the "feed the data from host" solution: each input is copied into a
    zero-filled array whose first dimension is device_num times larger, at the
    slice selected by global_rank.

    Args:
        elem (Union[Tensor, numpy.ndarray, tuple, list]): A single input or a sequence of inputs.
        device_num (int): Multiplier applied to the first dimension of every input.
        global_rank (int): Index of the slice that receives the input data.
        scaling_sens: Optional value; when truthy it is appended to the result as a float32 Tensor.

    Returns:
        tuple[Tensor], the expanded tensors, followed by the scaling_sens tensor if one was added.

    Raises:
        ValueError: If global_rank is not smaller than device_num, or if an element is
            neither a Tensor nor a numpy.ndarray.
    """
    inputs = elem if isinstance(elem, (tuple, list)) else [elem]
    if global_rank >= device_num:
        raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                         "the device num is {}".format(global_rank, device_num))

    full_tensors = []
    for data in inputs:
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        local_shape = data.shape
        ms_dtype = data.dtype
        full_shape = tuple(dim * device_num if axis == 0 else dim
                           for axis, dim in enumerate(local_shape))
        # The per-device batch size is the first dimension; it stays 1 for an empty shape.
        per_device = next(iter(local_shape), 1)
        full_numpy = np.zeros(full_shape, dtype_to_nptype(ms_dtype))
        offset = global_rank * per_device
        full_numpy[offset: offset + per_device] = data.asnumpy()
        full_tensors.append(Tensor(full_numpy))
    if scaling_sens:
        full_tensors.append(Tensor(scaling_sens, mstype.float32))
    return tuple(full_tensors)
def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
"""Construct tensor list to initialize the network which implemented in dataset sink."""
tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
@ -187,20 +152,6 @@ def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
return tensor_list_run, tensor_list_compile
def _to_full_shapes(shapes, device_num):
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
new_shapes = []
for shape in shapes:
new_shape = ()
for i, item in enumerate(shape):
if i == 0:
new_shape += (item * device_num,)
else:
new_shape += (item,)
new_shapes.append(new_shape)
return new_shapes
def _check_to_numpy(plugin, tensor):
"""Check the tensor and return a numpy.ndarray."""
np_value = tensor.asnumpy()

View File

@ -19,9 +19,9 @@ import os
from mindspore._checkparam import check_bool, check_int
from .. import context
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor
_construct_tensor_list
from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
def _send_data(dataset, epoch_num):
@ -236,6 +236,4 @@ class _DatasetIterNormal:
def __next__(self):
data = self.iter.__next__()
if _need_to_full():
return _to_full_tensor(data, self.device_num, self.global_rank)
return _to_tensor(data)

View File

@ -31,8 +31,7 @@ from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ._utils import _to_full_tensor
from ..parallel._utils import _need_to_full
from ..parallel._utils import _need_to_full, _to_full_tensor
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp

View File

@ -15,12 +15,11 @@
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import check_bool, check_int
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
def _send_data(dataset, epoch_num):

View File

@ -17,8 +17,8 @@ import os
from mindspore import context
from mindspore._checkparam import check_bool, check_int
from mindspore.parallel._utils import _get_device_num, _need_to_full
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
def _send_data(dataset, epoch_num):

View File

@ -34,7 +34,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.parallel._utils import _need_to_full
from mindspore.train import amp
from mindspore.train._utils import _to_full_tensor
from mindspore.parallel._utils import _to_full_tensor
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper

View File

@ -117,7 +117,7 @@ def train_and_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
callback = LossCallBack(config=config)
callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)

View File

@ -14,9 +14,8 @@
# ============================================================================
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.parallel_utils import ParallelMode
def _send_data(dataset):

View File

@ -16,7 +16,7 @@ import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train._utils import _to_full_shapes, _to_full_tensor
from mindspore.parallel._utils import _to_full_shapes, _to_full_tensor
def test_to_full_shapes():