forked from mindspore-Ecosystem/mindspore

add model init api to compile df graph before exec

commit 1ebf98b795
parent 1ba8e052f8
@@ -383,6 +383,10 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
+        # set parallel inputs in sink mode
+        if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
+            obj.set_parallel_input_with_inputs(*args)
+
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
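For orientation, a minimal pure-Python sketch of the new sink-mode guard; FakeTensor is a hypothetical stand-in for mindspore.Tensor, modeling only the attribute the check reads:

    from dataclasses import dataclass

    @dataclass
    class FakeTensor:
        # stand-in for mindspore.Tensor; only the attribute the guard reads
        virtual_flag: bool = False

    def needs_sink_parallel_inputs(auto_parallel_mode, args):
        # mirrors the condition added to _Executor.compile above
        return bool(auto_parallel_mode and args
                    and isinstance(args[0], FakeTensor) and args[0].virtual_flag)

    assert needs_sink_parallel_inputs(True, (FakeTensor(virtual_flag=True),))
    assert not needs_sink_parallel_inputs(True, (FakeTensor(),))   # real data: normal path
    assert not needs_sink_parallel_inputs(False, (FakeTensor(virtual_flag=True),))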
@@ -288,6 +288,15 @@ class Cell:
             parallel_inputs_run.append(new_tensor)
         return tuple(parallel_inputs_run)
 
+    def set_parallel_input_with_inputs(self, *inputs):
+        """
+        Slice input tensors by the parallel strategies and set the sliced inputs to `_parallel_inputs_run`.
+
+        Args:
+            inputs (tuple): inputs of the construct method.
+        """
+        self._parallel_inputs_run = self._load_inputs(*inputs)
+
     def _get_construct_inputs_number_and_name(self):
         """Compute self._construct_inputs_names and self._construct_inputs_num"""
         import inspect
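A sketch of the contract this new method establishes, with a mock in place of Cell (the real _load_inputs slices each tensor by its parallel layout; the mock only makes the caching visible):

    class MockCell:
        def __init__(self):
            self._parallel_inputs_run = None

        def _load_inputs(self, *inputs):
            # real Cell._load_inputs slices each tensor by its layout;
            # here we just tag the values so the caching is observable
            return tuple(("sliced", x) for x in inputs)

        def set_parallel_input_with_inputs(self, *inputs):
            self._parallel_inputs_run = self._load_inputs(*inputs)

    cell = MockCell()
    cell.set_parallel_input_with_inputs(1, 2)
    assert cell._parallel_inputs_run == (("sliced", 1), ("sliced", 2))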
@@ -304,6 +313,15 @@ class Cell:
         self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
         self._construct_inputs_num = self._construct_inputs_num - 1
 
+    def compile(self, *inputs):
+        """
+        Compiles the cell.
+
+        Args:
+            inputs (tuple): Input parameters.
+        """
+        _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
+
     def compile_and_run(self, *inputs):
         """
         Compiles and runs cell.
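A hedged usage sketch of the new Cell.compile API, assuming a working MindSpore install; Net and the input shape are illustrative placeholders:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.relu = nn.ReLU()

        def construct(self, x):
            return self.relu(x)

    net = Net()
    x = Tensor(np.ones([1, 3]).astype(np.float32))
    net.compile(x)   # builds the graph for net.phase without executing it
    out = net(x)     # later calls reuse the compiled graph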
@@ -314,13 +332,14 @@ class Cell:
         Returns:
             Object, the result of executing.
         """
-        _, compile_flag = _executor.compile(self, *inputs, phase=self.phase,
-                                            auto_parallel_mode=self._auto_parallel_mode)
+        _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
 
         if self._auto_parallel_mode:
-            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag):
+            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
+                # get parallel inputs in sink mode; parallel inputs were set in _executor.compile
                 parallel_inputs_run = self._parallel_inputs_run
             else:
+                # set parallel inputs in normal mode
                 self._parallel_inputs_run = self._load_inputs(*inputs)
                 parallel_inputs_run = self._parallel_inputs_run
         return _executor(self, *parallel_inputs_run, phase=self.phase)
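With compile_flag gone, dispatch reduces to a cache lookup versus an on-the-spot slice. A pure-Python sketch of the simplified logic (getattr stands in for the Tensor.virtual_flag check):

    def pick_parallel_inputs(cell, inputs, auto_parallel_mode):
        if not auto_parallel_mode:
            return inputs
        if inputs and getattr(inputs[0], "virtual_flag", False):
            # sink mode: parallel inputs were cached by _executor.compile
            return cell._parallel_inputs_run
        # normal mode: slice now and cache for later runs
        cell._parallel_inputs_run = cell._load_inputs(*inputs)
        return cell._parallel_inputs_run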
@@ -217,6 +217,94 @@ class Model:
             scaling_sens /= self._device_number
         return scaling_sens
 
+    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode):
+        """Initializes the dataset."""
+        need_wrap = False
+        if dataset_sink_mode:
+            # remove later to deal with loop sink
+            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
+                    and not context.get_context("enable_ge"):
+                need_wrap = True
+
+            if not is_train:
+                dataset.__loop_size__ = 1
+
+        dataset_helper = DatasetHelper(dataset, dataset_sink_mode)
+
+        # remove later to deal with loop sink
+        if need_wrap:
+            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
+            network.set_train(is_train)
+            network.phase = phase
+
+        return dataset_helper, network
+
+    def init(self, train_dataset=None, valid_dataset=None):
+        """
+        Initializes compute graphs and data graphs with sink mode.
+
+        Note:
+            The pre-init process only supports `GRAPH_MODE` and the `Ascend` target currently.
+
+        Args:
+            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, the training
+                graphs will be initialized. Default: None.
+            valid_dataset (Dataset): An evaluation dataset iterator. If `valid_dataset` is defined, the evaluation
+                graphs will be initialized, and `metrics` in `Model` can not be None. Default: None.
+
+        Examples:
+            >>> train_dataset = get_train_dataset()
+            >>> valid_dataset = get_valid_dataset()
+            >>> net = Net()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
+            >>> model.init(train_dataset, valid_dataset)
+            >>> model.train(2, train_dataset)
+            >>> model.eval(valid_dataset)
+        """
+        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
+            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
+
+        if not train_dataset and not valid_dataset:
+            raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
+
+        _device_number_check(self._parallel_mode, self._device_number)
+
+        if train_dataset:
+            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
+            self._train_network.set_train()
+            self._train_network.phase = 'train'
+
+            if self._parameter_broadcast:
+                self._train_network.set_broadcast_flag()
+
+            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                                        is_train=True,
+                                                                        phase='train',
+                                                                        dataset=train_dataset,
+                                                                        dataset_sink_mode=True)
+            self._train_network = train_network
+            for inputs in train_dataset_helper:
+                self._train_network.compile(*inputs)
+                break
+
+        if valid_dataset:
+            if not self._metric_fns:
+                raise RuntimeError('If `valid_dataset` is defined, the metric fn can not be None or empty.')
+
+            self._eval_network.set_train(False)
+            self._eval_network.phase = 'eval'
+            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
+                                                                       is_train=False,
+                                                                       phase='eval',
+                                                                       dataset=valid_dataset,
+                                                                       dataset_sink_mode=True)
+            self._eval_network = eval_network
+            for inputs in valid_dataset_helper:
+                self._eval_network.compile(*inputs)
+                break
+
     def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
         """
         Training.
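The wrap decision above, restated as a standalone predicate (pure Python; __ME_INITED__ is kept by name as the internal marker for an already-initialized dataset):

    def needs_data_wrapper(dataset, device_target, enable_ge, dataset_sink_mode):
        """Mirror of the need_wrap logic in _exec_preprocess: wrap in
        nn.DataWrapper only for sink mode on Ascend without GE, when the
        dataset is not yet initialized."""
        return (dataset_sink_mode
                and not hasattr(dataset, '__ME_INITED__')
                and device_target == "Ascend"
                and not enable_ge)

    class _Plain:
        pass

    assert needs_data_wrapper(_Plain(), "Ascend", False, True)
    assert not needs_data_wrapper(_Plain(), "Ascend", False, False)  # no sink: no wrap
    assert not needs_data_wrapper(_Plain(), "Ascend", True, True)    # GE path skips the wrapper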
@@ -277,21 +365,15 @@ class Model:
             list_callback (_ListCallback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
-        # remove later to deal with loop sink
-        need_wrap = False
-        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                and not context.get_context("enable_ge"):
-            need_wrap = True
-
-        dataset_helper = DatasetHelper(train_dataset)
-        # remove later to deal with loop sink
-        if need_wrap:
-            self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
-                                                 train_dataset.__ME_INITED__)
-            cb_params.train_network = self._train_network
-            self._train_network.set_train()
-
+        dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                              is_train=True,
+                                                              phase='train',
+                                                              dataset=train_dataset,
+                                                              dataset_sink_mode=True)
+        self._train_network = train_network
+        cb_params.train_network = self._train_network
+
         cb_params.cur_step_num = 0
         loop_size = dataset_helper.loop_size()
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
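After the refactor, _train keeps only the loop bookkeeping. A pure-Python illustration of the sink-mode step accounting, assuming, as the surrounding code does, that the device runs loop_size steps per dataset_helper iteration:

    loop_size = 10        # e.g. dataset_helper.loop_size()
    cur_step_num = 0
    for _ in range(3):    # three sink iterations
        cur_step_num += loop_size   # callbacks see steps advance in chunks
    assert cur_step_num == 30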
@@ -331,7 +413,11 @@ class Model:
             list_callback (_ListCallback): Executor of callback list. Default: None.
             cb_params (_InternalCallbackParam): Callback parameters. Default: None.
         """
-        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
+        dataset_helper, _ = self._exec_preprocess(self._train_network,
+                                                  is_train=True,
+                                                  phase='train',
+                                                  dataset=train_dataset,
+                                                  dataset_sink_mode=False)
         cb_params.cur_step_num = 0
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
@@ -437,26 +523,15 @@ class Model:
         Returns:
             Dict, returns the loss value & metrics values for the model in test mode.
         """
-        _device_number_check(self._parallel_mode, self._device_number)
-
         run_context = RunContext(cb_params)
-
-        # remove later to deal with loop sink
-        need_wrap = False
-        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                and not context.get_context("enable_ge"):
-            need_wrap = True
-
-        valid_dataset.__loop_size__ = 1
-        dataset_helper = DatasetHelper(valid_dataset)
-
-        # remove later to deal with loop sink
-        if need_wrap:
-            self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
-                                                valid_dataset.__ME_INITED__)
-        self._eval_network.set_train(mode=False)
-        self._eval_network.phase = 'eval'
-
+        dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
+                                                             is_train=False,
+                                                             phase='eval',
+                                                             dataset=valid_dataset,
+                                                             dataset_sink_mode=True)
+        self._eval_network = eval_network
+        cb_params.eval_network = self._eval_network
+
         list_callback.begin(run_context)
 
         for inputs in dataset_helper:
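Note that the eval-only `valid_dataset.__loop_size__ = 1` moved into _exec_preprocess (its `if not is_train` branch), and _device_number_check moves to the public eval() in a later hunk. A tiny sketch of the assumed loop-size rule, that evaluation sinks one step per iteration so every batch reaches the metric update:

    def sink_loop_size(is_train, configured_loop_size):
        # assumed semantics of the is_train branch in _exec_preprocess
        return configured_loop_size if is_train else 1

    assert sink_loop_size(True, 10) == 10
    assert sink_loop_size(False, 10) == 1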
@@ -490,7 +565,11 @@ class Model:
         run_context = RunContext(cb_params)
         list_callback.begin(run_context)
 
-        dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
+        dataset_helper, _ = self._exec_preprocess(self._eval_network,
+                                                  is_train=False,
+                                                  phase='eval',
+                                                  dataset=valid_dataset,
+                                                  dataset_sink_mode=False)
         for next_element in dataset_helper:
             cb_params.cur_step_num += 1
             list_callback.step_begin(run_context)
@@ -532,6 +611,7 @@ class Model:
         >>> model.eval(dataset)
         """
         check_bool(dataset_sink_mode)
+        _device_number_check(self._parallel_mode, self._device_number)
         if not self._metric_fns:
             raise ValueError("metric fn can not be None or empty.")
 
@@ -68,12 +68,12 @@ class LossNet(nn.Cell):
         return out
 
 
-def get_model():
+def get_model(metrics=None):
     """ get_model """
     net = Net()
     loss = nn.SoftmaxCrossEntropyWithLogits()
     optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+    model = Model(net, loss_fn=loss, optimizer=optim, metrics=metrics)
     return model
 
 
@@ -215,8 +215,27 @@ def test_model_build_abnormal_string():
     assert err
 
 
-def test_model_init_error():
-    """ test_model_init_error """
+def test_model_init():
+    """ test_model_init """
+    train_dataset = get_dataset()
+    eval_dataset = get_dataset()
+
+    with pytest.raises(RuntimeError):
+        context.set_context(mode=context.PYNATIVE_MODE)
+        get_model().init(train_dataset)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    get_model().init(train_dataset)
+    get_model(metrics={'acc'}).init(eval_dataset)
+
+    with pytest.raises(RuntimeError):
+        get_model().init(train_dataset, eval_dataset)
+    with pytest.raises(ValueError):
+        get_model().init()
+
+
+def test_init_model_error():
+    """ test_init_model_error """
     net = nn.ReLU()
     loss = nn.SoftmaxCrossEntropyWithLogits()
     with pytest.raises(KeyError):
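The renamed test pins down Model.init's validation order. A self-contained, pytest-runnable toy that mirrors that contract (a stand-in, not MindSpore code):

    import pytest

    def init_model(train_dataset=None, valid_dataset=None, metrics=None, graph_mode=True):
        """Toy stand-in mirroring Model.init's validation contract."""
        if not graph_mode:
            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
        if not train_dataset and not valid_dataset:
            raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
        if valid_dataset and not metrics:
            raise RuntimeError('If `valid_dataset` is defined, the metric fn can not be None or empty.')

    def test_init_contract():
        with pytest.raises(ValueError):
            init_model()                                                  # neither dataset given
        with pytest.raises(RuntimeError):
            init_model(train_dataset=object(), valid_dataset=object())   # valid_dataset but no metrics
        init_model(train_dataset=object())                                # train-only: fine
        init_model(valid_dataset=object(), metrics={'acc'})               # eval with metrics: fine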