forked from mindspore-Ecosystem/mindspore
Reset datasink to false when training with ps using ascend
This commit is contained in:
parent
5c5ca060ba
commit
b81da9d274
|
@ -1009,6 +1009,11 @@ class Model:
|
|||
... loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
device_target = context.get_context("device_target")
|
||||
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
||||
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
||||
dataset_sink_mode = False
|
||||
|
||||
Validator.check_bool(dataset_sink_mode)
|
||||
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
|
||||
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
|
||||
|
@ -1144,6 +1149,10 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
|
||||
>>> model.fit(2, train_dataset, valid_dataset)
|
||||
"""
|
||||
device_target = context.get_context("device_target")
|
||||
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
||||
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
||||
dataset_sink_mode = False
|
||||
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
|
||||
|
|
Loading…
Reference in New Issue