fix googlenet & deepfm

This commit is contained in:
panfengfeng 2020-08-22 10:44:56 +08:00
parent 77198f3182
commit 30b69d3488
2 changed files with 14 additions and 6 deletions

View File

@ -82,10 +82,10 @@ After installing MindSpore via the official website, you can start training and
python train.py > train.log 2>&1 & python train.py > train.log 2>&1 &
# run distributed training example # run distributed training example
sh scripts/run_train.sh rank_table.json Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
# run evaluation example # run evaluation example
python eval.py > eval.log 2>&1 & OR sh run_eval.sh python eval.py > eval.log 2>&1 & OR Ascend: sh run_eval.sh OR GPU: sh run_eval_gpu.sh
``` ```
@ -161,7 +161,7 @@ The model checkpoint will be saved in the current directory.
### Distributed Training ### Distributed Training
``` ```
sh scripts/run_train.sh rank_table.json Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
``` ```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows: The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows:
@ -187,7 +187,9 @@ Before running the command below, please check the checkpoint path used for eval
``` ```
python eval.py > eval.log 2>&1 & python eval.py > eval.log 2>&1 &
OR OR
sh scripts/run_eval.sh Ascned: sh scripts/run_eval.sh
OR
GPU: sh scripts/run_eval_gpu.sh
``` ```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:

View File

@ -83,6 +83,8 @@ if __name__ == '__main__':
rank_size=rank_size, rank_size=rank_size,
rank_id=rank_id) rank_id=rank_id)
steps_size = ds_train.get_dataset_size()
model_builder = ModelBuilder(ModelConfig, TrainConfig) model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net() train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric() auc_metric = AUCMetric()
@ -95,8 +97,12 @@ if __name__ == '__main__':
if train_config.save_checkpoint: if train_config.save_checkpoint:
if rank_size: if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank()) train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, if args_opt.device_target == "GPU":
keep_checkpoint_max=train_config.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
keep_checkpoint_max=train_config.keep_checkpoint_max)
else:
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
directory=args_opt.ckpt_path, directory=args_opt.ckpt_path,
config=config_ck) config=config_ck)