fix error when train twice in pynative and feed model

This commit is contained in:
liyong 2021-03-05 10:04:14 +08:00
parent a018390e40
commit 621b796ed3
3 changed files with 24 additions and 18 deletions

View File

@ -402,6 +402,7 @@ class Model:
# build callback list # build callback list
with _CallbackManager(callbacks) as list_callback: with _CallbackManager(callbacks) as list_callback:
self._check_reuse_dataset(train_dataset)
if not dataset_sink_mode: if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("device_target") == "CPU": elif context.get_context("device_target") == "CPU":
@ -409,7 +410,6 @@ class Model:
"So the training process will be performed with dataset not sink.") "So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
else: else:
self._check_reuse_dataset(train_dataset)
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size) self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
@staticmethod @staticmethod

View File

@ -310,11 +310,17 @@ class OptimizerSemiAutoAndAutoParallelFactory:
def test_optimizer_parallel_auto_4p_6_parameter_same_strategy_1_1_2_1_momentum(): def test_optimizer_parallel_auto_4p_6_parameter_same_strategy_1_1_2_1_momentum():
inputs_np = np.random.randn(16, 1, 32, 32).astype(np.float32) inputs_np = np.random.randn(16, 1, 32, 32).astype(np.float32)
dataset = FakeData(size=32, ds1 = FakeData(size=32,
batch_size=4, batch_size=4,
image_size=(1, 32, 32), image_size=(1, 32, 32),
use_parallel=True, use_parallel=True,
num_classes=116) num_classes=116)
ds2 = FakeData(size=32,
batch_size=4,
image_size=(1, 32, 32),
use_parallel=True,
num_classes=116)
strategy_dict = {'add1': ((1, 1, 2, 1), (1, 1, 2, 1)), strategy_dict = {'add1': ((1, 1, 2, 1), (1, 1, 2, 1)),
'mul1': ((1, 1, 2, 1), (1, 1, 2, 1)), 'mul1': ((1, 1, 2, 1), (1, 1, 2, 1)),
'fc1_matmul': ((1, 2), (1, 2)), 'fc1_matmul': ((1, 2), (1, 2)),
@ -323,6 +329,6 @@ def test_optimizer_parallel_auto_4p_6_parameter_same_strategy_1_1_2_1_momentum()
'mul3': ((1, 2), (1, 2))} 'mul3': ((1, 2), (1, 2))}
fact = OptimizerSemiAutoAndAutoParallelFactory(net=OptimizerSemiAutoAndAutoParallel6Net, fact = OptimizerSemiAutoAndAutoParallelFactory(net=OptimizerSemiAutoAndAutoParallel6Net,
strategy_dict=strategy_dict) strategy_dict=strategy_dict)
fact.mindspore_auto_parallel_impl(dataset=dataset, epoch=2, device_num=4) fact.mindspore_auto_parallel_impl(dataset=ds1, epoch=2, device_num=4)
fact.mindspore_optimizer_auto_parallel_impl(dataset=dataset, epoch=2, device_num=4) fact.mindspore_optimizer_auto_parallel_impl(dataset=ds2, epoch=2, device_num=4)
fact.checkpoint_cmp(inputs_np=inputs_np) fact.checkpoint_cmp(inputs_np=inputs_np)

View File

@ -122,7 +122,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
class NetFactory: class NetFactory:
def __init__(self, dataset, input_shape=(2, 1, 32, 32), in_channels=1, out_channels=3, def __init__(self, input_shape=(2, 1, 32, 32), in_channels=1, out_channels=3,
kernel_size=5, vocab_size=5, embedding_size=1, output_channels=3072, kernel_size=5, vocab_size=5, embedding_size=1, output_channels=3072,
epoch_size=1, target='CPU', sparse=True): epoch_size=1, target='CPU', sparse=True):
self.in_channels = in_channels self.in_channels = in_channels
@ -131,13 +131,12 @@ class NetFactory:
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.output_channels = output_channels self.output_channels = output_channels
self.dataset = dataset
self.epoch_size = epoch_size self.epoch_size = epoch_size
self.target = target self.target = target
self.sparse = sparse self.sparse = sparse
self.input_np = np.random.randn(*input_shape).astype(np.float32) self.input_np = np.random.randn(*input_shape).astype(np.float32)
def no_ps_impl(self): def no_ps_impl(self, dataset):
context.set_ps_context(enable_ps=False) context.set_ps_context(enable_ps=False)
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size, net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
self.embedding_size, self.output_channels, self.target, self.sparse) self.embedding_size, self.output_channels, self.target, self.sparse)
@ -148,13 +147,13 @@ class NetFactory:
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters())) opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
opt.target = 'CPU' opt.target = 'CPU'
model = Model(net, loss, opt) model = Model(net, loss, opt)
model.train(self.epoch_size, self.dataset, dataset_sink_mode=False) model.train(self.epoch_size, dataset, dataset_sink_mode=False)
input_me = Tensor(self.input_np) input_me = Tensor(self.input_np)
out_me = model.predict(input_me) out_me = model.predict(input_me)
context.set_ps_context(enable_ps=True) context.set_ps_context(enable_ps=True)
return out_me.asnumpy() return out_me.asnumpy()
def part_ps_impl(self): def part_ps_impl(self, dataset):
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size, net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
self.embedding_size, self.output_channels, self.target, self.sparse) self.embedding_size, self.output_channels, self.target, self.sparse)
net.embedding_lookup.set_param_ps() net.embedding_lookup.set_param_ps()
@ -165,20 +164,21 @@ class NetFactory:
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters())) opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
opt.target = 'CPU' opt.target = 'CPU'
model = Model(net, loss, opt) model = Model(net, loss, opt)
model.train(self.epoch_size, self.dataset, dataset_sink_mode=False) model.train(self.epoch_size, dataset, dataset_sink_mode=False)
input_me = Tensor(self.input_np) input_me = Tensor(self.input_np)
out_me = model.predict(input_me) out_me = model.predict(input_me)
return out_me.asnumpy() return out_me.asnumpy()
def part_cmp(self): def part_cmp(self):
part_ps = self.part_ps_impl() ds1 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
no_ps = self.no_ps_impl() ds2 = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
part_ps = self.part_ps_impl(ds1)
no_ps = self.no_ps_impl(ds2)
print(part_ps) print(part_ps)
print(no_ps) print(no_ps)
assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4) assert np.allclose(no_ps, part_ps, rtol=1.0e-4, atol=1.0e-4)
if __name__ == "__main__": if __name__ == "__main__":
datasets = create_dataset(os.path.join(dataset_path, "train"), 32, 1) fact = NetFactory()
fact = NetFactory(dataset=datasets)
fact.part_cmp() fact.part_cmp()