forked from mindspore-Ecosystem/mindspore
!13726 add IPT Ascend
From: @xiaoan95 Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34
This commit is contained in:
commit
b6bf797ae7
|
@ -23,7 +23,7 @@ from mindspore import context
|
|||
import mindspore.dataset as de
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0)
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -46,11 +46,12 @@ def main():
|
|||
net_m.set_train(False)
|
||||
num_imgs = train_de_dataset.get_dataset_size()
|
||||
psnrs = np.zeros((num_imgs, 1))
|
||||
inference = ipt.IPT_post(net_m, args)
|
||||
for batch_idx, imgs in enumerate(train_loader):
|
||||
lr = imgs['LR']
|
||||
hr = imgs['HR']
|
||||
hr_np = np.float32(hr.asnumpy())
|
||||
pred = net_m.infrc(lr)
|
||||
pred = inference.forward(lr)
|
||||
pred_np = np.float32(pred.asnumpy())
|
||||
pred_np = quantize(pred_np, 255)
|
||||
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
|
||||
|
|
|
@ -23,9 +23,6 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
# from mindspore.ops.primitive import constexpr
|
||||
# import IPython
|
||||
|
||||
class MultiheadAttention(nn.Cell):
|
||||
"""
|
||||
Apply multi-headed attention from "from_tensor" to "to_tensor".
|
||||
|
@ -85,7 +82,7 @@ class MultiheadAttention(nn.Cell):
|
|||
self.shape_q_2d = (-1, q_tensor_width)
|
||||
self.shape_k_2d = (-1, k_tensor_width)
|
||||
self.shape_v_2d = (-1, v_tensor_width)
|
||||
self.hidden_width = hidden_width
|
||||
self.hidden_width = int(hidden_width)
|
||||
# units = num_attention_heads * self.size_per_head
|
||||
if self.same_dim:
|
||||
self.in_proj_layer = \
|
||||
|
@ -132,46 +129,49 @@ class MultiheadAttention(nn.Cell):
|
|||
self.softmax_cast = P.Cast()
|
||||
self.matmul_dense = P.MatMul(transpose_b=True)
|
||||
self.split = P.Split(0, 3)
|
||||
self.equal = P.Equal()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, tensor_q, tensor_k, tensor_v, batch_size, seq_length, attention_mask=None):
|
||||
def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
|
||||
"""Apply multihead attention."""
|
||||
self.batch_size = batch_size
|
||||
shape_qkv = (self.batch_size, -1,
|
||||
batch_size, seq_length, _ = self.shape(tensor_q)
|
||||
shape_qkv = (batch_size, -1,
|
||||
self.num_attention_heads, self.size_per_head)
|
||||
shape_linear = (self.batch_size * seq_length,
|
||||
shape_linear = (batch_size * seq_length,
|
||||
self.num_attention_heads * self.size_per_head)
|
||||
if self.do_return_2d_tensor:
|
||||
shape_return = (self.batch_size * seq_length,
|
||||
if self.do_return_2d_tensor is True:
|
||||
shape_return = (batch_size * seq_length,
|
||||
self.num_attention_heads * self.size_per_head)
|
||||
if seq_length == -1:
|
||||
shape_return = (-1, self.num_attention_heads *
|
||||
self.size_per_head)
|
||||
else:
|
||||
shape_return = (self.batch_size, seq_length,
|
||||
shape_return = (batch_size, seq_length,
|
||||
self.num_attention_heads * self.size_per_head)
|
||||
|
||||
tensor_q_2d = self.reshape(tensor_q, self.shape_q_2d)
|
||||
tensor_k_2d = self.reshape(tensor_k, self.shape_k_2d)
|
||||
tensor_v_2d = self.reshape(tensor_v, self.shape_v_2d)
|
||||
|
||||
if P.Equal()(tensor_q_2d, tensor_v_2d)[0][0]:
|
||||
if self.equal(tensor_q_2d, tensor_v_2d) is True:
|
||||
x = self.matmul_dense(self.in_proj_layer, tensor_q_2d)
|
||||
query_out, key_out, value_out = self.split(x)
|
||||
|
||||
elif self.same_dim:
|
||||
_start = int(0)
|
||||
_end = int(self.hidden_width)
|
||||
elif self.same_dim is True:
|
||||
_start = 0
|
||||
_end = self.hidden_width
|
||||
_w = self.in_proj_layer[_start:_end, :]
|
||||
# _b = None
|
||||
query_out = self.matmul_dense(_w, tensor_q_2d)
|
||||
|
||||
_start = int(self.hidden_width)
|
||||
_end = int(self.hidden_width * 2)
|
||||
_start = self.hidden_width
|
||||
_end = self.hidden_width * 2
|
||||
_w = self.in_proj_layer[_start:_end, :]
|
||||
# _b = None
|
||||
key_out = self.matmul_dense(_w, tensor_k_2d)
|
||||
|
||||
_start = int(self.hidden_width * 2)
|
||||
_start = self.hidden_width * 2
|
||||
|
||||
_end = None
|
||||
_w = self.in_proj_layer[_start:]
|
||||
# _b = None
|
||||
|
@ -247,7 +247,7 @@ class TransformerEncoderLayer(nn.Cell):
|
|||
permute_recover = (b, n, d)
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, src2, batch_size=b, seq_length=n)
|
||||
src2 = self.self_attn(q, k, src2)
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.reshape(src2, permute_linear)
|
||||
|
@ -301,13 +301,12 @@ class TransformerDecoderLayer(nn.Cell):
|
|||
permute_recover = (b, n, d)
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, tensor_v=tgt2, batch_size=b, seq_length=n)
|
||||
tgt2 = self.self_attn(q, k, tensor_v=tgt2)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
|
||||
tensor_k=self.with_pos_embed(memory, pos),
|
||||
tensor_v=memory,
|
||||
batch_size=b, seq_length=n)
|
||||
tensor_v=memory,)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.reshape(tgt2, permute_linear)
|
||||
|
@ -393,6 +392,7 @@ class VisionTransformer(nn.Cell):
|
|||
num_layers,
|
||||
hidden_dim,
|
||||
num_queries,
|
||||
idx,
|
||||
positional_encoding_type="learned",
|
||||
dropout_rate=0,
|
||||
norm=False,
|
||||
|
@ -422,7 +422,7 @@ class VisionTransformer(nn.Cell):
|
|||
self.no_pos = no_pos
|
||||
|
||||
self.unf = _unfold_(patch_dim)
|
||||
self.fold = _fold_(patch_dim)
|
||||
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))
|
||||
|
||||
if self.mlp is not True:
|
||||
self.linear_encoding = nn.Dense(
|
||||
|
@ -437,7 +437,6 @@ class VisionTransformer(nn.Cell):
|
|||
|
||||
self.query_embed = nn.Embedding(
|
||||
num_queries, embedding_dim * self.seq_length)
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
embedding_dim, num_heads, hidden_dim, dropout_rate)
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_layers)
|
||||
|
@ -455,30 +454,31 @@ class VisionTransformer(nn.Cell):
|
|||
)
|
||||
|
||||
self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
|
||||
|
||||
def construct(self, x, query_idx):
|
||||
self.query_idx = idx
|
||||
self.query_idx_tensor = Tensor(idx, mstype.int32)
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
B, _, _, _ = x.shape
|
||||
x = self.unf(x)
|
||||
B, N, _ = x.shape
|
||||
|
||||
if self.mlp is not True:
|
||||
x = self.reshape(x, (int(B * N), -1))
|
||||
x = self.reshape(x, (B * N, -1))
|
||||
x = self.dropout_layer1(self.linear_encoding(x)) + x
|
||||
x = self.reshape(x, (B, N, -1))
|
||||
query_embed = self.tile(
|
||||
self.reshape(self.query_embed.embedding_table[int(
|
||||
query_idx)], (1, self.seq_length, self.embedding_dim)),
|
||||
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
|
||||
(B, 1, 1))
|
||||
|
||||
if not self.no_pos:
|
||||
pos = self.position_encoding(x)
|
||||
|
||||
x = self.encoder(x + pos)
|
||||
x = self.encoder(x + pos)
|
||||
else:
|
||||
x = self.encoder(x)
|
||||
x = self.decoder(x, x, query_pos=query_embed)
|
||||
|
||||
if self.mlp is not True:
|
||||
x = self.reshape(x, (int(B * N), -1))
|
||||
x = self.reshape(x, (B * N, -1))
|
||||
x = self.mlp_head(x) + x
|
||||
x = self.reshape(x, (B, N, -1))
|
||||
x = self.fold(x)
|
||||
|
@ -542,9 +542,9 @@ class ResBlock(nn.Cell):
|
|||
def _pixelsf_(x, scale):
|
||||
"""ipt"""
|
||||
N, C, iH, iW = x.shape
|
||||
oH = int(iH * scale)
|
||||
oW = int(iW * scale)
|
||||
oC = int(C // (scale ** 2))
|
||||
oH = iH * scale
|
||||
oW = iW * scale
|
||||
oC = C // (scale ** 2)
|
||||
|
||||
output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))
|
||||
|
||||
|
@ -565,11 +565,12 @@ class SmallUpSampler(nn.Cell):
|
|||
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
|
||||
self.reshape = P.Reshape()
|
||||
self.upsize = upsize
|
||||
self.pixelsf = _pixelsf_
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
x = self.conv(x)
|
||||
output = _pixelsf_(x, self.upsize)
|
||||
output = self.pixelsf(x, self.upsize)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -628,7 +629,8 @@ class IPT(nn.Cell):
|
|||
dropout_rate=args.dropout_rate,
|
||||
mlp=args.no_mlp,
|
||||
pos_every=args.pos_every,
|
||||
no_pos=args.no_pos)
|
||||
no_pos=args.no_pos,
|
||||
idx=self.scale_idx)
|
||||
|
||||
self.tail = nn.CellList([
|
||||
nn.SequentialCell(
|
||||
|
@ -645,7 +647,7 @@ class IPT(nn.Cell):
|
|||
"""ipt"""
|
||||
x = self.sub_mean(x)
|
||||
x = self.head[self.scale_idx](x)
|
||||
res = self.body(x, self.scale_idx)
|
||||
res = self.body(x)
|
||||
res += x
|
||||
x = self.tail[self.scale_idx](res)
|
||||
x = self.add_mean(x)
|
||||
|
@ -654,30 +656,43 @@ class IPT(nn.Cell):
|
|||
|
||||
def set_scale(self, scale_idx):
|
||||
"""ipt"""
|
||||
self.body.query_idx = scale_idx
|
||||
self.scale_idx = scale_idx
|
||||
|
||||
def infrc(self, x):
|
||||
|
||||
class IPT_post():
|
||||
"""ipt"""
|
||||
def __init__(self, model, args):
|
||||
super(IPT_post, self).__init__()
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.scale_idx = 0
|
||||
self.reshape = P.Reshape()
|
||||
self.tile = P.Tile()
|
||||
self.transpose = P.Transpose()
|
||||
self.cc_0 = P.Concat(axis=0)
|
||||
self.cc_2 = P.Concat(axis=2)
|
||||
self.cc_3 = P.Concat(axis=3)
|
||||
|
||||
def set_scale(self, scale_idx):
|
||||
"""ipt"""
|
||||
forward_function = self.forward_chop_new
|
||||
self.body.query_idx = scale_idx
|
||||
self.scale_idx = scale_idx
|
||||
|
||||
return forward_function(x)
|
||||
|
||||
def forward_chop_new(self, x, shave=12, batchsize=64):
|
||||
def forward(self, x, shave=12, batchsize=64):
|
||||
"""ipt"""
|
||||
h, w = x.shape[-2:]
|
||||
padsize = int(self.args.patch_size)
|
||||
shave = int(self.args.patch_size / 4)
|
||||
scale = self.args.scale[self.scale_idx]
|
||||
|
||||
h_cut = (h - padsize) % (padsize - shave)
|
||||
w_cut = (w - padsize) % (padsize - shave)
|
||||
|
||||
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
|
||||
x_unfold = unf_1(x)
|
||||
x_unfold = unf_1.compute(x)
|
||||
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2)
|
||||
|
||||
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
|
||||
y_hw_cut = self.construct(x_hw_cut)
|
||||
y_hw_cut = self.model(x_hw_cut)
|
||||
|
||||
x_h_cut = x[:, :, (h - padsize):, :]
|
||||
x_w_cut = x[:, :, :, (w - padsize):]
|
||||
|
@ -696,66 +711,71 @@ class IPT(nn.Cell):
|
|||
x_unfold, (x_unfold.shape[0], -1, padsize, padsize))
|
||||
x_range = x_unfold.shape[0] // batchsize + \
|
||||
(x_unfold.shape[0] % batchsize != 0)
|
||||
|
||||
cc_0 = P.Concat(axis=0)
|
||||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_unfold = self.construct(
|
||||
y_unfold = self.model(
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
else:
|
||||
y_unfold = cc_0((y_unfold, self.construct(
|
||||
y_unfold = self.cc_0((y_unfold, self.model(
|
||||
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
y_unf_shape_0 = y_unfold.shape[0]
|
||||
fold_1 = \
|
||||
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y = fold_1(self.transpose(self.reshape(
|
||||
y = fold_1.compute(self.transpose(self.reshape(
|
||||
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
y = cc_2((y_h_top, y[:, :, padsize * scale:, :]))
|
||||
y = cc_3((y_w_top, y[:, :, :, padsize * scale:]))
|
||||
if y[:, :, padsize * scale:, :].shape[2] == 0:
|
||||
y = y_h_top
|
||||
else:
|
||||
y = self.cc_2((y_h_top, y[:, :, padsize * scale:, :]))
|
||||
if y[:, :, :, padsize * scale:].shape[3] == 0:
|
||||
y = y_w_top
|
||||
else:
|
||||
y = self.cc_3((y_w_top, y[:, :, :, padsize * scale:]))
|
||||
y_unfold = y_unfold[:, :, int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
|
||||
int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
|
||||
fold_2 = _stride_fold_(padsize * scale - shave * scale,
|
||||
output_shape=((h - h_cut - shave) *
|
||||
scale, (w - w_cut - shave) * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y_inter = fold_2(self.transpose(self.reshape(
|
||||
y_inter = fold_2.compute(self.transpose(self.reshape(
|
||||
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
|
||||
y = cc_3((cc_3((y[:, :, :, :int(shave / 2 * scale)], cc_2((cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)), y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])))), y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
|
||||
y = cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
|
||||
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
|
||||
y_w_cat = cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
|
||||
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
|
||||
y = cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
|
||||
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
|
||||
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
|
||||
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
|
||||
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
|
||||
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
|
||||
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
|
||||
|
||||
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
|
||||
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
|
||||
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
|
||||
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
|
||||
|
||||
return y
|
||||
|
||||
def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
|
||||
"""ipt"""
|
||||
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
|
||||
x_h_cut_unfold = unf_1(x_h_cut)
|
||||
x_h_cut_unfold = unf_1.compute(x_h_cut)
|
||||
x_h_cut_unfold = self.transpose(x_h_cut_unfold, (1, 0, 2))
|
||||
|
||||
x_h_cut_unfold = self.reshape(
|
||||
x_h_cut_unfold, (x_h_cut_unfold.shape[0], -1, padsize, padsize))
|
||||
x_range = x_h_cut_unfold.shape[0] // batchsize + \
|
||||
(x_h_cut_unfold.shape[0] % batchsize != 0)
|
||||
cc_0 = P.Concat(axis=0)
|
||||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_h_cut_unfold = self.construct(
|
||||
y_h_cut_unfold = self.model(
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
else:
|
||||
y_h_cut_unfold = \
|
||||
cc_0((y_h_cut_unfold, self.construct(
|
||||
self.cc_0((y_h_cut_unfold, self.model(
|
||||
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
|
||||
fold_1 = \
|
||||
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y_h_cut = fold_1(self.transpose(self.reshape(
|
||||
y_h_cut = fold_1.compute(self.transpose(self.reshape(
|
||||
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
|
||||
y_h_cut_unfold = y_h_cut_unfold[:, :, :, int(
|
||||
shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
|
||||
|
@ -763,37 +783,35 @@ class IPT(nn.Cell):
|
|||
output_shape=(padsize * scale,
|
||||
(w - w_cut - shave) * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y_h_cut_inter = fold_2(self.transpose(self.reshape(
|
||||
y_h_cut_inter = fold_2.compute(self.transpose(self.reshape(
|
||||
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
|
||||
cc_3 = P.Concat(axis=3)
|
||||
y_h_cut = cc_3((cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)],
|
||||
y_h_cut_inter)), y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
|
||||
concat1 = self.cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)], y_h_cut_inter))
|
||||
y_h_cut = self.cc_3((concat1, y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
|
||||
return y_h_cut
|
||||
|
||||
def cut_w_new(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
|
||||
"""ipt"""
|
||||
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
|
||||
x_w_cut_unfold = unf_1(x_w_cut)
|
||||
x_w_cut_unfold = unf_1.compute(x_w_cut)
|
||||
x_w_cut_unfold = self.transpose(x_w_cut_unfold, (1, 0, 2))
|
||||
|
||||
x_w_cut_unfold = self.reshape(
|
||||
x_w_cut_unfold, (x_w_cut_unfold.shape[0], -1, padsize, padsize))
|
||||
x_range = x_w_cut_unfold.shape[0] // batchsize + \
|
||||
(x_w_cut_unfold.shape[0] % batchsize != 0)
|
||||
cc_0 = P.Concat(axis=0)
|
||||
for i in range(x_range):
|
||||
if i == 0:
|
||||
y_w_cut_unfold = self.construct(
|
||||
y_w_cut_unfold = self.model(
|
||||
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
|
||||
else:
|
||||
y_w_cut_unfold = cc_0((y_w_cut_unfold,
|
||||
self.construct(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
|
||||
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
|
||||
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
|
||||
fold_1 = _stride_fold_(padsize * scale,
|
||||
output_shape=((h - h_cut) * scale,
|
||||
padsize * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y_w_cut = fold_1(self.transpose(self.reshape(
|
||||
y_w_cut = fold_1.compute(self.transpose(self.reshape(
|
||||
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
|
||||
y_w_cut_unfold = y_w_cut_unfold[:, :, int(
|
||||
shave / 2 * scale):padsize * scale - int(shave / 2 * scale), :]
|
||||
|
@ -801,19 +819,18 @@ class IPT(nn.Cell):
|
|||
output_shape=((h - h_cut - shave)
|
||||
* scale, padsize * scale),
|
||||
stride=padsize * scale - shave * scale)
|
||||
y_w_cut_inter = fold_2(self.transpose(self.reshape(
|
||||
y_w_cut_inter = fold_2.compute(self.transpose(self.reshape(
|
||||
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
|
||||
cc_2 = P.Concat(axis=2)
|
||||
y_w_cut = cc_2((cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :],
|
||||
y_w_cut_inter)), y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
|
||||
concat1 = self.cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :], y_w_cut_inter))
|
||||
y_w_cut = self.cc_2((concat1, y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
|
||||
return y_w_cut
|
||||
|
||||
class _stride_unfold_():
|
||||
'''stride'''
|
||||
|
||||
class _stride_unfold_(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
def __init__(
|
||||
self, kernel_size, stride=-1):
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=-1):
|
||||
|
||||
super(_stride_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
|
@ -821,28 +838,24 @@ class _stride_unfold_(nn.Cell):
|
|||
else:
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
self.unfold = _unfold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
def compute(self, x):
|
||||
"""stride"""
|
||||
x = x.asnumpy()
|
||||
N, C, H, W = x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
nh = int((H - self.kernel_size) / self.stride + 1)
|
||||
nw = int((W - self.kernel_size) / self.stride + 1)
|
||||
nh = (H - self.kernel_size) // self.stride + 1
|
||||
nw = (W - self.kernel_size) // self.stride + 1
|
||||
for i in range(nh):
|
||||
leftup_idx_x.append(i * self.stride)
|
||||
for i in range(nw):
|
||||
leftup_idx_y.append(i * self.stride)
|
||||
NumBlock_x = len(leftup_idx_x)
|
||||
NumBlock_y = len(leftup_idx_y)
|
||||
zeroslike = P.ZerosLike()
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
|
||||
NumBlock_y * self.kernel_size), mstype.float32)
|
||||
unf_x = np.zeros((N, C, NumBlock_x * self.kernel_size, NumBlock_y * self.kernel_size), dtype=np.float32)
|
||||
N, C, H, W = unf_x.shape
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
|
@ -852,23 +865,28 @@ class _stride_unfold_(nn.Cell):
|
|||
org_j = leftup_idx_y[j]
|
||||
fills = x[:, :, org_i:org_i + self.kernel_size,
|
||||
org_j:org_j + self.kernel_size]
|
||||
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
|
||||
cc_2(
|
||||
(cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)),
|
||||
zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
|
||||
unf_j:unf_j + self.kernel_size]))))),
|
||||
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
|
||||
zeros2 = np.zeros(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size].shape)
|
||||
concat1 = np.concatenate((zeros2, fills), axis=2)
|
||||
zeros3 = np.zeros(unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size].shape)
|
||||
concat2 = np.concatenate((concat1, zeros3), axis=2)
|
||||
zeros1 = np.zeros(unf_x[:, :, :, :unf_j].shape)
|
||||
concat3 = np.concatenate((zeros1, concat2), axis=3)
|
||||
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
|
||||
concat4 = np.concatenate((concat3, zeros4), axis=3)
|
||||
unf_x += concat4
|
||||
unf_x = Tensor(unf_x, mstype.float32)
|
||||
y = self.unfold(unf_x)
|
||||
return y
|
||||
|
||||
class _stride_fold_(nn.Cell):
|
||||
"""ipt"""
|
||||
class _stride_fold_():
|
||||
'''stride'''
|
||||
|
||||
def __init__(
|
||||
self, kernel_size, output_shape=(-1, -1), stride=-1):
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
output_shape=(-1, -1),
|
||||
stride=-1):
|
||||
|
||||
super(_stride_fold_, self).__init__()
|
||||
|
||||
if isinstance(kernel_size, (list, tuple)):
|
||||
self.kernel_size = kernel_size
|
||||
else:
|
||||
|
@ -880,66 +898,49 @@ class _stride_fold_(nn.Cell):
|
|||
self.stride = stride
|
||||
|
||||
self.output_shape = output_shape
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.fold = _fold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
zeroslike = P.ZerosLike()
|
||||
if self.output_shape[0] == -1:
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx = []
|
||||
for i in range(0, H, self.kernel_size[0]):
|
||||
leftup_idx.append(i)
|
||||
NumBlock = len(leftup_idx)
|
||||
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
|
||||
self.NumBlock_x = (self.output_shape[0] - self.kernel_size[0]) // self.stride + 1
|
||||
self.NumBlock_y = (self.output_shape[1] - self.kernel_size[1]) // self.stride + 1
|
||||
self.large_shape = [self.NumBlock_x * self.kernel_size[0], self.NumBlock_y * self.kernel_size[1]]
|
||||
self.fold = _fold_(self.kernel_size, self.large_shape)
|
||||
|
||||
for i in range(NumBlock):
|
||||
for j in range(NumBlock):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx[i]
|
||||
org_j = leftup_idx[j]
|
||||
fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
|
||||
y = fold_x
|
||||
else:
|
||||
NumBlock_x = int(
|
||||
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
|
||||
NumBlock_y = int(
|
||||
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
|
||||
large_shape = [NumBlock_x * self.kernel_size[0],
|
||||
NumBlock_y * self.kernel_size[1]]
|
||||
self.fold = _fold_(self.kernel_size, large_shape)
|
||||
large_x = self.fold(x)
|
||||
N, C, H, _ = large_x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
for i in range(NumBlock_x):
|
||||
leftup_idx_x.append(i * self.kernel_size[0])
|
||||
for i in range(NumBlock_y):
|
||||
leftup_idx_y.append(i * self.kernel_size[1])
|
||||
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
|
||||
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
|
||||
org_j:org_j + self.kernel_size[1]]
|
||||
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
|
||||
y = fold_x
|
||||
def compute(self, x):
|
||||
'''stride'''
|
||||
NumBlock_x = self.NumBlock_x
|
||||
NumBlock_y = self.NumBlock_y
|
||||
large_x = self.fold(x)
|
||||
large_x = large_x.asnumpy()
|
||||
N, C, _, _ = large_x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
for i in range(NumBlock_x):
|
||||
leftup_idx_x.append(i * self.kernel_size[0])
|
||||
for i in range(NumBlock_y):
|
||||
leftup_idx_y.append(i * self.kernel_size[1])
|
||||
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
fold_i = i * self.stride
|
||||
fold_j = j * self.stride
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fills = large_x[:, :, org_i:org_i + self.kernel_size[0], org_j:org_j + self.kernel_size[1]]
|
||||
t2 = fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]
|
||||
zeros2 = np.zeros(t2.shape)
|
||||
concat1 = np.concatenate((zeros2, fills), axis=2)
|
||||
t3 = fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]
|
||||
zeros3 = np.zeros(t3.shape)
|
||||
concat2 = np.concatenate((concat1, zeros3), axis=2)
|
||||
t1 = fold_x[:, :, :, :fold_j]
|
||||
zeros1 = np.zeros(t1.shape)
|
||||
concat3 = np.concatenate((zeros1, concat2), axis=3)
|
||||
t4 = fold_x[:, :, :, fold_j + self.kernel_size[1]:]
|
||||
zeros4 = np.zeros(t4.shape)
|
||||
concat4 = np.concatenate((concat3, zeros4), axis=3)
|
||||
fold_x += concat4
|
||||
y = Tensor(fold_x, mstype.float32)
|
||||
return y
|
||||
|
||||
|
||||
class _unfold_(nn.Cell):
|
||||
"""ipt"""
|
||||
|
||||
|
@ -957,20 +958,16 @@ class _unfold_(nn.Cell):
|
|||
def construct(self, x):
|
||||
"""ipt"""
|
||||
N, C, H, W = x.shape
|
||||
numH = int(H / self.kernel_size)
|
||||
numW = int(W / self.kernel_size)
|
||||
numH = H // self.kernel_size
|
||||
numW = W // self.kernel_size
|
||||
if numH * self.kernel_size != H or numW * self.kernel_size != W:
|
||||
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
|
||||
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, C, int(
|
||||
numH * numW), self.kernel_size, self.kernel_size))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
|
||||
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
|
||||
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
|
||||
output_img = self.reshape(output_img, (N, numH * numW, -1))
|
||||
return output_img
|
||||
|
||||
|
||||
|
@ -994,22 +991,17 @@ class _fold_(nn.Cell):
|
|||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
"""ipt"""
|
||||
N, C, L = x.shape
|
||||
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
|
||||
if self.output_shape[0] == -1:
|
||||
numH = int(np.sqrt(C))
|
||||
numW = int(np.sqrt(C))
|
||||
org_H = int(numH * self.kernel_size[0])
|
||||
org_W = org_H
|
||||
else:
|
||||
org_H = int(self.output_shape[0])
|
||||
org_W = int(self.output_shape[1])
|
||||
numH = int(org_H / self.kernel_size[0])
|
||||
numW = int(org_W / self.kernel_size[1])
|
||||
|
||||
org_C = L // (self.kernel_size[0] * self.kernel_size[1])
|
||||
org_H = self.output_shape[0]
|
||||
org_W = self.output_shape[1]
|
||||
numH = org_H // self.kernel_size[0]
|
||||
numW = org_W // self.kernel_size[1]
|
||||
output_img = self.reshape(
|
||||
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
|
||||
|
||||
|
|
Loading…
Reference in New Issue