!6256 modify ckpt path

Merge pull request !6256 from TuDouNi/master
This commit is contained in:
mindspore-ci-bot 2020-09-16 14:25:46 +08:00 committed by Gitee
commit 069a1cfbd7
8 changed files with 17 additions and 15 deletions

View File

@ -36,7 +36,7 @@ def set_config(args):
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform,
"run_distribute": False
})
@ -57,7 +57,7 @@ def set_config(args):
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform,
"ccl": "nccl",
"run_distribute": args.run_distribute
@ -79,7 +79,7 @@ def set_config(args):
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform,
"ccl": "hccl",
"device_id": int(os.getenv('DEVICE_ID', '0')),

View File

@ -24,6 +24,7 @@ from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_rank
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import save_checkpoint
@ -93,10 +94,6 @@ if __name__ == '__main__':
features_path = args_opt.dataset_path + '_features'
idx_list = list(range(step_size))
if os.path.isdir(config.save_checkpoint_path):
os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time()))
os.mkdir(config.save_checkpoint_path)
for epoch in range(epoch_size):
random.shuffle(idx_list)
epoch_start = time.time()
@ -110,7 +107,11 @@ if __name__ == '__main__':
print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
if (epoch + 1) % config.save_checkpoint_epochs == 0:
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
rank = 0
if config.run_distribute:
rank = get_rank()
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
save_checkpoint(network, os.path.join(save_ckpt_path, \
f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start))

View File

@ -7,6 +7,6 @@ do_shuffle=true
enable_data_sink=true
data_sink_steps=100
accumulation_steps=1
save_checkpoint_path=./checkpoint/
save_checkpoint_path=./
save_checkpoint_steps=10000
save_checkpoint_num=1

View File

@ -131,6 +131,7 @@ def run_transformer_train():
else:
device_num = 1
rank_id = 0
save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_0/')
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle,
dataset_path=args.data_path,

View File

@ -36,7 +36,7 @@ do
env > env.log
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--ckpt_path="./" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--do_eval=True > output.log 2>&1 &

View File

@ -31,7 +31,7 @@ env > env.log
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--ckpt_path="./" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='GPU' \

View File

@ -38,7 +38,7 @@ def argparse_init():
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.")
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
help="The strategy checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
@ -77,7 +77,7 @@ class WideDeepConfig():
self.output_path = "./output"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.ckpt_path = "./"
self.stra_ckpt = './checkpoints/strategy.ckpt'
self.host_device_mix = 0
self.dataset_type = "tfrecord"

View File

@ -35,7 +35,7 @@ def argparse_init():
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
parser.add_argument("--ckpt_path", type=str, default="./") # The location of the checkpoints file.
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
return parser
@ -67,7 +67,7 @@ class WideDeepConfig():
self.output_path = "./output/"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.ckpt_path = "./"
def argparse_init(self):
"""