forked from mindspore-Ecosystem/mindspore
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
""" infer """
|
|
from argparse import ArgumentParser
|
|
import numpy as np
|
|
|
|
from mindspore import Tensor
|
|
from ....dataset_mock import MindData
|
|
|
|
# Registry mapping a model name to an already-constructed network instance.
# NOTE(review): `resnet50` is not imported in the visible part of this file;
# confirm where it is expected to come from.
__factory = dict(
    resnet50=resnet50(),
)
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the resnet50 example.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            the behaviour of calling this function with no arguments.

    Returns:
        argparse.Namespace: Holds the attributes ``model``, ``phase``,
        ``file_path`` and ``batch_size``.
    """
    parser = ArgumentParser(description="resnet50 example")

    parser.add_argument("--model", type=str, default="resnet50",
                        help="the network architecture for training or testing")
    parser.add_argument("--phase", type=str, default="test",
                        help="the phase of the model, default is test.")
    parser.add_argument("--file_path", type=str, default="/data/file/test1.txt",
                        help="data directory of training or testing")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size for training or testing ")

    # Forwarding argv lets callers (and tests) supply arguments explicitly
    # instead of depending on the process-wide sys.argv.
    return parser.parse_args(argv)
|
|
|
|
|
|
def get_model(name):
    """Look up a network in the module-level ``__factory`` registry.

    Args:
        name: Registry key of the network, e.g. "resnet50".

    Returns:
        The object stored in ``__factory`` under ``name``.

    Raises:
        KeyError: If ``name`` is not a key of ``__factory``.
    """
    if name not in __factory:
        # Build a single message string: passing the label and the name as
        # two separate arguments makes KeyError render them as a tuple.
        raise KeyError("unknown model: {}".format(name))
    return __factory[name]
|
|
|
|
|
|
def get_dataset(batch_size=32):
    """Build the MindData mock dataset used by this script.

    Args:
        batch_size: Leading dimension of the declared output shape; it is
            also forwarded to MindData as ``batch_size``.

    Returns:
        The constructed MindData instance.
    """
    shape = (batch_size, 3, 224, 224)
    return MindData(
        size=2,
        batch_size=batch_size,
        np_types=np.float32,
        output_shapes=shape,
        input_indexs=(0, 1),
    )
|
|
|
|
|
|
# pylint: disable=unused-argument
def test(name, file_path, batch_size):
    """Run one forward pass of the named network over the mock dataset.

    Args:
        name: Key passed to get_model to select the network.
        file_path: Not used by this function.
        batch_size: Batch size passed to get_dataset.
    """
    network = get_model(name)
    batches = get_dataset(batch_size=batch_size)

    # Convert every item yielded by the dataset to numpy and join them
    # along the first axis.
    arrays = [item.asnumpy() for item in batches]
    stacked = np.concatenate(arrays, axis=0)

    # NOTE(review): get_dataset declares shapes as (batch, 3, 224, 224); with
    # that layout this permutation would not produce a channels-first array.
    # Confirm what layout MindData actually yields.
    batch_data = stacked.transpose((0, 3, 1, 2))

    input_tensor = Tensor(batch_data)
    print(input_tensor.shape)
    network(input_tensor)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: only the inference path is implemented here, so a
    # request for the "train" phase is rejected before anything is run.
    args = parse_args()
    if args.phase == "train":
        raise NotImplementedError
    test(name=args.model, file_path=args.file_path, batch_size=args.batch_size)
|