forked from mindspore-Ecosystem/mindspore
!4965 fix googlenet deepfm
Merge pull request !4965 from panfengfeng/fix_googlenet_deepfm
This commit is contained in:
commit
0c60f7e6ac
|
@ -82,10 +82,10 @@ After installing MindSpore via the official website, you can start training and
|
|||
python train.py > train.log 2>&1 &
|
||||
|
||||
# run distributed training example
|
||||
sh scripts/run_train.sh rank_table.json
|
||||
Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
|
||||
|
||||
# run evaluation example
|
||||
python eval.py > eval.log 2>&1 & OR sh run_eval.sh
|
||||
python eval.py > eval.log 2>&1 & OR Ascend: sh run_eval.sh OR GPU: sh run_eval_gpu.sh
|
||||
```
|
||||
|
||||
|
||||
|
@ -161,7 +161,7 @@ The model checkpoint will be saved in the current directory.
|
|||
### Distributed Training
|
||||
|
||||
```
|
||||
sh scripts/run_train.sh rank_table.json
|
||||
Ascend: sh scripts/run_train.sh rank_table.json OR GPU: sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows:
|
||||
|
@ -187,7 +187,9 @@ Before running the command below, please check the checkpoint path used for eval
|
|||
```
|
||||
python eval.py > eval.log 2>&1 &
|
||||
OR
|
||||
sh scripts/run_eval.sh
|
||||
Ascned: sh scripts/run_eval.sh
|
||||
OR
|
||||
GPU: sh scripts/run_eval_gpu.sh
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
|
|
|
@ -83,6 +83,8 @@ if __name__ == '__main__':
|
|||
rank_size=rank_size,
|
||||
rank_id=rank_id)
|
||||
|
||||
steps_size = ds_train.get_dataset_size()
|
||||
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
auc_metric = AUCMetric()
|
||||
|
@ -95,8 +97,12 @@ if __name__ == '__main__':
|
|||
if train_config.save_checkpoint:
|
||||
if rank_size:
|
||||
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
if args_opt.device_target == "GPU":
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
else:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
|
||||
directory=args_opt.ckpt_path,
|
||||
config=config_ck)
|
||||
|
|
Loading…
Reference in New Issue