forked from mindspore-Ecosystem/mindspore

commit 2b0550b77a
parent dc0cf6f66c

    update nasnet scripts
@@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype
 
 from src.config import nasnet_a_mobile_config_gpu as cfg
 from src.dataset import create_dataset
-from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
+from src.nasnet_a_mobile import NASNetAMobileWithLoss
 from src.lr_generator import get_lr
 
 
@@ -104,9 +104,13 @@ if __name__ == '__main__':
     optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                         momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
 
-    net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
-    net_with_grads.set_train()
-    model = Model(net_with_grads)
+    # net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
+    # net_with_grads.set_train()
+    # model = Model(net_with_grads)
+
+    # high performance
+    net_with_loss.set_train()
+    model = Model(net_with_loss, optimizer=optimizer)
 
     print("============== Starting Training ==============")
     loss_cb = LossMonitor(per_print_times=batches_per_epoch)
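In short, the commit swaps the hand-written train-one-step wrapper (which clipped gradients before each update) for MindSpore's built-in training path: passing optimizer= to Model lets the framework build the forward/backward/update step itself, which is the "high performance" route the new comment refers to, at the cost of dropping gradient clipping. For reference, below is a minimal sketch of the kind of cell being commented out. The exact internals of NASNetAMobileTrainOneStepWithClipGradient are not shown in this diff, so the global-norm clipping and the CLIP_NORM threshold here are illustrative assumptions, not the repository's code.

    import mindspore.nn as nn
    import mindspore.ops as ops

    CLIP_NORM = 10.0  # assumed threshold; the real value lives in src/nasnet_a_mobile.py

    class TrainOneStepWithClipGradient(nn.Cell):
        """One training step: forward, backward, clip gradients, apply optimizer."""
        def __init__(self, network, optimizer):
            super().__init__(auto_prefix=False)
            self.network = network            # loss network (returns a scalar loss)
            self.network.set_grad()
            self.weights = optimizer.parameters
            self.optimizer = optimizer
            self.grad = ops.GradOperation(get_by_list=True)

        def construct(self, *inputs):
            loss = self.network(*inputs)
            grads = self.grad(self.network, self.weights)(*inputs)
            grads = ops.clip_by_global_norm(grads, CLIP_NORM)  # assumed: clip by global norm
            self.optimizer(grads)             # apply the parameter update
            return loss

With the new path, the equivalent wiring is just model = Model(net_with_loss, optimizer=optimizer) followed by model.train(...); Model internally wraps the loss network and optimizer in a standard nn.TrainOneStepCell, so the whole step compiles into a single graph, but no clipping is applied.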