openpose network optimize 8p acc.

This commit is contained in:
zhanghuiyao 2020-12-17 14:25:56 +08:00
parent da74482812
commit 3c5a8538d5
11 changed files with 415 additions and 349 deletions

View File

@ -142,10 +142,10 @@ Parameters for both training and evaluation can be set in config.py
'batch_size': 10 # training batch size
'lr_gamma': 0.1 # lr scale when reach lr_steps
'lr_steps': '100000,200000,250000' # the steps when lr * lr_gamma
'loss scale': 16386 # the loss scale of mixed precision
'loss scale': 16384 # the loss scale of mixed precision
'max_epoch_train': 60 # total training epochs
'insize': 368 # image size used as input to the model
'keep_checkpoint_max': 5 # only keep the last keep_checkpoint_max checkpoint
'keep_checkpoint_max': 1 # only keep the last keep_checkpoint_max checkpoint
'log_interval': 100 # the interval of print a log
'ckpt_interval': 5000 # the interval of saving a output model
```
@ -195,7 +195,7 @@ For more configuration details, please refer the script `config.py`.
```python
# grep "AP" eval.log
{'AP': 0.40030956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
{'AP': 0.39830956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
```
@ -209,14 +209,14 @@ For more configuration details, please refer the script `config.py`.
| -------------------------- | -----------------------------------------------------------
| Model Version | openpose
| Resource | Ascend 910 CPU 2.60GHz192coresMemory755G
| uploaded Date | 10/20/2020 (month/day/year)
| uploaded Date | 12/14/2020 (month/day/year)
| MindSpore Version | 1.0.1-alpha
| Training Parameters | epoch = 60, steps = 30k, batch_size = 10, lr = 0.0001
| Optimizer | Adam
| Training Parameters | epoch=60(1pcs)/80(8pcs), steps=30k(1pcs)/5k(8pcs), batch_size=10, init_lr=0.0001
| Optimizer | Adam(1pcs)/Momentum(8pcs)
| Loss Function | MSE
| outputs | pose
| Speed | 1pc: 29imgs/s
| Total time | 1pc: 30h
| Speed | 1pcs: 35fps, 8pcs: 230fps
| Total time | 1pcs: 22.5h, 8pcs: 5.1h
| Checkpoint for Fine tuning | 602.33M (.ckpt file)

View File

@ -17,10 +17,9 @@ import os
import argparse
import warnings
import sys
import cv2
from tqdm import tqdm
import numpy as np
from tqdm import tqdm
import cv2
from scipy.ndimage.filters import gaussian_filter
from pycocotools.coco import COCO as LoadAnn
from pycocotools.cocoeval import COCOeval as MapEval
@ -30,9 +29,10 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import dtype as mstype
from src.dataset import valdata
from src.openposenet import OpenPoseNet
from src.config import params, JointType
from src.openposenet import OpenPoseNet
from src.dataset import valdata
warnings.filterwarnings("ignore")
devid = int(os.getenv('DEVICE_ID'))
@ -40,6 +40,18 @@ context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False, device_id=devid)
show_gt = 0
parser = argparse.ArgumentParser('mindspore openpose_net test')
parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model')
parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs')
parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json',
help='path of annotations')
parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args, _ = parser.parse_known_args()
def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
class NullWriter():
def write(self, arg):
@ -68,23 +80,6 @@ def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
return info_str
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser('mindspore openpose_net test')
parser.add_argument('--model_path', type=str, default='./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt',
help='path of testing model')
parser.add_argument('--imgpath_val', type=str, default='/data0/zhy/dataset/coco/val2017',
help='path of testing imgs')
parser.add_argument('--ann', type=str, default='/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json',
help='path of annotations')
parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs')
# distributed related
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args, _ = parser.parse_known_args()
return args
def load_model(test_net, model_path):
assert os.path.exists(model_path)
@ -178,7 +173,7 @@ def compute_peaks_from_heatmaps(heatmaps):
return all_peaks
def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
candidate_connections = []
for joint_a in cand_a:
for joint_b in cand_b:
@ -186,33 +181,33 @@ def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
norm = np.linalg.norm(vector)
if norm == 0:
continue
ys = np.linspace(joint_a[1], joint_b[1], num=cfg['n_integ_points'])
xs = np.linspace(joint_a[0], joint_b[0], num=cfg['n_integ_points'])
ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points'])
xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points'])
integ_points = np.stack([ys, xs]).T.round().astype('i')
paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
unit_vector = vector / norm
inner_products = np.dot(paf_in_edge, unit_vector)
integ_value = inner_products.sum() / len(inner_products)
integ_value_with_dist_prior = integ_value + min(cfg['limb_length_ratio'] * img_len / norm -
cfg['length_penalty_value'], 0)
n_valid_points = sum(inner_products > cfg['inner_product_thresh'])
if n_valid_points > cfg['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
integ_value_with_dist_prior = integ_value + min(params_['limb_length_ratio'] * img_len / norm -
params_['length_penalty_value'], 0)
n_valid_points = sum(inner_products > params_['inner_product_thresh'])
if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
return candidate_connections
def compute_connections(pafs, all_peaks, img_len, cfg):
def compute_connections(pafs, all_peaks, img_len, params_):
all_connections = []
for i in range(len(cfg['limbs_point'])):
for i in range(len(params_['limbs_point'])):
paf_index = [i * 2, i * 2 + 1]
paf = pafs[paf_index] # shape: (2, 320, 320)
limb_point = cfg['limbs_point'][i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
limb_point = params_['limbs_point'][i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
if cand_a.shape[0] > 0 and cand_b.shape[0] > 0:
candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg)
candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, params_)
connections = np.zeros((0, 3))
@ -226,11 +221,11 @@ def compute_connections(pafs, all_peaks, img_len, cfg):
all_connections.append(np.zeros((0, 3)))
return all_connections
def grouping_key_points(all_connections, candidate_peaks, cfg):
def grouping_key_points(all_connections, candidate_peaks, params_):
subsets = -1 * np.ones((0, 20))
for l, connections in enumerate(all_connections):
joint_a, joint_b = cfg['limbs_point'][l]
joint_a, joint_b = params_['limbs_point'][l]
for ind_a, ind_b, score in connections[:, :3]:
ind_a, ind_b = int(ind_a), int(ind_b)
joint_found_cnt = 0
@ -249,6 +244,7 @@ def grouping_key_points(all_connections, candidate_peaks, cfg):
found_subset[-1] += 1 # increment joint count
found_subset[-2] += candidate_peaks[ind_b, 3] + score
elif joint_found_cnt == 2:
found_subset_1 = subsets[joint_found_subset_index[0]]
@ -289,10 +285,8 @@ def grouping_key_points(all_connections, candidate_peaks, cfg):
pass
# delete low score subsets
keep = np.logical_and(subsets[:, -1] >= cfg['n_subset_limbs_thresh'],
subsets[:, -2] / subsets[:, -1] >= cfg['subset_score_thresh'])
# cfg['n_subset_limbs_thresh'] = 3
# cfg['subset_score_thresh'] = 0.2
keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'],
subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh'])
subsets = subsets[keep]
return subsets
@ -319,7 +313,7 @@ def detect(img, network):
# map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size']) # 320
map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])
print("image size is: ", input_w, input_h)
# print("image size is: ", input_w, input_h)
resized_image = cv2.resize(orig_img, (input_w, input_h))
x_data = preprocess(resized_image)
@ -394,7 +388,7 @@ def draw_person_pose(orig_img, poses):
return canvas
def depreprocess(img):
# x_data = img.astype('f')
#x_data = img.astype('f')
x_data = img[0]
x_data += 0.5
x_data *= 255
@ -402,15 +396,14 @@ def depreprocess(img):
x_data = x_data.transpose(1, 2, 0)
return x_data
def _eval():
args = parse_args()
def val():
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
if not os.path.exists(args.output_path):
os.mkdir(args.output_path)
network = OpenPoseNet()
network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
network.set_train(False)
load_model(network, args.model_path)
@ -455,4 +448,4 @@ def _eval():
print('result: ', res)
if __name__ == "__main__":
_eval()
val()

View File

@ -15,9 +15,8 @@
# ============================================================================
export DEVICE_ID=0
export RANK_ID=0
python eval.py \
--model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt \
--imgpath_val /data0/zhy/dataset/coco/val2017 \
--ann /data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json \
--model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-80_663.ckpt \
--imgpath_val ./dataset/val2017 \
--ann ./dataset/annotations/person_keypoints_val2017.json \
> eval.log 2>&1 &

View File

@ -14,5 +14,6 @@
# limitations under the License.
# ============================================================================
export DEVICE_ID=0
cd ..
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 &

View File

@ -53,21 +53,41 @@ class JointType(IntEnum):
params = {
# paths
'data_dir': '/data0/zhy/dataset/coco',
'vgg_path': '/data0/zhy/dataset/coco/vgg19-0-97_5004.ckpt',
'data_dir': './dataset',
'save_model_path': './checkpoints/',
'load_pretrain': False,
'pretrained_model_path': "",
# training params
'batch_size': 10,
# train type
'train_type': 'fix_loss_scale', # chose in ['clip_grad', 'fix_loss_scale']
'train_type_NP': 'clip_grad',
# vgg bn
'vgg_with_bn': False,
'vgg_path': './vgg_model/vgg19-0-97_5004.ckpt',
# if clip_grad
'GRADIENT_CLIP_TYPE': 1,
'GRADIENT_CLIP_VALUE': 10.0,
# optimizer and lr
'optimizer': "Adam", # chose in ['Momentum', 'Adam']
'optimizer_NP': "Momentum",
'group_params': True,
'group_params_NP': False,
'lr': 1e-4,
'lr_gamma': 0.1,
'lr_steps': '100000,200000,250000',
'lr_steps_NP': '250000',
'loss_scale': 16386,
'lr_type': 'default', # chose in ["default", "cosine"]
'lr_gamma': 0.1, # if default
'lr_steps': '100000,200000,250000', # if default
'lr_steps_NP': '250000,300000', # if default
'warmup_epoch': 5, # if cosine
'max_epoch_train': 60,
'max_epoch_train_NP': 80,
'loss_scale': 16384,
# default param
'batch_size': 10,
'min_keypoints': 5,
'min_area': 32 * 32,
'insize': 368,
@ -75,9 +95,9 @@ params = {
'paf_sigma': 8,
'heatmap_sigma': 7,
'eva_num': 100,
'keep_checkpoint_max': 5,
'keep_checkpoint_max': 1,
'log_interval': 100,
'ckpt_interval': 663, # 5000,
'ckpt_interval': 5304,
'min_box_size': 64,
'max_box_size': 512,

View File

@ -15,10 +15,10 @@
import os
import math
import random
import cv2
import numpy as np
import cv2
from pycocotools.coco import COCO as ReadJson
import mindspore.dataset as de
from src.config import JointType, params
@ -41,6 +41,7 @@ class txtdataset():
self.imgIds = random.sample(self.imgIds, n_samples)
print('{} images: {}'.format(mode, len(self)))
def __len__(self):
return len(self.imgIds)
@ -217,9 +218,9 @@ class txtdataset():
flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]
def swap_joints(poses, joint_type_1, joint_type_2):
tmp = poses[:, joint_type_1].copy()
poses[:, joint_type_1] = poses[:, joint_type_2]
def swap_joints(poses, joint_type_, joint_type_2):
tmp = poses[:, joint_type_].copy()
poses[:, joint_type_] = poses[:, joint_type_2]
poses[:, joint_type_2] = tmp
swap_joints(poses, JointType.LeftEye, JointType.RightEye)
@ -243,8 +244,10 @@ class txtdataset():
aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)
return aug_img, ignore_mask, poses
# ------------------------------------------------------------------
# ------------------------------- end -----------------------------------
# ------------------------------ Heatmap ------------------------------------
# return shape: (height, width)
def generate_gaussian_heatmap(self, shape, joint, sigma):
x, y = joint
@ -269,6 +272,38 @@ class txtdataset():
heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
return heatmaps.astype('f')
def generate_gaussian_heatmap_fast(self, shape, joint, sigma):
x, y = joint
grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
grid_x = grid_x + 0.4375
grid_y = grid_y + 0.4375
grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2)
return gaussian_heatmap
def generate_heatmaps_fast(self, img, poses, heatmap_sigma):
resize_shape = (img.shape[0] // 8, img.shape[1] // 8)
heatmaps = np.zeros((0,) + resize_shape)
sum_heatmap = np.zeros(resize_shape)
for joint_index in range(len(JointType)):
heatmap = np.zeros(resize_shape)
for pose in poses:
if pose[joint_index, 2] > 0:
jointmap = self.generate_gaussian_heatmap_fast(resize_shape, pose[joint_index][:2]/8,
heatmap_sigma/8)
index_1 = jointmap > heatmap
heatmap[index_1] = jointmap[index_1]
index_2 = jointmap > sum_heatmap
sum_heatmap[index_2] = jointmap[index_2]
heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape)))
bg_heatmap = 1 - sum_heatmap # background channel
heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
return heatmaps.astype('f')
# ------------------------------ end ------------------------------------
# ------------------------------ PAF ------------------------------------
# return shape: (2, height, width)
def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
if np.array_equal(joint_from, joint_to): # same joint
@ -285,7 +320,7 @@ class txtdataset():
grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] * \
vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
(grid_y - joint_from[1])
vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width # paf_width : 8
paf_flag = horizontal_paf_flag & vertical_paf_flag
@ -314,6 +349,55 @@ class txtdataset():
pafs = np.vstack((pafs, paf))
return pafs.astype('f')
def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width):
if np.array_equal(joint_from, joint_to): # same joint
return np.zeros((2,) + shape[:-1])
joint_distance = np.linalg.norm(joint_to - joint_from)
unit_vector = (joint_to - joint_from) / joint_distance
rad = np.pi / 2
# [[0, 1], [-1, 0]]
rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
# [[u_y], [-u_x]]
vertical_unit_vector = np.dot(rot_matrix, unit_vector)
grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
grid_x = grid_x + 0.4375
grid_y = grid_y + 0.4375
horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
(grid_y - joint_from[1])
vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width # paf_width : 8/8 = 1
paf_flag = horizontal_paf_flag & vertical_paf_flag
constant_paf = np.stack((paf_flag, paf_flag)) *\
np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)
return constant_paf
def generate_pafs_fast(self, img, poses, paf_sigma):
resize_shape = (img.shape[0]//8, img.shape[1]//8, 3)
pafs = np.zeros((0,) + resize_shape[:-1])
for limb in params['limbs_point']:
paf = np.zeros((2,) + resize_shape[:-1])
paf_flags = np.zeros(paf.shape) # for constant paf
for pose in poses:
joint_from, joint_to = pose[limb]
if joint_from[2] > 0 and joint_to[2] > 0:
limb_paf = self.generate_constant_paf_fast(resize_shape, joint_from[:2]/8, joint_to[:2]/8, paf_sigma/8) # [2, 368, 368]
limb_paf_flags = limb_paf != 0
paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape)
paf += limb_paf
index_1 = paf_flags > 0
paf[index_1] /= paf_flags[index_1]
pafs = np.vstack((pafs, paf))
return pafs.astype('f')
# ------------------------------ end ------------------------------------
def get_img_annotation(self, ind=None, img_id=None):
annotations = None
@ -389,14 +473,18 @@ class txtdataset():
resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
shape=(self.insize, self.insize))
heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma'])
pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma']) # params['paf_sigma']: 8
# heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma'])
# resized_heatmaps = self.resize_output(heatmaps)
resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
# pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma'])
# resized_pafs = self.resize_output(pafs)
resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])
ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
resized_pafs = self.resize_output(pafs)
resized_heatmaps = self.resize_output(heatmaps)
resized_ignore_mask = self.resize_output(ignore_mask)
return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask
def preprocess(self, img):
@ -459,7 +547,6 @@ class DistributedSampler():
def __len__(self):
return self.num_samplers
def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
#cv2.setNumThreads(0)
val = ReadJson(jsonpath)
@ -470,23 +557,6 @@ def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
return ds
def openpose(jsonpath, imgpath, maskpath, per_batch_size, max_epoch, rank, group_size, mode='train'):
train = ReadJson(jsonpath)
num_parallel = 48
if group_size > 1:
num_parallel = 20
dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
sampler = DistributedSampler(dataset, rank, group_size)
de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
num_parallel_workers=num_parallel, sampler=sampler, shuffle=True)
de_dataset = de_dataset.project(columns=["image", "pafs", "heatmaps", "ignore_mask"])
de_dataset = de_dataset.batch(batch_size=per_batch_size, drop_remainder=True, num_parallel_workers=num_parallel)
steap_pre_epoch = de_dataset.get_dataset_size()
de_dataset = de_dataset.repeat(max_epoch)
return de_dataset, steap_pre_epoch
def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
multiprocessing=True, num_worker=20):

View File

@ -15,7 +15,6 @@
import os
import argparse
import cv2
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO as ReadJson
@ -23,44 +22,44 @@ from pycocotools.coco import COCO as ReadJson
from config import params
class DataLoader():
def __init__(self, coco, dir_name, data_mode='train'):
self.train = coco
def __init__(self, train_, dir_name, mode_='train'):
self.train = train_
self.dir_name = dir_name
assert data_mode in ['train', 'val'], 'Data loading mode is invalid.'
self.mode = data_mode
self.catIds = coco.getCatIds() # catNms=['person']
self.imgIds = sorted(coco.getImgIds(catIds=self.catIds))
assert mode_ in ['train', 'val'], 'Data loading mode is invalid.'
self.mode = mode_
self.catIds = train_.getCatIds() # catNms=['person']
self.imgIds = sorted(train_.getImgIds(catIds=self.catIds))
def __len__(self):
return len(self.imgIds)
def gen_masks(self, image, anns):
_mask_all = np.zeros(image.shape[:2], 'bool')
_mask_miss = np.zeros(image.shape[:2], 'bool')
for ann in anns:
def gen_masks(self, image_, annotations_):
mask_all_1 = np.zeros(image_.shape[:2], 'bool')
mask_miss_1 = np.zeros(image_.shape[:2], 'bool')
for ann in annotations_:
mask = self.train.annToMask(ann).astype('bool')
if ann['iscrowd'] == 1:
intxn = _mask_all & mask
_mask_miss = np.bitwise_or(_mask_miss.astype(int), np.subtract(mask, intxn, dtype=np.int32))
_mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
intxn = mask_all_1 & mask
mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), np.subtract(mask, intxn, dtype=np.int32))
mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
elif ann['num_keypoints'] < params['min_keypoints'] or ann['area'] <= params['min_area']:
_mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
_mask_miss = np.bitwise_or(_mask_miss.astype(int), mask.astype(int))
mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), mask.astype(int))
else:
_mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
return _mask_all, _mask_miss
mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
return mask_all_1, mask_miss_1
def dwaw_gen_masks(self, image, mask, color=(0, 0, 1)):
def dwaw_gen_masks(self, image_, mask, color=(0, 0, 1)):
bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
mskd = image * bimsk.astype(np.int32)
mskd = image_ * bimsk.astype(np.int32)
clmsk = np.ones(bimsk.shape) * bimsk
for ind in range(3):
clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
image = image + 0.7 * clmsk - 0.7 * mskd
return image.astype(np.uint8)
for index_1 in range(3):
clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255
image_ = image_ + 0.7 * clmsk - 0.7 * mskd
return image_.astype(np.uint8)
def draw_masks_and_keypoints(self, image, anns):
for ann in anns:
def draw_masks_and_keypoints(self, image_, annotations_):
for ann in annotations_:
# masks
mask = self.train.annToMask(ann).astype(np.uint8)
if ann['iscrowd'] == 1:
@ -70,30 +69,30 @@ class DataLoader():
else:
color = (1, 0, 0)
bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
mskd = image * bimsk.astype(np.int32)
mskd = image_ * bimsk.astype(np.int32)
clmsk = np.ones(bimsk.shape) * bimsk
for ind in range(3):
clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
image = image + 0.7 * clmsk - 0.7 * mskd
for index_1 in range(3):
clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255
image_ = image_ + 0.7 * clmsk - 0.7 * mskd
# keypoints
for x, y, v in np.array(ann['keypoints']).reshape(-1, 3):
if v == 1:
cv2.circle(image, (x, y), 3, (255, 255, 0), -1)
cv2.circle(image_, (x, y), 3, (255, 255, 0), -1)
elif v == 2:
cv2.circle(image, (x, y), 3, (255, 0, 255), -1)
return image.astype(np.uint8)
cv2.circle(image_, (x, y), 3, (255, 0, 255), -1)
return image_.astype(np.uint8)
def get_img_annotation(self, ind=None, image_id=None):
def get_img_annotation(self, ind=None, img_id_=None):
if ind is not None:
image_id = self.imgIds[ind]
img_id_ = self.imgIds[ind]
anno_ids = self.train.getAnnIds(imgIds=[image_id])
_annotations = self.train.loadAnns(anno_ids)
anno_ids = self.train.getAnnIds(imgIds=[img_id_])
annotations_ = self.train.loadAnns(anno_ids)
img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([image_id])[0]['file_name'])
_image = cv2.imread(img_file)
return _image, _annotations, image_id
img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([img_id_])[0]['file_name'])
image_ = cv2.imread(img_file)
return image_, annotations_, img_id_
if __name__ == '__main__':
@ -107,7 +106,7 @@ if __name__ == '__main__':
path_list = [args.train_ann, args.val_ann, args.train_dir, args.val_dir]
for index, mode in enumerate(['train', 'val']):
train = ReadJson(path_list[index])
data_loader = DataLoader(train, path_list[index+2], mode=mode)
data_loader = DataLoader(train, path_list[index+2], mode_=mode)
save_dir = os.path.join(params['data_dir'], 'ignore_mask_{}'.format(mode))
if not os.path.exists(save_dir):

View File

@ -12,37 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.train.callback import Callback
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.config import params
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
time_stamp_init = False
time_stamp_first = 0
grad_scale = C.MultitypeFuncGraph("grad_scale")
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
reciprocal = P.Reciprocal()
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE']
GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE']
@grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
return RowTensor(grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape)
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor]: clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class openpose_loss(_Loss):
def __init__(self):
super(openpose_loss, self).__init__()
@ -99,54 +115,6 @@ class openpose_loss(_Loss):
return total_loss, heatmaps_loss, pafs_loss
class Depend_network(nn.Cell):
def __init__(self, network):
super(Depend_network, self).__init__()
self.network = network
def construct(self, *args):
loss, _, _ = self.network(*args) # loss, heatmaps_loss, pafs_loss
return loss
class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.depend_network = Depend_network(network)
# self.weights = ms.ParameterTuple(network.trainable_params())
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.print = P.Print()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
#if mean.get_device_num_is_set():
# if mean:
#degree = context.get_auto_parallel_context("device_num")
# else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss, heatmaps_loss, pafs_loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
#grads = self.grad(self.network, weights)(*args, sens)
grads = self.grad(self.depend_network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
#return F.depend(loss, self.optimizer(grads))
# for grad in grads:
# self.print(grad)
loss = F.depend(loss, self.optimizer(grads))
return loss, heatmaps_loss, pafs_loss
class BuildTrainNetwork(nn.Cell):
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
@ -157,51 +125,39 @@ class BuildTrainNetwork(nn.Cell):
logit_pafs, logit_heatmap = self.network(input_data)
loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
return loss
#loss = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
# return loss, heatmaps_loss, pafs_loss
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
class TrainOneStepWithClipGradientCell(nn.Cell):
'''TrainOneStepWithClipGradientCell'''
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepWithClipGradientCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap()
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.loss_sum = 0
def construct(self, *inputs):
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = time.time()
time_stamp_init = True
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
self.count += 1
self.loss_sum += float(loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
loss = self.loss_sum/self.count
loss_file = open("./loss.log", "a+")
loss_file.write("%lu epoch: %s step: %s ,loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.loss_sum = 0
weights = self.weights
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

View File

@ -17,19 +17,18 @@ from mindspore.nn import Conv2d, ReLU
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
#selfCat = P.Concat(axis=1)
time_stamp_init = False
time_stamp_first = 0
loadvgg = 1
class OpenPoseNet(nn.Cell):
insize = 368
def __init__(self, vggpath=''):
def __init__(self, vggpath='', vgg_with_bn=False):
super(OpenPoseNet, self).__init__()
self.base = Base_model()
self.base = Base_model(vgg_with_bn=vgg_with_bn)
self.stage_1 = Stage_1()
self.stage_2 = Stage_x()
self.stage_3 = Stage_x()
@ -39,23 +38,15 @@ class OpenPoseNet(nn.Cell):
self.shape = P.Shape()
self.cat = P.Concat(axis=1)
self.print = P.Print()
# for m in self.modules():
# if isinstance(m, Conv2d):
# init.constant_(m.bias, 0)
if loadvgg and vggpath:
param_dict = load_checkpoint(vggpath)
param_dict_new = {}
trans_name = 'base.vgg_base.'
for key, values in param_dict.items():
#print('key:',key,self.shape(values))
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[trans_name+key[17:]] = values
# else:
# param_dict_new[key] = values
#print(param_dict_new)
load_param_into_net(self.base.vgg_base, param_dict_new)
def construct(self, x):
@ -205,20 +196,17 @@ class VGG_Base_MS(nn.Cell):
return x
class Base_model(nn.Cell):
def __init__(self):
def __init__(self, vgg_with_bn=False):
super(Base_model, self).__init__()
#cfgs_zh = {'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512]}
cfgs_zh = {'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512]}
#cfgs_zh = {'16': [64, 64,128, 128, 256, 256, 256, 512, 512, 512]}
self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=False)
#self.vgg_base = VGG_Base()
#self.vgg_base = VGG_Base_MS()
self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=vgg_with_bn)
self.conv4_3_CPM = Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
has_bias=True)
self.conv4_4_CPM = Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
has_bias=True)
self.relu = ReLU()
def construct(self, x):
x = self.vgg_base(x)
x = self.relu(self.conv4_3_CPM(x))

View File

@ -1,27 +1,11 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import math
import time
import numpy as np
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import LossMonitor, Callback
from mindspore.common.tensor import Tensor
from src.config import params
from mindspore.common import dtype as mstype
class MyLossMonitor(LossMonitor):
def __init__(self, per_print_times=1):
@ -32,6 +16,7 @@ class MyLossMonitor(LossMonitor):
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
@ -47,63 +32,76 @@ class MyLossMonitor(LossMonitor):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
# print("epoch: %s step: %s, loss is %s, step time: %.3f s." % (cb_params.cur_epoch_num, cur_step_in_epoch,
# loss,
# (time.time() - self._start_time)), flush=True)
self._loss_list.append(loss)
if cb_params.cur_step_num % 100 == 0:
print("epoch: %s, steps: [%s] mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch,
print("epoch: %s, steps: [%s], mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch,
np.array(self._loss_list).mean()), flush=True)
self._loss_list = []
self._start_time = time.time()
class MyScaleSensCallback(Callback):
'''MyLossScaleCallback'''
def __init__(self, loss_scale_list, epoch_list):
super(MyScaleSensCallback, self).__init__()
self.loss_scale_list = loss_scale_list
self.epoch_list = epoch_list
self.scaling_sens = loss_scale_list[0]
def parse_args():
"""Parse train arguments."""
parser = argparse.ArgumentParser('mindspore openpose training')
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch = cb_params.cur_epoch_num
# dataset related
parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir')
parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json',
help='train annotations json')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
for i, _ in enumerate(self.epoch_list):
if epoch >= self.epoch_list[i]:
self.scaling_sens = self.loss_scale_list[i+1]
else:
break
args, _ = parser.parse_known_args()
args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann)
args.imgpath_train = os.path.join(params['data_dir'], args.train_dir)
args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train')
return args
scaling_sens_tensor = Tensor(self.scaling_sens, dtype=mstype.float32)
cb_params.train_network.set_sense_scale(scaling_sens_tensor)
print("Epoch: set train network scale sens to {}".format(self.scaling_sens))
def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size):
def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3):
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
else:
lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr
def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size, lr_type='default', warmup_epoch=5):
if lr_type == 'default':
lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f')
for step in lr_steps:
step //= group_size
lr_stage[step:] *= lr_gamma
elif lr_type == 'cosine':
lr_stage = _dynamic_lr(lr, steps_per_epoch * max_epoch_train, warmup_epoch * steps_per_epoch,
warmup_ratio=1 / 3)
lr_stage = np.array(lr_stage).astype('f')
else:
raise ValueError("lr type {} is not support.".format(lr_type))
lr_base = lr_stage.copy()
lr_base = lr_base / 4
lr_vgg = lr_base.copy()
vgg_freeze_step = 2000
vgg_freeze_step = 2000 // group_size
lr_vgg[:vgg_freeze_step] = 0
return lr_stage, lr_base, lr_vgg
# zhang add
def adjust_learning_rate(init_lr, lr_gamma, steps_per_epoch, max_epoch_train, stepvalues):
lr_stage = np.array([init_lr] * steps_per_epoch * max_epoch_train).astype('f')
for epoch in stepvalues:
lr_stage[epoch * steps_per_epoch:] *= lr_gamma
lr_base = lr_stage.copy()
lr_base = lr_base / 4
lr_vgg = lr_base.copy()
vgg_freeze_step = 2000
lr_vgg[:vgg_freeze_step] = 0
return lr_stage, lr_base, lr_vgg

View File

@ -13,26 +13,38 @@
# limitations under the License.
# ============================================================================
import os
import argparse
import mindspore
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.nn.optim import Adam, Momentum
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn.optim import Adam
from src.dataset import create_dataset
from src.openposenet import OpenPoseNet
from src.loss import openpose_loss, BuildTrainNetwork
from src.loss import openpose_loss, BuildTrainNetwork, TrainOneStepWithClipGradientCell
from src.config import params
from src.utils import parse_args, get_lr, load_model, MyLossMonitor
from src.utils import get_lr, load_model, MyLossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
parser = argparse.ArgumentParser('mindspore openpose training')
parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir')
parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json',
help='train annotations json')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
args, _ = parser.parse_known_args()
args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann)
args.imgpath_train = os.path.join(params['data_dir'], args.train_dir)
args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train')
def train():
"""Train function."""
args = parse_args()
args.outputs_dir = params['save_model_path']
@ -46,11 +58,15 @@ def train():
args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/")
args.rank = 0
# with out loss_scale
if args.group_size > 1:
args.max_epoch = params["max_epoch_train_NP"]
args.loss_scale = params['loss_scale'] / 2
args.lr_steps = list(map(int, params["lr_steps_NP"].split(',')))
params['train_type'] = params['train_type_NP']
params['optimizer'] = params['optimizer_NP']
params['group_params'] = params['group_params_NP']
else:
args.max_epoch = params["max_epoch_train"]
args.loss_scale = params['loss_scale']
args.lr_steps = list(map(int, params["lr_steps"].split(',')))
@ -58,9 +74,7 @@ def train():
print('start create network')
criterion = openpose_loss()
criterion.add_flags_recursive(fp32=True)
network = OpenPoseNet(vggpath=params['vgg_path'])
# network.add_flags_recursive(fp32=True)
network = OpenPoseNet(vggpath=params['vgg_path'], vgg_with_bn=params['vgg_with_bn'])
if params["load_pretrain"]:
print("load pretrain model:", params["pretrained_model_path"])
load_model(network, params["pretrained_model_path"])
@ -72,7 +86,7 @@ def train():
print('start create dataset')
else:
print('Error: wrong data path')
return 0
num_worker = 20 if args.group_size > 1 else 48
de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train,
@ -90,9 +104,14 @@ def train():
lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size,
params['lr_gamma'],
steps_per_epoch,
params["max_epoch_train"],
args.max_epoch,
args.lr_steps,
args.group_size)
args.group_size,
lr_type=params['lr_type'],
warmup_epoch=params['warmup_epoch'])
# optimizer
if params['group_params']:
vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params()))
base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params()))
stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params()))
@ -101,24 +120,47 @@ def train():
{'params': base_params, 'lr': lr_base},
{'params': stages_params, 'lr': lr_stage}]
opt = Adam(group_params, loss_scale=args.loss_scale)
if params['optimizer'] == "Momentum":
opt = Momentum(group_params, learning_rate=lr_stage, momentum=0.9)
elif params['optimizer'] == "Adam":
opt = Adam(group_params)
else:
raise ValueError("optimizer not support.")
else:
if params['optimizer'] == "Momentum":
opt = Momentum(train_net.trainable_params(), learning_rate=lr_stage, momentum=0.9)
elif params['optimizer'] == "Adam":
opt = Adam(train_net.trainable_params(), learning_rate=lr_stage)
else:
raise ValueError("optimizer not support.")
train_net.set_train(True)
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)
params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval'])
# callback
config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'],
keep_checkpoint_max=params["keep_checkpoint_max"])
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck)
time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size())
if args.rank == 0:
callback_list = [MyLossMonitor(), time_cb, ckpoint_cb]
print("============== Starting Training ==============")
model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list,
dataset_sink_mode=False)
else:
callback_list = [MyLossMonitor(), time_cb]
# train
if params['train_type'] == 'clip_grad':
train_net = TrainOneStepWithClipGradientCell(train_net, opt, sens=args.loss_scale)
train_net.set_train()
model = Model(train_net)
elif params['train_type'] == 'fix_loss_scale':
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
train_net.set_train()
model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)
else:
raise ValueError("Type {} is not support.".format(params['train_type']))
print("============== Starting Training ==============")
model.train(args.max_epoch, de_dataset_train, callbacks=callback_list,
dataset_sink_mode=False)
return 0
if __name__ == "__main__":
# mindspore.common.seed.set_seed(1)
mindspore.common.seed.set_seed(1)
train()