commit
069a1cfbd7
|
@ -36,7 +36,7 @@ def set_config(args):
|
|||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 20,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"save_checkpoint_path": "./",
|
||||
"platform": args.platform,
|
||||
"run_distribute": False
|
||||
})
|
||||
|
@ -57,7 +57,7 @@ def set_config(args):
|
|||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 200,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"save_checkpoint_path": "./",
|
||||
"platform": args.platform,
|
||||
"ccl": "nccl",
|
||||
"run_distribute": args.run_distribute
|
||||
|
@ -79,7 +79,7 @@ def set_config(args):
|
|||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 200,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"save_checkpoint_path": "./",
|
||||
"platform": args.platform,
|
||||
"ccl": "hccl",
|
||||
"device_id": int(os.getenv('DEVICE_ID', '0')),
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.nn import WithLossCell, TrainOneStepCell
|
|||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.communication.management import get_rank
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
@ -93,10 +94,6 @@ if __name__ == '__main__':
|
|||
features_path = args_opt.dataset_path + '_features'
|
||||
idx_list = list(range(step_size))
|
||||
|
||||
if os.path.isdir(config.save_checkpoint_path):
|
||||
os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time()))
|
||||
os.mkdir(config.save_checkpoint_path)
|
||||
|
||||
for epoch in range(epoch_size):
|
||||
random.shuffle(idx_list)
|
||||
epoch_start = time.time()
|
||||
|
@ -110,7 +107,11 @@ if __name__ == '__main__':
|
|||
print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
|
||||
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
|
||||
if (epoch + 1) % config.save_checkpoint_epochs == 0:
|
||||
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
|
||||
rank = 0
|
||||
if config.run_distribute:
|
||||
rank = get_rank()
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
save_checkpoint(network, os.path.join(save_ckpt_path, \
|
||||
f"mobilenetv2_head_{epoch+1}.ckpt"))
|
||||
print("total cost {:5.4f} s".format(time.time() - start))
|
||||
|
||||
|
|
|
@ -7,6 +7,6 @@ do_shuffle=true
|
|||
enable_data_sink=true
|
||||
data_sink_steps=100
|
||||
accumulation_steps=1
|
||||
save_checkpoint_path=./checkpoint/
|
||||
save_checkpoint_path=./
|
||||
save_checkpoint_steps=10000
|
||||
save_checkpoint_num=1
|
||||
|
|
|
@ -131,6 +131,7 @@ def run_transformer_train():
|
|||
else:
|
||||
device_num = 1
|
||||
rank_id = 0
|
||||
save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_0/')
|
||||
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
|
||||
rank_id=rank_id, do_shuffle=args.do_shuffle,
|
||||
dataset_path=args.data_path,
|
||||
|
|
|
@ -36,7 +36,7 @@ do
|
|||
env > env.log
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--ckpt_path="./" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--do_eval=True > output.log 2>&1 &
|
||||
|
|
|
@ -31,7 +31,7 @@ env > env.log
|
|||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--ckpt_path="./" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--device_target='GPU' \
|
||||
|
|
|
@ -38,7 +38,7 @@ def argparse_init():
|
|||
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
|
||||
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
|
||||
parser.add_argument("--output_path", type=str, default="./output/")
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.")
|
||||
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
|
||||
help="The strategy checkpoint file.")
|
||||
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
|
||||
|
@ -77,7 +77,7 @@ class WideDeepConfig():
|
|||
self.output_path = "./output"
|
||||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
self.ckpt_path = "./"
|
||||
self.stra_ckpt = './checkpoints/strategy.ckpt'
|
||||
self.host_device_mix = 0
|
||||
self.dataset_type = "tfrecord"
|
||||
|
|
|
@ -35,7 +35,7 @@ def argparse_init():
|
|||
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
|
||||
|
||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
|
||||
parser.add_argument("--ckpt_path", type=str, default="./") # The location of the checkpoints file.
|
||||
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
|
||||
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
|
||||
return parser
|
||||
|
@ -67,7 +67,7 @@ class WideDeepConfig():
|
|||
self.output_path = "./output/"
|
||||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
self.ckpt_path = "./"
|
||||
|
||||
def argparse_init(self):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue