From 4e3e6006b6a378e04c079b7b070e3118878866e1 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Mon, 21 Sep 2020 17:05:50 +0800 Subject: [PATCH] modify init --- mindspore/train/model.py | 17 ++++------------- .../models/resnet50/test_resnet50_imagenet.py | 5 ----- tests/ut/python/train/test_training.py | 19 ------------------- 3 files changed, 4 insertions(+), 37 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 1bfda62b5ea..32da5a9e012 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -267,7 +267,7 @@ class Model: return dataset_helper, network - def init(self, train_dataset=None, valid_dataset=None): + def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1): """ Initialize compute graphs and data graphs with the sink mode. @@ -279,17 +279,7 @@ class Model: initialized. Default: None. valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs will be initialized, and `metrics` in `Model` can not be None. Default: None. - - Examples: - >>> train_dataset = get_train_dataset() - >>> valid_dataset = get_valid_dataset() - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) - >>> model.init(train_dataset, valid_dataset) - >>> model.train(2, train_dataset) - >>> model.eval(valid_dataset) + sink_size (int): Control the amount of data in each sink. Default: -1. """ if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') @@ -309,7 +299,8 @@ class Model: is_train=True, phase='train', dataset=train_dataset, - dataset_sink_mode=True) + dataset_sink_mode=True, + sink_size=sink_size) self._train_network = train_network for inputs in train_dataset_helper: self._train_network.compile(*inputs) diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index c1b682a31b4..0482aff053b 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -201,11 +201,6 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): metrics={'acc': DistAccuracy(batch_size=config.eval_batch_size, device_num=device_num)}, eval_network=dist_eval_network) - # model init - print("init_start", device_id) - model.init(dataset, eval_dataset) - print("init_stop", device_id) - # callbacks loss_cb = LossGet(1, step_size) diff --git a/tests/ut/python/train/test_training.py b/tests/ut/python/train/test_training.py index ad269511040..184077195e0 100644 --- a/tests/ut/python/train/test_training.py +++ b/tests/ut/python/train/test_training.py @@ -221,25 +221,6 @@ def test_model_build_abnormal_string(): assert err -def test_model_init(): - """ test_model_init_error """ - train_dataset = get_dataset() - eval_dataset = get_dataset() - - with pytest.raises(RuntimeError): - context.set_context(mode=context.PYNATIVE_MODE) - get_model().init(train_dataset) - - context.set_context(mode=context.GRAPH_MODE) - get_model().init(train_dataset) - get_model(metrics={'acc'}).init(eval_dataset) - - with pytest.raises(RuntimeError): - get_model().init(train_dataset, eval_dataset) - with pytest.raises(ValueError): - get_model().init() - - def test_init_model_error(): """ test_init_model_error """ net = nn.ReLU()