pr to master #8
|
@ -28,22 +28,17 @@ from src.utils import switch_precision, set_context
|
|||
if __name__ == '__main__':
|
||||
args_opt = eval_parse_args()
|
||||
config = set_config(args_opt)
|
||||
set_context(config)
|
||||
backbone_net, head_net, net = define_net(config, args_opt.is_training)
|
||||
|
||||
#load the trained checkpoint file to the net for evaluation
|
||||
if args_opt.head_ckpt:
|
||||
load_ckpt(backbone_net, args_opt.pretrain_ckpt)
|
||||
load_ckpt(head_net, args_opt.head_ckpt)
|
||||
else:
|
||||
load_ckpt(net, args_opt.pretrain_ckpt)
|
||||
load_ckpt(net, args_opt.pretrain_ckpt)
|
||||
|
||||
set_context(config)
|
||||
switch_precision(net, mstype.float16, config)
|
||||
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if step_size == 0:
|
||||
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
|
||||
raise ValueError("The step_size of dataset is zero. Check if the images count of eval dataset is more \
|
||||
than batch_size in config.py")
|
||||
|
||||
net.set_train(False)
|
||||
|
@ -53,5 +48,3 @@ if __name__ == '__main__':
|
|||
|
||||
res = model.eval(dataset)
|
||||
print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
|
||||
if args_opt.head_ckpt:
|
||||
print(f"head_ckpt={args_opt.head_ckpt}")
|
||||
|
|
|
@ -43,7 +43,6 @@ run_ascend()
|
|||
--platform=$1 \
|
||||
--dataset_path=$2 \
|
||||
--pretrain_ckpt=$3 \
|
||||
--head_ckpt=$4 \
|
||||
&> ../eval.log & # dataset val folder path
|
||||
}
|
||||
|
||||
|
@ -69,7 +68,6 @@ run_gpu()
|
|||
--platform=$1 \
|
||||
--dataset_path=$2 \
|
||||
--pretrain_ckpt=$3 \
|
||||
--head_ckpt=$4 \
|
||||
&> ../eval.log & # dataset train folder
|
||||
}
|
||||
|
||||
|
@ -95,7 +93,6 @@ run_cpu()
|
|||
--platform=$1 \
|
||||
--dataset_path=$2 \
|
||||
--pretrain_ckpt=$3 \
|
||||
--head_ckpt=$4 \
|
||||
&> ../eval.log & # dataset train folder
|
||||
}
|
||||
|
||||
|
@ -105,7 +102,7 @@ then
|
|||
echo "Usage:
|
||||
Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
|
||||
GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
|
||||
CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]"
|
||||
CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -123,5 +120,5 @@ elif [ $1 = "GPU" ] ; then
|
|||
elif [ $1 = "Ascend" ] ; then
|
||||
run_ascend "$@"
|
||||
else
|
||||
echo "Unsupported device_target."
|
||||
echo "Unsupported platform."
|
||||
fi;
|
||||
|
|
|
@ -43,8 +43,8 @@ run_ascend()
|
|||
--visible_devices=$3 \
|
||||
--training_script=${BASEPATH}/../train.py \
|
||||
--dataset_path=$5 \
|
||||
--train_method=$6 \
|
||||
--pretrain_ckpt=$7 \
|
||||
--pretrain_ckpt=$6 \
|
||||
--freeze_layer=$7 \
|
||||
&> ../train.log & # dataset train folder
|
||||
}
|
||||
|
||||
|
@ -76,8 +76,8 @@ run_gpu()
|
|||
python ${BASEPATH}/../train.py \
|
||||
--platform=$1 \
|
||||
--dataset_path=$4 \
|
||||
--train_method=$5 \
|
||||
--pretrain_ckpt=$6 \
|
||||
--pretrain_ckpt=$5 \
|
||||
--freeze_layer=$6 \
|
||||
&> ../train.log & # dataset train folder
|
||||
}
|
||||
|
||||
|
@ -102,17 +102,17 @@ run_cpu()
|
|||
python ${BASEPATH}/../train.py \
|
||||
--platform=$1 \
|
||||
--dataset_path=$2 \
|
||||
--train_method=$3 \
|
||||
--pretrain_ckpt=$4 \
|
||||
--pretrain_ckpt=$3 \
|
||||
--freeze_layer=$4 \
|
||||
&> ../train.log & # dataset train folder
|
||||
}
|
||||
|
||||
if [ $# -gt 7 ] || [ $# -lt 4 ]
|
||||
then
|
||||
echo "Usage:
|
||||
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
|
||||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
|
||||
CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]"
|
||||
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
|
||||
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
|
||||
CPU: sh run_train.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -123,5 +123,5 @@ elif [ $1 = "GPU" ] ; then
|
|||
elif [ $1 = "CPU" ] ; then
|
||||
run_cpu "$@"
|
||||
else
|
||||
echo "Unsupported device_target."
|
||||
echo "Unsupported platform."
|
||||
fi;
|
||||
|
|
|
@ -41,11 +41,10 @@ def train_parse_args():
|
|||
train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
|
||||
help='run platform, only support CPU, GPU and Ascend')
|
||||
train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
|
||||
train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
|
||||
help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
|
||||
train from initialization model")
|
||||
train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
|
||||
for fine tune or incremental learning')
|
||||
train_parser.add_argument('--freeze_layer', type=str, default=None, choices=["none", "backbone"], \
|
||||
help="freeze the weights of network from start to which layers")
|
||||
train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
|
||||
train_args = train_parser.parse_args()
|
||||
train_args.is_training = True
|
||||
|
@ -58,8 +57,6 @@ def eval_parse_args():
|
|||
eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
|
||||
eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \
|
||||
for fine tune or incremental learning')
|
||||
eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \
|
||||
for incremental learning')
|
||||
eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.')
|
||||
eval_args = eval_parser.parse_args()
|
||||
eval_args.is_training = False
|
||||
|
|
|
@ -122,5 +122,5 @@ def extract_features(net, dataset_path, config):
|
|||
features = model.predict(Tensor(image))
|
||||
np.save(features_path, features.asnumpy())
|
||||
np.save(label_path, label)
|
||||
print(f"Complete the batch {i}/{step_size}")
|
||||
print(f"Complete the batch {i+1}/{step_size}")
|
||||
return step_size
|
||||
|
|
|
@ -298,8 +298,6 @@ class MobileNetV2(nn.Cell):
|
|||
has_dropout (bool): Is dropout used. Default is false
|
||||
inverted_residual_setting (list): Inverted residual settings. Default is None
|
||||
round_nearest (list): Channel round to . Default is 8
|
||||
backbone(nn.Cell): Backbone of MobileNetV2.
|
||||
head(nn.Cell): Classification head of MobileNetV2.
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
|
|
|
@ -82,6 +82,6 @@ def config_ckpoint(config, lr, step_size):
|
|||
rank = get_rank()
|
||||
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/"
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
return cb
|
||||
|
|
|
@ -53,30 +53,23 @@ if __name__ == '__main__':
|
|||
# define network
|
||||
backbone_net, head_net, net = define_net(config, args_opt.is_training)
|
||||
|
||||
# load the ckpt file to the network for fine tune or incremental leaning
|
||||
if args_opt.pretrain_ckpt:
|
||||
if args_opt.train_method == "fine_tune":
|
||||
load_ckpt(net, args_opt.pretrain_ckpt)
|
||||
elif args_opt.train_method == "incremental_learn":
|
||||
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
|
||||
elif args_opt.train_method == "train":
|
||||
pass
|
||||
else:
|
||||
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
|
||||
|
||||
# CPU only support "incremental_learn"
|
||||
if args_opt.train_method == "incremental_learn":
|
||||
if args_opt.pretrain_ckpt and args_opt.freeze_layer == "backbone":
|
||||
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
|
||||
step_size = extract_features(backbone_net, args_opt.dataset_path, config)
|
||||
net = head_net
|
||||
|
||||
elif args_opt.train_method in ("train", "fine_tune"):
|
||||
else:
|
||||
if args_opt.platform == "CPU":
|
||||
raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
|
||||
raise ValueError("CPU only support fine tune the head net, doesn't support fine tune the all net")
|
||||
|
||||
if args_opt.pretrain_ckpt:
|
||||
load_ckpt(backbone_net, args_opt.pretrain_ckpt)
|
||||
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
|
||||
step_size = dataset.get_dataset_size()
|
||||
if step_size == 0:
|
||||
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
|
||||
than batch_size in config.py")
|
||||
|
||||
if step_size == 0:
|
||||
raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
|
||||
than batch_size in config.py")
|
||||
|
||||
# Currently, only Ascend support switch precision.
|
||||
switch_precision(net, mstype.float16, config)
|
||||
|
@ -99,15 +92,32 @@ if __name__ == '__main__':
|
|||
total_epochs=epoch_size,
|
||||
steps_per_epoch=step_size))
|
||||
|
||||
if args_opt.train_method == "incremental_learn":
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
|
||||
if args_opt.pretrain_ckpt is None or args_opt.freeze_layer == "none":
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
|
||||
config.weight_decay, config.loss_scale)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
|
||||
|
||||
network = WithLossCell(net, loss)
|
||||
cb = config_ckpoint(config, lr, step_size)
|
||||
print("============== Starting Training ==============")
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
print("============== End Training ==============")
|
||||
|
||||
else:
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay)
|
||||
|
||||
network = WithLossCell(head_net, loss)
|
||||
network = TrainOneStepCell(network, opt)
|
||||
network.set_train()
|
||||
|
||||
features_path = args_opt.dataset_path + '_features'
|
||||
idx_list = list(range(step_size))
|
||||
rank = 0
|
||||
if config.run_distribute:
|
||||
rank = get_rank()
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
if not os.path.isdir(save_ckpt_path):
|
||||
os.mkdir(save_checkpoint)
|
||||
|
||||
for epoch in range(epoch_size):
|
||||
random.shuffle(idx_list)
|
||||
|
@ -119,24 +129,8 @@ if __name__ == '__main__':
|
|||
losses.append(network(feature, label).asnumpy())
|
||||
epoch_mseconds = (time.time()-epoch_start) * 1000
|
||||
per_step_mseconds = epoch_mseconds / step_size
|
||||
print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
|
||||
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
|
||||
print("epoch[{}/{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
|
||||
.format(epoch + 1, epoch_size, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
|
||||
if (epoch + 1) % config.save_checkpoint_epochs == 0:
|
||||
rank = 0
|
||||
if config.run_distribute:
|
||||
rank = get_rank()
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
save_checkpoint(network, os.path.join(save_ckpt_path, \
|
||||
f"mobilenetv2_head_{epoch+1}.ckpt"))
|
||||
save_checkpoint(net, os.path.join(save_ckpt_path, f"mobilenetv2_{epoch+1}.ckpt"))
|
||||
print("total cost {:5.4f} s".format(time.time() - start))
|
||||
|
||||
elif args_opt.train_method in ("train", "fine_tune"):
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
|
||||
config.weight_decay, config.loss_scale)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
|
||||
|
||||
cb = config_ckpoint(config, lr, step_size)
|
||||
print("============== Starting Training ==============")
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
print("============== End Training ==============")
|
||||
|
|
Loading…
Reference in New Issue