forked from mindspore-Ecosystem/mindspore
!5368 fix sink_size bug for transformer
Merge pull request !5368 from yuchaojie/transformer_st
This commit is contained in:
commit
f3fc24c82a
|
@ -170,8 +170,14 @@ def run_transformer_train():
|
|||
|
||||
netwithgrads.set_train(True)
|
||||
model = Model(netwithgrads)
|
||||
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"),
|
||||
sink_size=args.save_checkpoint_steps)
|
||||
|
||||
enable_sink = (args.enable_data_sink == "true")
|
||||
if enable_sink:
|
||||
sink_size = args.save_checkpoint_steps
|
||||
model.train(args.epoch_size*dataset.get_dataset_size()//sink_size, dataset, callbacks=callbacks,
|
||||
dataset_sink_mode=enable_sink, sink_size=sink_size)
|
||||
else:
|
||||
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=enable_sink)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_transformer_train()
|
||||
|
|
Loading…
Reference in New Issue