forked from mindspore-Ecosystem/mindspore
parent
59935cd7d9
commit
b6d7009695
|
@ -75,6 +75,7 @@ def parse_args():
|
|||
args_opt.max_predict_img_size = cfg.max_predict_img_size
|
||||
args_opt.last_model_name = cfg.last_model_name
|
||||
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||
args_opt.is_train = False
|
||||
|
||||
return args_opt
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"]
|
|||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU"], help="device target(default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
args.is_train = False
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
|
@ -39,7 +40,7 @@ if args.device_target == "Ascend":
|
|||
|
||||
if __name__ == '__main__':
|
||||
|
||||
net = AdvancedEast()
|
||||
net = AdvancedEast(args)
|
||||
|
||||
assert args.ckpt_file is not None, "checkpoint_path is None."
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.ops import composite as C
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import ResizeNearestNeighbor
|
||||
from mindspore import Tensor, ParameterTuple, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import numpy as np
|
||||
|
||||
|
@ -73,6 +73,7 @@ class AdvancedEast(nn.Cell):
|
|||
if self.device_target == 'GPU':
|
||||
|
||||
self.vgg16 = vgg16()
|
||||
if args.is_train:
|
||||
param_dict = load_checkpoint(cfg.vgg_weights)
|
||||
load_param_into_net(self.vgg16, param_dict)
|
||||
|
||||
|
@ -110,7 +111,23 @@ class AdvancedEast(nn.Cell):
|
|||
self.conv9 = nn.Conv2d(32, 2, 1, weight_init='XavierUniform', has_bias=True)
|
||||
self.conv10 = nn.Conv2d(32, 4, 1, weight_init='XavierUniform', has_bias=True)
|
||||
else:
|
||||
if args.is_train:
|
||||
vgg_dict = np.load(cfg.vgg_npy, encoding='latin1', allow_pickle=True).item()
|
||||
shape_dict = {
|
||||
'conv1_1': [64, 3, 3, 3],
|
||||
'conv1_2': [64, 64, 3, 3],
|
||||
'conv2_1': [128, 64, 3, 3],
|
||||
'conv2_2': [128, 128, 3, 3],
|
||||
'conv3_1': [256, 128, 3, 3],
|
||||
'conv3_2': [256, 256, 3, 3],
|
||||
'conv3_3': [256, 256, 3, 3],
|
||||
'conv4_1': [512, 256, 3, 3],
|
||||
'conv4_2': [512, 512, 3, 3],
|
||||
'conv4_3': [512, 512, 3, 3],
|
||||
'conv5_1': [512, 512, 3, 3],
|
||||
'conv5_2': [512, 512, 3, 3],
|
||||
'conv5_3': [512, 512, 3, 3],
|
||||
}
|
||||
|
||||
def get_var(name, idx):
|
||||
value = vgg_dict[name][idx]
|
||||
|
@ -131,13 +148,18 @@ class AdvancedEast(nn.Cell):
|
|||
|
||||
def __init__(self, name):
|
||||
super(VGG_Conv, self).__init__()
|
||||
if args.is_train:
|
||||
filters, conv_biases = get_conv_var(name)
|
||||
out_channels, in_channels, filter_size, _ = filters.shape
|
||||
else:
|
||||
out_channels, in_channels, filter_size, _ = shape_dict[name]
|
||||
self.conv2d = P.Conv2D(out_channels, filter_size, pad_mode='same', mode=1)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.weight = Parameter(initializer(filters, [out_channels, in_channels, filter_size, filter_size]),
|
||||
self.weight = Parameter(initializer(filters if args.is_train else TruncatedNormal(),
|
||||
[out_channels, in_channels, filter_size, filter_size]),
|
||||
name='weight')
|
||||
self.bias = Parameter(initializer(conv_biases, [out_channels]), name='bias')
|
||||
self.bias = Parameter(initializer(conv_biases if args.is_train else TruncatedNormal(),
|
||||
[out_channels]), name='bias')
|
||||
self.relu = P.ReLU()
|
||||
self.gn = nn.GroupNorm(32, out_channels)
|
||||
|
||||
|
|
|
@ -67,6 +67,7 @@ def parse_args():
|
|||
args_opt.last_model_name = cfg.last_model_name
|
||||
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||
args_opt.ds_sink_mode = cfg.ds_sink_mode
|
||||
args_opt.is_train = True
|
||||
return args_opt
|
||||
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ def parse_args(cloud_args=None):
|
|||
args_opt.results_dir = cfg.results_dir
|
||||
args_opt.last_model_name = cfg.last_model_name
|
||||
args_opt.saved_model_file_path = cfg.saved_model_file_path
|
||||
args_opt.is_train = True
|
||||
return args_opt
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue