model_train

This commit is contained in:
liuyang_655 2020-08-31 21:01:18 +08:00
parent 521e351dac
commit 65a9c80aae
2 changed files with 2 additions and 4 deletions

View File

@ -378,8 +378,8 @@ class Model:
with _CallbackManager(callbacks) as list_callback: with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode: if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE: elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU":
logger.warning("The pynative mode cannot support dataset sink mode currently." logger.warning("The pynative mode and CPU cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.") "So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
else: else:

View File

@ -44,8 +44,6 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.device_target == "CPU":
args.dataset_sink_mode = False
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), ds_train = create_dataset(os.path.join(args.data_path, "train"),