!15129 fix ctpn vgg16 backbone training

From: @qujianwei
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-04-14 14:50:21 +08:00 committed by Gitee
commit 210a4a2490
4 changed files with 26 additions and 25 deletions

View File

@ -169,11 +169,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T
...
from src.vgg16 import VGG16
...
-network = VGG16()
+network = VGG16(num_classes=cfg.num_classes)
...
```
To train a better model, you can modify some parameters in modelzoo/official/cv/vgg16/src/config.py. Here we suggest modifying "warmup_epochs" as shown below; you can also try adjusting other parameters.
```python
imagenet_cfg = edict({
    ...
    "warmup_epochs": 5
    ...
})
```
Then you can train it with ImageNet2012.
> Notes:
> For RANK_TABLE_FILE, refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described at [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it is better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compilation time increases with model size.
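
As a quick, hedged illustration (standard MindSpore serialization calls; the checkpoint path and `num_classes=1000` are placeholders for whatever your run produces), this is the kind of file `pretrained_path` is expected to point at once the modified VGG16 has been trained on ImageNet2012:

```python
# Sketch only: restore an ImageNet2012-trained vgg16 checkpoint into the backbone
# definition before handing it to CTPN training as `pretrained_path`.
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.vgg16 import VGG16

network = VGG16(num_classes=1000)                                 # 1000 = ImageNet2012 classes
param_dict = load_checkpoint("/path/to/vgg16-imagenet2012.ckpt")  # placeholder checkpoint path
load_param_into_net(network, param_dict)                          # backbone + classifier weights restored
```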

View File

@ -56,6 +56,7 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
        os.mkdir(output_dir)
    for file in os.listdir(img_dir):
        img_basenames.append(os.path.basename(file))
+    img_basenames = sorted(img_basenames)
    for data in ds.create_dict_iterator():
        img_data = data['image']
        img_metas = data['image_shape']

View File

@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)

-def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
-    """Batchnorm2D wrapper."""
-    gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
-    beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
-    moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
-    moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
-    return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
-                          beta_init=beta_init, moving_mean_init=moving_mean_init,
-                          moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
    """Conv2D wrapper."""
-    weights = 'ones'
    layers = []
    conv = nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     pad_mode=pad_mode, weight_init=weights, has_bias=False)
+                     pad_mode=pad_mode, has_bias=False)
    if not weights_update:
        conv.weight.requires_grad = False
    layers += [conv]
-    layers += [_BatchNorm2dInit(out_channels)]
+    layers += [nn.BatchNorm2d(out_channels)]
    return nn.SequentialCell(layers)

def _fc(in_channels, out_channels):
    '''full connection layer'''
-    weight = _weight_variable((out_channels, in_channels))
-    bias = _weight_variable((out_channels,))
-    return nn.Dense(in_channels, out_channels, weight, bias)
+    return nn.Dense(in_channels, out_channels)

class VGG16FeatureExtraction(nn.Cell):
@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell):
        self.relu = nn.ReLU()
        self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
        self.fc2 = _fc(in_channels=4096, out_channels=4096)
-        self.batch_size = 32
        self.reshape = P.Reshape()
+        self.dropout = nn.Dropout(0.5)

    def construct(self, x):
        """
        :param x: shape=(B, 512, 7, 7)
        :return:
        """
-        x = self.reshape(x, (self.batch_size, 7*7*512))
+        x = self.reshape(x, (-1, 7*7*512))
        x = self.fc1(x)
        x = self.relu(x)
+        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
+        x = self.dropout(x)
        return x

class VGG16(nn.Cell):
-    def __init__(self):
+    def __init__(self, num_classes):
        """VGG16 construct for training backbone"""
        super(VGG16, self).__init__()
-        self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
+        self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = VGG16Classfier()
-        self.fc3 = _fc(in_channels=4096, out_channels=1000)
+        self.fc3 = _fc(in_channels=4096, out_channels=num_classes)

    def construct(self, x):
        """
        :param x: shape=(B, 3, 224, 224)
        :return: logits, shape=(B, 1000)
        """
-        feature_maps = self.feature_extraction(x)
+        feature_maps = self.vgg16_feature_extractor(x)
        x = self.max_pool(feature_maps)
        x = self.classifier(x)
        x = self.fc3(x)
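
A minimal usage sketch of the updated backbone (module path as in the README above; the batch size and `num_classes` are illustrative): because the classifier now reshapes to `(-1, 7*7*512)` instead of a hard-coded batch size of 32, any batch size works.

```python
import numpy as np
from mindspore import Tensor
from src.vgg16 import VGG16

net = VGG16(num_classes=1000)                                    # num_classes is now an explicit argument
x = Tensor(np.random.randn(8, 3, 224, 224).astype(np.float32))  # batch size 8, no longer tied to 32
logits = net(x)                                                  # expected shape: (8, 1000)
```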

View File

@ -145,7 +145,7 @@ def create_train_dataset(dataset_type):
        # test: icdar2013 test
        icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
            config.icdar13_test_path[1], "")
-        image_files = icdar_test_image_files
+        image_files = sorted(icdar_test_image_files)
        image_anno_dict = icdar_test_anno_dict
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
            prefix="ctpn_test.mindrecord", file_num=1)