deeplabv3 hwc & ssd off_optimize & resnet18
modified: model_zoo/official/cv/deeplabv3/eval.py
This commit is contained in:
parent
5d0490909d
commit
b719324e33
|
@ -21,6 +21,7 @@ import cv2
|
|||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.nets import net_factory
|
||||
|
@ -47,6 +48,8 @@ def parse_args():
|
|||
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
||||
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
||||
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
|
||||
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
||||
help="NCHW or NHWC")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
return args
|
||||
|
@ -70,12 +73,16 @@ def resize_long(img, long_size=513):
|
|||
|
||||
|
||||
class BuildEvalNetwork(nn.Cell):
|
||||
def __init__(self, network):
|
||||
def __init__(self, network, input_format="NCHW"):
|
||||
super(BuildEvalNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.softmax = nn.Softmax(axis=1)
|
||||
self.transpose = ops.Transpose()
|
||||
self.format = input_format
|
||||
|
||||
def construct(self, input_data):
|
||||
if self.format == "NHWC":
|
||||
input_data = self.transpose(input_data, (0, 3, 1, 2))
|
||||
output = self.network(input_data)
|
||||
output = self.softmax(output)
|
||||
return output
|
||||
|
@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513):
|
|||
pad_w = crop_size - img_.shape[1]
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
# hwc to chw
|
||||
img_ = img_.transpose((2, 0, 1))
|
||||
return img_, resize_h, resize_w
|
||||
|
@ -162,7 +168,7 @@ def net_eval():
|
|||
else:
|
||||
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
|
||||
|
||||
eval_net = BuildEvalNetwork(network)
|
||||
eval_net = BuildEvalNetwork(network, args.input_format)
|
||||
|
||||
# load model
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
|
|
|
@ -32,6 +32,8 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"
|
|||
parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'],
|
||||
help='Select model structure (Default: deeplab_v3_s8)')
|
||||
parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)')
|
||||
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
||||
help="NCHW or NHWC")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
@ -43,10 +45,13 @@ if __name__ == '__main__':
|
|||
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True)
|
||||
else:
|
||||
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True)
|
||||
network = BuildEvalNetwork(network)
|
||||
network = BuildEvalNetwork(network, args.input_format)
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
|
||||
# load the parameter into net
|
||||
load_param_into_net(network, param_dict)
|
||||
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
|
||||
if args.input_format == "NHWC":
|
||||
input_data = Tensor(np.ones([args.batch_size, args.input_size, args.input_size, 3]).astype(np.float32))
|
||||
else:
|
||||
input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32))
|
||||
export(network, input_data, file_name=args.file_name, file_format=args.file_format)
|
||||
|
|
|
@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop;
|
|||
using mindspore::dataset::vision::Normalize;
|
||||
using mindspore::dataset::vision::HWC2CHW;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
using mindspore::GlobalContext;
|
||||
using mindspore::Context;
|
||||
using mindspore::Serialization;
|
||||
using mindspore::Model;
|
||||
using mindspore::ModelContext;
|
||||
using mindspore::Status;
|
||||
using mindspore::ModelType;
|
||||
using mindspore::GraphCell;
|
||||
|
@ -62,14 +61,15 @@ int main(int argc, char **argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||
Model model;
|
||||
Status ret = model.Build(GraphCell(graph), context);
|
||||
|
||||
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
|
||||
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
|
||||
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
|
||||
auto model_context = std::make_shared<mindspore::ModelContext>();
|
||||
|
||||
Model model(GraphCell(graph), model_context);
|
||||
Status ret = model.Build();
|
||||
if (ret != kSuccess) {
|
||||
std::cout << "ERROR: Build failed." << std::endl;
|
||||
return 1;
|
||||
|
|
|
@ -66,6 +66,7 @@ int main(int argc, char **argv) {
|
|||
auto context = std::make_shared<Context>();
|
||||
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||
ascend310->SetDeviceID(FLAGS_device_id);
|
||||
ascend310->SetBufferOptimizeMode("off_optimize");
|
||||
context->MutableDeviceInfo().push_back(ascend310);
|
||||
mindspore::Graph graph;
|
||||
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||
|
|
Loading…
Reference in New Issue