deeplabv3 hwc & ssd off_optimize & resnet18
modified: model_zoo/official/cv/deeplabv3/eval.py
This commit is contained in:
parent
5d0490909d
commit
b719324e33
|
@ -21,6 +21,7 @@ import cv2
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as ops
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.nets import net_factory
|
from src.nets import net_factory
|
||||||
|
@ -47,6 +48,8 @@ def parse_args():
|
||||||
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
||||||
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
||||||
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
|
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
|
||||||
|
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
||||||
|
help="NCHW or NHWC")
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
return args
|
return args
|
||||||
|
@ -70,12 +73,16 @@ def resize_long(img, long_size=513):
|
||||||
|
|
||||||
|
|
||||||
class BuildEvalNetwork(nn.Cell):
|
class BuildEvalNetwork(nn.Cell):
|
||||||
def __init__(self, network):
|
def __init__(self, network, input_format="NCHW"):
|
||||||
super(BuildEvalNetwork, self).__init__()
|
super(BuildEvalNetwork, self).__init__()
|
||||||
self.network = network
|
self.network = network
|
||||||
self.softmax = nn.Softmax(axis=1)
|
self.softmax = nn.Softmax(axis=1)
|
||||||
|
self.transpose = ops.Transpose()
|
||||||
|
self.format = input_format
|
||||||
|
|
||||||
def construct(self, input_data):
|
def construct(self, input_data):
|
||||||
|
if self.format == "NHWC":
|
||||||
|
input_data = self.transpose(input_data, (0, 3, 1, 2))
|
||||||
output = self.network(input_data)
|
output = self.network(input_data)
|
||||||
output = self.softmax(output)
|
output = self.softmax(output)
|
||||||
return output
|
return output
|
||||||
|
@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513):
|
||||||
pad_w = crop_size - img_.shape[1]
|
pad_w = crop_size - img_.shape[1]
|
||||||
if pad_h > 0 or pad_w > 0:
|
if pad_h > 0 or pad_w > 0:
|
||||||
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||||
|
|
||||||
# hwc to chw
|
# hwc to chw
|
||||||
img_ = img_.transpose((2, 0, 1))
|
img_ = img_.transpose((2, 0, 1))
|
||||||
return img_, resize_h, resize_w
|
return img_, resize_h, resize_w
|
||||||
|
@ -162,7 +168,7 @@ def net_eval():
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
|
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
|
||||||
|
|
||||||
eval_net = BuildEvalNetwork(network)
|
eval_net = BuildEvalNetwork(network, args.input_format)
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
param_dict = load_checkpoint(args.ckpt_path)
|
param_dict = load_checkpoint(args.ckpt_path)
|
||||||
|
|
|
@ -32,6 +32,8 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
||||||
parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'],
|
parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'],
|
||||||
help='Select model structure (Default: deeplab_v3_s8)')
|
help='Select model structure (Default: deeplab_v3_s8)')
|
||||||
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
||||||
|
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
||||||
|
help="NCHW or NHWC")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
@ -43,10 +45,13 @@ if __name__ == '__main__':
|
||||||
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
|
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
|
||||||
else:
|
else:
|
||||||
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
|
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
|
||||||
network = BuildEvalNetwork(network)
|
network = BuildEvalNetwork(network, args.input_format)
|
||||||
param_dict = load_checkpoint(args.ckpt_file)
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
|
|
||||||
# load the parameter into net
|
# load the parameter into net
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
|
if args.input_format == "NHWC":
|
||||||
|
input_data = Tensor(np.ones([args.batch_size, args.input_size, args.input_size, 3]).astype(np.float32))
|
||||||
|
else:
|
||||||
|
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
|
||||||
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
|
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||||
|
|
|
@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop;
|
||||||
using mindspore::dataset::vision::Normalize;
|
using mindspore::dataset::vision::Normalize;
|
||||||
using mindspore::dataset::vision::HWC2CHW;
|
using mindspore::dataset::vision::HWC2CHW;
|
||||||
using mindspore::dataset::TensorTransform;
|
using mindspore::dataset::TensorTransform;
|
||||||
using mindspore::GlobalContext;
|
using mindspore::Context;
|
||||||
using mindspore::Serialization;
|
using mindspore::Serialization;
|
||||||
using mindspore::Model;
|
using mindspore::Model;
|
||||||
using mindspore::ModelContext;
|
|
||||||
using mindspore::Status;
|
using mindspore::Status;
|
||||||
using mindspore::ModelType;
|
using mindspore::ModelType;
|
||||||
using mindspore::GraphCell;
|
using mindspore::GraphCell;
|
||||||
|
@ -62,14 +61,15 @@ int main(int argc, char **argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
|
ascend310->SetDeviceID(FLAGS_device_id);
|
||||||
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
|
mindspore::Graph graph;
|
||||||
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||||
|
Model model;
|
||||||
|
Status ret = model.Build(GraphCell(graph), context);
|
||||||
|
|
||||||
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
|
|
||||||
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
|
|
||||||
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
|
|
||||||
auto model_context = std::make_shared<mindspore::ModelContext>();
|
|
||||||
|
|
||||||
Model model(GraphCell(graph), model_context);
|
|
||||||
Status ret = model.Build();
|
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
std::cout << "ERROR: Build failed." << std::endl;
|
std::cout << "ERROR: Build failed." << std::endl;
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
@ -66,6 +66,7 @@ int main(int argc, char **argv) {
|
||||||
auto context = std::make_shared<Context>();
|
auto context = std::make_shared<Context>();
|
||||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
ascend310->SetDeviceID(FLAGS_device_id);
|
ascend310->SetDeviceID(FLAGS_device_id);
|
||||||
|
ascend310->SetBufferOptimizeMode("off_optimize");
|
||||||
context->MutableDeviceInfo().push_back(ascend310);
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
mindspore::Graph graph;
|
mindspore::Graph graph;
|
||||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||||
|
|
Loading…
Reference in New Issue