forked from mindspore-Ecosystem/mindspore
bug fix
This commit is contained in:
parent
e67b74e8e3
commit
c596e8cab6
|
@ -41,7 +41,7 @@ Validation and eval evaluationdataset used: [Set5](<http://people.rennes.inria.f
|
|||
|
||||
The process of training SRGAN needs a pretrained VGG19 based on Imagenet.
|
||||
|
||||
[Training scripts](<https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/official/cv/vgg16>)|
|
||||
[Training scripts](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16>)|
|
||||
[VGG19 pretrained model](<https://download.mindspore.cn/model_zoo/>)
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
@ -51,8 +51,8 @@ The process of training SRGAN needs a pretrained VGG19 based on Imagenet.
|
|||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
|
@ -97,13 +97,16 @@ SRGAN
|
|||
# distributed training
|
||||
Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH]
|
||||
|
||||
eg: sh run_distribute_train.sh 8 1 ./hccl_8p.json ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR
|
||||
# standalone training
|
||||
Usage: sh run_standalone_train.sh [DEVICE_ID] [LRPATH] [GTPATH] [VGGCKPT] [VLRPATH] [VGTPATH]
|
||||
|
||||
eg: sh run_distribute_train.sh 0 ./DIV2K_train_LR_bicubic/X4 ./DIV2K_train_HR ./vgg.ckpt ./Set5/LR ./Set5/HR
|
||||
```
|
||||
|
||||
### [Training Result](#content)
|
||||
|
||||
Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint file.
|
||||
Training result will be stored in scripts/train_parallel0/ckpt. You can find checkpoint file.
|
||||
|
||||
### [Evaluation Script Parameters](#content)
|
||||
|
||||
|
@ -111,7 +114,9 @@ Training result will be stored in scripts/srgan0/ckpt. You can find checkpoint f
|
|||
|
||||
```bash
|
||||
# evaling
|
||||
sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH]
|
||||
sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]
|
||||
|
||||
eg: sh run_eval.sh ./ckpt/best.ckpt ./Set14/LR ./Set14/HR 0
|
||||
```
|
||||
|
||||
### [Evaluation result](#content)
|
||||
|
@ -153,4 +158,4 @@ Evaluation result will be stored in the scripts/result. Under this, you can find
|
|||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
# ============================================================================
|
||||
|
||||
"""file for evaling"""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from skimage.color import rgb2ycbcr
|
||||
from skimage.metrics import peak_signal_noise_ratio
|
||||
from PIL import Image
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import context
|
||||
|
@ -34,7 +32,7 @@ parser.add_argument("--test_LR_path", type=str, default='/data/Set14/LR')
|
|||
parser.add_argument("--test_GT_path", type=str, default='/data/Set14/HR')
|
||||
parser.add_argument("--res_num", type=int, default=16)
|
||||
parser.add_argument("--scale", type=int, default=4)
|
||||
parser.add_argument("--generator_path", type=str, default='./scripts/srgan0/ckpt/G_model_1000.ckpt')
|
||||
parser.add_argument("--generator_path", type=str, default='./ckpt/best.ckpt')
|
||||
parser.add_argument("--mode", type=str, default='train')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
|
||||
i = 0
|
||||
|
@ -47,9 +45,6 @@ if __name__ == '__main__':
|
|||
params = load_checkpoint(args.generator_path)
|
||||
load_param_into_net(generator, params)
|
||||
op = ops.ReduceSum(keep_dims=False)
|
||||
if not os.path.exists("./result/Set14"):
|
||||
os.makedirs("./result/Set14")
|
||||
weizhi = './result/Set14/psnr.txt'
|
||||
psnr_list = []
|
||||
|
||||
print("=======starting test=====")
|
||||
|
@ -78,16 +73,4 @@ if __name__ == '__main__':
|
|||
|
||||
psnr = peak_signal_noise_ratio(y_output / 255.0, y_gt / 255.0, data_range=1.0)
|
||||
psnr_list.append(psnr)
|
||||
psnr = str(psnr)
|
||||
with open(weizhi, "w") as f:
|
||||
f.write('psnr : %s \n' % psnr)
|
||||
|
||||
result = Image.fromarray((output * 255.0).astype(np.uint8))
|
||||
result.save('./result/Set14/res_%04d.png'%i)
|
||||
i = i+1
|
||||
mean = np.mean(psnr_list)
|
||||
mean = str(mean)
|
||||
print("avg PSNR:")
|
||||
print(mean)
|
||||
with open(weizhi, "w") as f:
|
||||
f.write('avg psnr : %s \n' % mean)
|
||||
print("avg PSNR:", np.mean(psnr_list))
|
||||
|
|
|
@ -12,18 +12,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
if [ $# != 3 ]
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH]"
|
||||
echo "Usage: sh run_eval.sh [CKPT] [EVALLRPATH] [EVALGTPATH] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export CKPT=$1
|
||||
export EVALLRPATH=$2
|
||||
export EVALGTPATH=$3
|
||||
export DEVICE_ID=$4
|
||||
|
||||
rm -rf ./eval
|
||||
mkdir ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../*.py ./eval
|
||||
cd ./eval || exit
|
||||
|
||||
env > env.log
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python ../eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --test_GT_path=$EVALGTPATH &> log &
|
||||
fi
|
||||
python ./eval.py --generator_path=$CKPT --test_LR_path=$EVALLRPATH --device_id $DEVICE_ID\
|
||||
--test_GT_path=$EVALGTPATH &> log &
|
||||
|
|
|
@ -35,7 +35,7 @@ cp -r ../src ./train_standalone
|
|||
cp -r ../*.py ./train_standalone
|
||||
cd ./train_standalone || exit
|
||||
|
||||
echo "start traning"
|
||||
echo "start training"
|
||||
env > env.log
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
|
|
|
@ -113,7 +113,7 @@ if __name__ == '__main__':
|
|||
os.makedirs("./ckpt")
|
||||
print('start training:')
|
||||
|
||||
print('tart training PSNR:')
|
||||
print('start training PSNR:')
|
||||
# warm up generator
|
||||
for epoch in range(args.start_psnr_epoch, args.psnr_epochs):
|
||||
print("training {:d} epoch:".format(epoch+1))
|
||||
|
@ -124,7 +124,6 @@ if __name__ == '__main__':
|
|||
mse_loss = train_psnr(hr, lr)
|
||||
steps = train_ds.get_dataset_size()
|
||||
time_elapsed = (time.time()-mysince)
|
||||
print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
|
||||
step_time = time_elapsed / steps
|
||||
print('per step needs time:{:.0f}ms'.format(step_time * 1000))
|
||||
print("mse_loss:")
|
||||
|
@ -192,7 +191,7 @@ if __name__ == '__main__':
|
|||
discriminator_optimizer = nn.Adam(discriminator.trainable_params(), 1e-4)
|
||||
train_discriminator = TrainOneStepD(discriminator_loss, discriminator_optimizer)
|
||||
train_generator = TrainOnestepG(generator_loss, generator_optimizer)
|
||||
|
||||
print("========================================")
|
||||
print('start training GAN :')
|
||||
# trainGAN
|
||||
for epoch in range(args.start_gan_epoch, args.gan_epochs):
|
||||
|
@ -205,7 +204,6 @@ if __name__ == '__main__':
|
|||
G_loss = train_generator(hr, lr)
|
||||
time_elapsed1 = (time.time()-mysince1)
|
||||
steps = train_ds.get_dataset_size()
|
||||
print('the epoch needs time:{:.0f}m {:.0f}s'.format(time_elapsed1 // 60, time_elapsed1 % 60))
|
||||
step_time1 = time_elapsed1 / steps
|
||||
print('per step needs time:{:.0f}ms'.format(step_time1 * 1000))
|
||||
print("D_loss:")
|
||||
|
|
Loading…
Reference in New Issue