From c596e8cab65db8b8b895194f5780da55b4d7723a Mon Sep 17 00:00:00 2001 From: l_emon <715485601@qq.com> Date: Thu, 15 Jul 2021 16:56:31 +0800 Subject: [PATCH] bug fix --- model_zoo/research/cv/SRGAN/README.md | 17 +++++++++------ model_zoo/research/cv/SRGAN/eval.py | 21 ++----------------- .../research/cv/SRGAN/scripts/run_eval.sh | 17 +++++++++------ .../cv/SRGAN/scripts/run_standalone_train.sh | 2 +- model_zoo/research/cv/SRGAN/train.py | 6 ++---- 5 files changed, 27 insertions(+), 36 deletions(-) diff --git a/model_zoo/research/cv/SRGAN/README.md b/model_zoo/research/cv/SRGAN/README.md index 20eabfd9b1b..0904b8eb250 100644 --- a/model_zoo/research/cv/SRGAN/README.md +++ b/model_zoo/research/cv/SRGAN/README.md @@ -41,7 +41,7 @@ Validation and eval evaluationdataset used: [Set5]()| +[Training scripts]()| [VGG19 pretrained model]() # [Environment Requirements](#contents) @@ -51,8 +51,8 @@ The process of training SRGAN needs a pretrained VGG19 based on Imagenet. - Framework - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: - - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html) - - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html) + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) # [Script Description](#contents) @@ -97,13 +97,16 @@ SRGAN # distributed training Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH] +eg: sh run_distribute_train.sh 8 1 ./hccl_8p.json ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR # standalone training Usage: sh run_standalone_train.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH] + +eg: sh run_distribute_train.sh 0 ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR ``` ### [Training Result](#content) -Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint file. +Training result will be stored in scripts/train_parallel0/ckpt. You can find checkpoint file. ### [Evaluation Script Parameters](#content) @@ -111,7 +114,9 @@ Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint f ```bash # evaling -sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] +sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID] + +eg: sh run_eval.sh ./ckpt/best.ckpt ./Set14/LR ./Set14/HR 0 ``` ### [Evaluation result](#content) @@ -153,4 +158,4 @@ Evaluation result will be stored in the scripts/result. Under this, you can find # [ModelZoo Homepage](#contents) -Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). \ No newline at end of file +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/research/cv/SRGAN/eval.py b/model_zoo/research/cv/SRGAN/eval.py index 1aef8f116c4..8ad35559472 100644 --- a/model_zoo/research/cv/SRGAN/eval.py +++ b/model_zoo/research/cv/SRGAN/eval.py @@ -14,12 +14,10 @@ # ============================================================================ """file for evaling""" -import os import argparse import numpy as np from skimage.color import rgb2ycbcr from skimage.metrics import peak_signal_noise_ratio -from PIL import Image from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed from mindspore import context @@ -34,7 +32,7 @@ parser.add_argument("--test_LR_path", type=str, default='/data/Set14/LR') parser.add_argument("--test_GT_path", type=str, default='/data/Set14/HR') parser.add_argument("--res_num", type=int, default=16) parser.add_argument("--scale", type=int, default=4) -parser.add_argument("--generator_path", type=str, default='./scripts/srgan0/ckpt/G_model_1000.ckpt') +parser.add_argument("--generator_path", type=str, default='./ckpt/best.ckpt') parser.add_argument("--mode", type=str, default='train') parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.") i = 0 @@ -47,9 +45,6 @@ if __name__ == '__main__': params = load_checkpoint(args.generator_path) load_param_into_net(generator, params) op = ops.ReduceSum(keep_dims=False) - if not os.path.exists("./result/Set14"): - os.makedirs("./result/Set14") - weizhi = './result/Set14/psnr.txt' psnr_list = [] print("=======starting test=====") @@ -78,16 +73,4 @@ if __name__ == '__main__': psnr = peak_signal_noise_ratio(y_output / 255.0, y_gt / 255.0, data_range=1.0) psnr_list.append(psnr) - psnr = str(psnr) - with open(weizhi, "w") as f: - f.write('psnr : %s \n' % psnr) - - result = Image.fromarray((output * 255.0).astype(np.uint8)) - result.save('./result/Set14/res_%04d.png'%i) - i = i+1 - mean = np.mean(psnr_list) - mean = str(mean) - print("avg PSNR:") - print(mean) - with open(weizhi, "w") as f: - f.write('avg psnr : %s \n' % mean) + print("avg PSNR:", np.mean(psnr_list)) diff --git a/model_zoo/research/cv/SRGAN/scripts/run_eval.sh b/model_zoo/research/cv/SRGAN/scripts/run_eval.sh index e5733fd9719..dce198f9acf 100644 --- a/model_zoo/research/cv/SRGAN/scripts/run_eval.sh +++ b/model_zoo/research/cv/SRGAN/scripts/run_eval.sh @@ -12,18 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -if [ $# != 3 ] +if [ $# != 4 ] then - echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH]" + echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]" exit 1 fi export CKPT=$1 export EVALLRPATH=$2 export EVALGTPATH=$3 +export DEVICE_ID=$4 + +rm -rf ./eval +mkdir ./eval +cp -r ../src ./eval +cp -r ../*.py ./eval +cd ./eval || exit env > env.log -if [ $# == 3 ] -then -python ../eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --test_GT_path=$EVALGTPATH &> log & -fi +python ./eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --device_id $DEVICE_ID\ + --test_GT_path=$EVALGTPATH &> log & diff --git a/model_zoo/research/cv/SRGAN/scripts/run_standalone_train.sh b/model_zoo/research/cv/SRGAN/scripts/run_standalone_train.sh index a1a8cb5c30b..c910bc61db7 100644 --- a/model_zoo/research/cv/SRGAN/scripts/run_standalone_train.sh +++ b/model_zoo/research/cv/SRGAN/scripts/run_standalone_train.sh @@ -35,7 +35,7 @@ cp -r ../src ./train_standalone cp -r ../*.py ./train_standalone cd ./train_standalone || exit -echo "start traning" +echo "start training" env > env.log if [ $# == 6 ] then diff --git a/model_zoo/research/cv/SRGAN/train.py b/model_zoo/research/cv/SRGAN/train.py index d66d3c23fb9..896e9ed2691 100644 --- a/model_zoo/research/cv/SRGAN/train.py +++ b/model_zoo/research/cv/SRGAN/train.py @@ -113,7 +113,7 @@ if __name__ == '__main__': os.makedirs("./ckpt") print('start training:') - print('tart training PSNR:') + print('start training PSNR:') # warm up generator for epoch in range(args.start_psnr_epoch, args.psnr_epochs): print("training {:d} epoch:".format(epoch+1)) @@ -124,7 +124,6 @@ if __name__ == '__main__': mse_loss = train_psnr(hr, lr) steps = train_ds.get_dataset_size() time_elapsed = (time.time()-mysince) - print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) step_time = time_elapsed / steps print('per step needs time:{:.0f}ms'.format(step_time * 1000)) print("mse_loss:") @@ -192,7 +191,7 @@ if __name__ == '__main__': discriminator_optimizer = nn.Adam(discriminator.trainable_params(), 1e-4) train_discriminator = TrainOneStepD(discriminator_loss, discriminator_optimizer) train_generator = TrainOnestepG(generator_loss, generator_optimizer) - + print("========================================") print('start training GAN :') # trainGAN for epoch in range(args.start_gan_epoch, args.gan_epochs): @@ -205,7 +204,6 @@ if __name__ == '__main__': G_loss = train_generator(hr, lr) time_elapsed1 = (time.time()-mysince1) steps = train_ds.get_dataset_size() - print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed1 // 60, time_elapsed1 % 60)) step_time1 = time_elapsed1 / steps print('per step needs time:{:.0f}ms'.format(step_time1 * 1000)) print("D_loss:")