fix ctpn backbone vgg16 init

This commit is contained in:
qujianwei 2021-04-08 19:23:34 +08:00
parent 68093fa5d1
commit ee023b2e02
4 changed files with 26 additions and 25 deletions

View File

@ -169,11 +169,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T
...
from src.vgg16 import VGG16
...
network = VGG16()
network = VGG16(num_classes=cfg.num_classes)
...
```
To train a better model, you can modify some parameter in modelzoo/official/cv/vgg16/src/config.py, here we suggested you modify the "warmup_epochs" just like below, you can also try to adjust other parameter.
```python
imagenet_cfg = edict({
...
"warmup_epochs": 5
...
})
```
Then you can train it with ImageNet2012.
> Notes:
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.

View File

@ -56,6 +56,7 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
os.mkdir(output_dir)
for file in os.listdir(img_dir):
img_basenames.append(os.path.basename(file))
img_basenames = sorted(img_basenames)
for data in ds.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']

View File

@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
"""Conv2D wrapper."""
weights = 'ones'
layers = []
conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
pad_mode=pad_mode, has_bias=False)
if not weights_update:
conv.weight.requires_grad = False
layers += [conv]
layers += [_BatchNorm2dInit(out_channels)]
layers += [nn.BatchNorm2d(out_channels)]
return nn.SequentialCell(layers)
def _fc(in_channels, out_channels):
'''full connection layer'''
weight = _weight_variable((out_channels, in_channels))
bias = _weight_variable((out_channels,))
return nn.Dense(in_channels, out_channels, weight, bias)
return nn.Dense(in_channels, out_channels)
class VGG16FeatureExtraction(nn.Cell):
@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell):
self.relu = nn.ReLU()
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
self.fc2 = _fc(in_channels=4096, out_channels=4096)
self.batch_size = 32
self.reshape = P.Reshape()
self.dropout = nn.Dropout(0.5)
def construct(self, x):
"""
:param x: shape=(B, 512, 7, 7)
:return:
"""
x = self.reshape(x, (self.batch_size, 7*7*512))
x = self.reshape(x, (-1, 7*7*512))
x = self.fc1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.relu(x)
x = self.dropout(x)
return x
class VGG16(nn.Cell):
def __init__(self):
def __init__(self, num_classes):
"""VGG16 construct for training backbone"""
super(VGG16, self).__init__()
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True)
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = VGG16Classfier()
self.fc3 = _fc(in_channels=4096, out_channels=1000)
self.fc3 = _fc(in_channels=4096, out_channels=num_classes)
def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return: logits, shape=(B, 1000)
"""
feature_maps = self.feature_extraction(x)
feature_maps = self.vgg16_feature_extractor(x)
x = self.max_pool(feature_maps)
x = self.classifier(x)
x = self.fc3(x)

View File

@ -145,7 +145,7 @@ def create_train_dataset(dataset_type):
# test: icdar2013 test
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
config.icdar13_test_path[1], "")
image_files = icdar_test_image_files
image_files = sorted(icdar_test_image_files)
image_anno_dict = icdar_test_anno_dict
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
prefix="ctpn_test.mindrecord", file_num=1)