forked from mindspore-Ecosystem/mindspore
add YOLOv3 infer scipt and change dataset to MindRecord
This commit is contained in:
parent
cc0ba93d17
commit
0c81759ae6
|
@ -26,6 +26,7 @@ class ConfigYOLOV3ResNet18:
|
|||
img_shape = [352, 640]
|
||||
feature_shape = [32, 3, 352, 640]
|
||||
num_classes = 80
|
||||
nms_max_num = 50
|
||||
|
||||
backbone_input_shape = [64, 64, 128, 256]
|
||||
backbone_shape = [64, 128, 256, 512]
|
||||
|
@ -33,6 +34,8 @@ class ConfigYOLOV3ResNet18:
|
|||
backbone_stride = [1, 2, 2, 2]
|
||||
|
||||
ignore_threshold = 0.5
|
||||
obj_threshold = 0.3
|
||||
nms_threshold = 0.4
|
||||
|
||||
anchor_scales = [(10, 13),
|
||||
(16, 30),
|
||||
|
|
|
@ -16,16 +16,14 @@
|
|||
"""YOLOv3 dataset"""
|
||||
from __future__ import division
|
||||
|
||||
import abc
|
||||
import io
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
|
||||
import mindspore.dataset as de
|
||||
from mindspore.mindrecord import FileWriter
|
||||
import mindspore.dataset.transforms.vision.py_transforms as P
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
|
||||
iter_cnt = 0
|
||||
|
@ -114,6 +112,29 @@ def preprocess_fn(image, box, is_training):
|
|||
|
||||
return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
|
||||
|
||||
def _infer_data(img_data, input_shape, box):
|
||||
w, h = img_data.size
|
||||
input_h, input_w = input_shape
|
||||
scale = min(float(input_w) / float(w), float(input_h) / float(h))
|
||||
nw = int(w * scale)
|
||||
nh = int(h * scale)
|
||||
img_data = img_data.resize((nw, nh), Image.BICUBIC)
|
||||
|
||||
new_image = np.zeros((input_h, input_w, 3), np.float32)
|
||||
new_image.fill(128)
|
||||
img_data = np.array(img_data)
|
||||
if len(img_data.shape) == 2:
|
||||
img_data = np.expand_dims(img_data, axis=-1)
|
||||
img_data = np.concatenate([img_data, img_data, img_data], axis=-1)
|
||||
|
||||
dh = int((input_h - nh) / 2)
|
||||
dw = int((input_w - nw) / 2)
|
||||
new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
|
||||
new_image /= 255.
|
||||
new_image = np.transpose(new_image, (2, 0, 1))
|
||||
new_image = np.expand_dims(new_image, 0)
|
||||
return new_image, np.array([h, w], np.float32), box
|
||||
|
||||
def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
|
||||
"""Data augmentation function."""
|
||||
if not isinstance(image, Image.Image):
|
||||
|
@ -124,32 +145,7 @@ def preprocess_fn(image, box, is_training):
|
|||
h, w = image_size
|
||||
|
||||
if not is_training:
|
||||
image = image.resize((w, h), Image.BICUBIC)
|
||||
image_data = np.array(image) / 255.
|
||||
if len(image_data.shape) == 2:
|
||||
image_data = np.expand_dims(image_data, axis=-1)
|
||||
image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
|
||||
image_data = image_data.astype(np.float32)
|
||||
|
||||
# correct boxes
|
||||
box_data = np.zeros((max_boxes, 5))
|
||||
if len(box) >= 1:
|
||||
np.random.shuffle(box)
|
||||
if len(box) > max_boxes:
|
||||
box = box[:max_boxes]
|
||||
# xmin ymin xmax ymax
|
||||
box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
|
||||
box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
|
||||
box_data[:len(box)] = box
|
||||
else:
|
||||
image_data, box_data = None, None
|
||||
|
||||
# preprocess bounding boxes
|
||||
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
|
||||
_preprocess_true_boxes(box_data, anchors, image_size)
|
||||
|
||||
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
|
||||
ori_image_shape, gt_box1, gt_box2, gt_box3
|
||||
return _infer_data(image, image_size, box)
|
||||
|
||||
flip = _rand() < .5
|
||||
# correct boxes
|
||||
|
@ -235,12 +231,16 @@ def preprocess_fn(image, box, is_training):
|
|||
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
|
||||
ori_image_shape, gt_box1, gt_box2, gt_box3
|
||||
|
||||
if is_training:
|
||||
images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
|
||||
return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
|
||||
|
||||
images, shape, anno = _data_aug(image, box, is_training)
|
||||
return images, shape, anno
|
||||
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Annotation parser."""
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
for anno_str in annos_str:
|
||||
anno = list(map(int, anno_str.strip().split(',')))
|
||||
|
@ -248,135 +248,59 @@ def anno_parser(annos_str):
|
|||
return annos
|
||||
|
||||
|
||||
def expand_path(path):
|
||||
"""Get file list from path."""
|
||||
files = []
|
||||
if os.path.isdir(path):
|
||||
for file in os.listdir(path):
|
||||
if os.path.isfile(os.path.join(path, file)):
|
||||
files.append(file)
|
||||
else:
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
return files
|
||||
|
||||
|
||||
def read_image(img_path):
|
||||
"""Read image with PIL."""
|
||||
with open(img_path, "rb") as f:
|
||||
img = f.read()
|
||||
data = io.BytesIO(img)
|
||||
img = Image.open(data)
|
||||
return np.array(img)
|
||||
|
||||
|
||||
class BaseDataset():
|
||||
"""BaseDataset for GeneratorDataset iterator."""
|
||||
def __init__(self, image_dir, anno_path):
|
||||
self.image_dir = image_dir
|
||||
self.anno_path = anno_path
|
||||
self.cur_index = 0
|
||||
self.samples = []
|
||||
self.image_anno_dict = {}
|
||||
self._load_samples()
|
||||
|
||||
def __getitem__(self, item):
|
||||
sample = self.samples[item]
|
||||
return self._next_data(sample, self.image_dir, self.image_anno_dict)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
@staticmethod
|
||||
def _next_data(sample, image_dir, image_anno_dict):
|
||||
"""Get next data."""
|
||||
image = read_image(os.path.join(image_dir, sample))
|
||||
annos = image_anno_dict[sample]
|
||||
return [np.array(image), np.array(annos)]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_samples(self):
|
||||
"""Base load samples."""
|
||||
|
||||
|
||||
class YoloDataset(BaseDataset):
|
||||
"""YoloDataset for GeneratorDataset iterator."""
|
||||
def _load_samples(self):
|
||||
"""Load samples."""
|
||||
image_files_raw = expand_path(self.image_dir)
|
||||
self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
|
||||
self.dataset_size = len(self.samples)
|
||||
if self.dataset_size == 0:
|
||||
raise RuntimeError("Valid dataset is none!")
|
||||
|
||||
def _filter_valid_data(self, anno_path, image_files_raw):
|
||||
"""Filter valid data."""
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
image_files = []
|
||||
anno_dict = {}
|
||||
print("Start filter valid data.")
|
||||
image_anno_dict = {}
|
||||
if not os.path.isdir(image_dir):
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
if not os.path.isfile(anno_path):
|
||||
raise RuntimeError("Annotation file is not valid.")
|
||||
|
||||
with open(anno_path, "rb") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line_str = line.decode("utf-8")
|
||||
line_str = line.decode("utf-8").strip()
|
||||
line_split = str(line_str).split(' ')
|
||||
anno_dict[line_split[0].split("/")[-1]] = line_split[1:]
|
||||
anno_set = set(anno_dict.keys())
|
||||
image_set = set(image_files_raw)
|
||||
for image_file in (anno_set & image_set):
|
||||
image_files.append(image_file)
|
||||
self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
|
||||
image_files.sort()
|
||||
print("Filter valid data done!")
|
||||
return image_files
|
||||
file_name = line_split[0]
|
||||
if os.path.isfile(os.path.join(image_dir, file_name)):
|
||||
image_anno_dict[file_name] = anno_parser(line_split[1:])
|
||||
image_files.append(file_name)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
|
||||
class DistributedSampler():
|
||||
"""DistributedSampler for YOLOv3"""
|
||||
def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank % num_replicas
|
||||
self.epoch = 0
|
||||
self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
|
||||
"""Create MindRecord file by image_dir and anno_path."""
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
indices = indices.tolist()
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
yolo_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int64", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(yolo_json, "yolo_json")
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
for image_name in image_files:
|
||||
image_path = os.path.join(image_dir, image_name)
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name])
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
||||
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
||||
is_training=True, num_parallel_workers=8):
|
||||
"""Creatr YOLOv3 dataset with GeneratorDataset."""
|
||||
yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
|
||||
distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
|
||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
|
||||
ds.set_dataset_size(len(distributed_sampler))
|
||||
"""Creatr YOLOv3 dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
|
||||
num_parallel_workers=num_parallel_workers, shuffle=is_training)
|
||||
decode = C.Decode()
|
||||
ds = ds.map(input_columns=["image"], operations=decode)
|
||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||
|
||||
if is_training:
|
||||
hwc_to_chw = P.HWC2CHW()
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
|
||||
|
@ -386,4 +310,9 @@ def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, devi
|
|||
ds = ds.shuffle(buffer_size=256)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
else:
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "annotation"],
|
||||
columns_order=["image", "image_shape", "annotation"],
|
||||
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
|
||||
return ds
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Evaluation for yolo_v3"""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
|
||||
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
from util import metrics
|
||||
|
||||
def yolo_eval(dataset_path, ckpt_path):
|
||||
"""Yolov3 evaluation."""
|
||||
|
||||
ds = create_yolo_dataset(dataset_path, is_training=False)
|
||||
config = ConfigYOLOV3ResNet18()
|
||||
net = yolov3_resnet18(config)
|
||||
eval_net = YoloWithEval(net, config)
|
||||
print("Load Checkpoint!")
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
|
||||
eval_net.set_train(False)
|
||||
i = 1.
|
||||
total = ds.get_dataset_size()
|
||||
start = time.time()
|
||||
pred_data = []
|
||||
print("\n========================================\n")
|
||||
print("total images num: ", total)
|
||||
print("Processing, please wait a moment.")
|
||||
for data in ds.create_dict_iterator():
|
||||
img_np = data['image']
|
||||
image_shape = data['image_shape']
|
||||
annotation = data['annotation']
|
||||
|
||||
eval_net.set_train(False)
|
||||
output = eval_net(Tensor(img_np), Tensor(image_shape))
|
||||
for batch_idx in range(img_np.shape[0]):
|
||||
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
|
||||
"box_scores": output[1].asnumpy()[batch_idx],
|
||||
"annotation": annotation})
|
||||
percent = round(i / total * 100, 2)
|
||||
|
||||
print(' %s [%d/%d]' % (str(percent) + '%', i, total), end='\r')
|
||||
i += 1
|
||||
print(' %s [%d/%d] cost %d ms' % (str(100.0) + '%', total, total, int((time.time() - start) * 1000)), end='\n')
|
||||
|
||||
precisions, recalls = metrics(pred_data)
|
||||
print("\n========================================\n")
|
||||
for i in range(config.num_classes):
|
||||
print("class {} precision is {:.2f}%, recall is {:.2f}%".format(i, precisions[i] * 100, recalls[i] * 100))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Yolov3 evaluation')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_eval",
|
||||
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
|
||||
"image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
|
||||
"rather than image_dir and anno_path. Default is ./Mindrecord_eval")
|
||||
parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
|
||||
"the absolute image path is joined by the image_dir "
|
||||
"and the relative path in anno_path.")
|
||||
parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
|
||||
parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
|
||||
enable_auto_mixed_precision=False)
|
||||
|
||||
# It will generate mindrecord file in args_opt.mindrecord_dir,
|
||||
# and the file name is yolo.mindrecord0, 1, ... file_num.
|
||||
if not os.path.isdir(args_opt.mindrecord_dir):
|
||||
os.makedirs(args_opt.mindrecord_dir)
|
||||
|
||||
prefix = "yolo.mindrecord"
|
||||
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
|
||||
print("Create Mindrecord")
|
||||
data_to_mindrecord_byte_image(args_opt.image_dir,
|
||||
args_opt.anno_path,
|
||||
args_opt.mindrecord_dir,
|
||||
prefix=prefix,
|
||||
file_num=8)
|
||||
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
|
||||
else:
|
||||
print("image_dir or anno_path not exits")
|
||||
print("Start Eval!")
|
||||
yolo_eval(mindrecord_file, args_opt.ckpt_path)
|
|
@ -14,17 +14,26 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
|
||||
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EPOCH_SIZE=$2
|
||||
MINDRECORD_DIR=$3
|
||||
IMAGE_DIR=$4
|
||||
ANNO_PATH=$5
|
||||
|
||||
# Before start distribute train, first create mindrecord files.
|
||||
python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \
|
||||
--anno_path=$ANNO_PATH
|
||||
|
||||
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$6
|
||||
export RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
IMAGE_DIR=$3
|
||||
ANNO_PATH=$4
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$5
|
||||
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
|
@ -40,6 +49,7 @@ do
|
|||
--distribute=1 \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--image_dir=$IMAGE_DIR \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--anno_path=$ANNO_PATH > log.txt 2>&1 &
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
|
||||
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
|
|
@ -14,8 +14,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
|
||||
echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
|
||||
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
|
||||
echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4
|
||||
python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
|
||||
|
|
|
@ -16,26 +16,30 @@
|
|||
"""
|
||||
######################## train YOLOv3 example ########################
|
||||
train YOLOv3 and get network model files(.ckpt) :
|
||||
python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
|
||||
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
|
||||
|
||||
If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path.
|
||||
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from mindspore.train import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||
from dataset import create_yolo_dataset
|
||||
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
|
||||
|
||||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
||||
"""Set learning rate"""
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
lr = learning_rate
|
||||
for i in range(global_step):
|
||||
|
@ -57,7 +61,9 @@ def init_net_param(net, init='ones'):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="YOLOv3")
|
||||
parser = argparse.ArgumentParser(description="YOLOv3 train")
|
||||
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
|
||||
"Mindrecord, default is false.")
|
||||
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
|
@ -67,12 +73,19 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
|
||||
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
|
||||
parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
|
||||
parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
|
||||
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
|
||||
help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
|
||||
"image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
|
||||
"rather than image_dir and anno_path. Default is ./Mindrecord_train")
|
||||
parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
|
||||
"the absolute image path is joined by the image_dir "
|
||||
"and the relative path in anno_path")
|
||||
parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
|
||||
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
|
||||
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
|
||||
enable_auto_mixed_precision=False)
|
||||
if args_opt.distribute:
|
||||
device_num = args_opt.device_num
|
||||
context.reset_auto_parallel_context()
|
||||
|
@ -80,16 +93,42 @@ if __name__ == '__main__':
|
|||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
init()
|
||||
rank = args_opt.device_id
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
context.set_context(enable_hccl=False)
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
print("Start create dataset!")
|
||||
|
||||
# It will generate mindrecord file in args_opt.mindrecord_dir,
|
||||
# and the file name is yolo.mindrecord0, 1, ... file_num.
|
||||
if not os.path.isdir(args_opt.mindrecord_dir):
|
||||
os.makedirs(args_opt.mindrecord_dir)
|
||||
|
||||
prefix = "yolo.mindrecord"
|
||||
mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
|
||||
print("Create Mindrecord.")
|
||||
data_to_mindrecord_byte_image(args_opt.image_dir,
|
||||
args_opt.anno_path,
|
||||
args_opt.mindrecord_dir,
|
||||
prefix=prefix,
|
||||
file_num=8)
|
||||
print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
|
||||
else:
|
||||
print("image_dir or anno_path not exits.")
|
||||
|
||||
if not args_opt.only_create_dataset:
|
||||
loss_scale = float(args_opt.loss_scale)
|
||||
dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
|
||||
dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
|
||||
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done!")
|
||||
|
||||
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
|
||||
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
|
||||
init_net_param(net, "XavierUniform")
|
||||
|
@ -97,19 +136,22 @@ if __name__ == '__main__':
|
|||
# checkpoint
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
|
||||
if args_opt.checkpoint_path != "":
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
|
||||
decay_step=1000, decay_rate=0.95))
|
||||
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
|
||||
net = TrainingWrapper(net, opt, loss_scale)
|
||||
|
||||
if args_opt.checkpoint_path != "":
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
|
||||
|
||||
model = Model(net)
|
||||
dataset_sink_mode = False
|
||||
if args_opt.mode == "graph":
|
||||
print("In graph mode, one epoch return a loss.")
|
||||
dataset_sink_mode = True
|
||||
print("Start train YOLOv3.")
|
||||
print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""metrics utils"""
|
||||
|
||||
import numpy as np
|
||||
from config import ConfigYOLOV3ResNet18
|
||||
|
||||
|
||||
def calc_iou(bbox_pred, bbox_ground):
|
||||
"""Calculate iou of predicted bbox and ground truth."""
|
||||
x1 = bbox_pred[0]
|
||||
y1 = bbox_pred[1]
|
||||
width1 = bbox_pred[2] - bbox_pred[0]
|
||||
height1 = bbox_pred[3] - bbox_pred[1]
|
||||
|
||||
x2 = bbox_ground[0]
|
||||
y2 = bbox_ground[1]
|
||||
width2 = bbox_ground[2] - bbox_ground[0]
|
||||
height2 = bbox_ground[3] - bbox_ground[1]
|
||||
|
||||
endx = max(x1 + width1, x2 + width2)
|
||||
startx = min(x1, x2)
|
||||
width = width1 + width2 - (endx - startx)
|
||||
|
||||
endy = max(y1 + height1, y2 + height2)
|
||||
starty = min(y1, y2)
|
||||
height = height1 + height2 - (endy - starty)
|
||||
|
||||
if width <= 0 or height <= 0:
|
||||
iou = 0
|
||||
else:
|
||||
area = width * height
|
||||
area1 = width1 * height1
|
||||
area2 = width2 * height2
|
||||
iou = area * 1. / (area1 + area2 - area)
|
||||
|
||||
return iou
|
||||
|
||||
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
"""Apply NMS to bboxes."""
|
||||
x1 = all_boxes[:, 0]
|
||||
y1 = all_boxes[:, 1]
|
||||
x2 = all_boxes[:, 2]
|
||||
y2 = all_boxes[:, 3]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
order = all_scores.argsort()[::-1]
|
||||
keep = []
|
||||
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
if len(keep) >= max_boxes:
|
||||
break
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def metrics(pred_data):
|
||||
"""Calculate precision and recall of predicted bboxes."""
|
||||
config = ConfigYOLOV3ResNet18()
|
||||
num_classes = config.num_classes
|
||||
count_corrects = [1e-6 for _ in range(num_classes)]
|
||||
count_grounds = [1e-6 for _ in range(num_classes)]
|
||||
count_preds = [1e-6 for _ in range(num_classes)]
|
||||
|
||||
for i, sample in enumerate(pred_data):
|
||||
gt_anno = sample["annotation"]
|
||||
box_scores = sample['box_scores']
|
||||
boxes = sample['boxes']
|
||||
mask = box_scores >= config.obj_threshold
|
||||
boxes_ = []
|
||||
scores_ = []
|
||||
classes_ = []
|
||||
max_boxes = config.nms_max_num
|
||||
for c in range(num_classes):
|
||||
class_boxes = np.reshape(boxes, [-1, 4])[np.reshape(mask[:, c], [-1])]
|
||||
class_box_scores = np.reshape(box_scores[:, c], [-1])[np.reshape(mask[:, c], [-1])]
|
||||
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
|
||||
class_boxes = class_boxes[nms_index]
|
||||
class_box_scores = class_box_scores[nms_index]
|
||||
classes = np.ones_like(class_box_scores, 'int32') * c
|
||||
boxes_.append(class_boxes)
|
||||
scores_.append(class_box_scores)
|
||||
classes_.append(classes)
|
||||
|
||||
boxes = np.concatenate(boxes_, axis=0)
|
||||
classes = np.concatenate(classes_, axis=0)
|
||||
|
||||
|
||||
# metric
|
||||
count_correct = [1e-6 for _ in range(num_classes)]
|
||||
count_ground = [1e-6 for _ in range(num_classes)]
|
||||
count_pred = [1e-6 for _ in range(num_classes)]
|
||||
|
||||
for anno in gt_anno:
|
||||
count_ground[anno[4]] += 1
|
||||
|
||||
for box_index, box in enumerate(boxes):
|
||||
bbox_pred = [box[1], box[0], box[3], box[2]]
|
||||
count_pred[classes[box_index]] += 1
|
||||
|
||||
for anno in gt_anno:
|
||||
class_ground = anno[4]
|
||||
|
||||
if classes[box_index] == class_ground:
|
||||
iou = calc_iou(bbox_pred, anno)
|
||||
if iou >= 0.5:
|
||||
count_correct[class_ground] += 1
|
||||
break
|
||||
|
||||
count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
|
||||
count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
|
||||
count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]
|
||||
|
||||
precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
|
||||
recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
|
||||
return precision, recall
|
|
@ -34,6 +34,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"tensor_add", "add"},
|
||||
{"reduce_mean", "reduce_mean_d"},
|
||||
{"reduce_max", "reduce_max_d"},
|
||||
{"reduce_min", "reduce_min_d"},
|
||||
{"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
|
||||
{"conv2d_backprop_input", "conv2d_backprop_input_d"},
|
||||
{"top_kv2", "top_k"},
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""YOLOv3 based on ResNet18."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
|
@ -31,19 +32,14 @@ def weight_variable():
|
|||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class _conv_with_pad(nn.Cell):
|
||||
class _conv2d(nn.Cell):
|
||||
"""Create Conv2D with padding."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
|
||||
super(_conv_with_pad, self).__init__()
|
||||
total_pad = kernel_size - 1
|
||||
pad_begin = total_pad // 2
|
||||
pad_end = total_pad - pad_begin
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (pad_begin, pad_end), (pad_begin, pad_end)))
|
||||
super(_conv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=0, pad_mode='valid',
|
||||
kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
|
||||
weight_init=weight_variable())
|
||||
def construct(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
@ -101,15 +97,15 @@ class BasicBlock(nn.Cell):
|
|||
momentum=0.99):
|
||||
super(BasicBlock, self).__init__()
|
||||
|
||||
self.conv1 = _conv_with_pad(in_channels, out_channels, 3, stride=stride)
|
||||
self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
|
||||
self.bn1 = _fused_bn(out_channels, momentum=momentum)
|
||||
self.conv2 = _conv_with_pad(out_channels, out_channels, 3)
|
||||
self.conv2 = _conv2d(out_channels, out_channels, 3)
|
||||
self.bn2 = _fused_bn(out_channels, momentum=momentum)
|
||||
self.relu = P.ReLU()
|
||||
self.down_sample_layer = None
|
||||
self.downsample = (in_channels != out_channels)
|
||||
if self.downsample:
|
||||
self.down_sample_layer = _conv_with_pad(in_channels, out_channels, 1, stride=stride)
|
||||
self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -166,7 +162,7 @@ class ResNet(nn.Cell):
|
|||
raise ValueError("the length of "
|
||||
"layer_num, inchannel, outchannel list must be 4!")
|
||||
|
||||
self.conv1 = _conv_with_pad(3, 64, 7, stride=2)
|
||||
self.conv1 = _conv2d(3, 64, 7, stride=2)
|
||||
self.bn1 = _fused_bn(64)
|
||||
self.relu = P.ReLU()
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||
|
@ -452,7 +448,7 @@ class DetectionBlock(nn.Cell):
|
|||
|
||||
if self.training:
|
||||
return grid, prediction, box_xy, box_wh
|
||||
return self.concat((box_xy, box_wh, box_confidence, box_probs))
|
||||
return box_xy, box_wh, box_confidence, box_probs
|
||||
|
||||
|
||||
class Iou(nn.Cell):
|
||||
|
@ -675,3 +671,78 @@ class TrainingWrapper(nn.Cell):
|
|||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class YoloBoxScores(nn.Cell):
|
||||
"""
|
||||
Calculate the boxes of the original picture size and the score of each box.
|
||||
|
||||
Args:
|
||||
config (Class): YOLOv3 config.
|
||||
|
||||
Returns:
|
||||
Tensor, the boxes of the original picture size.
|
||||
Tensor, the score of each box.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(YoloBoxScores, self).__init__()
|
||||
self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
|
||||
self.num_classes = config.num_classes
|
||||
|
||||
def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
|
||||
batch_size = F.shape(box_xy)[0]
|
||||
x = box_xy[:, :, :, :, 0:1]
|
||||
y = box_xy[:, :, :, :, 1:2]
|
||||
box_yx = P.Concat(-1)((y, x))
|
||||
w = box_wh[:, :, :, :, 0:1]
|
||||
h = box_wh[:, :, :, :, 1:2]
|
||||
box_hw = P.Concat(-1)((h, w))
|
||||
|
||||
new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape))
|
||||
offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
|
||||
scale = self.input_shape / new_shape
|
||||
box_yx = (box_yx - offset) * scale
|
||||
box_hw = box_hw * scale
|
||||
|
||||
box_min = box_yx - box_hw / 2.0
|
||||
box_max = box_yx + box_hw / 2.0
|
||||
boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
|
||||
box_min[:, :, :, :, 1:2],
|
||||
box_max[:, :, :, :, 0:1],
|
||||
box_max[:, :, :, :, 1:2]))
|
||||
image_scale = P.Tile()(image_shape, (1, 2))
|
||||
boxes = boxes * image_scale
|
||||
boxes = F.reshape(boxes, (batch_size, -1, 4))
|
||||
boxes_scores = box_confidence * box_probs
|
||||
boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
|
||||
return boxes, boxes_scores
|
||||
|
||||
|
||||
class YoloWithEval(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of YOLOv3 evaluation.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function and optimizer must not be added.
|
||||
config (Class): YOLOv3 config.
|
||||
|
||||
Returns:
|
||||
Tensor, the boxes of the original picture size.
|
||||
Tensor, the score of each box.
|
||||
Tensor, the original picture size.
|
||||
"""
|
||||
def __init__(self, network, config):
|
||||
super(YoloWithEval, self).__init__()
|
||||
self.yolo_network = network
|
||||
self.box_score_0 = YoloBoxScores(config)
|
||||
self.box_score_1 = YoloBoxScores(config)
|
||||
self.box_score_2 = YoloBoxScores(config)
|
||||
|
||||
def construct(self, x, image_shape):
|
||||
yolo_output = self.yolo_network(x)
|
||||
boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
|
||||
boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
|
||||
boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
|
||||
boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
|
||||
boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
|
||||
return boxes, boxes_scores, image_shape
|
||||
|
|
|
@ -85,7 +85,9 @@ from .logical_and import _logical_and_tbe
|
|||
from .logical_not import _logical_not_tbe
|
||||
from .logical_or import _logical_or_tbe
|
||||
from .reduce_max import _reduce_max_tbe
|
||||
from .reduce_min import _reduce_min_tbe
|
||||
from .reduce_sum import _reduce_sum_tbe
|
||||
from .round import _round_tbe
|
||||
from .tanh import _tanh_tbe
|
||||
from .tanh_grad import _tanh_grad_tbe
|
||||
from .softmax import _softmax_tbe
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceMin op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "ReduceMin",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "OPAQUE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "reduce_min_d.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "reduce_min_d",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
{
|
||||
"name": "axis",
|
||||
"param_type": "required",
|
||||
"type": "listInt",
|
||||
"value": "all"
|
||||
},
|
||||
{
|
||||
"name": "keep_dims",
|
||||
"param_type": "required",
|
||||
"type": "bool",
|
||||
"value": "all"
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
def _reduce_min_tbe():
|
||||
"""ReduceMin TBE register"""
|
||||
return
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Round op"""
|
||||
from mindspore.ops.op_info_register import op_info_register
|
||||
|
||||
|
||||
@op_info_register("""{
|
||||
"op_name": "Round",
|
||||
"imply_type": "TBE",
|
||||
"fusion_type": "ELEMWISE",
|
||||
"async_flag": false,
|
||||
"binfile_name": "round.so",
|
||||
"compute_cost": 10,
|
||||
"kernel_name": "round",
|
||||
"partial_flag": true,
|
||||
"attr": [
|
||||
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
|
||||
],
|
||||
"name": "x",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"index": 0,
|
||||
"dtype": [
|
||||
"float16", "float16", "float16", "float", "float", "float"
|
||||
],
|
||||
"format": [
|
||||
"DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
|
||||
],
|
||||
"name": "y",
|
||||
"need_compile": false,
|
||||
"param_type": "required",
|
||||
"shape": "all"
|
||||
}
|
||||
]
|
||||
}""")
|
||||
def _round_tbe():
|
||||
"""Round TBE register"""
|
||||
return
|
Loading…
Reference in New Issue