fixup

fixup
This commit is contained in:
累到崴脚 2021-06-02 16:30:54 +08:00
parent 59935cd7d9
commit b6d7009695
5 changed files with 35 additions and 9 deletions

View File

@ -75,6 +75,7 @@ def parse_args():
args_opt.max_predict_img_size = cfg.max_predict_img_size
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
args_opt.is_train = False
return args_opt

View File

@ -32,6 +32,7 @@ parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"]
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU"], help="device target(default: Ascend)")
args = parser.parse_args()
args.is_train = False
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
@ -39,7 +40,7 @@ if args.device_target == "Ascend":
if __name__ == '__main__':
net = AdvancedEast()
net = AdvancedEast(args)
assert args.ckpt_file is not None, "checkpoint_path is None."

View File

@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import ResizeNearestNeighbor
from mindspore import Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import numpy as np
@ -73,8 +73,9 @@ class AdvancedEast(nn.Cell):
if self.device_target == 'GPU':
self.vgg16 = vgg16()
param_dict = load_checkpoint(cfg.vgg_weights)
load_param_into_net(self.vgg16, param_dict)
if args.is_train:
param_dict = load_checkpoint(cfg.vgg_weights)
load_param_into_net(self.vgg16, param_dict)
self.bn1 = nn.BatchNorm2d(1024, momentum=0.99, eps=1e-3)
self.conv1 = nn.Conv2d(1024, 128, 1, weight_init='XavierUniform', has_bias=True)
@ -110,7 +111,23 @@ class AdvancedEast(nn.Cell):
self.conv9 = nn.Conv2d(32, 2, 1, weight_init='XavierUniform', has_bias=True)
self.conv10 = nn.Conv2d(32, 4, 1, weight_init='XavierUniform', has_bias=True)
else:
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
if args.is_train:
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
shape_dict = {
'conv1_1': [64, 3, 3, 3],
'conv1_2': [64, 64, 3, 3],
'conv2_1': [128, 64, 3, 3],
'conv2_2': [128, 128, 3, 3],
'conv3_1': [256, 128, 3, 3],
'conv3_2': [256, 256, 3, 3],
'conv3_3': [256, 256, 3, 3],
'conv4_1': [512, 256, 3, 3],
'conv4_2': [512, 512, 3, 3],
'conv4_3': [512, 512, 3, 3],
'conv5_1': [512, 512, 3, 3],
'conv5_2': [512, 512, 3, 3],
'conv5_3': [512, 512, 3, 3],
}
def get_var(name, idx):
value = vgg_dict[name][idx]
@ -131,13 +148,18 @@ class AdvancedEast(nn.Cell):
def __init__(self, name):
super(VGG_Conv, self).__init__()
filters, conv_biases = get_conv_var(name)
out_channels, in_channels, filter_size, _ = filters.shape
if args.is_train:
filters, conv_biases = get_conv_var(name)
out_channels, in_channels, filter_size, _ = filters.shape
else:
out_channels, in_channels, filter_size, _ = shape_dict[name]
self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
self.bias_add = P.BiasAdd()
self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
self.weight = Parameter(initializer(filters if args.is_train else TruncatedNormal(),
[out_channels, in_channels, filter_size, filter_size]),
name='weight')
self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
self.bias = Parameter(initializer(conv_biases if args.is_train else TruncatedNormal(),
[out_channels]), name='bias')
self.relu = P.ReLU()
self.gn = nn.GroupNorm(32, out_channels)

View File

@ -67,6 +67,7 @@ def parse_args():
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
args_opt.ds_sink_mode = cfg.ds_sink_mode
args_opt.is_train = True
return args_opt

View File

@ -69,6 +69,7 @@ def parse_args(cloud_args=None):
args_opt.results_dir = cfg.results_dir
args_opt.last_model_name = cfg.last_model_name
args_opt.saved_model_file_path = cfg.saved_model_file_path
args_opt.is_train = True
return args_opt