diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 75f1638a8fc..545b0330dcf 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -378,8 +378,8 @@ class Model: with _CallbackManager(callbacks) as list_callback: if not dataset_sink_mode: self._train_process(epoch, train_dataset, list_callback, cb_params) - elif context.get_context("mode") == context.PYNATIVE_MODE: - logger.warning("The pynative mode cannot support dataset sink mode currently." + elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU": + logger.warning("The pynative mode and CPU cannot support dataset sink mode currently." "So the training process will be performed with dataset not sink.") self._train_process(epoch, train_dataset, list_callback, cb_params) else: diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 7cd379134aa..1561b3daef1 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -44,8 +44,6 @@ if __name__ == "__main__": args = parser.parse_args() - if args.device_target == "CPU": - args.dataset_sink_mode = False context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) ds_train = create_dataset(os.path.join(args.data_path, "train"),