!13726 add IPT Ascend

From: @xiaoan95
Reviewed-by: @c_34, @wuxuejian
Signed-off-by: @c_34
mindspore-ci-bot 2021-03-24 17:02:10 +08:00 committed by Gitee
commit b6bf797ae7
2 changed files with 182 additions and 189 deletions


@@ -23,7 +23,7 @@ from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
def main():
@@ -46,11 +46,12 @@ def main():
net_m.set_train(False)
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
inference = ipt.IPT_post(net_m, args)
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
hr_np = np.float32(hr.asnumpy())
pred = net_m.infrc(lr)
pred = inference.forward(lr)
pred_np = np.float32(pred.asnumpy())
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
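For reference, quantize and calc_psnr come from the IPT sources; below is a minimal numpy sketch of what this evaluation step computes. The helper bodies here are assumptions that approximate, but are not, the repo's implementations:

import numpy as np

def quantize_sketch(img, rgb_range=255.0):
    # Clamp to the displayable range and round to integer pixel levels.
    pixel_range = 255.0 / rgb_range
    return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range

def calc_psnr_sketch(sr, hr, scale, rgb_range=255.0):
    # PSNR on the BT.601 luma channel, ignoring a scale-pixel border.
    diff = (sr - hr) / rgb_range
    weights = np.array([65.738, 129.057, 25.064]).reshape(1, 3, 1, 1) / 256.0
    diff = (diff * weights).sum(axis=1, keepdims=True)
    valid = diff[..., scale:-scale, scale:-scale]
    return -10.0 * np.log10(np.mean(valid ** 2))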


@@ -23,9 +23,6 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
# from mindspore.ops.primitive import constexpr
# import IPython
class MultiheadAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
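For orientation, the computation this Cell implements is standard multi-head attention. A self-contained numpy sketch of the underlying math, assuming batch-first (B, S, D) shapes; the Cell itself works on 2-D reshapes and a packed projection, as the hunks below show:

import numpy as np

def multihead_attention_sketch(q, k, v, num_heads):
    # Split the last dim into heads, attend per head, merge back.
    B, S, D = q.shape
    d = D // num_heads
    def split(t):
        return t.reshape(B, -1, num_heads, d).transpose(0, 2, 1, 3)
    qh, kh, vh = split(q), split(k), split(v)
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(d)
    scores = scores - scores.max(axis=-1, keepdims=True)   # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = weights @ vh                                      # (B, heads, S, d)
    return out.transpose(0, 2, 1, 3).reshape(B, -1, D)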
@@ -85,7 +82,7 @@ class MultiheadAttention(nn.Cell):
self.shape_q_2d = (-1, q_tensor_width)
self.shape_k_2d = (-1, k_tensor_width)
self.shape_v_2d = (-1, v_tensor_width)
self.hidden_width = hidden_width
self.hidden_width = int(hidden_width)
# units = num_attention_heads * self.size_per_head
if self.same_dim:
self.in_proj_layer = \
@@ -132,46 +129,49 @@ class MultiheadAttention(nn.Cell):
self.softmax_cast = P.Cast()
self.matmul_dense = P.MatMul(transpose_b=True)
self.split = P.Split(0, 3)
self.equal = P.Equal()
self.shape = P.Shape()
def construct(self, tensor_q, tensor_k, tensor_v, batch_size, seq_length, attention_mask=None):
def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
"""Apply multihead attention."""
self.batch_size = batch_size
shape_qkv = (self.batch_size, -1,
batch_size, seq_length, _ = self.shape(tensor_q)
shape_qkv = (batch_size, -1,
self.num_attention_heads, self.size_per_head)
shape_linear = (self.batch_size * seq_length,
shape_linear = (batch_size * seq_length,
self.num_attention_heads * self.size_per_head)
if self.do_return_2d_tensor:
shape_return = (self.batch_size * seq_length,
if self.do_return_2d_tensor is True:
shape_return = (batch_size * seq_length,
self.num_attention_heads * self.size_per_head)
if seq_length == -1:
shape_return = (-1, self.num_attention_heads *
self.size_per_head)
else:
shape_return = (self.batch_size, seq_length,
shape_return = (batch_size, seq_length,
self.num_attention_heads * self.size_per_head)
tensor_q_2d = self.reshape(tensor_q, self.shape_q_2d)
tensor_k_2d = self.reshape(tensor_k, self.shape_k_2d)
tensor_v_2d = self.reshape(tensor_v, self.shape_v_2d)
if P.Equal()(tensor_q_2d, tensor_v_2d)[0][0]:
if self.equal(tensor_q_2d, tensor_v_2d) is True:
x = self.matmul_dense(self.in_proj_layer, tensor_q_2d)
query_out, key_out, value_out = self.split(x)
elif self.same_dim:
_start = int(0)
_end = int(self.hidden_width)
elif self.same_dim is True:
_start = 0
_end = self.hidden_width
_w = self.in_proj_layer[_start:_end, :]
# _b = None
query_out = self.matmul_dense(_w, tensor_q_2d)
_start = int(self.hidden_width)
_end = int(self.hidden_width * 2)
_start = self.hidden_width
_end = self.hidden_width * 2
_w = self.in_proj_layer[_start:_end, :]
# _b = None
key_out = self.matmul_dense(_w, tensor_k_2d)
_start = int(self.hidden_width * 2)
_start = self.hidden_width * 2
_end = None
_w = self.in_proj_layer[_start:]
# _b = None
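The same_dim branch above slices a single packed projection weight into per-tensor Q/K/V weights at row offsets 0, hidden_width, and 2 * hidden_width. Schematically, in numpy (a sketch; hidden stands in for self.hidden_width):

import numpy as np

hidden = 64
in_proj = np.random.randn(3 * hidden, hidden).astype(np.float32)   # packed weight
x_2d = np.random.randn(10, hidden).astype(np.float32)              # flattened tokens

w_q = in_proj[:hidden, :]                # rows [0, hidden)
w_k = in_proj[hidden:2 * hidden, :]      # rows [hidden, 2 * hidden)
w_v = in_proj[2 * hidden:, :]            # remaining rows
q_out = w_q @ x_2d.T                     # P.MatMul(transpose_b=True)(_w, x)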
@@ -247,7 +247,7 @@ class TransformerEncoderLayer(nn.Cell):
permute_recover = (b, n, d)
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, src2, batch_size=b, seq_length=n)
src2 = self.self_attn(q, k, src2)
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.reshape(src2, permute_linear)
@@ -301,13 +301,12 @@ class TransformerDecoderLayer(nn.Cell):
permute_recover = (b, n, d)
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, tensor_v=tgt2, batch_size=b, seq_length=n)
tgt2 = self.self_attn(q, k, tensor_v=tgt2)
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
tensor_k=self.with_pos_embed(memory, pos),
tensor_v=memory,
batch_size=b, seq_length=n)
tensor_v=memory)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.reshape(tgt2, permute_linear)
@@ -393,6 +392,7 @@ class VisionTransformer(nn.Cell):
num_layers,
hidden_dim,
num_queries,
idx,
positional_encoding_type="learned",
dropout_rate=0,
norm=False,
@@ -422,7 +422,7 @@ class VisionTransformer(nn.Cell):
self.no_pos = no_pos
self.unf = _unfold_(patch_dim)
self.fold = _fold_(patch_dim)
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))
if self.mlp is not True:
self.linear_encoding = nn.Dense(
@@ -437,7 +437,6 @@ class VisionTransformer(nn.Cell):
self.query_embed = nn.Embedding(
num_queries, embedding_dim * self.seq_length)
encoder_layer = TransformerEncoderLayer(
embedding_dim, num_heads, hidden_dim, dropout_rate)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
@@ -455,30 +454,31 @@ class VisionTransformer(nn.Cell):
)
self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
def construct(self, x, query_idx):
self.query_idx = idx
self.query_idx_tensor = Tensor(idx, mstype.int32)
def construct(self, x):
"""ipt"""
B, _, _, _ = x.shape
x = self.unf(x)
B, N, _ = x.shape
if self.mlp is not True:
x = self.reshape(x, (int(B * N), -1))
x = self.reshape(x, (B * N, -1))
x = self.dropout_layer1(self.linear_encoding(x)) + x
x = self.reshape(x, (B, N, -1))
query_embed = self.tile(
self.reshape(self.query_embed.embedding_table[int(
query_idx)], (1, self.seq_length, self.embedding_dim)),
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
(B, 1, 1))
if not self.no_pos:
pos = self.position_encoding(x)
x = self.encoder(x + pos)
else:
x = self.encoder(x)
x = self.decoder(x, x, query_pos=query_embed)
if self.mlp is not True:
x = self.reshape(x, (int(B * N), -1))
x = self.reshape(x, (B * N, -1))
x = self.mlp_head(x) + x
x = self.reshape(x, (B, N, -1))
x = self.fold(x)
@@ -542,9 +542,9 @@ class ResBlock(nn.Cell):
def _pixelsf_(x, scale):
"""ipt"""
N, C, iH, iW = x.shape
oH = int(iH * scale)
oW = int(iW * scale)
oC = int(C // (scale ** 2))
oH = iH * scale
oW = iW * scale
oC = C // (scale ** 2)
output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))
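This reshape is the first step of a pixel shuffle; the rest of the function (cut off by the hunk) follows the usual transpose/reshape pattern, sketched here in numpy:

import numpy as np

def pixel_shuffle_sketch(x, scale):
    # (N, C, H, W) -> (N, C / scale**2, H * scale, W * scale)
    N, C, iH, iW = x.shape
    oC = C // (scale ** 2)
    x = x.reshape(N, oC, scale, scale, iH, iW)
    x = x.transpose(0, 1, 4, 2, 5, 3)   # interleave the scale factors
    return x.reshape(N, oC, iH * scale, iW * scale)

assert pixel_shuffle_sketch(np.zeros((2, 16, 3, 3)), 4).shape == (2, 1, 12, 12)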
@@ -565,11 +565,12 @@ class SmallUpSampler(nn.Cell):
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
self.reshape = P.Reshape()
self.upsize = upsize
self.pixelsf = _pixelsf_
def construct(self, x):
"""ipt"""
x = self.conv(x)
output = _pixelsf_(x, self.upsize)
output = self.pixelsf(x, self.upsize)
return output
@@ -628,7 +629,8 @@ class IPT(nn.Cell):
dropout_rate=args.dropout_rate,
mlp=args.no_mlp,
pos_every=args.pos_every,
no_pos=args.no_pos)
no_pos=args.no_pos,
idx=self.scale_idx)
self.tail = nn.CellList([
nn.SequentialCell(
@@ -645,7 +647,7 @@ class IPT(nn.Cell):
"""ipt"""
x = self.sub_mean(x)
x = self.head[self.scale_idx](x)
res = self.body(x, self.scale_idx)
res = self.body(x)
res += x
x = self.tail[self.scale_idx](res)
x = self.add_mean(x)
@@ -654,30 +656,43 @@ class IPT(nn.Cell):
def set_scale(self, scale_idx):
"""ipt"""
self.body.query_idx = scale_idx
self.scale_idx = scale_idx
def infrc(self, x):
class IPT_post():
"""ipt"""
def __init__(self, model, args):
super(IPT_post, self).__init__()
self.model = model
self.args = args
self.scale_idx = 0
self.reshape = P.Reshape()
self.tile = P.Tile()
self.transpose = P.Transpose()
self.cc_0 = P.Concat(axis=0)
self.cc_2 = P.Concat(axis=2)
self.cc_3 = P.Concat(axis=3)
def set_scale(self, scale_idx):
"""ipt"""
forward_function = self.forward_chop_new
self.body.query_idx = scale_idx
self.scale_idx = scale_idx
return forward_function(x)
def forward_chop_new(self, x, shave=12, batchsize=64):
def forward(self, x, shave=12, batchsize=64):
"""ipt"""
h, w = x.shape[-2:]
padsize = int(self.args.patch_size)
shave = int(self.args.patch_size / 4)
scale = self.args.scale[self.scale_idx]
h_cut = (h - padsize) % (padsize - shave)
w_cut = (w - padsize) % (padsize - shave)
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
x_unfold = unf_1(x)
x_unfold = unf_1.compute(x)
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2)
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
y_hw_cut = self.construct(x_hw_cut)
y_hw_cut = self.model(x_hw_cut)
x_h_cut = x[:, :, (h - padsize):, :]
x_w_cut = x[:, :, :, (w - padsize):]
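The cut arithmetic above trims the right and bottom edges so that padsize patches taken at stride padsize - shave tile the remaining image exactly; a worked example with illustrative values:

h, w = 256, 256
padsize, shave = 48, 12              # illustrative: args.patch_size and patch_size // 4
stride = padsize - shave             # 36
h_cut = (h - padsize) % stride       # 208 % 36 = 28 rows handled by x_h_cut
n_patches_h = (h - h_cut - padsize) // stride + 1   # 6 patch rows cover the rest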
@@ -696,66 +711,71 @@ class IPT(nn.Cell):
x_unfold, (x_unfold.shape[0], -1, padsize, padsize))
x_range = x_unfold.shape[0] // batchsize + \
(x_unfold.shape[0] % batchsize != 0)
cc_0 = P.Concat(axis=0)
for i in range(x_range):
if i == 0:
y_unfold = self.construct(
y_unfold = self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
else:
y_unfold = cc_0((y_unfold, self.construct(
y_unfold = self.cc_0((y_unfold, self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
y_unf_shape_0 = y_unfold.shape[0]
fold_1 = \
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
stride=padsize * scale - shave * scale)
y = fold_1(self.transpose(self.reshape(
y = fold_1.compute(self.transpose(self.reshape(
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
y = cc_2((y_h_top, y[:, :, padsize * scale:, :]))
y = cc_3((y_w_top, y[:, :, :, padsize * scale:]))
if y[:, :, padsize * scale:, :].shape[2] == 0:
y = y_h_top
else:
y = self.cc_2((y_h_top, y[:, :, padsize * scale:, :]))
if y[:, :, :, padsize * scale:].shape[3] == 0:
y = y_w_top
else:
y = self.cc_3((y_w_top, y[:, :, :, padsize * scale:]))
y_unfold = y_unfold[:, :, int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
fold_2 = _stride_fold_(padsize * scale - shave * scale,
output_shape=((h - h_cut - shave) *
scale, (w - w_cut - shave) * scale),
stride=padsize * scale - shave * scale)
y_inter = fold_2(self.transpose(self.reshape(
y_inter = fold_2.compute(self.transpose(self.reshape(
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
y = cc_3((cc_3((y[:, :, :, :int(shave / 2 * scale)], cc_2((cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)), y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])))), y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
y = cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y_w_cat = cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y = cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
return y
def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
"""ipt"""
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
x_h_cut_unfold = unf_1(x_h_cut)
x_h_cut_unfold = unf_1.compute(x_h_cut)
x_h_cut_unfold = self.transpose(x_h_cut_unfold, (1, 0, 2))
x_h_cut_unfold = self.reshape(
x_h_cut_unfold, (x_h_cut_unfold.shape[0], -1, padsize, padsize))
x_range = x_h_cut_unfold.shape[0] // batchsize + \
(x_h_cut_unfold.shape[0] % batchsize != 0)
cc_0 = P.Concat(axis=0)
for i in range(x_range):
if i == 0:
y_h_cut_unfold = self.construct(
y_h_cut_unfold = self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
else:
y_h_cut_unfold = \
cc_0((y_h_cut_unfold, self.construct(
self.cc_0((y_h_cut_unfold, self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
fold_1 = \
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
stride=padsize * scale - shave * scale)
y_h_cut = fold_1(self.transpose(self.reshape(
y_h_cut = fold_1.compute(self.transpose(self.reshape(
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
y_h_cut_unfold = y_h_cut_unfold[:, :, :, int(
shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
@@ -763,37 +783,35 @@ class IPT(nn.Cell):
output_shape=(padsize * scale,
(w - w_cut - shave) * scale),
stride=padsize * scale - shave * scale)
y_h_cut_inter = fold_2(self.transpose(self.reshape(
y_h_cut_inter = fold_2.compute(self.transpose(self.reshape(
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
cc_3 = P.Concat(axis=3)
y_h_cut = cc_3((cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)],
y_h_cut_inter)), y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
concat1 = self.cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)], y_h_cut_inter))
y_h_cut = self.cc_3((concat1, y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
return y_h_cut
def cut_w_new(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
"""ipt"""
unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
x_w_cut_unfold = unf_1(x_w_cut)
x_w_cut_unfold = unf_1.compute(x_w_cut)
x_w_cut_unfold = self.transpose(x_w_cut_unfold, (1, 0, 2))
x_w_cut_unfold = self.reshape(
x_w_cut_unfold, (x_w_cut_unfold.shape[0], -1, padsize, padsize))
x_range = x_w_cut_unfold.shape[0] // batchsize + \
(x_w_cut_unfold.shape[0] % batchsize != 0)
cc_0 = P.Concat(axis=0)
for i in range(x_range):
if i == 0:
y_w_cut_unfold = self.construct(
y_w_cut_unfold = self.model(
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
else:
y_w_cut_unfold = cc_0((y_w_cut_unfold,
self.construct(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
fold_1 = _stride_fold_(padsize * scale,
output_shape=((h - h_cut) * scale,
padsize * scale),
stride=padsize * scale - shave * scale)
y_w_cut = fold_1(self.transpose(self.reshape(
y_w_cut = fold_1.compute(self.transpose(self.reshape(
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
y_w_cut_unfold = y_w_cut_unfold[:, :, int(
shave / 2 * scale):padsize * scale - int(shave / 2 * scale), :]
@@ -801,19 +819,18 @@ class IPT(nn.Cell):
output_shape=((h - h_cut - shave)
* scale, padsize * scale),
stride=padsize * scale - shave * scale)
y_w_cut_inter = fold_2(self.transpose(self.reshape(
y_w_cut_inter = fold_2.compute(self.transpose(self.reshape(
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
cc_2 = P.Concat(axis=2)
y_w_cut = cc_2((cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :],
y_w_cut_inter)), y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
concat1 = self.cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :], y_w_cut_inter))
y_w_cut = self.cc_2((concat1, y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
return y_w_cut
class _stride_unfold_():
'''stride'''
class _stride_unfold_(nn.Cell):
"""ipt"""
def __init__(
self, kernel_size, stride=-1):
def __init__(self,
kernel_size,
stride=-1):
super(_stride_unfold_, self).__init__()
if stride == -1:
@@ -821,28 +838,24 @@ class _stride_unfold_(nn.Cell):
else:
self.stride = stride
self.kernel_size = kernel_size
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.unfold = _unfold_(kernel_size)
def construct(self, x):
"""ipt"""
def compute(self, x):
"""stride"""
x = x.asnumpy()
N, C, H, W = x.shape
leftup_idx_x = []
leftup_idx_y = []
nh = int((H - self.kernel_size) / self.stride + 1)
nw = int((W - self.kernel_size) / self.stride + 1)
nh = (H - self.kernel_size) // self.stride + 1
nw = (W - self.kernel_size) // self.stride + 1
for i in range(nh):
leftup_idx_x.append(i * self.stride)
for i in range(nw):
leftup_idx_y.append(i * self.stride)
NumBlock_x = len(leftup_idx_x)
NumBlock_y = len(leftup_idx_y)
zeroslike = P.ZerosLike()
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
NumBlock_y * self.kernel_size), mstype.float32)
unf_x = np.zeros((N, C, NumBlock_x * self.kernel_size, NumBlock_y * self.kernel_size), dtype=np.float32)
N, C, H, W = unf_x.shape
for i in range(NumBlock_x):
for j in range(NumBlock_y):
@@ -852,23 +865,28 @@ class _stride_unfold_(nn.Cell):
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size,
org_j:org_j + self.kernel_size]
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
cc_2(
(cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)),
zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
unf_j:unf_j + self.kernel_size]))))),
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
zeros2 = np.zeros(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size].shape)
concat1 = np.concatenate((zeros2, fills), axis=2)
zeros3 = np.zeros(unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size].shape)
concat2 = np.concatenate((concat1, zeros3), axis=2)
zeros1 = np.zeros(unf_x[:, :, :, :unf_j].shape)
concat3 = np.concatenate((zeros1, concat2), axis=3)
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
concat4 = np.concatenate((concat3, zeros4), axis=3)
unf_x += concat4
unf_x = Tensor(unf_x, mstype.float32)
y = self.unfold(unf_x)
return y
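Since compute() operates on a plain numpy array, the concatenate-with-zeros construction in the loop has the same effect as adding the patch into a slice; a minimal sketch of that equivalence (names are stand-ins for unf_x, fills, unf_i, unf_j):

import numpy as np

unf = np.zeros((1, 3, 8, 8), dtype=np.float32)    # stands in for unf_x
patch = np.ones((1, 3, 4, 4), dtype=np.float32)   # stands in for fills
i, j = 4, 4                                       # stand-ins for unf_i, unf_j
unf[:, :, i:i + 4, j:j + 4] += patch              # same effect as unf_x += concat4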
class _stride_fold_(nn.Cell):
"""ipt"""
class _stride_fold_():
'''stride'''
def __init__(
self, kernel_size, output_shape=(-1, -1), stride=-1):
def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):
super(_stride_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
@@ -880,66 +898,49 @@ class _stride_fold_(nn.Cell):
self.stride = stride
self.output_shape = output_shape
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.fold = _fold_(kernel_size)
def construct(self, x):
"""ipt"""
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
zeroslike = P.ZerosLike()
if self.output_shape[0] == -1:
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx = []
for i in range(0, H, self.kernel_size[0]):
leftup_idx.append(i)
NumBlock = len(leftup_idx)
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
self.NumBlock_x = (self.output_shape[0] - self.kernel_size[0]) // self.stride + 1
self.NumBlock_y = (self.output_shape[1] - self.kernel_size[1]) // self.stride + 1
self.large_shape = [self.NumBlock_x * self.kernel_size[0], self.NumBlock_y * self.kernel_size[1]]
self.fold = _fold_(self.kernel_size, self.large_shape)
for i in range(NumBlock):
for j in range(NumBlock):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx[i]
org_j = leftup_idx[j]
fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
y = fold_x
else:
NumBlock_x = int(
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
NumBlock_y = int(
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
large_shape = [NumBlock_x * self.kernel_size[0],
NumBlock_y * self.kernel_size[1]]
self.fold = _fold_(self.kernel_size, large_shape)
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx_x = []
leftup_idx_y = []
for i in range(NumBlock_x):
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
y = fold_x
def compute(self, x):
'''stride'''
NumBlock_x = self.NumBlock_x
NumBlock_y = self.NumBlock_y
large_x = self.fold(x)
large_x = large_x.asnumpy()
N, C, _, _ = large_x.shape
leftup_idx_x = []
leftup_idx_y = []
for i in range(NumBlock_x):
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = large_x[:, :, org_i:org_i + self.kernel_size[0], org_j:org_j + self.kernel_size[1]]
t2 = fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]
zeros2 = np.zeros(t2.shape)
concat1 = np.concatenate((zeros2, fills), axis=2)
t3 = fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]
zeros3 = np.zeros(t3.shape)
concat2 = np.concatenate((concat1, zeros3), axis=2)
t1 = fold_x[:, :, :, :fold_j]
zeros1 = np.zeros(t1.shape)
concat3 = np.concatenate((zeros1, concat2), axis=3)
t4 = fold_x[:, :, :, fold_j + self.kernel_size[1]:]
zeros4 = np.zeros(t4.shape)
concat4 = np.concatenate((concat3, zeros4), axis=3)
fold_x += concat4
y = Tensor(fold_x, mstype.float32)
return y
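Folding sums the overlapping regions rather than averaging them, which is why the callers above keep only the central shave / 2-trimmed core of each patch before the second fold. A one-dimensional numpy illustration of the overlap-add:

import numpy as np

k, stride, n = 4, 3, 3                        # patch size, stride, patch count
out = np.zeros((n - 1) * stride + k)
for i in range(n):
    out[i * stride:i * stride + k] += 1.0     # each patch contributes ones
# out == [1, 1, 1, 2, 1, 1, 2, 1, 1, 1]: overlapped samples are counted twice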
class _unfold_(nn.Cell):
"""ipt"""
@@ -957,20 +958,16 @@ class _unfold_(nn.Cell):
def construct(self, x):
"""ipt"""
N, C, H, W = x.shape
numH = int(H / self.kernel_size)
numW = int(W / self.kernel_size)
numH = H // self.kernel_size
numW = W // self.kernel_size
if numH * self.kernel_size != H or numW * self.kernel_size != W:
x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, int(
numH * numW), self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
output_img = self.reshape(output_img, (N, numH * numW, -1))
return output_img
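For comparison, the standard reshape/transpose trick for non-overlapping patch extraction looks like this in numpy (a sketch; the Cell's permutation order differs in detail):

import numpy as np

N, C, H, W, k = 1, 3, 8, 8, 4
x = np.random.randn(N, C, H, W).astype(np.float32)
numH, numW = H // k, W // k
p = x.reshape(N, C, numH, k, numW, k)
p = p.transpose(0, 2, 4, 1, 3, 5)             # (N, numH, numW, C, k, k)
p = p.reshape(N, numH * numW, C * k * k)      # one flattened row per patch
assert p.shape == (1, 4, 48)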
@@ -994,22 +991,17 @@ class _fold_(nn.Cell):
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.sqrt = P.Sqrt()
self.cast = P.Cast()
def construct(self, x):
"""ipt"""
N, C, L = x.shape
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
if self.output_shape[0] == -1:
numH = int(np.sqrt(C))
numW = int(np.sqrt(C))
org_H = int(numH * self.kernel_size[0])
org_W = org_H
else:
org_H = int(self.output_shape[0])
org_W = int(self.output_shape[1])
numH = int(org_H / self.kernel_size[0])
numW = int(org_W / self.kernel_size[1])
org_C = L // (self.kernel_size[0] * self.kernel_size[1])
org_H = self.output_shape[0]
org_W = self.output_shape[1]
numH = org_H // self.kernel_size[0]
numW = org_W // self.kernel_size[1]
output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))