!185 modify comment about normal mode

Merge pull request !185 from jinyaohui/master

commit 0565e4641e
@@ -67,7 +67,7 @@ if __name__ == '__main__':
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
-    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or non-sink mode, default is sink")
+    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink")
     parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
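A side note on the unchanged context lines above (not part of this commit): argparse with type=bool does not parse "--distribute False" as False, because argparse passes the raw string to the type callable and bool("False") is True for any non-empty string. A minimal sketch of the usual workaround, with a hypothetical str2bool converter:

import argparse

def str2bool(value):
    # argparse hands us the raw string; bool("False") would wrongly be True.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--distribute", type=str2bool, default=False,
                    help="Run distribute, default is false.")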
@@ -453,7 +453,7 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr>& info, const py::
   }

   // process the first args of tensor
-  // only in Dataset non-sink Mode, fp_bp graph need input tensors
+  // only in dataset normal(non-sink) mode, fp_bp graph need input tensors
   if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) {
     for (std::size_t i = 0; i < size; i++) {
       ValuePtr converted = nullptr;
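For context: in dataset sink mode the data is streamed to the device through a queue and the forward/backward (fp_bp) graph runs without per-step host inputs, which is why the argument conversion above only happens in DS_NORMAL_MODE, where the host feeds input tensors each step. A hedged usage sketch of the two modes, assuming the standard mindspore.Model.train interface; net, loss, opt, and ds_train are placeholders defined elsewhere:

from mindspore import Model

# Placeholders: net, loss, opt, ds_train are assumed to exist.
model = Model(net, loss_fn=loss, optimizer=opt)

# Sink mode: the dataset is bound to an on-device queue; the graph consumes it directly.
model.train(epoch=10, train_dataset=ds_train, dataset_sink_mode=True)

# Normal (non-sink) mode: the host passes input tensors into the graph step by step.
model.train(epoch=10, train_dataset=ds_train, dataset_sink_mode=False)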
@@ -447,7 +447,7 @@ void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     value = ConfigManager::GetInstance().iter_num();
   } else {
-    MS_LOG(INFO) << "Run with non-sink mode, the iterator number will always be 1";
+    MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1";
     value = 1;
     ConfigManager::GetInstance().set_iter_num(value);
   }
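The branch above initializes the graph loop variable: in sink mode the device loops iter_num times per launch, while in normal (non-sink) mode each launch executes exactly one step because the host drives the loop. A minimal Python illustration of the same decision (dataset_mode and iter_num here are stand-ins, not the real C++ API):

def loop_count(dataset_mode, iter_num):
    # Sink mode: the device-side loop runs iter_num iterations per launch.
    if dataset_mode == "sink":
        return iter_num
    # Normal (non-sink) mode: the host launches the graph once per step.
    return 1

assert loop_count("sink", 100) == 100
assert loop_count("normal", 100) == 1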
@@ -51,7 +51,7 @@ class DynamicLossScaleUpdateCell(Cell):
     In every training step, the loss scaling value will be updated by loss scaling value/`scale_factor`
     when there is overflow. And it will be increased by loss scaling value * `scale_factor` if there is no
     overflow for a continuous `scale_window` steps. This cell is used for Graph mode training in which all
-    logic will be executed on device side(Another training mode is non-sink mode in which some logic will be
+    logic will be executed on device side(Another training mode is normal(non-sink) mode in which some logic will be
     executed on host).

     Args:
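The docstring describes the standard dynamic loss-scaling rule: divide the scale by scale_factor on overflow, multiply it by scale_factor after scale_window consecutive overflow-free steps. A minimal pure-Python sketch of that update rule, as an illustration only, not MindSpore's implementation:

class DynamicLossScaler:
    def __init__(self, init_scale=2.0**24, scale_factor=2.0, scale_window=1000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        if overflow:
            # Overflow: shrink the scale and restart the window.
            self.scale /= self.scale_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A full window without overflow: grow the scale again.
                self.scale *= self.scale_factor
                self.good_steps = 0
        return self.scale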
@@ -112,8 +112,8 @@ def test_save_checkpoint():
     os.remove('./test_files/test_ckpt-model.pkl')


-def test_loss_monitor_sink_model():
-    """Test loss monitor sink model."""
+def test_loss_monitor_sink_mode():
+    """Test loss monitor sink mode."""
     cb_params = _InternalCallbackParam()
     cb_params.cur_epoch_num = 4
     cb_params.cur_step_num = 2
@@ -131,8 +131,8 @@ def test_loss_monitor_sink_model():
     callbacklist.end(run_context)


-def test_loss_monitor_feed_model():
-    """Test loss monitor non-sink mode."""
+def test_loss_monitor_normal_mode():
+    """Test loss monitor normal(non-sink) mode."""
     cb_params = _InternalCallbackParam()
     run_context = RunContext(cb_params)
     loss_cb = LossMonitor(1)
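The renamed tests exercise LossMonitor through MindSpore's internal callback plumbing. A hedged sketch of that pattern, assuming _InternalCallbackParam and RunContext behave as in these tests (internal APIs, subject to change); batch_num and net_outputs are attributes assumed from typical callback usage, with a placeholder loss value:

from mindspore.train.callback import LossMonitor, RunContext, _InternalCallbackParam

cb_params = _InternalCallbackParam()
cb_params.cur_epoch_num = 4      # epoch counter the callback reads
cb_params.cur_step_num = 2       # step counter the callback reads
cb_params.batch_num = 2          # assumption: steps per epoch
cb_params.net_outputs = 2.0      # placeholder loss value

run_context = RunContext(cb_params)
loss_cb = LossMonitor(1)         # print the loss every step
loss_cb.begin(run_context)
loss_cb.step_end(run_context)    # reads cb_params and prints the current loss
loss_cb.end(run_context)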