forked from mindspore-Ecosystem/mindspore
rectify wdsr scripts
parent 6012dab4b2
commit ec4c7387b8
@@ -116,8 +116,8 @@ The WDSR network is mainly composed of several basic modules (including convolution and pooling layers).
 - Framework
 - [MindSpore](https://www.mindspore.cn/install/en)
 - For more details, see the following resources:
-    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
-    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

 # Quick Start

@@ -1,19 +1,3 @@
-#!/bin/bash
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
 #!/bin/bash
 # Copyright 2020-2021 Huawei Technologies Co., Ltd
 #

@@ -52,8 +52,6 @@ parser.add_argument('--epochs', type=int, default=300,
                     help='number of epochs to train')
 parser.add_argument('--batch_size', type=int, default=16,
                     help='input batch size for training')
-#parser.add_argument('--self_ensemble', action='store_true',
-#                    help='use self-ensemble method for test')
 parser.add_argument('--test_only', action='store_true',
                     help='set this option to test the model')
 # Optimization specifications

@@ -61,6 +59,8 @@ parser.add_argument('--lr', type=float, default=1e-4,
                     help='learning rate')
 parser.add_argument('--init_loss_scale', type=float, default=65536.,
                     help='scaling factor')
+parser.add_argument('--loss_scale', type=float, default=1024.0,
+                    help='loss_scale')
 parser.add_argument('--decay', type=str, default='200',
                     help='learning rate decay type')
 parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),

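These two flags configure different loss-scaling paths: `--loss_scale` is the fixed factor passed to `nn.Adam` in the trainer hunks below, while `--init_loss_scale` seeds the `DynamicLossScaleManager` used in `train_net`. As a rough illustration of how a dynamic loss scale with `init_loss_scale=65536`, `scale_factor=2`, and `scale_window=1000` typically evolves (this is the conventional update rule, sketched for orientation only, not code from this repository):

```python
# Conventional dynamic loss-scaling rule, shown for illustration only:
# shrink on overflow, grow after scale_window consecutive clean steps.
def update_loss_scale(scale, overflow, good_steps, scale_factor=2, scale_window=1000):
    if overflow:
        return max(scale / scale_factor, 1.0), 0
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0
    return scale, good_steps

scale, good = 65536.0, 0   # --init_loss_scale
for overflow in [False] * 1000 + [True]:
    scale, good = update_loss_scale(scale, overflow, good)
print(scale)  # grows to 131072.0 after 1000 clean steps, then halves back to 65536.0
```
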
@@ -20,11 +20,8 @@ def get_patch(*args, patch_size=96, scale=2, input_large=False):
     ih, iw = args[0].shape[:2]
     tp = patch_size
     ip = tp // scale
-    # print("tp=%g,scale=%g,"%(tp,scale),end='',flush=True)
-    # print("ih=%g,iw=%g,ip=%g"%(ih,iw,ip),flush=True)
     ix = random.randrange(0, iw - ip + 1)
     iy = random.randrange(0, ih - ip + 1)
-    # ix = iy = 0
     if not input_large:
         tx, ty = scale * ix, scale * iy
     else:

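`get_patch` draws a pair of aligned random crops: the low-resolution patch is `patch_size // scale` pixels on a side at `(ix, iy)`, and the matching high-resolution patch is `patch_size` pixels at `(scale * ix, scale * iy)`. A self-contained sketch of that arithmetic, assuming the first argument is the LR image and the second the HR image (names and slicing here are illustrative, not the repository's exact code):

```python
import random
import numpy as np

def crop_pair(lr, hr, patch_size=96, scale=2):
    ih, iw = lr.shape[:2]            # LR spatial size
    tp = patch_size                  # HR (target) patch size
    ip = tp // scale                 # corresponding LR (input) patch size
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    tx, ty = scale * ix, scale * iy  # same location scaled up in the HR image
    return lr[iy:iy + ip, ix:ix + ip], hr[ty:ty + tp, tx:tx + tp]

lr = np.zeros((48, 64, 3), np.uint8)
hr = np.zeros((96, 128, 3), np.uint8)
lr_patch, hr_patch = crop_pair(lr, hr, patch_size=96, scale=2)
print(lr_patch.shape, hr_patch.shape)  # (48, 48, 3) (96, 96, 3)
```
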
@@ -106,13 +106,11 @@ class SRData:
             with open(f, 'wb') as _f:
                 pickle.dump(imageio.imread(img), _f)

-    # pylint: disable=unused-variable
     def __getitem__(self, idx):
-        lr, hr, filename = self._load_file(idx)
+        lr, hr, _ = self._load_file(idx)
         pair = self.get_patch(lr, hr)
         pair = common.set_channel(*pair, n_channels=self.args.n_colors)
         pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
-        #return pair_t[0], pair_t[1], [self.idx_scale], [filename]
         return pair_t[0], pair_t[1]

     def __len__(self):

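With the commented-out return gone, `__getitem__` yields exactly two arrays (the LR and HR patches), which is the shape expected when such a dataset is wrapped in `mindspore.dataset.GeneratorDataset` with two output columns. A minimal, hypothetical wiring (the toy source, column names, and batch size are assumptions, not taken from the repository):

```python
import numpy as np
import mindspore.dataset as ds

# Stand-in for SRData: any object with __getitem__/__len__ returning (lr, hr) pairs.
class ToyPairs:
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        lr = np.zeros((3, 48, 48), np.float32)
        hr = np.zeros((3, 96, 96), np.float32)
        return lr, hr

dataset = ds.GeneratorDataset(ToyPairs(), column_names=["lr", "hr"], shuffle=True)
dataset = dataset.batch(2, drop_remainder=True)
for lr, hr in dataset.create_tuple_iterator():
    print(lr.shape, hr.shape)  # (2, 3, 48, 48) (2, 3, 96, 96)
```
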
@@ -163,7 +161,6 @@ class SRData:
     def _load_file(self, idx):
         """srdata"""
         idx = self._get_index(idx)
-        # print(idx,flush=True)
         f_hr = self.images_hr[idx]
         f_lr = self.images_lr[self.idx_scale][idx]
         filename, _ = os.path.splitext(os.path.basename(f_hr))

@@ -36,7 +36,6 @@ class Block(nn.Cell):
     """residual block"""
     def __init__(self):
         super(Block, self).__init__()
-        # wn = lambda x: mindspore.nn.GroupNorm(x)
         act = nn.ReLU()
         self.res_scale = 1
         body = []

@@ -30,7 +30,7 @@ class Trainer():
         self.criterion = nn.L1Loss()
         self.loss_history = []
         self.begin_time = time.time()
-        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
+        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=args.loss_scale)
         self.loss_net = nn.WithLossCell(self.model, self.criterion)
         self.net = nn.TrainOneStepCell(self.loss_net, self.optimizer)
     def train(self, epoch):

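The trainer follows MindSpore's cell-composition pattern: `nn.WithLossCell` pairs the network with `nn.L1Loss`, and `nn.TrainOneStepCell` runs one forward/backward/update per call through the `nn.Adam` optimizer, whose loss scale now comes from `--loss_scale` instead of the hard-coded 1024.0. A minimal self-contained sketch of that pattern (the single-convolution network and random tensors are placeholders, not WDSR):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

# Toy stand-in for the super-resolution network (not WDSR itself).
net = nn.Conv2d(3, 3, 3, pad_mode="same")
criterion = nn.L1Loss()
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-4, loss_scale=1024.0)

loss_net = nn.WithLossCell(net, criterion)              # computes criterion(net(lr), hr)
train_step = nn.TrainOneStepCell(loss_net, optimizer)   # one forward/backward/update per call
train_step.set_train()

lr = ms.Tensor(np.random.rand(1, 3, 48, 48).astype(np.float32))
hr = ms.Tensor(np.random.rand(1, 3, 48, 48).astype(np.float32))
loss = train_step(lr, hr)   # returns the loss for this step
print(loss)
```
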
@@ -56,7 +56,7 @@ def train_net():
     for i in range(0, args.epochs):
         cur_lr = args.lr / (2 ** ((i + 1)//200))
         lr.extend([cur_lr] * step_size)
-    opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=1024.0)
+    opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
     loss = nn.L1Loss()
     loss_scale_manager = DynamicLossScaleManager(init_loss_scale=args.init_loss_scale, \
                                                  scale_factor=2, scale_window=1000)

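`train_net` expands the learning rate into a per-step list before building the optimizer: the rate starts at `--lr` and is halved once the epoch index crosses 200 (matching the default `--decay` of `'200'`), and each epoch's value is repeated `step_size` times so Adam sees one entry per training step. A standalone illustration of that list construction (the `step_size` value here is made up):

```python
# Standalone illustration of the per-step learning-rate list built in train_net.
base_lr = 1e-4     # --lr
epochs = 300       # --epochs
step_size = 100    # assumed number of steps per epoch

lr = []
for i in range(0, epochs):
    cur_lr = base_lr / (2 ** ((i + 1) // 200))  # halves from the 200th epoch on
    lr.extend([cur_lr] * step_size)

print(lr[0], lr[198 * step_size], lr[199 * step_size])  # 0.0001 0.0001 5e-05
```
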