forked from mindspore-Ecosystem/mindspore
!6651 Modify init interface to internal interface
Merge pull request !6651 from lijiaqi/init
This commit is contained in:
commit
091ee5085e
|
@ -267,7 +267,7 @@ class Model:
|
||||||
|
|
||||||
return dataset_helper, network
|
return dataset_helper, network
|
||||||
|
|
||||||
def init(self, train_dataset=None, valid_dataset=None):
|
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1):
|
||||||
"""
|
"""
|
||||||
Initialize compute graphs and data graphs with the sink mode.
|
Initialize compute graphs and data graphs with the sink mode.
|
||||||
|
|
||||||
|
@ -279,17 +279,7 @@ class Model:
|
||||||
initialized. Default: None.
|
initialized. Default: None.
|
||||||
valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
|
valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
|
||||||
will be initialized, and `metrics` in `Model` can not be None. Default: None.
|
will be initialized, and `metrics` in `Model` can not be None. Default: None.
|
||||||
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
||||||
Examples:
|
|
||||||
>>> train_dataset = get_train_dataset()
|
|
||||||
>>> valid_dataset = get_valid_dataset()
|
|
||||||
>>> net = Net()
|
|
||||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
|
||||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
||||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
|
|
||||||
>>> model.init(train_dataset, valid_dataset)
|
|
||||||
>>> model.train(2, train_dataset)
|
|
||||||
>>> model.eval(valid_dataset)
|
|
||||||
"""
|
"""
|
||||||
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
|
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
|
||||||
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
|
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
|
||||||
|
@ -309,7 +299,8 @@ class Model:
|
||||||
is_train=True,
|
is_train=True,
|
||||||
phase='train',
|
phase='train',
|
||||||
dataset=train_dataset,
|
dataset=train_dataset,
|
||||||
dataset_sink_mode=True)
|
dataset_sink_mode=True,
|
||||||
|
sink_size=sink_size)
|
||||||
self._train_network = train_network
|
self._train_network = train_network
|
||||||
for inputs in train_dataset_helper:
|
for inputs in train_dataset_helper:
|
||||||
self._train_network.compile(*inputs)
|
self._train_network.compile(*inputs)
|
||||||
|
|
|
@ -201,11 +201,6 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
|
||||||
metrics={'acc': DistAccuracy(batch_size=config.eval_batch_size, device_num=device_num)},
|
metrics={'acc': DistAccuracy(batch_size=config.eval_batch_size, device_num=device_num)},
|
||||||
eval_network=dist_eval_network)
|
eval_network=dist_eval_network)
|
||||||
|
|
||||||
# model init
|
|
||||||
print("init_start", device_id)
|
|
||||||
model.init(dataset, eval_dataset)
|
|
||||||
print("init_stop", device_id)
|
|
||||||
|
|
||||||
# callbacks
|
# callbacks
|
||||||
loss_cb = LossGet(1, step_size)
|
loss_cb = LossGet(1, step_size)
|
||||||
|
|
||||||
|
|
|
@ -221,25 +221,6 @@ def test_model_build_abnormal_string():
|
||||||
assert err
|
assert err
|
||||||
|
|
||||||
|
|
||||||
def test_model_init():
|
|
||||||
""" test_model_init_error """
|
|
||||||
train_dataset = get_dataset()
|
|
||||||
eval_dataset = get_dataset()
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
|
||||||
get_model().init(train_dataset)
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
get_model().init(train_dataset)
|
|
||||||
get_model(metrics={'acc'}).init(eval_dataset)
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
get_model().init(train_dataset, eval_dataset)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
get_model().init()
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_model_error():
|
def test_init_model_error():
|
||||||
""" test_init_model_error """
|
""" test_init_model_error """
|
||||||
net = nn.ReLU()
|
net = nn.ReLU()
|
||||||
|
|
Loading…
Reference in New Issue