diff --git a/model_zoo/official/cv/resnet/eval.py b/model_zoo/official/cv/resnet/eval.py index e0925346c7..0d793aaa91 100755 --- a/model_zoo/official/cv/resnet/eval.py +++ b/model_zoo/official/cv/resnet/eval.py @@ -58,11 +58,8 @@ if __name__ == '__main__': context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False, device_id=device_id) # create dataset - if args_opt.net == "resnet50": - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, - target=target) - else: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, + target=target) step_size = dataset.get_dataset_size() # define net diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh new file mode 100755 index 0000000000..95e0d7df06 --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,93 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] && [ $# != 4 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: training resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) + +if [ $# == 4 ] +then + PATH2=$(get_real_path $4) +fi + + +if [ ! -d $PATH2 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 5 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 + +rm -rf ./train_parallel +mkdir ./train_parallel +cp ../*.py ./train_parallel +cp *.sh ./train_parallel +cp -r ../src ./train_parallel +cd ./train_parallel || exit + +if [ $# == 3 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py --net=$1 --dataset=$2 --run_distribute=True \ + --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log & +fi + +if [ $# == 4 ] +then + mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py --net=$1 --dataset=$2 --run_distribute=True \ + --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & +fi diff --git a/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh b/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh new file mode 100755 index 0000000000..0be444d738 --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] && [ $# != 4 ] +then + echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" +exit 1 +fi + +if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] +then + echo "error: the selected net is neither resnet50 nor resnet101" +exit 1 +fi + +if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ] +then + echo "error: the selected dataset is neither cifar10 nor imagenet2012" +exit 1 +fi + +if [ $1 == "resnet101" ] && [ $2 == "cifar10" ] +then + echo "error: training resnet101 with cifar10 dataset is unsupported now!" +exit 1 +fi + + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $3) + +if [ $# == 4 ] +then + PATH2=$(get_real_path $4) +fi + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 4 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log +if [ $# == 3 ] +then + python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 &> log & +fi + +if [ $# == 4 ] +then + python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & +fi +cd .. diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 79730fc460..d4a8969ed1 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -139,7 +139,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= return ds -def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): +def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): """ create a train or eval imagenet2012 dataset for resnet101 Args: @@ -158,36 +158,26 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): else: ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=device_num, shard_id=rank_id) - resize_height = 224 - rescale = 1.0 / 255.0 - shift = 0.0 + image_size = 224 + mean = [0.475 * 255, 0.451 * 255, 0.392 * 255] + std = [0.275 * 255, 0.267 * 255, 0.278 * 255] # define map operations - decode_op = C.Decode() - - random_resize_crop_op = C.RandomResizedCrop(resize_height, (0.08, 1.0), (0.75, 1.33), max_attempts=100) - horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1)) - resize_op_256 = C.Resize((256, 256)) - center_crop = C.CenterCrop(224) - rescale_op = C.Rescale(rescale, shift) - normalize_op = C.Normalize((0.475, 0.451, 0.392), (0.275, 0.267, 0.278)) - changeswap_op = C.HWC2CHW() - if do_train: - trans = [decode_op, - random_resize_crop_op, - horizontal_flip_op, - rescale_op, - normalize_op, - changeswap_op] - + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(rank_id/ (rank_id +1)), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] else: - trans = [decode_op, - resize_op_256, - center_crop, - rescale_op, - normalize_op, - changeswap_op] + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] type_cast_op = C2.TypeCast(mstype.int32) diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index be1c6290b1..b721fa0f92 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -86,12 +86,8 @@ if __name__ == '__main__': ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" # create dataset - if args_opt.net == "resnet50": - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, - batch_size=config.batch_size, target=target) - else: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, - batch_size=config.batch_size) + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, + batch_size=config.batch_size, target=target) step_size = dataset.get_dataset_size() # define net @@ -122,7 +118,7 @@ if __name__ == '__main__': lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine') else: - lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, 120, + lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size, config.pretrain_epoch_size * step_size) lr = Tensor(lr) @@ -147,9 +143,13 @@ if __name__ == '__main__': amp_level="O2", keep_batchnorm_fp32=False) else: # GPU target - loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean') - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum) + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + ##Mixed precision + #model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + # amp_level="O2", keep_batchnorm_fp32=True) # define callbacks time_cb = TimeMonitor(data_size=step_size)