This commit is contained in:
l_emon 2021-07-15 16:56:31 +08:00
parent e67b74e8e3
commit c596e8cab6
5 changed files with 27 additions and 36 deletions

View File

@ -41,7 +41,7 @@ Validation and eval evaluationdataset used: [Set5](<http://people.rennes.inria.f
The process of training SRGAN needs a pretrained VGG19 based on Imagenet.
[Training scripts](<https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/official/cv/vgg16>)|
[Training scripts](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16>)|
[VGG19 pretrained model](<https://download.mindspore.cn/model_zoo/>)
# [Environment Requirements](#contents)
@ -51,8 +51,8 @@ The process of training SRGAN needs a pretrained VGG19 based on Imagenet.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
@ -97,13 +97,16 @@ SRGAN
# distributed training
Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH]
eg: sh run_distribute_train.sh 8 1 ./hccl_8p.json ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR
# standalone training
Usage: sh run_standalone_train.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH]
eg: sh run_distribute_train.sh 0 ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR
```
### [Training Result](#content)
Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint file.
Training result will be stored in scripts/train_parallel0/ckpt. You can find checkpoint file.
### [Evaluation Script Parameters](#content)
@ -111,7 +114,9 @@ Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint f
```bash
# evaling
sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH]
sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]
eg: sh run_eval.sh ./ckpt/best.ckpt ./Set14/LR ./Set14/HR 0
```
### [Evaluation result](#content)
@ -153,4 +158,4 @@ Evaluation result will be stored in the scripts/result. Under this, you can find
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -14,12 +14,10 @@
# ============================================================================
"""file for evaling"""
import os
import argparse
import numpy as np
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio
from PIL import Image
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import context
@ -34,7 +32,7 @@ parser.add_argument("--test_LR_path", type=str, default='/data/Set14/LR')
parser.add_argument("--test_GT_path", type=str, default='/data/Set14/HR')
parser.add_argument("--res_num", type=int, default=16)
parser.add_argument("--scale", type=int, default=4)
parser.add_argument("--generator_path", type=str, default='./scripts/srgan0/ckpt/G_model_1000.ckpt')
parser.add_argument("--generator_path", type=str, default='./ckpt/best.ckpt')
parser.add_argument("--mode", type=str, default='train')
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
i = 0
@ -47,9 +45,6 @@ if __name__ == '__main__':
params = load_checkpoint(args.generator_path)
load_param_into_net(generator, params)
op = ops.ReduceSum(keep_dims=False)
if not os.path.exists("./result/Set14"):
os.makedirs("./result/Set14")
weizhi = './result/Set14/psnr.txt'
psnr_list = []
print("=======starting test=====")
@ -78,16 +73,4 @@ if __name__ == '__main__':
psnr = peak_signal_noise_ratio(y_output / 255.0, y_gt / 255.0, data_range=1.0)
psnr_list.append(psnr)
psnr = str(psnr)
with open(weizhi, "w") as f:
f.write('psnr : %s \n' % psnr)
result = Image.fromarray((output * 255.0).astype(np.uint8))
result.save('./result/Set14/res_%04d.png'%i)
i = i+1
mean = np.mean(psnr_list)
mean = str(mean)
print("avg PSNR:")
print(mean)
with open(weizhi, "w") as f:
f.write('avg psnr : %s \n' % mean)
print("avg PSNR:", np.mean(psnr_list))

View File

@ -12,18 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 3 ]
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH]"
echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]"
exit 1
fi
export CKPT=$1
export EVALLRPATH=$2
export EVALGTPATH=$3
export DEVICE_ID=$4
rm -rf ./eval
mkdir ./eval
cp -r ../src ./eval
cp -r ../*.py ./eval
cd ./eval || exit
env > env.log
if [ $# == 3 ]
then
python ../eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --test_GT_path=$EVALGTPATH &> log &
fi
python ./eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --device_id $DEVICE_ID\
--test_GT_path=$EVALGTPATH &> log &

View File

@ -35,7 +35,7 @@ cp -r ../src ./train_standalone
cp -r ../*.py ./train_standalone
cd ./train_standalone || exit
echo "start traning"
echo "start training"
env > env.log
if [ $# == 6 ]
then

View File

@ -113,7 +113,7 @@ if __name__ == '__main__':
os.makedirs("./ckpt")
print('start training:')
print('tart training PSNR:')
print('start training PSNR:')
# warm up generator
for epoch in range(args.start_psnr_epoch, args.psnr_epochs):
print("training {:d} epoch:".format(epoch+1))
@ -124,7 +124,6 @@ if __name__ == '__main__':
mse_loss = train_psnr(hr, lr)
steps = train_ds.get_dataset_size()
time_elapsed = (time.time()-mysince)
print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
step_time = time_elapsed / steps
print('per step needs time:{:.0f}ms'.format(step_time * 1000))
print("mse_loss:")
@ -192,7 +191,7 @@ if __name__ == '__main__':
discriminator_optimizer = nn.Adam(discriminator.trainable_params(), 1e-4)
train_discriminator = TrainOneStepD(discriminator_loss, discriminator_optimizer)
train_generator = TrainOnestepG(generator_loss, generator_optimizer)
print("========================================")
print('start training GAN :')
# trainGAN
for epoch in range(args.start_gan_epoch, args.gan_epochs):
@ -205,7 +204,6 @@ if __name__ == '__main__':
G_loss = train_generator(hr, lr)
time_elapsed1 = (time.time()-mysince1)
steps = train_ds.get_dataset_size()
print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed1 // 60, time_elapsed1 % 60))
step_time1 = time_elapsed1 / steps
print('per step needs time:{:.0f}ms'.format(step_time1 * 1000))
print("D_loss:")