deeplabv3 hwc & ssd off_optimize & resnet18

modified:   model_zoo/official/cv/deeplabv3/eval.py
This commit is contained in:
unknown 2021-03-26 10:31:24 +08:00
parent 5d0490909d
commit b719324e33
4 changed files with 26 additions and 14 deletions

View File

@ -21,6 +21,7 @@ import cv2
from mindspore import Tensor from mindspore import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory from src.nets import net_factory
@ -47,6 +48,8 @@ def parse_args():
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn') parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate') parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
help="NCHW or NHWC")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
return args return args
@ -70,12 +73,16 @@ def resize_long(img, long_size=513):
class BuildEvalNetwork(nn.Cell): class BuildEvalNetwork(nn.Cell):
def __init__(self, network): def __init__(self, network, input_format="NCHW"):
super(BuildEvalNetwork, self).__init__() super(BuildEvalNetwork, self).__init__()
self.network = network self.network = network
self.softmax = nn.Softmax(axis=1) self.softmax = nn.Softmax(axis=1)
self.transpose = ops.Transpose()
self.format = input_format
def construct(self, input_data): def construct(self, input_data):
if self.format == "NHWC":
input_data = self.transpose(input_data, (0, 3, 1, 2))
output = self.network(input_data) output = self.network(input_data)
output = self.softmax(output) output = self.softmax(output)
return output return output
@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513):
pad_w = crop_size - img_.shape[1] pad_w = crop_size - img_.shape[1]
if pad_h > 0 or pad_w > 0: if pad_h > 0 or pad_w > 0:
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
# hwc to chw # hwc to chw
img_ = img_.transpose((2, 0, 1)) img_ = img_.transpose((2, 0, 1))
return img_, resize_h, resize_w return img_, resize_h, resize_w
@ -162,7 +168,7 @@ def net_eval():
else: else:
raise NotImplementedError('model [{:s}] not recognized'.format(args.model)) raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
eval_net = BuildEvalNetwork(network) eval_net = BuildEvalNetwork(network, args.input_format)
# load model # load model
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)

View File

@ -32,6 +32,8 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'], parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'],
help='Select model structure (Default: deeplab_v3_s8)') help='Select model structure (Default: deeplab_v3_s8)')
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)') parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
help="NCHW or NHWC")
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@ -43,10 +45,13 @@ if __name__ == '__main__':
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True) network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
else: else:
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True) network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
network = BuildEvalNetwork(network) network = BuildEvalNetwork(network, args.input_format)
param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(args.ckpt_file)
# load the parameter into net # load the parameter into net
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32)) if args.input_format == "NHWC":
input_data = Tensor(np.ones([args.batch_size, args.input_size, args.input_size, 3]).astype(np.float32))
else:
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
export(network, input_data, file_name=args.file_name, file_format=args.file_format) export(network, input_data, file_name=args.file_name, file_format=args.file_format)

View File

@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize; using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW; using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform; using mindspore::dataset::TensorTransform;
using mindspore::GlobalContext; using mindspore::Context;
using mindspore::Serialization; using mindspore::Serialization;
using mindspore::Model; using mindspore::Model;
using mindspore::ModelContext;
using mindspore::Status; using mindspore::Status;
using mindspore::ModelType; using mindspore::ModelType;
using mindspore::GraphCell; using mindspore::GraphCell;
@ -62,14 +61,15 @@ int main(int argc, char **argv) {
return 1; return 1;
} }
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
auto model_context = std::make_shared<mindspore::ModelContext>();
Model model(GraphCell(graph), model_context);
Status ret = model.Build();
if (ret != kSuccess) { if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl; std::cout << "ERROR: Build failed." << std::endl;
return 1; return 1;

View File

@ -66,6 +66,7 @@ int main(int argc, char **argv) {
auto context = std::make_shared<Context>(); auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id); ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetBufferOptimizeMode("off_optimize");
context->MutableDeviceInfo().push_back(ascend310); context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph; mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);