!20634 optimize FaceRecognitionForTracking training speed
Merge pull request !20634 from zhouneng/code_docs_fix_issue_I3XLT6
Commit 243116d480
@@ -46,7 +46,7 @@ world_size: 8
 # logging related
 log_interval: 10
 ckpt_path: '../../output'
-ckpt_interval: 200
+ckpt_interval: 400

 # train/eval option
 data_dir: ''
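The only config change doubles the checkpoint interval from 200 to 400 steps, halving checkpoint I/O during training. A minimal sketch of how a step-based interval like this is typically wired into a checkpoint callback (the train script imports ModelCheckpoint and CheckpointConfig below; the variable names and the keep_checkpoint_max value here are assumptions, not the repo's exact code):

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

ckpt_interval = 400   # value from the updated config
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=ckpt_interval,
                            keep_checkpoint_max=10)   # placeholder cap on kept checkpoints
ckpt_cb = ModelCheckpoint(config=ckpt_cfg, directory='../../output', prefix='reid')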
@@ -70,6 +70,10 @@ echo $PRETRAINED_BACKBONE
 export RANK_TABLE_FILE=$RANK_TABLE
 export RANK_SIZE=8

+cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
+avg=`expr $cpus \/ $RANK_SIZE`
+gap=`expr $avg \- 1`
+
 config_path="${dirname_path}/reid_8p_ascend_config.yaml"
 echo "config path is : ${config_path}"

@@ -77,12 +81,15 @@ echo 'start training'
 for((i=0;i<=$RANK_SIZE-1;i++));
 do
 echo 'start rank '$i
+start=`expr $i \* $avg`
+end=`expr $start \+ $gap`
+cmdopt=$start"-"$end
 mkdir ${current_exec_path}/device$i
 cd ${current_exec_path}/device$i || exit
 export RANK_ID=$i
 dev=`expr $i + 0`
 export DEVICE_ID=$dev
-python ${dirname_path}/${SCRIPT_NAME} \
+taskset -c $cmdopt python ${dirname_path}/${SCRIPT_NAME} \
 --config_path=$config_path \
 --is_distributed=1 \
 --data_dir=$DATA_DIR \
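Together, the two launch-script hunks pin each training process to its own slice of CPU cores: the host core count is divided by RANK_SIZE, and rank i gets the contiguous range start-end passed to taskset -c. An illustrative Python rendering of the same arithmetic (not part of the commit, just a restatement of the `expr` math):

import os

rank_size = 8
cpus = os.cpu_count()        # same count the script reads from /proc/cpuinfo
avg = cpus // rank_size      # cores reserved for each rank
gap = avg - 1

for i in range(rank_size):
    start = i * avg
    end = start + gap
    print(f"rank {i}: taskset -c {start}-{end} python train.py ...")

Binding ranks to disjoint core ranges keeps the eight per-rank dataset pipelines from contending for the same cores.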
@@ -38,9 +38,9 @@ def get_de_dataset(args):
                     VC.HWC2CHW()]

     de_dataset = de.ImageFolderDataset(dataset_dir=args.data_dir, num_shards=args.world_size,
-                                       shard_id=args.local_rank, shuffle=True)
-    de_dataset = de_dataset.map(input_columns="image", operations=transform_img)
-    de_dataset = de_dataset.map(input_columns="label", operations=transform_label)
+                                       shard_id=args.local_rank, shuffle=True, num_parallel_workers=4)
+    de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=4)
+    de_dataset = de_dataset.map(input_columns="label", operations=transform_label, num_parallel_workers=4)
     de_dataset = de_dataset.project(columns=["image", "label"])
     de_dataset = de_dataset.batch(args.per_batch_size, drop_remainder=True)

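The dataset hunk adds num_parallel_workers=4 to the sharded ImageFolderDataset and to both map operations, so per-rank decoding and augmentation run in parallel, presumably sized to fit inside the core slice reserved by taskset above. A self-contained sketch of the same pattern, with a placeholder data path, placeholder transforms and batch size, and the VC import alias assumed from this file:

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as VC

workers = 4   # value introduced by this commit
ds = de.ImageFolderDataset(dataset_dir='/path/to/images',   # placeholder path
                           num_shards=8, shard_id=0, shuffle=True,
                           num_parallel_workers=workers)
ds = ds.map(operations=[VC.Decode(), VC.HWC2CHW()], input_columns="image",
            num_parallel_workers=workers)
ds = ds.batch(32, drop_remainder=True)   # placeholder batch size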
@@ -22,7 +22,6 @@ import numpy as np

 import mindspore
 from mindspore import context
-from mindspore import Tensor
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
@@ -67,6 +66,9 @@ def init_argument():
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.world_size,
                                           gradients_mean=True)

+    if config.device_target == 'Ascend' and config.is_distributed:
+        context.set_auto_parallel_context(all_reduce_fusion_config=[1, 10])
+
     mindspore.common.set_seed(1)

     # logger
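all_reduce_fusion_config=[1, 10] gives the parallel backend the parameter indices at which to split gradients into AllReduce fusion groups, so many small gradient all-reduces are merged into a few larger ones during distributed Ascend training. A minimal sketch of the setting in isolation (the parallel mode and device_num mirror the surrounding code; the split indices are model-specific):

from mindspore import context
from mindspore.context import ParallelMode

context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=8, gradients_mean=True)
# Split points were presumably tuned for this network; other models need other values.
context.set_auto_parallel_context(all_reduce_fusion_config=[1, 10])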
@@ -141,7 +143,13 @@ def run_train():
     de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
     cfg.steps_per_epoch = steps_per_epoch
     cfg.logger.info('step per epoch: %s', cfg.steps_per_epoch)
-    de_dataloader = de_dataset.create_tuple_iterator()
+
+    # increase training speed for Ascend and distribute mode
+    if config.device_target == 'Ascend' and config.is_distributed:
+        de_dataloader = de_dataset.create_tuple_iterator(do_copy=False)
+    else:
+        de_dataloader = de_dataset.create_tuple_iterator()
+
     cfg.logger.info('class num original: %s', class_num)
     if class_num % 16 != 0:
         class_num = (class_num // 16 + 1) * 16
@@ -214,8 +222,6 @@ def run_train():
     cfg.logger.important_info('====start train====')
     for i, total_data in enumerate(de_dataloader):
         data, gt = total_data
-        data = Tensor(data)
-        gt = Tensor(gt)

         loss = train_net(data, gt)
         loss_meter.update(loss.asnumpy())
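The last two hunks are linked: create_tuple_iterator already yields mindspore.Tensor objects, and do_copy=False only skips an extra host-side copy while producing them, so re-wrapping with Tensor(data) / Tensor(gt) was redundant work on every step and the Tensor import goes away with it. A toy, self-contained illustration of the iterator's output types (NumpySlicesDataset stands in for the real pipeline here):

import numpy as np
import mindspore.dataset as de

# Tiny stand-in dataset, only to show what the tuple iterator yields.
ds = de.NumpySlicesDataset({"image": np.zeros((4, 3, 96, 64), np.float32),
                            "label": np.zeros((4,), np.int32)}, shuffle=False)
ds = ds.batch(2)
for data, gt in ds.create_tuple_iterator(do_copy=False):
    print(type(data), type(gt))   # both are already mindspore Tensor, no wrapping needed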