add TNT model
This commit is contained in:
parent
ac5371b38f
commit
50d7062fed
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
eval.
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore import nn
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import dtype as mstype
|
||||
from src.pet_dataset import create_dataset
|
||||
from src.config import config_ascend, config_gpu
|
||||
from src.tnt import tnt_b
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--platform', type=str, default=None, help='run platform')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config_platform = None
|
||||
if args_opt.platform == "Ascend":
|
||||
config_platform = config_ascend
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
device_id=device_id, save_graphs=False)
|
||||
elif args_opt.platform == "GPU":
|
||||
config_platform = config_gpu
|
||||
context.set_context(mode=context.PYNATIVE_MODE,
|
||||
device_target="GPU", save_graphs=False)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
||||
net = tnt_b(num_class=config_platform.num_classes)
|
||||
|
||||
if args_opt.checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
if args_opt.platform == "Ascend":
|
||||
net.to_float(mstype.float16)
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.to_float(mstype.float32)
|
||||
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=False,
|
||||
config=config_platform,
|
||||
platform=args_opt.platform,
|
||||
batch_size=config_platform.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
model = Model(net, loss_fn=loss, metrics={'acc'})
|
||||
res = model.eval(dataset)
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
Binary file not shown.
After Width: | Height: | Size: 147 KiB |
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.tnt import tnt_b
|
||||
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'TNT-B':
|
||||
return tnt_b(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,128 @@
|
|||
# Contents
|
||||
|
||||
- [TNT Description](#tnt-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
## [TNT Description](#contents)
|
||||
|
||||
The TNT (Transformer in Transformer) network is a pure transformer model for visual recognition. TNT treats an image as a sequence of patches and treats a patch as a sequence of pixels. TNT block utilizes a outer transformer block to process the sequence of patches and an inner transformer block to process the sequence of pixels.
|
||||
|
||||
[Paper](https://arxiv.org/abs/2103.00112): Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, Yunhe Wang. Transformer in Transformer. preprint 2021.
|
||||
|
||||
## [Model architecture](#contents)
|
||||
|
||||
The overall network architecture of TNT is show below:
|
||||
![](./fig/tnt.PNG)
|
||||
|
||||
## [Dataset](#contents)
|
||||
|
||||
Dataset used: [Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/)
|
||||
|
||||
- Dataset size: 7049 colorful images in 1000 classes
|
||||
- Train: 3680 images
|
||||
- Test: 3369 images
|
||||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend/GPU)
|
||||
- Prepare hardware environment with Ascend or GPU. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below£º
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## [Script description](#contents)
|
||||
|
||||
### [Script and sample code](#contents)
|
||||
|
||||
```python
|
||||
TNT
|
||||
├── eval.py # inference entry
|
||||
├── fig
|
||||
│ └── tnt.png # the illustration of TNT network
|
||||
├── readme.md # Readme
|
||||
└── src
|
||||
├── config.py # config of model and data
|
||||
├── pet_dataset.py # dataset loader
|
||||
└── tnt.py # TNT network
|
||||
```
|
||||
|
||||
## [Training process](#contents)
|
||||
|
||||
To Be Done
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
After installing MindSpore via the official website, you can start evaluation as follows:
|
||||
|
||||
### Launch
|
||||
|
||||
```bash
|
||||
# infer example
|
||||
GPU: python eval.py --model tnt-b --dataset_path ~/Pets/test.mindrecord --platform GPU --checkpoint_path [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
> checkpoint can be downloaded at https://www.mindspore.cn/resources/hub.
|
||||
|
||||
### Result
|
||||
|
||||
```bash
|
||||
result: {'acc': 0.95} ckpt= ./tnt-b-pets.ckpt
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
||||
### [Performance](#contents)
|
||||
|
||||
#### Evaluation Performance
|
||||
|
||||
##### TNT on ImageNet2012
|
||||
|
||||
| Parameters | | |
|
||||
| -------------------------- | -------------------------------------- |---------------------------------- |
|
||||
| Model Version | TNT-B |TNT-S|
|
||||
| uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1 | 1.1 |
|
||||
| Dataset | ImageNet2012 | ImageNet2012|
|
||||
| Input size | 224x224 | 224x224|
|
||||
| Parameters (M) | 86.4 | 23.8 |
|
||||
| FLOPs (M) | 14.1 | 5.2 |
|
||||
| Accuracy (Top1) | 82.8 | 81.3 |
|
||||
|
||||
###### TNT on Oxford-IIIT Pet
|
||||
|
||||
| Parameters | | |
|
||||
| -------------------------- | -------------------------------------- |---------------------------------- |
|
||||
| Model Version | TNT-B |TNT-S|
|
||||
| uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1 | 1.1 |
|
||||
| Dataset | Oxford-IIIT Pet | Oxford-IIIT Pet|
|
||||
| Input size | 384x384 | 384x384|
|
||||
| Parameters (M) | 86.4 | 23.8 |
|
||||
| Accuracy (Top1) | 95.0 | 94.7 |
|
||||
|
||||
## [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.
|
||||
|
||||
## [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config_ascend = ed({
|
||||
"num_classes": 37,
|
||||
"image_height": 384,
|
||||
"image_width": 384,
|
||||
"batch_size": 50,
|
||||
"epoch_size": 300,
|
||||
"warmup_epochs": 5,
|
||||
"lr": 1e-3,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 0.05,
|
||||
"label_smooth": 0.1,
|
||||
"loss_scale": 1024,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 200,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
})
|
||||
|
||||
config_gpu = ed({
|
||||
"num_classes": 37,
|
||||
"image_height": 384,
|
||||
"image_width": 384,
|
||||
"batch_size": 50,
|
||||
"epoch_size": 300,
|
||||
"warmup_epochs": 5,
|
||||
"lr": 1e-3,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 0.05,
|
||||
"label_smooth": 0.1,
|
||||
"loss_scale": 1024,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 500,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
})
|
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
from mindspore.dataset.vision import Inter
|
||||
|
||||
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=1):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if platform == "Ascend":
|
||||
rank_size = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
if rank_size == 1:
|
||||
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=rank_size, shard_id=rank_id)
|
||||
elif platform == "GPU":
|
||||
if do_train:
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=get_group_size(), shard_id=get_rank())
|
||||
else:
|
||||
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=False)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
resize_height = config.image_height
|
||||
resize_width = config.image_width
|
||||
buffer_size = 1000
|
||||
|
||||
# define map operations
|
||||
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(resize_height, resize_width),
|
||||
scale=(0.08, 1.0), ratio=(3./4., 4./3.),
|
||||
interpolation=Inter.BICUBIC)
|
||||
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
|
||||
color_jitter = 0.4
|
||||
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
|
||||
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
|
||||
contrast=adjust_range,
|
||||
saturation=adjust_range)
|
||||
|
||||
decode_p = py_vision.Decode()
|
||||
resize_p = py_vision.Resize(int(resize_height), interpolation=Inter.BICUBIC)
|
||||
center_crop_p = py_vision.CenterCrop(resize_height)
|
||||
totensor = py_vision.ToTensor()
|
||||
normalize_p = py_vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
|
||||
if do_train:
|
||||
trans = py_transforms.Compose([decode_p, random_resize_crop_bicubic, random_horizontal_flip_op,
|
||||
random_color_jitter_op, totensor, normalize_p])
|
||||
else:
|
||||
trans = py_transforms.Compose([decode_p, resize_p, center_crop_p, totensor, normalize_p])
|
||||
|
||||
type_cast_op = c_transforms.TypeCast(mstype.int32)
|
||||
|
||||
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="label_list", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
# apply shuffle operations
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
@ -0,0 +1,390 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""TNT"""
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class MLP(nn.Cell):
|
||||
"""MLP"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.):
|
||||
super(MLP, self).__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Dense(in_features, hidden_features)
|
||||
self.dropout = nn.Dropout(1. - dropout)
|
||||
self.fc2 = nn.Dense(hidden_features, out_features)
|
||||
self.act = nn.GELU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Cell):
|
||||
"""Multi-head Attention"""
|
||||
|
||||
def __init__(self, dim, hidden_dim=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super(Attention, self).__init__()
|
||||
hidden_dim = hidden_dim or dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_dim // num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qk = nn.Dense(dim, hidden_dim * 2, has_bias=qkv_bias)
|
||||
self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias)
|
||||
self.softmax = nn.Softmax(axis=-1)
|
||||
self.batmatmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
self.attn_drop = nn.Dropout(1. - attn_drop)
|
||||
self.batmatmul = P.BatchMatMul()
|
||||
self.proj = nn.Dense(hidden_dim, dim)
|
||||
self.proj_drop = nn.Dropout(1. - proj_drop)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
"""Multi-head Attention"""
|
||||
B, N, _ = x.shape
|
||||
qk = self.transpose(self.reshape(self.qk(x), (B, N, 2, self.num_heads, self.head_dim)), (2, 0, 3, 1, 4))
|
||||
q, k = qk[0], qk[1]
|
||||
v = self.transpose(self.reshape(self.v(x), (B, N, self.num_heads, self.head_dim)), (0, 2, 1, 3))
|
||||
|
||||
attn = self.softmax(self.batmatmul_trans_b(q, k) * self.scale)
|
||||
attn = self.attn_drop(attn)
|
||||
x = self.reshape(self.transpose(self.batmatmul(attn, v), (0, 2, 1, 3)), (B, N, -1))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
"""drop connect implementation"""
|
||||
|
||||
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
|
||||
super(DropConnect, self).__init__()
|
||||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.keep_prob = 1 - drop_connect_rate
|
||||
self.dropout = P.Dropout(keep_prob=self.keep_prob)
|
||||
self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x):
|
||||
shape = self.shape(x)
|
||||
dtype = self.dtype(x)
|
||||
ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
|
||||
_, mask = self.dropout(ones_tensor)
|
||||
x = x * mask
|
||||
x = x / self.keep_prob_tensor
|
||||
return x
|
||||
|
||||
|
||||
class Pixel2Patch(nn.Cell):
|
||||
"""Projecting Pixel Embedding to Patch Embedding"""
|
||||
|
||||
def __init__(self, outer_dim):
|
||||
super(Pixel2Patch, self).__init__()
|
||||
self.norm_proj = nn.LayerNorm([outer_dim])
|
||||
self.proj = nn.Dense(outer_dim, outer_dim)
|
||||
self.fake = Parameter(Tensor(np.zeros((1, 1, outer_dim)),
|
||||
mstype.float32), name='fake', requires_grad=False)
|
||||
self.reshape = P.Reshape()
|
||||
self.tile = P.Tile()
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, pixel_embed, patch_embed):
|
||||
B, N, _ = patch_embed.shape
|
||||
proj = self.reshape(pixel_embed, (B, N - 1, -1))
|
||||
proj = self.proj(self.norm_proj(proj))
|
||||
proj = self.concat((self.tile(self.fake, (B, 1, 1)), proj))
|
||||
patch_embed = patch_embed + proj
|
||||
return patch_embed
|
||||
|
||||
|
||||
class TNTBlock(nn.Cell):
|
||||
"""TNT Block"""
|
||||
|
||||
def __init__(self, inner_config, outer_config, dropout=0., attn_dropout=0., drop_connect=0.):
|
||||
super().__init__()
|
||||
# inner transformer
|
||||
inner_dim = inner_config['dim']
|
||||
num_heads = inner_config['num_heads']
|
||||
mlp_ratio = inner_config['mlp_ratio']
|
||||
self.inner_norm1 = nn.LayerNorm([inner_dim])
|
||||
self.inner_attn = Attention(inner_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout,
|
||||
proj_drop=dropout)
|
||||
self.inner_norm2 = nn.LayerNorm([inner_dim])
|
||||
self.inner_mlp = MLP(inner_dim, int(inner_dim * mlp_ratio), dropout=dropout)
|
||||
# outer transformer
|
||||
outer_dim = outer_config['dim']
|
||||
num_heads = outer_config['num_heads']
|
||||
mlp_ratio = outer_config['mlp_ratio']
|
||||
self.outer_norm1 = nn.LayerNorm([outer_dim])
|
||||
self.outer_attn = Attention(outer_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout,
|
||||
proj_drop=dropout)
|
||||
self.outer_norm2 = nn.LayerNorm([outer_dim])
|
||||
self.outer_mlp = MLP(outer_dim, int(outer_dim * mlp_ratio), dropout=dropout)
|
||||
# pixel2patch
|
||||
self.pixel2patch = Pixel2Patch(outer_dim)
|
||||
# assistant
|
||||
self.drop_connect = DropConnect(drop_connect)
|
||||
self.reshape = P.Reshape()
|
||||
self.tile = P.Tile()
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, pixel_embed, patch_embed):
|
||||
"""TNT Block"""
|
||||
pixel_embed = pixel_embed + self.inner_attn(self.inner_norm1(pixel_embed))
|
||||
pixel_embed = pixel_embed + self.inner_mlp(self.inner_norm2(pixel_embed))
|
||||
|
||||
patch_embed = self.pixel2patch(pixel_embed, patch_embed)
|
||||
|
||||
patch_embed = patch_embed + self.outer_attn(self.outer_norm1(patch_embed))
|
||||
patch_embed = patch_embed + self.outer_mlp(self.outer_norm2(patch_embed))
|
||||
return pixel_embed, patch_embed
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
"""get_clones"""
|
||||
return nn.CellList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class TNTEncoder(nn.Cell):
|
||||
"""TNT"""
|
||||
|
||||
def __init__(self, encoder_layer, num_layers):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def construct(self, pixel_embed, patch_embed):
|
||||
"""TNT"""
|
||||
for layer in self.layers:
|
||||
pixel_embed, patch_embed = layer(pixel_embed, patch_embed)
|
||||
return pixel_embed, patch_embed
|
||||
|
||||
|
||||
class _stride_unfold_(nn.Cell):
|
||||
"""Unfold with stride"""
|
||||
|
||||
def __init__(
|
||||
self, kernel_size, stride=-1):
|
||||
super(_stride_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
else:
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.unfold = _unfold_(kernel_size)
|
||||
|
||||
def construct(self, x):
|
||||
"""TNT"""
|
||||
N, C, H, W = x.shape
|
||||
leftup_idx_x = []
|
||||
leftup_idx_y = []
|
||||
nh = int((H - self.kernel_size) / self.stride + 1)
|
||||
nw = int((W - self.kernel_size) / self.stride + 1)
|
||||
for i in range(nh):
|
||||
leftup_idx_x.append(i * self.stride)
|
||||
for i in range(nw):
|
||||
leftup_idx_y.append(i * self.stride)
|
||||
NumBlock_x = len(leftup_idx_x)
|
||||
NumBlock_y = len(leftup_idx_y)
|
||||
zeroslike = P.ZerosLike()
|
||||
cc_2 = P.Concat(axis=2)
|
||||
cc_3 = P.Concat(axis=3)
|
||||
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
|
||||
NumBlock_y * self.kernel_size), mstype.float32)
|
||||
N, C, H, W = unf_x.shape
|
||||
for i in range(NumBlock_x):
|
||||
for j in range(NumBlock_y):
|
||||
unf_i = i * self.kernel_size
|
||||
unf_j = j * self.kernel_size
|
||||
org_i = leftup_idx_x[i]
|
||||
org_j = leftup_idx_y[j]
|
||||
fill = x[:, :, org_i:org_i + self.kernel_size,
|
||||
org_j:org_j + self.kernel_size]
|
||||
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
|
||||
cc_2((cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fill)),
|
||||
zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
|
||||
unf_j:unf_j + self.kernel_size]))))),
|
||||
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
|
||||
y = self.unfold(unf_x)
|
||||
return y
|
||||
|
||||
|
||||
class _unfold_(nn.Cell):
|
||||
"""Unfold"""
|
||||
|
||||
def __init__(
|
||||
self, kernel_size, stride=-1):
|
||||
super(_unfold_, self).__init__()
|
||||
if stride == -1:
|
||||
self.stride = kernel_size
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
"""TNT"""
|
||||
N, C, H, W = x.shape
|
||||
numH = int(H / self.kernel_size)
|
||||
numW = int(W / self.kernel_size)
|
||||
if numH * self.kernel_size != H or numW * self.kernel_size != W:
|
||||
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
|
||||
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, C, int(
|
||||
numH * numW), self.kernel_size, self.kernel_size))
|
||||
|
||||
output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
|
||||
|
||||
output_img = self.reshape(output_img, (N, int(numH * numW), -1))
|
||||
return output_img
|
||||
|
||||
|
||||
class PixelEmbed(nn.Cell):
|
||||
"""Image to Pixel Embedding"""
|
||||
|
||||
def __init__(self, img_size, patch_size=16, in_channels=3, embedding_dim=768, stride=4):
|
||||
super(PixelEmbed, self).__init__()
|
||||
self.num_patches = (img_size // patch_size) * (img_size // patch_size)
|
||||
new_patch_size = math.ceil(patch_size / stride)
|
||||
self.new_patch_size = new_patch_size
|
||||
self.inner_dim = embedding_dim // new_patch_size // new_patch_size
|
||||
self.proj = nn.Conv2d(in_channels, self.inner_dim, kernel_size=7, pad_mode='pad',
|
||||
padding=3, stride=stride, has_bias=True)
|
||||
self.unfold = _unfold_(kernel_size=new_patch_size)
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.proj(x) # B, C, H, W
|
||||
x = self.unfold(x) # B, N, Ck2
|
||||
x = self.reshape(x, (B * self.num_patches, self.inner_dim, -1)) # B*N, C, M
|
||||
x = self.transpose(x, (0, 2, 1)) # B*N, M, C
|
||||
return x
|
||||
|
||||
|
||||
class TNT(nn.Cell):
|
||||
"""TNT"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size,
|
||||
patch_size,
|
||||
num_channels,
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
num_layers,
|
||||
hidden_dim,
|
||||
num_class,
|
||||
stride=4,
|
||||
dropout=0,
|
||||
attn_dropout=0,
|
||||
drop_connect=0.1
|
||||
):
|
||||
super(TNT, self).__init__()
|
||||
|
||||
assert embedding_dim % num_heads == 0
|
||||
assert img_size % patch_size == 0
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.img_size = img_size
|
||||
self.num_patches = int((img_size // patch_size) ** 2)
|
||||
new_patch_size = math.ceil(patch_size / stride)
|
||||
inner_dim = embedding_dim // new_patch_size // new_patch_size
|
||||
|
||||
self.patch_pos = Parameter(Tensor(np.random.rand(1, self.num_patches + 1, embedding_dim),
|
||||
mstype.float32), name='patch_pos', requires_grad=True)
|
||||
self.pixel_pos = Parameter(Tensor(np.random.rand(1, inner_dim, new_patch_size * new_patch_size),
|
||||
mstype.float32), name='pixel_pos', requires_grad=True)
|
||||
self.cls_token = Parameter(Tensor(np.random.rand(1, 1, embedding_dim),
|
||||
mstype.float32), requires_grad=True)
|
||||
self.patch_embed = Parameter(Tensor(np.zeros((1, self.num_patches, embedding_dim)),
|
||||
mstype.float32), name='patch_embed', requires_grad=False)
|
||||
self.fake = Parameter(Tensor(np.zeros((1, 1, embedding_dim)),
|
||||
mstype.float32), name='fake', requires_grad=False)
|
||||
self.pos_drop = nn.Dropout(1. - dropout)
|
||||
|
||||
self.pixel_embed = PixelEmbed(img_size, patch_size, num_channels, embedding_dim, stride)
|
||||
self.pixel2patch = Pixel2Patch(embedding_dim)
|
||||
|
||||
inner_config = {'dim': inner_dim, 'num_heads': 4, 'mlp_ratio': 4}
|
||||
outer_config = {'dim': embedding_dim, 'num_heads': num_heads, 'mlp_ratio': hidden_dim / embedding_dim}
|
||||
encoder_layer = TNTBlock(inner_config, outer_config, dropout=dropout, attn_dropout=attn_dropout,
|
||||
drop_connect=drop_connect)
|
||||
self.encoder = TNTEncoder(encoder_layer, num_layers)
|
||||
|
||||
self.head = nn.SequentialCell(
|
||||
nn.LayerNorm([embedding_dim]),
|
||||
nn.Dense(embedding_dim, num_class)
|
||||
)
|
||||
|
||||
self.add = P.TensorAdd()
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.tile = P.Tile()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
"""TNT"""
|
||||
B, _, _, _ = x.shape
|
||||
pixel_embed = self.pixel_embed(x)
|
||||
pixel_embed = pixel_embed + self.transpose(self.pixel_pos, (0, 2, 1)) # B*N, M, C
|
||||
|
||||
patch_embed = self.concat((self.cls_token, self.patch_embed))
|
||||
patch_embed = self.tile(patch_embed, (B, 1, 1))
|
||||
patch_embed = self.pos_drop(patch_embed + self.patch_pos)
|
||||
|
||||
patch_embed = self.pixel2patch(pixel_embed, patch_embed)
|
||||
|
||||
pixel_embed, patch_embed = self.encoder(pixel_embed, patch_embed)
|
||||
|
||||
y = self.head(patch_embed[:, 0])
|
||||
return y
|
||||
|
||||
|
||||
def tnt_b(num_class):
|
||||
return TNT(img_size=384,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
embedding_dim=640,
|
||||
num_heads=10,
|
||||
num_layers=12,
|
||||
hidden_dim=640*4,
|
||||
stride=4,
|
||||
num_class=num_class)
|
Loading…
Reference in New Issue