forked from mindspore-Ecosystem/mindspore
!15807 add resnet infer
From: @jiangzg001 Reviewed-by: @oacjiewen,@wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
c0690309aa
|
@ -0,0 +1,111 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""train resnet."""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
|
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet18, '
|
||||||
|
'resnet50 or resnet101')
|
||||||
|
parser.add_argument('--dataset', type=str, default=None, help='Dataset, imagenet2012')
|
||||||
|
|
||||||
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
|
||||||
|
help="Device target, support Ascend, GPU and CPU.")
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
if args_opt.dataset != "imagenet2012":
|
||||||
|
raise ValueError("Currently only support imagenet2012 dataset format")
|
||||||
|
if args_opt.net in ("resnet18", "resnet50"):
|
||||||
|
if args_opt.net == "resnet18":
|
||||||
|
from src.resnet import resnet18 as resnet
|
||||||
|
if args_opt.net == "resnet50":
|
||||||
|
from src.resnet import resnet50 as resnet
|
||||||
|
from src.config import config2 as config
|
||||||
|
from src.dataset_infer import create_dataset
|
||||||
|
|
||||||
|
elif args_opt.net == "resnet101":
|
||||||
|
from src.resnet import resnet101 as resnet
|
||||||
|
from src.config import config3 as config
|
||||||
|
from src.dataset_infer import create_dataset2 as create_dataset
|
||||||
|
else:
|
||||||
|
from src.resnet import se_resnet50 as resnet
|
||||||
|
from src.config import config4 as config
|
||||||
|
from src.dataset_infer import create_dataset3 as create_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def show_predict_info(label_list, prediction_list, filename_list, predict_ng):
|
||||||
|
label_index = 0
|
||||||
|
for label_index, predict_index, filename in zip(label_list, prediction_list, filename_list):
|
||||||
|
filename = np.array(filename).tostring().decode('utf8')
|
||||||
|
if label_index == -1:
|
||||||
|
print("file: '{}' predict class id is: {}".format(filename, predict_index))
|
||||||
|
continue
|
||||||
|
if predict_index != label_index:
|
||||||
|
predict_ng.append((filename, predict_index, label_index))
|
||||||
|
print("file: '{}' predict wrong, predict class id is: {}, "
|
||||||
|
"label is {}".format(filename, predict_index, label_index))
|
||||||
|
return predict_ng, label_index
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
target = args_opt.device_target
|
||||||
|
|
||||||
|
# init context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
|
if target == "Ascend":
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(device_id=device_id)
|
||||||
|
|
||||||
|
# create dataset
|
||||||
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
|
||||||
|
target=target)
|
||||||
|
step_size = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
# define net
|
||||||
|
net = resnet(class_num=config.class_num)
|
||||||
|
|
||||||
|
# load checkpoint
|
||||||
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
net.set_train(False)
|
||||||
|
|
||||||
|
print("start infer")
|
||||||
|
predict_negative = []
|
||||||
|
total_sample = step_size * config.batch_size
|
||||||
|
only_file = 0
|
||||||
|
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=1)
|
||||||
|
for i, data in enumerate(data_loader):
|
||||||
|
images = data["image"]
|
||||||
|
label = data["label"]
|
||||||
|
file_name = data["filename"]
|
||||||
|
res = net(Tensor(images))
|
||||||
|
res = res.asnumpy()
|
||||||
|
predict_id = np.argmax(res, axis=1)
|
||||||
|
predict_negative, only_file = show_predict_info(label.tolist(), predict_id.tolist(),
|
||||||
|
file_name.tolist(), predict_negative)
|
||||||
|
|
||||||
|
if only_file != -1:
|
||||||
|
print(f"total {total_sample} data, top1 acc is {(total_sample - len(predict_negative)) * 1.0 / total_sample}")
|
||||||
|
else:
|
||||||
|
print("infer completed")
|
|
@ -0,0 +1,78 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 4 ]
|
||||||
|
then
|
||||||
|
echo "Usage: bash run_eval.sh [resnet18|resnet50|resnet101|se-resnet50] [imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $1 != "resnet18" ] && [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
|
||||||
|
then
|
||||||
|
echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $2 != "imagenet2012" ]
|
||||||
|
then
|
||||||
|
echo "error: only support imagenet2012"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $3)
|
||||||
|
PATH2=$(get_real_path $4)
|
||||||
|
|
||||||
|
|
||||||
|
if [ ! -d $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PATH2 ]
|
||||||
|
then
|
||||||
|
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "infer" ];
|
||||||
|
then
|
||||||
|
rm -rf ./infer
|
||||||
|
fi
|
||||||
|
mkdir ./infer
|
||||||
|
cp ../*.py ./infer
|
||||||
|
cp *.sh ./infer
|
||||||
|
cp -r ../src ./infer
|
||||||
|
cd ./infer || exit
|
||||||
|
env > env.log
|
||||||
|
echo "start evaluation for device $DEVICE_ID"
|
||||||
|
python infer.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||||
|
cd ..
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,8 +17,10 @@ network config setting, will be used in train.py and eval.py
|
||||||
"""
|
"""
|
||||||
from easydict import EasyDict as ed
|
from easydict import EasyDict as ed
|
||||||
# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
|
# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
|
||||||
|
# infer_label is a directory and label mapping table. such as 'infer_label': {"directory0": 0, "directory1": 1, ...}
|
||||||
cfg = ed({
|
cfg = ed({
|
||||||
'optimizer': 'Momentum',
|
'optimizer': 'Momentum',
|
||||||
|
'infer_label': {}
|
||||||
})
|
})
|
||||||
|
|
||||||
# config for resent50, cifar10
|
# config for resent50, cifar10
|
||||||
|
|
|
@ -0,0 +1,319 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
create train or eval dataset.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
|
import mindspore.dataset.transforms.c_transforms as C2
|
||||||
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
from src.config import cfg
|
||||||
|
|
||||||
|
|
||||||
|
class ImgDataset:
|
||||||
|
"""
|
||||||
|
create img dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Returns:
|
||||||
|
de_dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset_path):
|
||||||
|
super(ImgDataset, self).__init__()
|
||||||
|
self.data = []
|
||||||
|
self.dir_label_dict = {}
|
||||||
|
self.img_format = (".bmp", ".png", ".jpg", ".jpeg")
|
||||||
|
self.dir_label = cfg.infer_label
|
||||||
|
dataset_list = sorted(os.listdir(dataset_path))
|
||||||
|
file_exist = dir_exist = False
|
||||||
|
for index, data_name in enumerate(dataset_list):
|
||||||
|
data_path = os.path.join(dataset_path, data_name)
|
||||||
|
if os.path.isdir(data_path):
|
||||||
|
dir_exist = True
|
||||||
|
self.dir_label_dict = self.get_file_label(data_name, data_path, index)
|
||||||
|
if os.path.isfile(data_path):
|
||||||
|
file_exist = True
|
||||||
|
self.dir_label_dict = self.get_file_label(data_name, data_path, index=-1)
|
||||||
|
if dir_exist and file_exist:
|
||||||
|
raise ValueError(f"{dataset_path} can not concurrently have image file and directory")
|
||||||
|
|
||||||
|
for data_name, img_label in self.dir_label_dict.items():
|
||||||
|
if os.path.isfile(data_name):
|
||||||
|
if not data_name.lower().endswith(self.img_format):
|
||||||
|
continue
|
||||||
|
img_data, file_name = self.read_image_data(data_name)
|
||||||
|
self.data.append((img_label, img_data, file_name))
|
||||||
|
else:
|
||||||
|
for file in os.listdir(data_name):
|
||||||
|
if not file.lower().endswith(self.img_format):
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(data_name, file)
|
||||||
|
img_data, file_name = self.read_image_data(file_path)
|
||||||
|
self.data.append((img_label, img_data, file_name))
|
||||||
|
|
||||||
|
def get_file_label(self, data_name, data_path, index):
|
||||||
|
if self.dir_label and data_name not in self.dir_label:
|
||||||
|
return self.dir_label_dict
|
||||||
|
if self.dir_label and os.path.isdir(data_name):
|
||||||
|
data_path_name = os.path.split(data_path)[-1]
|
||||||
|
self.dir_label_dict[data_path] = self.dir_label[data_path_name]
|
||||||
|
else:
|
||||||
|
self.dir_label_dict[data_path] = index
|
||||||
|
return self.dir_label_dict
|
||||||
|
|
||||||
|
def read_image_data(self, file_path):
|
||||||
|
file_name = os.path.split(file_path)[-1]
|
||||||
|
img_data = np.fromfile(file_path, np.uint8)
|
||||||
|
file_name = np.fromstring(file_name, np.uint8)
|
||||||
|
file_name = np.pad(file_name, (0, 300 - file_name.shape[0]))
|
||||||
|
return img_data, file_name
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.data[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||||
|
"""
|
||||||
|
create a train or eval imagenet2012 dataset for resnet50
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path(string): the path of dataset.
|
||||||
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1
|
||||||
|
batch_size(int): the batch size of dataset. Default: 32
|
||||||
|
target(str): the device target. Default: Ascend
|
||||||
|
distribute(bool): data for distribute or not. Default: False
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset
|
||||||
|
"""
|
||||||
|
if target == "Ascend":
|
||||||
|
device_num, rank_id = _get_rank_info()
|
||||||
|
else:
|
||||||
|
if distribute:
|
||||||
|
init()
|
||||||
|
rank_id = get_rank()
|
||||||
|
device_num = get_group_size()
|
||||||
|
else:
|
||||||
|
device_num = 1
|
||||||
|
|
||||||
|
dataset_generator = ImgDataset(dataset_path)
|
||||||
|
if device_num == 1:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
|
||||||
|
image_size = 224
|
||||||
|
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||||
|
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
if do_train:
|
||||||
|
trans = [
|
||||||
|
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||||
|
C.RandomHorizontalFlip(prob=0.5),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
trans = [
|
||||||
|
C.Decode(),
|
||||||
|
C.Resize(256),
|
||||||
|
C.CenterCrop(image_size),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
|
||||||
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
|
||||||
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||||
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||||
|
|
||||||
|
if do_train:
|
||||||
|
data_set = data_set.project(["image", "label"])
|
||||||
|
else:
|
||||||
|
data_set = data_set.project(["image", "label", "filename"])
|
||||||
|
|
||||||
|
# apply batch operations
|
||||||
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
# apply dataset repeat operation
|
||||||
|
data_set = data_set.repeat(repeat_num)
|
||||||
|
|
||||||
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||||
|
"""
|
||||||
|
create a train or eval imagenet2012 dataset for resnet101
|
||||||
|
Args:
|
||||||
|
dataset_path(string): the path of dataset.
|
||||||
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1
|
||||||
|
batch_size(int): the batch size of dataset. Default: 32
|
||||||
|
target(str): the device target. Default: Ascend
|
||||||
|
distribute(bool): data for distribute or not. Default: False
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset
|
||||||
|
"""
|
||||||
|
if target == "Ascend":
|
||||||
|
device_num, rank_id = _get_rank_info()
|
||||||
|
else:
|
||||||
|
if distribute:
|
||||||
|
init()
|
||||||
|
rank_id = get_rank()
|
||||||
|
device_num = get_group_size()
|
||||||
|
else:
|
||||||
|
device_num = 1
|
||||||
|
rank_id = 1
|
||||||
|
dataset_generator = ImgDataset(dataset_path)
|
||||||
|
if device_num == 1:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
image_size = 224
|
||||||
|
mean = [0.475 * 255, 0.451 * 255, 0.392 * 255]
|
||||||
|
std = [0.275 * 255, 0.267 * 255, 0.278 * 255]
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
if do_train:
|
||||||
|
trans = [
|
||||||
|
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||||
|
C.RandomHorizontalFlip(rank_id / (rank_id + 1)),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
trans = [
|
||||||
|
C.Decode(),
|
||||||
|
C.Resize(256),
|
||||||
|
C.CenterCrop(image_size),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
|
||||||
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
|
||||||
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
||||||
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||||
|
if do_train:
|
||||||
|
data_set = data_set.project(["image", "label"])
|
||||||
|
else:
|
||||||
|
data_set = data_set.project(["image", "label", "filename"])
|
||||||
|
# apply batch operations
|
||||||
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
# apply dataset repeat operation
|
||||||
|
data_set = data_set.repeat(repeat_num)
|
||||||
|
|
||||||
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False):
|
||||||
|
"""
|
||||||
|
create a train or eval imagenet2012 dataset for se-resnet50
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path(string): the path of dataset.
|
||||||
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1
|
||||||
|
batch_size(int): the batch size of dataset. Default: 32
|
||||||
|
target(str): the device target. Default: Ascend
|
||||||
|
distribute(bool): data for distribute or not. Default: False
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset
|
||||||
|
"""
|
||||||
|
if target == "Ascend":
|
||||||
|
device_num, rank_id = _get_rank_info()
|
||||||
|
else:
|
||||||
|
if distribute:
|
||||||
|
init()
|
||||||
|
rank_id = get_rank()
|
||||||
|
device_num = get_group_size()
|
||||||
|
else:
|
||||||
|
device_num = 1
|
||||||
|
dataset_generator = ImgDataset(dataset_path)
|
||||||
|
if device_num == 1:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
data_set = ds.GeneratorDataset(source=dataset_generator, column_names=["label", "image", "filename"],
|
||||||
|
num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
image_size = 224
|
||||||
|
mean = [123.68, 116.78, 103.94]
|
||||||
|
std = [1.0, 1.0, 1.0]
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
if do_train:
|
||||||
|
trans = [
|
||||||
|
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||||
|
C.RandomHorizontalFlip(prob=0.5),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
trans = [
|
||||||
|
C.Decode(),
|
||||||
|
C.Resize(292),
|
||||||
|
C.CenterCrop(256),
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW()
|
||||||
|
]
|
||||||
|
|
||||||
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
||||||
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
|
||||||
|
if do_train:
|
||||||
|
data_set = data_set.project(["image", "label"])
|
||||||
|
else:
|
||||||
|
data_set = data_set.project(["image", "label", "filename"])
|
||||||
|
# apply batch operations
|
||||||
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
# apply dataset repeat operation
|
||||||
|
data_set = data_set.repeat(repeat_num)
|
||||||
|
|
||||||
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rank_info():
|
||||||
|
"""
|
||||||
|
get rank size and rank id
|
||||||
|
"""
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
|
if rank_size > 1:
|
||||||
|
rank_size = get_group_size()
|
||||||
|
rank_id = get_rank()
|
||||||
|
else:
|
||||||
|
rank_size = 1
|
||||||
|
rank_id = 0
|
||||||
|
|
||||||
|
return rank_size, rank_id
|
Loading…
Reference in New Issue