deeplabv3 hwc & ssd off_optimize & resnet18

modified:   model_zoo/official/cv/deeplabv3/eval.py
This commit is contained in:
unknown 2021-03-26 10:31:24 +08:00
parent 5d0490909d
commit b719324e33
4 changed files with 26 additions and 14 deletions

View File

@ -21,6 +21,7 @@ import cv2
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nets import net_factory
@ -47,6 +48,8 @@ def parse_args():
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
help="NCHW or NHWC")
args, _ = parser.parse_known_args()
return args
@ -70,12 +73,16 @@ def resize_long(img, long_size=513):
class BuildEvalNetwork(nn.Cell):
def __init__(self, network):
def __init__(self, network, input_format="NCHW"):
super(BuildEvalNetwork, self).__init__()
self.network = network
self.softmax = nn.Softmax(axis=1)
self.transpose = ops.Transpose()
self.format = input_format
def construct(self, input_data):
if self.format == "NHWC":
input_data = self.transpose(input_data, (0, 3, 1, 2))
output = self.network(input_data)
output = self.softmax(output)
return output
@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513):
pad_w = crop_size - img_.shape[1]
if pad_h > 0 or pad_w > 0:
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
# hwc to chw
img_ = img_.transpose((2, 0, 1))
return img_, resize_h, resize_w
@ -162,7 +168,7 @@ def net_eval():
else:
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
eval_net = BuildEvalNetwork(network)
eval_net = BuildEvalNetwork(network, args.input_format)
# load model
param_dict = load_checkpoint(args.ckpt_path)

View File

@ -32,6 +32,8 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'],
help='Select model structure (Default: deeplab_v3_s8)')
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
help="NCHW or NHWC")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@ -43,10 +45,13 @@ if __name__ == '__main__':
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
else:
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
network = BuildEvalNetwork(network)
network = BuildEvalNetwork(network, args.input_format)
param_dict = load_checkpoint(args.ckpt_file)
# load the parameter into net
load_param_into_net(network, param_dict)
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
if args.input_format == "NHWC":
input_data = Tensor(np.ones([args.batch_size, args.input_size, args.input_size, 3]).astype(np.float32))
else:
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
export(network, input_data, file_name=args.file_name, file_format=args.file_format)

View File

@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform;
using mindspore::GlobalContext;
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::ModelContext;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
@ -62,14 +61,15 @@ int main(int argc, char **argv) {
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
auto model_context = std::make_shared<mindspore::ModelContext>();
Model model(GraphCell(graph), model_context);
Status ret = model.Build();
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;

View File

@ -66,6 +66,7 @@ int main(int argc, char **argv) {
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetBufferOptimizeMode("off_optimize");
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);