fix googlenet & deepfm

This commit is contained in:
panfengfeng 2020-08-22 10:44:56 +08:00
parent 77198f3182
commit 30b69d3488
2 changed files with 14 additions and 6 deletions

View File

@ -82,10 +82,10 @@ After installing MindSpore via the official website, you can start training and
python train.py > train.log 2>&1 &
# run distributed training example
sh scripts/run_train.sh rank_table.json
Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
# run evaluation example
python eval.py > eval.log 2>&1 & OR sh run_eval.sh
python eval.py > eval.log 2>&1 & OR Ascend: sh run_eval.sh OR GPU: sh run_eval_gpu.sh
```
@ -161,7 +161,7 @@ The model checkpoint will be saved in the current directory.
### Distributed Training
```
sh scripts/run_train.sh rank_table.json
Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows:
@ -187,7 +187,9 @@ Before running the command below, please check the checkpoint path used for eval
```
python eval.py > eval.log 2>&1 &
OR
sh scripts/run_eval.sh
Ascned: sh scripts/run_eval.sh
OR
GPU: sh scripts/run_eval_gpu.sh
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:

View File

@ -83,6 +83,8 @@ if __name__ == '__main__':
rank_size=rank_size,
rank_id=rank_id)
steps_size = ds_train.get_dataset_size()
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric()
@ -95,8 +97,12 @@ if __name__ == '__main__':
if train_config.save_checkpoint:
if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
if args_opt.device_target == "GPU":
config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
keep_checkpoint_max=train_config.keep_checkpoint_max)
else:
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
directory=args_opt.ckpt_path,
config=config_ck)