# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" infer """
from argparse import ArgumentParser
import numpy as np
from mindspore import Tensor
from ....dataset_mock import MindData
__factory = {
"resnet50": resnet50(),
}


def parse_args():
    """ parse_args """
    parser = ArgumentParser(description="resnet50 example")
    parser.add_argument("--model", type=str, default="resnet50",
                        help="the network architecture for training or testing")
    parser.add_argument("--phase", type=str, default="test",
                        help="the phase of the model, default is test.")
    parser.add_argument("--file_path", type=str, default="/data/file/test1.txt",
                        help="data directory of training or testing")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size for training or testing")
    return parser.parse_args()


def get_model(name):
    """ get_model """
    if name not in __factory:
        raise KeyError("unknown model: {}".format(name))
    return __factory[name]


def get_dataset(batch_size=32):
    """ get_dataset """
    dataset_types = np.float32
    dataset_shapes = (batch_size, 3, 224, 224)

    dataset = MindData(size=2, batch_size=batch_size,
                       np_types=dataset_types,
                       output_shapes=dataset_shapes,
                       input_indexs=(0, 1))
    return dataset


# pylint: disable=unused-argument
def test(name, file_path, batch_size):
    """ test """
    network = get_model(name)
    batch = get_dataset(batch_size=batch_size)

    # Collect every batch produced by the mocked dataset and merge them
    # along the batch dimension.
    data_list = []
    for data in batch:
        data_list.append(data.asnumpy())
    # The mock already yields data in NCHW layout (batch, 3, 224, 224),
    # so no NHWC -> NCHW transpose is needed before feeding the network.
    batch_data = np.concatenate(data_list, axis=0)

    input_tensor = Tensor(batch_data)
    print(input_tensor.shape)
    network(input_tensor)


if __name__ == '__main__':
    args = parse_args()
    if args.phase == "train":
        raise NotImplementedError
    test(args.model, args.file_path, args.batch_size)
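
# Example invocation (values shown are illustrative; the defaults defined in
# parse_args are used when a flag is omitted):
#   python infer.py --model resnet50 --phase test --batch_size 1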