forked from mindspore-Ecosystem/mindspore
!7843 fix nasnet & efficientnet scripts
Merge pull request !7843 from panfengfeng/fix_nasnet_hung
This commit is contained in:
commit
7633727fc8
|
@ -41,6 +41,7 @@ efficientnet_b0_config_gpu = edict({
|
|||
'smoothing': 0.1,
|
||||
# Use TensorFlow BatchNorm defaults for models that support it
|
||||
'bn_tf': False,
|
||||
'save_checkpoint': True,
|
||||
'keep_checkpoint_max': 10,
|
||||
'loss_scale': 1024,
|
||||
'resume_start_epoch': 0,
|
||||
|
|
|
@ -146,10 +146,14 @@ def main():
|
|||
loss_scale_manager = FixedLossScaleManager(
|
||||
cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(
|
||||
prefix=cfg.model, directory=output_dir, config=config_ck)
|
||||
callbacks = [time_cb, loss_cb]
|
||||
|
||||
if cfg.save_checkpoint:
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(
|
||||
prefix=cfg.model, directory=output_dir, config=config_ck)
|
||||
callbacks += [ckpoint_cb]
|
||||
|
||||
lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
|
||||
decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
|
||||
|
@ -176,7 +180,7 @@ def main():
|
|||
amp_level=cfg.amp_level
|
||||
)
|
||||
|
||||
callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []
|
||||
callbacks = callbacks if is_master else []
|
||||
|
||||
if args.resume:
|
||||
real_epoch = cfg.epochs - cfg.resume_start_epoch
|
||||
|
|
|
@ -14,5 +14,5 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \
|
||||
python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
||||
|
|
Loading…
Reference in New Issue