forked from mindspore-Ecosystem/mindspore
!17928 mobilenetv2 test
Merge pull request !17928 from huchunmei/mobilenetv2
commit 1318156356
@@ -1,4 +1,4 @@
-# Mobilenet_V1
+# Mobilenet_V1

 - [Mobilenet_V1](#mobilenet_v1)
 - [MobileNetV1 Description](#mobilenetv1-description)

@@ -79,17 +79,23 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
 ├── MobileNetV1
   ├── README.md                      # descriptions about MobileNetV1
   ├── scripts
-  │   ├──run_distribute_train.sh     # shell script for distribute train
-  │   ├──run_standalone_train.sh     # shell script for standalone train
-  │   ├──run_eval.sh                 # shell script for evaluation
+  │   ├──run_distribute_train.sh     # shell script for distribute train
+  │   ├──run_standalone_train.sh     # shell script for standalone train
+  │   ├──run_eval.sh                 # shell script for evaluation
   ├── src
-  │   ├──config.py                   # parameter configuration
-  │   ├──dataset.py                  # creating dataset
-  │   ├──lr_generator.py             # learning rate config
-  │   ├──mobilenet_v1_fpn.py         # MobileNetV1 architecture
-  │   ├──CrossEntropySmooth.py       # loss function
-  ├── train.py                       # training script
-  ├── eval.py                        # evaluation script
+  │   ├──dataset.py                  # creating dataset
+  │   ├──lr_generator.py             # learning rate config
+  │   ├──mobilenet_v1_fpn.py         # MobileNetV1 architecture
+  │   ├──CrossEntropySmooth.py       # loss function
+  │   └──model_utils
+  │       ├──config.py               # Processing configuration parameters
+  │       ├──device_adapter.py       # Get cloud ID
+  │       ├──local_adapter.py        # Get local ID
+  │       └──moxing_adapter.py       # Parameter processing
+  ├── default_config.yaml            # Training parameter profile(cifar10)
+  ├── default_config_imagenet.yaml   # Training parameter profile(imagenet)
+  ├── train.py                       # training script
+  ├── eval.py                        # evaluation script
 ```

 ## [Training process](#contents)
@@ -0,0 +1,94 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+checkpoint_file: './checkpoint/mobilenetv1-90_625.ckpt'
+device_target: Ascend
+enable_profiling: False
+
+# ==============================================================================
+modelarts_dataset_unzip_name: 'ImageNet_Original'
+need_modelarts_dataset_unzip: True
+
+# config for mobilenet, cifar10
+class_num: 10
+batch_size: 32
+loss_scale: 1024
+momentum: 0.9
+weight_decay: 0.0001  # 1e-4
+epoch_size: 90
+pretrain_epoch_size: 0
+save_checkpoint: True
+save_checkpoint_epochs: 5
+keep_checkpoint_max: 10
+save_checkpoint_path: "/cache/train"
+warmup_epochs: 5
+lr_decay_mode: "poly"
+lr_init: 0.01
+lr_end: 0.00001
+lr_max: 0.1
+
+# Image classification - train
+dataset: 'cifar10'
+run_distribute: True
+device_num: 1
+dataset_path: "/cache/data"
+device_target: 'Ascend'
+pre_trained: "./mobilenetv2-200_625.ckpt"  # "./mobilenetv1-90_195.ckpt"
+parameter_server: False
+
+# Image classification - eval
+checkpoint_path: "./mobilenetv1-90_625.ckpt"
+
+# mobilenetv1 export
+device_id: 0
+ckpt_file: "/cache/data/mobilenetv1-90_625.ckpt"
+width: 224
+height: 224
+file_name: "mobilenetv1"
+file_format: "AIR"
+
+# postprocess
+result_dir: ''
+label_dir: ''
+dataset_name: 'cifar10'
+
+# preprocess
+# data_path: ''  help='eval data dir'
+result_path: './preprocess_Result/'
+
+---
+# Config description for each option
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+result_dir: "result files path."
+label_dir: "image file path."
+
+file_name: "output file name."
+dataset: "Dataset, either cifar10 or imagenet2012"
+parameter_server: 'Run parameter server train'
+width: 'input width'
+height: 'input height'
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
+run_distribute: 'Run distribute, default is false.'
+do_train: 'Do train or not, default is true.'
+do_eval: 'Do eval or not, default is false.'
+pre_trained: 'Pretrained checkpoint path'
+device_id: 'Device id, default is 0.'
+device_num: 'Use device nums, default is 1.'
+rank_id: 'Rank id, default is 0.'
+file_format: 'file format'
+
+---
+device_target: ['Ascend', 'GPU', 'CPU']
+file_format: ["AIR", "ONNX", "MINDIR"]
+dataset_name: ["cifar10", "imagenet2012"]
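This profile is a multi-document YAML file: the first document holds the values, the second the per-option help strings, and the third the allowed choices. A minimal sketch of splitting it with PyYAML (`parse_yaml` in `src/model_utils/config.py`, added later in this diff, does the same):

```python
# Hedged sketch: split the three YAML documents of default_config.yaml.
import yaml

with open("default_config.yaml", "r") as fin:
    cfg, cfg_helper, cfg_choices = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))

print(cfg["batch_size"])           # 32
print(cfg_helper["dataset"])       # 'Dataset, either cifar10 or imagenet2012'
print(cfg_choices["file_format"])  # ['AIR', 'ONNX', 'MINDIR']
```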
@@ -0,0 +1,96 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+checkpoint_file: './checkpoint/mobilenetv1-90_625.ckpt'
+device_target: Ascend
+enable_profiling: False
+
+# ==============================================================================
+modelarts_dataset_unzip_name: 'ImageNet_Original'
+need_modelarts_dataset_unzip: True
+
+# config for mobilenet, imagenet2012
+class_num: 1001
+batch_size: 256
+loss_scale: 1024
+momentum: 0.9
+weight_decay: 0.0001  # 1e-4
+epoch_size: 90
+pretrain_epoch_size: 0
+save_checkpoint: True
+save_checkpoint_epochs: 5
+keep_checkpoint_max: 10
+save_checkpoint_path: "./"
+warmup_epochs: 0
+lr_decay_mode: "linear"
+use_label_smooth: True
+label_smooth_factor: 0.1
+lr_init: 0
+lr_max: 0.8
+lr_end: 0.0
+
+# Image classification - train
+dataset: 'imagenet2012'
+run_distribute: True
+device_num: 1
+dataset_path: "/cache/data"
+device_target: 'Ascend'
+pre_trained: "./mobilenetv2-200_625.ckpt"  # "./mobilenetv1-90_625.ckpt"
+parameter_server: False
+
+# Image classification - eval
+checkpoint_path: "./mobilenetv1-90_625.ckpt"
+
+# mobilenetv1 export
+device_id: 0
+ckpt_file: "/cache/data/mobilenetv1-90_625.ckpt"
+width: 224
+height: 224
+file_name: "mobilenetv1"
+file_format: "AIR"
+
+# postprocess
+result_dir: ''
+label_dir: ''
+dataset_name: 'imagenet2012'
+
+# preprocess
+# data_path: ''  help='eval data dir'
+result_path: './preprocess_Result/'
+
+---
+# Config description for each option
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+result_dir: "result files path."
+label_dir: "image file path."
+
+file_name: "output file name."
+dataset: "Dataset, either cifar10 or imagenet2012"
+parameter_server: 'Run parameter server train'
+width: 'input width'
+height: 'input height'
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
+run_distribute: 'Run distribute, default is false.'
+do_train: 'Do train or not, default is true.'
+do_eval: 'Do eval or not, default is false.'
+pre_trained: 'Pretrained checkpoint path'
+device_id: 'Device id, default is 0.'
+device_num: 'Use device nums, default is 1.'
+rank_id: 'Rank id, default is 0.'
+file_format: 'file format'
+
+---
+device_target: ['Ascend', 'GPU', 'CPU']
+file_format: ["AIR", "ONNX", "MINDIR"]
+dataset_name: ["cifar10", "imagenet2012"]
@@ -14,7 +14,7 @@
 # ============================================================================
 """eval mobilenet_v1."""
 import os
-import argparse
+import time
 from mindspore import context
 from mindspore.common import set_seed
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
@@ -22,26 +22,81 @@ from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.device_adapter import get_device_id, get_device_num

-parser = argparse.ArgumentParser(description='Image classification')
-parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
-
-parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
-parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
-parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
-args_opt = parser.parse_args()
-
 set_seed(1)

-if args_opt.dataset == 'cifar10':
-    from src.config import config1 as config
+if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
 else:
-    from src.config import config2 as config
     from src.dataset import create_dataset2 as create_dataset

-if __name__ == '__main__':
-    target = args_opt.device_target
+def modelarts_process():
+    """ modelarts process """
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
+                    int(int(time.time() - s_time) % 60)))
+                print("Extract Done")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices as most
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+        print("#" * 200, os.listdir(save_dir_1))
+        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
+
+    config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
+    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
+
+
+@moxing_wrapper(pre_process=modelarts_process)
+def eval_mobilenetv1():
+    config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
+    print('\nconfig:\n', config)
+    target = config.device_target

     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
@@ -50,20 +105,20 @@ if __name__ == '__main__':
         context.set_context(device_id=device_id)

     # create dataset
-    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
+    dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, batch_size=config.batch_size,
                              target=target)
-    step_size = dataset.get_dataset_size()
+    # step_size = dataset.get_dataset_size()

     # define net
     net = mobilenet(class_num=config.class_num)

     # load checkpoint
-    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    param_dict = load_checkpoint(config.checkpoint_path)
     load_param_into_net(net, param_dict)
     net.set_train(False)

     # define loss, model
-    if args_opt.dataset == "imagenet2012":
+    if config.dataset == "imagenet2012":
         if not config.use_label_smooth:
             config.label_smooth_factor = 0.0
         loss = CrossEntropySmooth(sparse=True, reduction='mean',
@@ -76,4 +131,7 @@ if __name__ == '__main__':

     # eval model
     res = model.eval(dataset)
-    print("result:", res, "ckpt=", args_opt.checkpoint_path)
+    print("result:", res, "ckpt=", config.checkpoint_path)
+
+if __name__ == '__main__':
+    eval_mobilenetv1()
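The unzip helper above coordinates all devices on one host through a lock file: the first device extracts the archive and then creates `/tmp/unzip_sync.lock`; every other device polls until the lock appears. A stripped-down sketch of the same pattern (the helper name `prepare_once` is illustrative, not from the source):

```python
# Hedged sketch of the one-writer / many-waiters pattern used in modelarts_process().
import os
import time

SYNC_LOCK = "/tmp/unzip_sync.lock"  # same path as in the diff above

def prepare_once(device_id, work):
    """First device on the host runs `work`; the rest wait for the lock file."""
    if device_id == 0 and not os.path.exists(SYNC_LOCK):
        work()                       # e.g. unzip the dataset
        try:
            os.mknod(SYNC_LOCK)      # signal completion; may already exist
        except IOError:
            pass
    while not os.path.exists(SYNC_LOCK):
        time.sleep(1)                # other devices poll until the flag shows up
```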
@@ -13,44 +13,25 @@
 # limitations under the License.
 # ============================================================================

-import argparse
 import numpy as np

 from mindspore import context, Tensor
 from mindspore.train.serialization import export, load_checkpoint

 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
+from src.model_utils.config import config
+from src.model_utils.device_adapter import get_device_id

-parser = argparse.ArgumentParser(description="mobilenetv1 export")
-parser.add_argument("--device_id", type=int, default=0, help="Device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
-parser.add_argument("--dataset", type=str, default="imagenet2012", help="Dataset, either cifar10 or imagenet2012")
-parser.add_argument('--width', type=int, default=224, help='input width')
-parser.add_argument('--height', type=int, default=224, help='input height')
-parser.add_argument("--file_name", type=str, default="mobilenetv1", help="output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
-parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
-
-if args.dataset == "cifar10":
-    from src.config import config1 as config
-else:
-    from src.config import config2 as config
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)

 if __name__ == "__main__":
-    target = args.device_target
+    target = config.device_target
     if target != "GPU":
-        context.set_context(device_id=args.device_id)
+        context.set_context(device_id=get_device_id())

     network = mobilenet(class_num=config.class_num)

-    param_dict = load_checkpoint(args.ckpt_file, net=network)
-
+    param_dict = load_checkpoint(config.ckpt_file, net=network)
     network.set_train(False)

-    input_data = Tensor(np.zeros([config.batch_size, 3, args.height, args.width]).astype(np.float32))
-
-    export(network, input_data, file_name=args.file_name, file_format=args.file_format)
+    input_data = Tensor(np.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32))
+    export(network, input_data, file_name=config.file_name, file_format=config.file_format)
@@ -15,41 +15,33 @@
 """postprocess for 310 inference"""
 import os
 import json
-import argparse
 import numpy as np
 from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
+from src.model_utils.config import config

-parser = argparse.ArgumentParser(description="postprocess")
-parser.add_argument("--result_dir", type=str, required=True, help="result files path.")
-parser.add_argument("--label_dir", type=str, required=True, help="image file path.")
-parser.add_argument('--dataset_name', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012")
-args = parser.parse_args()

 def calcul_acc(lab, preds):
     return sum(1 for x, y in zip(lab, preds) if x == y) / len(lab)

 if __name__ == '__main__':
     batch_size = 1
     top1_acc = Top1CategoricalAccuracy()
-    rst_path = args.result_dir
+    rst_path = config.result_dir
     label_list = []
     pred_list = []

-    if args.dataset_name == "cifar10":
-        from src.config import config1 as cfg
-        labels = np.load(args.label_dir, allow_pickle=True)
+    if config.dataset_name == "cifar10":
+        labels = np.load(config.label_dir, allow_pickle=True)
         for idx, label in enumerate(labels):
-            f_name = os.path.join(rst_path, "mobilenetv1_data_bs" + str(cfg.batch_size) + "_" + str(idx) + "_0.bin")
+            f_name = os.path.join(rst_path, "mobilenetv1_data_bs" + str(config.batch_size) + "_" + str(idx) + "_0.bin")
             pred = np.fromfile(f_name, np.float32)
-            pred = pred.reshape(cfg.batch_size, int(pred.shape[0] / cfg.batch_size))
+            pred = pred.reshape(config.batch_size, int(pred.shape[0] / config.batch_size))
             top1_acc.update(pred, labels[idx])
         print("acc: ", top1_acc.eval())
     else:
-        from src.config import config2 as cfg
         top5_acc = Top5CategoricalAccuracy()
         file_list = os.listdir(rst_path)
-        with open(args.label_dir, "r") as label:
+        with open(config.label_dir, "r") as label:
             labels = json.load(label)
         for f in file_list:
             label = f.split("_0.bin")[0] + ".JPEG"
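postprocess.py relies on `Top1CategoricalAccuracy` accumulating over repeated `update(logits, labels)` calls and reporting via `eval()`, exactly as the cifar10 branch above does. A minimal sketch with made-up values, assuming MindSpore is installed:

```python
# Hedged sketch of the metric API used by postprocess.py.
import numpy as np
from mindspore.nn import Top1CategoricalAccuracy

top1 = Top1CategoricalAccuracy()
logits = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]], dtype=np.float32)
labels = np.array([1, 0], dtype=np.int32)
top1.update(logits, labels)   # call once per batch
print(top1.eval())            # -> 1.0 for this toy batch
```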
@@ -14,11 +14,13 @@
 # ============================================================================
 """preprocess"""
 import os
-import argparse
 import json
 import numpy as np
 from src.dataset import create_dataset1
+
+from src.model_utils.config import config


 def create_label(result_path, dir_path):
     print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
     dirs = os.listdir(dir_path)
@@ -41,33 +43,23 @@ def create_label(result_path, dir_path):

     print("[INFO] Completed! Total {} data.".format(total))

-parser = argparse.ArgumentParser('preprocess')
-parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012")
-parser.add_argument('--data_path', type=str, default='', help='eval data dir')
-parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
-args = parser.parse_args()
-
-if args.dataset == "cifar10":
-    from src.config import config1 as cfg
-else:
-    from src.config import config2 as cfg
-
-args.per_batch_size = cfg.batch_size
-#args.image_size = list(map(int, cfg.image_size.split(',')))
+config.per_batch_size = config.batch_size
+#config.image_size = list(map(int, config.image_size.split(',')))


 if __name__ == "__main__":
-    if args.dataset == "cifar10":
-        dataset = create_dataset1(args.data_path, False, args.per_batch_size)
-        img_path = os.path.join(args.result_path, "00_data")
+    if config.dataset == "cifar10":
+        dataset = create_dataset1(config.data_path, False, config.per_batch_size)
+        img_path = os.path.join(config.result_path, "00_data")
         os.makedirs(img_path)
         label_list = []
         for idx, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
-            file_name = "mobilenetv1_data_bs" + str(args.per_batch_size) + "_" + str(idx) + ".bin"
+            file_name = "mobilenetv1_data_bs" + str(config.per_batch_size) + "_" + str(idx) + ".bin"
             file_path = os.path.join(img_path, file_name)
             data["image"].tofile(file_path)
             label_list.append(data["label"])
-        np.save(os.path.join(args.result_path, "cifar10_label_ids.npy"), label_list)
+        np.save(os.path.join(config.result_path, "cifar10_label_ids.npy"), label_list)
         print("=" * 20, "export bin files finished", "=" * 20)
     else:
-        create_label(args.result_path, args.data_path)
+        create_label(config.result_path, config.data_path)
@@ -68,6 +68,20 @@ export RANK_TABLE_FILE=$PATH1
 export SERVER_ID=0
 rank_start=$((DEVICE_NUM * SERVER_ID))

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $1 == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $1 == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
 for((i=0; i<${DEVICE_NUM}; i++))
 do
     export DEVICE_ID=${i}

@@ -75,6 +89,7 @@ do
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
     cp ../*.py ./train_parallel$i
+    cp ../*.yaml ./train_parallel$i
     cp *.sh ./train_parallel$i
     cp -r ../src ./train_parallel$i
     cd ./train_parallel$i || exit

@@ -82,12 +97,12 @@ do
     env > env.log
     if [ $# == 3 ]
     then
-        python train.py --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
+        python train.py --config_path=$CONFIG_FILE --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log.txt &
     fi

     if [ $# == 4 ]
     then
-        python train.py --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
+        python train.py --config_path=$CONFIG_FILE --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log.txt &
     fi

     cd ..
@@ -56,16 +56,31 @@ export DEVICE_ID=0
 export RANK_SIZE=$DEVICE_NUM
 export RANK_ID=0

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $1 == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $1 == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
 if [ -d "eval" ];
 then
     rm -rf ./eval
 fi
 mkdir ./eval
 cp ../*.py ./eval
+cp ../*.yaml ./eval
 cp *.sh ./eval
 cp -r ../src ./eval
 cd ./eval || exit
 env > env.log
 echo "start evaluation for device $DEVICE_ID"
-python eval.py --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
+python eval.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log_eval.txt &
 cd ..
@@ -50,15 +50,30 @@ then
     exit 1
 fi

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $1 == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $1 == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
 if [ -d "eval" ];
 then
     rm -rf ./eval
 fi
 mkdir ./eval
 cp ../*.py ./eval
+cp ../*.yaml ./eval
 cp *.sh ./eval
 cp -r ../src ./eval
 cd ./eval || exit
 env > env.log
-python eval.py --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target=CPU &> log &
+python eval.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target=CPU &> log_eval_cpu.txt &
 cd ..
@@ -33,6 +33,21 @@ dataset_path=$(get_real_path $2)
 dataset_name="imagenet2012"
 DVPP="CPU"

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $dataset_name == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $dataset_name == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
+
 device_id=0
 if [ $# == 3 ]; then
     device_id=$3

@@ -60,7 +75,6 @@ export SLOG_PRINT_to_STDOUT=0
 export GLOG_v=2
 export DUMP_GE_GRAPH=2
-

 export ASCEND_HOME=/usr/local/Ascend

 export PATH=$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/toolkit/bin:$PATH

@@ -81,7 +95,7 @@ function preprocess_data()
         rm -rf ./preprocess_Result
     fi
     mkdir preprocess_Result
-    python3.7 ../preprocess.py --dataset=$dataset_name --data_path=$dataset_path --result_path=./preprocess_Result/
+    python3.7 ../preprocess.py --config_path=$CONFIG_FILE --dataset=$dataset_name --data_path=$dataset_path --result_path=./preprocess_Result/
 }

 function compile_app()

@@ -112,7 +126,7 @@ function infer()
 function cal_acc()
 {
-    python3.7 ../postprocess.py --result_dir=./result_Files --label_dir=./preprocess_Result/imagenet_label.json &> acc.log
+    python3.7 ../postprocess.py --config_path=$CONFIG_FILE --result_dir=./result_Files --label_dir=./preprocess_Result/imagenet_label.json &> acc.log

 }
@@ -59,12 +59,27 @@ export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $1 == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $1 == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
 if [ -d "train" ];
 then
     rm -rf ./train
 fi
 mkdir ./train
 cp ../*.py ./train
+cp ../*.yaml ./train
 cp *.sh ./train
 cp -r ../src ./train
 cd ./train || exit

@@ -72,11 +87,11 @@ echo "start training for device $DEVICE_ID"
 env > env.log
 if [ $# == 2 ]
 then
-    python train.py --dataset=$1 --dataset_path=$PATH1 &> log &
+    python train.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 &> log.txt &
 fi

 if [ $# == 3 ]
 then
-    python train.py --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
+    python train.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log.txt &
 fi
 cd ..
@@ -53,23 +53,38 @@ then
     exit 1
 fi

+BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
+if [ $# -ge 1 ]; then
+    if [ $1 == 'cifar10' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+    elif [ $1 == 'imagenet2012' ]; then
+        CONFIG_FILE="${BASE_PATH}/../default_config_imagenet.yaml"
+    else
+        echo "Unrecognized parameter"
+        exit 1
+    fi
+else
+    CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
+fi
+
 if [ -d "train" ];
 then
     rm -rf ./train
 fi
 mkdir ./train
 cp ../*.py ./train
+cp ../*.yaml ./train
 cp *.sh ./train
 cp -r ../src ./train
 cd ./train || exit
 env > env.log
 if [ $# == 2 ]
 then
-    python train.py --dataset=$1 --dataset_path=$PATH1 --device_target=CPU &> log &
+    python train.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 --device_target=CPU &> log.txt &
 fi

 if [ $# == 3 ]
 then
-    python train.py --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 --device_target=CPU &> log &
+    python train.py --config_path=$CONFIG_FILE --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 --device_target=CPU &> log.txt &
 fi
 cd ..
@@ -1,60 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-network config setting, will be used in train.py and eval.py
-"""
-from easydict import EasyDict as ed
-
-# config for mobilenet, cifar10
-config1 = ed({
-    "class_num": 10,
-    "batch_size": 32,
-    "loss_scale": 1024,
-    "momentum": 0.9,
-    "weight_decay": 1e-4,
-    "epoch_size": 90,
-    "pretrain_epoch_size": 0,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 5,
-    "keep_checkpoint_max": 10,
-    "save_checkpoint_path": "./",
-    "warmup_epochs": 5,
-    "lr_decay_mode": "poly",
-    "lr_init": 0.01,
-    "lr_end": 0.00001,
-    "lr_max": 0.1
-})
-
-# config for mobilenet, imagenet2012
-config2 = ed({
-    "class_num": 1001,
-    "batch_size": 256,
-    "loss_scale": 1024,
-    "momentum": 0.9,
-    "weight_decay": 1e-4,
-    "epoch_size": 90,
-    "pretrain_epoch_size": 0,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 5,
-    "keep_checkpoint_max": 10,
-    "save_checkpoint_path": "./",
-    "warmup_epochs": 0,
-    "lr_decay_mode": "linear",
-    "use_label_smooth": True,
-    "label_smooth_factor": 0.1,
-    "lr_init": 0,
-    "lr_max": 0.8,
-    "lr_end": 0.0
-})
@@ -0,0 +1,127 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parse arguments"""
+
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+import yaml
+
+class Config:
+    """
+    Configuration namespace. Convert dictionary to members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
+    """
+    Parse command line arguments to the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
+            cfgs = [x for x in cfgs]
+            if len(cfgs) == 1:
+                cfg_helper = {}
+                cfg = cfgs[0]
+                cfg_choices = {}
+            elif len(cfgs) == 2:
+                cfg, cfg_helper = cfgs
+                cfg_choices = {}
+            elif len(cfgs) == 3:
+                cfg, cfg_helper, cfg_choices = cfgs
+            else:
+                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
+            print(cfg_helper)
+        except:
+            raise ValueError("Failed to parse yaml")
+    return cfg, cfg_helper, cfg_choices
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
+    """
+    parser = argparse.ArgumentParser(description="default name", add_help=False)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
+                        help="Config file path")
+    path_args, _ = parser.parse_known_args()
+    default, helper, choices = parse_yaml(path_args.config_path)
+    pprint(default)
+    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
+    final_config = merge(args, default)
+    return Config(final_config)
+
+config = get_config()
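With this module in place, every entry point imports one shared `config` object, and `parse_cli_to_yaml` registers one argparse flag per scalar YAML key, so any default can be overridden at launch. A minimal usage sketch (the flag values are illustrative):

```python
# Hedged sketch: consuming the merged config object.
# Running, for example,
#   python train.py --config_path=default_config.yaml --batch_size=64 --lr_max=0.2
# makes the CLI values win over the YAML defaults (see merge() above).
from src.model_utils.config import config   # parses sys.argv at import time

print(config.batch_size)      # YAML default, or the CLI override
print(config.device_target)   # nested dicts become attribute access too
```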
@@ -0,0 +1,27 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Device adapter for ModelArts"""
+
+from .config import config
+
+if config.enable_modelarts:
+    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+else:
+    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+
+__all__ = [
+    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
+]
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Local adapter"""
+
+import os
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return "Local Job"
@@ -0,0 +1,122 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from mindspore.profiler import Profiler
+from .config import config
+
+_global_sync_count = 0
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id != "" else "default"
+    return job_id
+
+def sync_data(from_path, to_path):
+    """
+    Download data from remote obs to local directory if the first url is remote url and the second one is local path
+    Upload data from local directory to remote obs in contrast.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains 8 devices as most.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            if config.enable_profiling:
+                profiler = Profiler()
+
+            run_func(*args, **kwargs)
+
+            if config.enable_profiling:
+                profiler.analyse()
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper
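`moxing_wrapper` turns a plain function into a ModelArts-aware one: inputs are synced from OBS before the call, and `output_path` is uploaded afterwards. A sketch mirroring how train.py and eval.py in this PR consume it (the function names here are illustrative):

```python
# Hedged sketch of the decorator usage pattern from train.py / eval.py.
from src.model_utils.moxing_adapter import moxing_wrapper

def prepare_dataset():
    # optional pre_process hook, e.g. unzip the dataset (see modelarts_pre_process below)
    pass

@moxing_wrapper(pre_process=prepare_dataset)
def run_train():
    ...  # the actual training loop; runs between download and upload

if __name__ == "__main__":
    run_train()
```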
@@ -14,8 +14,7 @@
 # ============================================================================
 """train mobilenet_v1."""
 import os
-import argparse
-import ast
+import time
 from mindspore import context
 from mindspore import Tensor
 from mindspore.nn.optim.momentum import Momentum
@@ -32,40 +31,91 @@ import mindspore.common.initializer as weight_init
 from src.lr_generator import get_lr
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.device_adapter import get_device_id, get_device_num

-parser = argparse.ArgumentParser(description='Image classification')
-parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
-parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
-parser.add_argument('--device_num', type=int, default=1, help='Device num.')
-
-parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
-parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
-parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
-parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
-args_opt = parser.parse_args()
-
 set_seed(1)

-if args_opt.dataset == 'cifar10':
-    from src.config import config1 as config
+if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
 else:
-    from src.config import config2 as config
     from src.dataset import create_dataset2 as create_dataset

-if __name__ == '__main__':
-    target = args_opt.device_target
-
+def modelarts_pre_process():
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
+                    int(int(time.time() - s_time) % 60)))
+                print("Extract Done")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices as most
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+        print("#" * 200, os.listdir(save_dir_1))
+        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
+
+    config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
+    config.ckpt_path = config.output_path
+    config.pre_trained = os.path.join(config.dataset_path, config.pre_trained)
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def train_mobilenetv1():
+    config.dataset_path = os.path.join(config.dataset_path, 'train')
+    target = config.device_target
+    ckpt_save_dir = config.save_checkpoint_path

     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
-    if args_opt.parameter_server:
+    if config.parameter_server:
         context.set_ps_context(enable_ps=True)
     device_id = int(os.getenv('DEVICE_ID', '0'))
-    if args_opt.run_distribute:
+    if config.run_distribute:
         if target == "Ascend":
             context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+            context.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
             init()
             context.set_auto_parallel_context(all_reduce_fusion_config=[75])
@@ -77,18 +127,18 @@ if __name__ == '__main__':
             ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

     # create dataset
-    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
+    dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, repeat_num=1,
                              batch_size=config.batch_size, target=target)
     step_size = dataset.get_dataset_size()

     # define net
     net = mobilenet(class_num=config.class_num)
-    if args_opt.parameter_server:
+    if config.parameter_server:
         net.set_param_ps()

     # init weight
-    if args_opt.pre_trained:
-        param_dict = load_checkpoint(args_opt.pre_trained)
+    if config.pre_trained:
+        param_dict = load_checkpoint(config.pre_trained)
         load_param_into_net(net, param_dict)
     else:
         for _, cell in net.cells_and_names():
@@ -125,7 +175,7 @@ if __name__ == '__main__':
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
                    config.loss_scale)
     # define loss, model
-    if args_opt.dataset == "imagenet2012":
+    if config.dataset == "imagenet2012":
         if not config.use_label_smooth:
             config.label_smooth_factor = 0.0
         loss = CrossEntropySmooth(sparse=True, reduction="mean",
@@ -143,7 +193,7 @@ if __name__ == '__main__':
     time_cb = TimeMonitor(data_size=step_size)
     loss_cb = LossMonitor()
     cb = [time_cb, loss_cb]
-    if config.save_checkpoint and device_id % min(8, args_opt.device_num) == 0:
+    if config.save_checkpoint and device_id % min(8, get_device_num()) == 0:
         config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                      keep_checkpoint_max=config.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(prefix="mobilenetv1", directory=ckpt_save_dir, config=config_ck)
@@ -151,4 +201,7 @@ if __name__ == '__main__':

     # train model
     model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
-                sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not args_opt.parameter_server))
+                sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not config.parameter_server))
+
+if __name__ == '__main__':
+    train_mobilenetv1()
@@ -77,13 +77,19 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
   │   ├──run_train_nfs_cache.sh   # shell script for train with NFS dataset and leverage caching service for better performance
   ├── src
   │   ├──aipp.cfg                 # aipp config
-  │   ├──args.py                  # parse args
-  │   ├──config.py                # parameter configuration
   │   ├──dataset.py               # creating dataset
   │   ├──lr_generator.py          # learning rate config
   │   ├──mobilenetV2.py           # MobileNetV2 architecture
   │   ├──models.py                # contain define_net and Loss, Monitor
   │   ├──utils.py                 # utils to load ckpt_file for fine tune or incremental learn
+  │   └──model_utils
+  │       ├──config.py            # Processing configuration parameters
+  │       ├──device_adapter.py    # Get cloud ID
+  │       ├──local_adapter.py     # Get local ID
+  │       └──moxing_adapter.py    # Parameter processing
+  ├── default_config.yaml         # Training parameter profile(ascend)
+  ├── default_config_cpu.yaml     # Training parameter profile(cpu)
+  ├── default_config_gpu.yaml     # Training parameter profile(gpu)
   ├── train.py                    # training script
   ├── eval.py                     # evaluation script
   ├── export.py                   # export mindir script
@@ -73,24 +73,30 @@ The overall network architecture of MobileNetV2 is as follows:

 ```python
 ├── MobileNetV2
-  ├── README.md                  # description of MobileNetV2
-  ├── ascend310_infer            # for Ascend 310 inference
+  ├── README.md                  # description of MobileNetV2
+  ├── ascend310_infer            # for Ascend 310 inference
   ├── scripts
   │   ├──run_train.sh            # shell script for training, fine-tuning or incremental learning with CPU, GPU or Ascend
   │   ├──run_eval.sh             # shell script for evaluation with CPU, GPU or Ascend
   │   ├──cache_util.sh           # helper functions for using the cache
   │   ├──run_train_nfs_cache.sh  # shell script for training with an NFS dataset, accelerated by the caching service
-  │   ├──run_infer_310.sh        # shell script for inference with Dvpp or CPU operators
+  │   ├──run_infer_310.sh        # shell script for inference with Dvpp or CPU operators
   ├── src
   │   ├──aipp.cfg                # aipp config
   │   ├──args.py                 # argument parsing
   │   ├──config.py               # parameter configuration
   │   ├──dataset.py              # creating dataset
   │   ├──launch.py               # launcher python script
   │   ├──lr_generator.py         # learning rate config
   │   ├──mobilenetV2.py          # MobileNetV2 architecture
   │   ├──models.py               # load define_net, Loss and Monitor
   │   ├──utils.py                # load ckpt_file for fine-tuning or incremental learning
+  │   └──model_utils
+  │       ├──config.py           # get .yaml configuration parameters
+  │       ├──device_adapter.py   # get cloud ID
+  │       ├──local_adapter.py    # get local ID
+  │       └──moxing_adapter.py   # data preparation on the cloud
+  ├── default_config.yaml        # training configuration parameters (ascend)
+  ├── default_config_cpu.yaml    # training configuration parameters (cpu)
+  ├── default_config_gpu.yaml    # training configuration parameters (gpu)
   ├── train.py                   # training script
   ├── eval.py                    # evaluation script
   ├── export.py                  # model export script
@@ -0,0 +1,102 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
device_target: Ascend
enable_profiling: False

# ==============================================================================
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True

num_classes: 1000
image_height: 224
image_width: 224
batch_size: 256
epoch_size: 200
warmup_epochs: 4
lr_init: 0.00
lr_end: 0.00
lr_max: 0.4
momentum: 0.9
weight_decay: 0.00004  # 4e-5
label_smooth: 0.1
loss_scale: 1024
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 200
save_checkpoint_path: "./"
platform: 'Ascend'
device_id: 0           # overridden at runtime via get_device_id()
rank_id: 0             # overridden at runtime via get_rank_id()
rank_size: 1           # overridden at runtime via get_device_num()
run_distribute: False  # overridden at runtime (rank_size > 1)
activation: "Softmax"

# Image classification - train
dataset_path: "/cache/data"
pretrain_ckpt: "./mobilenetv2-200_625.ckpt"
freeze_layer: ""
filter_head: False
enable_cache: False
cache_session_id: ""
is_training: True

# mobilenetv2 eval
is_training_eval: False
run_distribute_eval: False

# mobilenetv2 export
device_id_export: 0
batch_size_export: 1
ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt"
file_name: "mobilenetv2"
file_format: "MINDIR"
is_training_export: False
run_distribute_export: False

# postprocess.py / mobilenetv2 acc calculation
batch_size_postprocess: 1
result_path: ''  # "result files path."
label_path: ''   # "label path."

---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'

pretrain_ckpt: 'Pretrained checkpoint path for fine tune or incremental learning'
platform: 'Target device type'
freeze_layer: 'Freeze the weights of network from start to which layers'
filter_head: 'Filter head weight parameters when load checkpoint, default is False.'
enable_cache: 'Caching the dataset in memory to speedup dataset processing, default is False.'
cache_session_id: 'The session id for cache service.'
file_name: 'Output file name.'

enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
dataset: 'Dataset, default is coco.'
pre_trained: 'Pretrain file path.'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_format: 'File format'
img_path: 'Image file path.'
result_path: 'Result file path.'

---
platform: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]
freeze_layer: ["", "none", "backbone"]
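The three `---`-separated documents in this yaml (values, per-key help text, per-key choices) are consumed together by `src/model_utils/config.py`. A minimal sketch of how they come apart, assuming PyYAML is installed and the file above is saved as `default_config.yaml`:

```python
import yaml

# Minimal sketch: split the three `---`-separated documents into
# (values, per-key help text, per-key choices).
with open("default_config.yaml", "r") as fin:
    docs = list(yaml.load_all(fin, Loader=yaml.FullLoader))

cfg, helper, choices = (docs + [{}, {}])[:3]
print(cfg["batch_size"])        # 256
print(helper["platform"])       # Target device type
print(choices["file_format"])   # ['AIR', 'ONNX', 'MINDIR']
```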
@@ -0,0 +1,93 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
device_target: Ascend
enable_profiling: False

# ==============================================================================
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True

num_classes: 26
image_height: 224
image_width: 224
batch_size: 150
epoch_size: 15
warmup_epochs: 0
lr_init: 0.0
lr_end: 0.03
lr_max: 0.03
momentum: 0.9
weight_decay: 0.00004  # 4e-5
label_smooth: 0.1
loss_scale: 1024
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 20
save_checkpoint_path: "./"
platform: 'CPU'
run_distribute: False
activation: "Softmax"

# Image classification - train
dataset_path: "/cache/data"
pretrain_ckpt: "./mobilenetv2-200_625.ckpt"
freeze_layer: ""
filter_head: False
enable_cache: False
cache_session_id: ""
is_training: True

# mobilenetv2 eval
is_training_eval: False
run_distribute_eval: False

# mobilenetv2 export
device_id_export: 0
batch_size_export: 1
ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt"
file_name: "mobilenetv2"
file_format: "MINDIR"
is_training_export: False
run_distribute_export: False

# postprocess.py / mobilenetv2 acc calculation
batch_size_postprocess: 1
result_path: ''  # "result files path."
label_path: ''   # "label path."

---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'

device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
dataset: 'Dataset, default is coco.'
pre_trained: 'Pretrain file path.'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_format: 'File format'
img_path: 'Image file path.'
result_path: 'Result file path.'

---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]
@@ -0,0 +1,91 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
device_target: Ascend
enable_profiling: False

# ==============================================================================
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True

num_classes: 1000
image_height: 224
image_width: 224
batch_size: 150
epoch_size: 200
warmup_epochs: 0
lr_init: 0.0
lr_end: 0.0
lr_max: 0.8
momentum: 0.9
weight_decay: 0.00004  # 4e-5
label_smooth: 0.1
loss_scale: 1024
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 200
save_checkpoint_path: "./"
platform: 'GPU'
run_distribute: False
activation: "Softmax"

# Image classification - train
dataset_path: "/cache/data"
pretrain_ckpt: "./mobilenetv2-200_625.ckpt"
freeze_layer: ""
filter_head: False
enable_cache: False
cache_session_id: ""
is_training: True

# mobilenetv2 eval
is_training_eval: False
run_distribute_eval: False

# mobilenetv2 export
device_id_export: 0
batch_size_export: 1
ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt"
file_name: "mobilenetv2"
file_format: "MINDIR"
is_training_export: False
run_distribute_export: False

# postprocess.py / mobilenetv2 acc calculation
batch_size_postprocess: 1
result_path: ''  # "result files path."
label_path: ''   # "label path."

---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'

device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
only_create_dataset: 'If set it true, only create Mindrecord, default is false.'
run_distribute: 'Run distribute, default is false.'
do_train: 'Do train or not, default is true.'
do_eval: 'Do eval or not, default is false.'
dataset: 'Dataset, default is coco.'
pre_trained: 'Pretrain file path.'
device_id: 'Device id, default is 0.'
device_num: 'Use device nums, default is 1.'
rank_id: 'Rank id, default is 0.'
file_format: 'File format'
img_path: 'Image file path.'
result_path: 'Result file path.'

---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ["AIR", "ONNX", "MINDIR"]
@@ -15,27 +15,95 @@
"""
eval.
"""
import time
import os
from mindspore import nn
from mindspore.train.model import Model
from mindspore.common import dtype as mstype

from src.dataset import create_dataset
from src.models import define_net, load_ckpt
from src.utils import switch_precision, set_context
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

config.is_training = config.is_training_eval
config.device_id = get_device_id()
config.rank_id = get_rank_id()
config.rank_size = get_device_num()
config.run_distribute = config.rank_size > 1

def modelarts_process():
    """ modelarts process """
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done")
            else:
                print("This is not a zip file.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices at most
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
        print("#" * 200, os.listdir(save_dir_1))
        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))

    config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
    config.pretrain_ckpt = os.path.join(config.output_path, config.pretrain_ckpt)


@moxing_wrapper(pre_process=modelarts_process)
def eval_mobilenetv2():
    config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
    print('\nconfig: \n', config)
    set_context(config)
    _, _, net = define_net(config, config.is_training)

    load_ckpt(net, config.pretrain_ckpt)

    switch_precision(net, mstype.float16, config)

    dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, config=config)
    step_size = dataset.get_dataset_size()
    if step_size == 0:
        raise ValueError("The step_size of dataset is zero. Check if the images count of eval dataset is more \
@@ -47,4 +115,7 @@ if __name__ == '__main__':
    model = Model(net, loss_fn=loss, metrics={'acc'})

    res = model.eval(dataset)
    print(f"result:{res}\npretrain_ckpt={config.pretrain_ckpt}")


if __name__ == '__main__':
    eval_mobilenetv2()
@@ -15,35 +15,32 @@
"""
mobilenetv2 export file.
"""
import numpy as np
from mindspore import Tensor, export, context
from src.models import define_net, load_ckpt
from src.utils import set_context
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

config.device_id = get_device_id()
config.rank_id = get_rank_id()
config.rank_size = get_device_num()
config.run_distribute = config.rank_size > 1

config.batch_size = config.batch_size_export
config.is_training = config.is_training_export

context.set_context(mode=context.GRAPH_MODE, device_target=config.platform)
if config.platform == "Ascend":
    context.set_context(device_id=get_device_id())

if __name__ == '__main__':
    print('\nconfig: \n', config)
    set_context(config)
    _, _, net = define_net(config, config.is_training)

    load_ckpt(net, config.ckpt_file)
    input_shp = [config.batch_size, 3, config.image_height, config.image_width]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net, input_array, file_name=config.file_name, file_format=config.file_format)
@@ -14,15 +14,11 @@
# ============================================================================
"""post process for 310 inference"""
import os
import numpy as np
from src.model_utils.config import config

config.batch_size = config.batch_size_postprocess


def calcul_acc(labels, preds):
    return sum(1 for x, y in zip(labels, preds) if x == y) / len(labels)
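`calcul_acc` is plain top-1 accuracy over paired label/prediction lists. A quick sanity check with hypothetical values:

```python
# Quick sanity check (hypothetical inputs); calcul_acc is the function
# defined in postprocess.py above.
def calcul_acc(labels, preds):
    return sum(1 for x, y in zip(labels, preds) if x == y) / len(labels)

assert calcul_acc([1, 2, 3, 4], [1, 2, 0, 4]) == 0.75  # 3 of 4 match
```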
@@ -55,4 +51,4 @@ def get_result(result_path, label_path):


if __name__ == '__main__':
    get_result(config.result_path, config.label_path)
@@ -15,7 +15,6 @@
# ============================================================================

run_ascend()
{
    # check pretrain_ckpt file
@@ -27,6 +26,7 @@ run_ascend()

    # set environment
    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config.yaml"
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    export DEVICE_ID=0
    export RANK_ID=0
@@ -40,6 +40,7 @@ run_ascend()

    # launch
    python ${BASEPATH}/../eval.py \
        --config_path=$CONFIG_FILE \
        --platform=$1 \
        --dataset_path=$2 \
        --pretrain_ckpt=$3 \
@@ -56,6 +57,7 @@ run_gpu()
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_gpu.yaml"
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../eval" ];
    then
@@ -65,6 +67,7 @@ run_gpu()
    cd ../eval || exit

    python ${BASEPATH}/../eval.py \
        --config_path=$CONFIG_FILE \
        --platform=$1 \
        --dataset_path=$2 \
        --pretrain_ckpt=$3 \
@@ -81,6 +84,7 @@ run_cpu()
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_cpu.yaml"
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../eval" ];
    then
@@ -90,6 +94,7 @@ run_cpu()
    cd ../eval || exit

    python ${BASEPATH}/../eval.py \
        --config_path=$CONFIG_FILE \
        --platform=$1 \
        --dataset_path=$2 \
        --pretrain_ckpt=$3 \
@@ -87,7 +87,9 @@ function infer()

function cal_acc()
{
    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config.yaml"
    python3.7 ../postprocess.py --config_path=$CONFIG_FILE --result_path=./result_Files --label_path=$label_path &> acc.log &
}

compile_app
@@ -48,6 +48,8 @@ run_ascend()
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config.yaml"

    VISIABLE_DEVICES=$3
    IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
    if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
@@ -71,11 +73,13 @@ run_ascend()
        rm -rf ./rank$i
        mkdir ./rank$i
        cp ../*.py ./rank$i
        cp ../*.yaml ./rank$i
        cp -r ../src ./rank$i
        cd ./rank$i || exit
        echo "start training for rank $RANK_ID, device $DEVICE_ID"
        env > env.log
        python train.py \
            --config_path=$CONFIG_FILE \
            --platform=$1 \
            --dataset_path=$5 \
            --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -119,6 +123,8 @@ run_gpu()
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_gpu.yaml"

    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../train" ];
    then
@@ -130,6 +136,7 @@ run_gpu()
    export CUDA_VISIBLE_DEVICES="$3"
    mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
        python ${BASEPATH}/../train.py \
            --config_path=$CONFIG_FILE \
            --platform=$1 \
            --dataset_path=$4 \
            --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -165,6 +172,8 @@ run_cpu()
    fi

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_cpu.yaml"

    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../train" ];
    then
@@ -174,6 +183,7 @@ run_cpu()
    cd ../train || exit

    python ${BASEPATH}/../train.py \
        --config_path=$CONFIG_FILE \
        --platform=$1 \
        --dataset_path=$2 \
        --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -54,6 +54,8 @@ run_ascend()
    CACHE_SESSION_ID=$(generate_cache_session)

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config.yaml"

    VISIABLE_DEVICES=$3
    IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
    if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
@@ -77,11 +79,13 @@ run_ascend()
        rm -rf ./rank$i
        mkdir ./rank$i
        cp ../*.py ./rank$i
        cp ../*.yaml ./rank$i
        cp -r ../src ./rank$i
        cd ./rank$i || exit
        echo "start training for rank $RANK_ID, device $DEVICE_ID"
        env > env.log
        python train.py \
            --config_path=$CONFIG_FILE \
            --platform=$1 \
            --dataset_path=$5 \
            --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -131,6 +135,7 @@ run_gpu()
    CACHE_SESSION_ID=$(generate_cache_session)

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_gpu.yaml"
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../train" ];
    then
@@ -142,6 +147,7 @@ run_gpu()
    export CUDA_VISIBLE_DEVICES="$3"
    mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
        python ${BASEPATH}/../train.py \
            --config_path=$CONFIG_FILE \
            --platform=$1 \
            --dataset_path=$4 \
            --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -183,6 +189,7 @@ run_cpu()
    CACHE_SESSION_ID=$(generate_cache_session)

    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    CONFIG_FILE="${BASEPATH}/../default_config_cpu.yaml"
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
    if [ -d "../train" ];
    then
@@ -192,6 +199,7 @@ run_cpu()
    cd ../train || exit

    python ${BASEPATH}/../train.py \
        --config_path=$CONFIG_FILE \
        --platform=$1 \
        --dataset_path=$2 \
        --pretrain_ckpt=$PRETRAINED_CKPT \
@@ -211,4 +219,4 @@
    run_cpu "$@"
else
    echo "Unsupported platform."
fi;
@@ -1,61 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
import ast

def train_parse_args():
    train_parser = argparse.ArgumentParser(description='Image classification train')
    train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
        help='run platform, only support CPU, GPU and Ascend')
    train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
    train_parser.add_argument('--pretrain_ckpt', type=str, default="", help='Pretrained checkpoint path \
        for fine tune or incremental learning')
    train_parser.add_argument('--freeze_layer', type=str, default="", choices=["", "none", "backbone"], \
        help="freeze the weights of network from start to which layers")
    train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
    train_parser.add_argument('--filter_head', type=ast.literal_eval, default=False, \
        help='Filter head weight parameters when load checkpoint, default is False.')
    train_parser.add_argument('--enable_cache', type=ast.literal_eval, default=False, \
        help='Caching the dataset in memory to speedup dataset processing, default is False.')
    train_parser.add_argument('--cache_session_id', type=str, default="", help='The session id for cache service.')
    train_args = train_parser.parse_args()
    train_args.is_training = True
    if train_args.platform == "CPU":
        train_args.run_distribute = False
    return train_args

def eval_parse_args():
    eval_parser = argparse.ArgumentParser(description='Image classification eval')
    eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
        help='run platform, only support GPU, CPU and Ascend')
    eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
    eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \
        for fine tune or incremental learning')
    eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.')
    eval_args = eval_parser.parse_args()
    eval_args.is_training = False
    return eval_args

def export_parse_args():
    export_parser = argparse.ArgumentParser(description='Image classification export')
    export_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
        help='run platform, only support GPU, CPU and Ascend')
    export_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \
        for fine tune or incremental learning')
    export_args = export_parser.parse_args()
    export_args.is_training = False
    export_args.run_distribute = False
    return export_args
@@ -1,100 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
import os
from easydict import EasyDict as ed


def set_config(args):
    if not args.run_distribute:
        args.run_distribute = False
    config_cpu = ed({
        "num_classes": 26,
        "image_height": 224,
        "image_width": 224,
        "batch_size": 150,
        "epoch_size": 15,
        "warmup_epochs": 0,
        "lr_init": .0,
        "lr_end": 0.03,
        "lr_max": 0.03,
        "momentum": 0.9,
        "weight_decay": 4e-5,
        "label_smooth": 0.1,
        "loss_scale": 1024,
        "save_checkpoint": True,
        "save_checkpoint_epochs": 1,
        "keep_checkpoint_max": 20,
        "save_checkpoint_path": "./",
        "platform": args.platform,
        "run_distribute": args.run_distribute,
        "activation": "Softmax"
    })
    config_gpu = ed({
        "num_classes": 1000,
        "image_height": 224,
        "image_width": 224,
        "batch_size": 150,
        "epoch_size": 200,
        "warmup_epochs": 0,
        "lr_init": .0,
        "lr_end": .0,
        "lr_max": 0.8,
        "momentum": 0.9,
        "weight_decay": 4e-5,
        "label_smooth": 0.1,
        "loss_scale": 1024,
        "save_checkpoint": True,
        "save_checkpoint_epochs": 1,
        "keep_checkpoint_max": 200,
        "save_checkpoint_path": "./",
        "platform": args.platform,
        "run_distribute": args.run_distribute,
        "activation": "Softmax"
    })
    config_ascend = ed({
        "num_classes": 1000,
        "image_height": 224,
        "image_width": 224,
        "batch_size": 256,
        "epoch_size": 200,
        "warmup_epochs": 4,
        "lr_init": 0.00,
        "lr_end": 0.00,
        "lr_max": 0.4,
        "momentum": 0.9,
        "weight_decay": 4e-5,
        "label_smooth": 0.1,
        "loss_scale": 1024,
        "save_checkpoint": True,
        "save_checkpoint_epochs": 1,
        "keep_checkpoint_max": 200,
        "save_checkpoint_path": "./",
        "platform": args.platform,
        "device_id": int(os.getenv('DEVICE_ID', '0')),
        "rank_id": int(os.getenv('RANK_ID', '0')),
        "rank_size": int(os.getenv('RANK_SIZE', '1')),
        "run_distribute": int(os.getenv('RANK_SIZE', '1')) > 1,
        "activation": "Softmax"
    })
    config = ed({"CPU": config_cpu,
                 "GPU": config_gpu,
                 "Ascend": config_ascend})

    if args.platform not in config.keys():
        raise ValueError("Unsupported platform.")

    return config[args.platform]
@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Valid choices for each option.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except Exception as err:
            raise ValueError("Failed to parse yaml") from err
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()
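Because `config` is built once at import time, every script that does `from src.model_utils.config import config` sees the yaml defaults merged with any command-line overrides. A minimal sketch of the intended flow (the flag values are hypothetical):

```python
# Every scalar key in the yaml becomes a CLI flag, e.g.
#   python train.py --config_path=./default_config_gpu.yaml --batch_size=64
# After get_config() runs, the override is visible as an attribute:
from src.model_utils.config import config

print(config.batch_size)   # 64 here, instead of the yaml default 150
print(config.platform)     # 'GPU'
```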
@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
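The local adapter reads everything from environment variables set by the launch scripts, so its behaviour is easy to check in isolation. A small sketch with hypothetical values:

```python
import os

# Hypothetical environment, as a distributed launcher would set it.
os.environ["DEVICE_ID"] = "3"
os.environ["RANK_SIZE"] = "8"

from src.model_utils.local_adapter import get_device_id, get_device_num

assert get_device_id() == 3
assert get_device_num() == 8
```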
@@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID', "")
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path.
    Upload data from local directory to remote obs in contrast.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

            if pre_process:
                pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
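In use, `moxing_wrapper` brackets a script's entry point: the OBS downloads and the `pre_process` hook run first, then the wrapped function, then outputs are synced back to `train_url`. A minimal sketch (the entry-point and hook names are hypothetical; `train.py` below shows the real usage):

```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def prepare():
    # Hypothetical pre-step: runs after the OBS download, before training.
    print("data ready under:", config.data_path)

@moxing_wrapper(pre_process=prepare)
def main():
    print("training entry point")

if __name__ == '__main__':
    main()
```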
@@ -33,40 +33,99 @@ from mindspore.common import set_seed

from src.dataset import create_dataset, extract_features
from src.lr_generator import get_lr
from src.utils import context_device_init, switch_precision, config_ckpoint
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id


set_seed(1)
config.device_id = get_device_id()
config.rank_id = get_rank_id()
config.rank_size = get_device_num()
config.run_distribute = config.rank_size > 1

def modelarts_pre_process():
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done")
            else:
                print("This is not a zip file.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices at most
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
        print("#" * 200, os.listdir(save_dir_1))
        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))

    config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
    config.pretrain_ckpt = os.path.join(config.output_path, config.pretrain_ckpt)


@moxing_wrapper(pre_process=modelarts_pre_process)
def train_mobilenetv2():
    config.dataset_path = os.path.join(config.dataset_path, 'train')
    print('\nconfig: \n', config)
    start = time.time()

    # set context and device init
    context_device_init(config)

    # define network
    backbone_net, head_net, net = define_net(config, config.is_training)
    dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, config=config,
                             enable_cache=config.enable_cache, cache_session_id=config.cache_session_id)
    step_size = dataset.get_dataset_size()
    if config.platform == "GPU":
        context.set_context(enable_graph_kernel=True)
    if config.pretrain_ckpt:
        if config.freeze_layer == "backbone":
            load_ckpt(backbone_net, config.pretrain_ckpt, trainable=False)
            step_size = extract_features(backbone_net, config.dataset_path, config)
        elif config.filter_head:
            load_ckpt(backbone_net, config.pretrain_ckpt)
        else:
            load_ckpt(net, config.pretrain_ckpt)
    if step_size == 0:
        raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
            than batch_size in config.py")
@@ -92,7 +151,7 @@ if __name__ == '__main__':
                       total_epochs=epoch_size,
                       steps_per_epoch=step_size))

    if config.pretrain_ckpt == "" or config.freeze_layer != "backbone":
        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                       config.weight_decay, config.loss_scale)
@@ -111,7 +170,7 @@ if __name__ == '__main__':
        network = TrainOneStepCell(network, opt)
        network.set_train()

        features_path = config.dataset_path + '_features'
        idx_list = list(range(step_size))
        rank = 0
        if config.run_distribute:
@@ -136,5 +195,9 @@ if __name__ == '__main__':
            save_checkpoint(net, os.path.join(save_ckpt_path, f"mobilenetv2_{epoch+1}.ckpt"))
        print("total cost {:5.4f} s".format(time.time() - start))

    if config.enable_cache:
        print("Remember to shut down the cache server via \"cache_admin --stop\"")


if __name__ == '__main__':
    train_mobilenetv2()