yolov3-darknet add weight transform scripts.

This commit is contained in:
linqingke 2020-09-23 17:08:54 +08:00
parent ae356c1df9
commit af440bb3d8
7 changed files with 148 additions and 20 deletions

View File

@ -37,7 +37,7 @@ do
cp -r ./src ./eval_$i
cd ./eval_$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
echo "start infering for rank $i, device $DEVICE_ID"
env > env.log
python eval.py \
--data_dir=$DATASET \

View File

@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset.set_dataset_size(len(sampler))
de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)

View File

@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.unet.unet_model import UNet
parser = argparse.ArgumentParser(description='Export ckpt to air')
parser.add_argument('--ckpt_file', type=str, default="ckpt_unet_medical_adam-1_600.ckpt",
help='The path of input ckpt file')
parser.add_argument('--air_file', type=str, default="unet_medical_adam-1_600.air", help='The path of output air file')
args = parser.parse_args()
net = UNet(n_channels=1, n_classes=2)
# return a parameter dict for model
param_dict = load_checkpoint(args.ckpt_file)
# load the parameter into net
load_param_into_net(net, param_dict)
input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 572, 572]).astype(np.float32)
export(net, Tensor(input_data), file_name=args.air_file, file_format='AIR')

View File

@ -69,7 +69,7 @@ After installing MindSpore via the official website, you can start training and
```
# The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper.
# The parameter of pretrained_backbone is not necessary.
# pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` .
# The parameter of training_shape define image shape for network, default is "".
# It means use 10 kinds of shape as input shape, or it can be set some kind of shape.
# run training example(1p) by python command.

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert weight to mindspore ckpt."""
import os
import argparse
import numpy as np
from mindspore.train.serialization import save_checkpoint
from mindspore import Tensor
from src.yolo import YOLOV3DarkNet53
def load_weight(weights_file):
"""Loads pre-trained weights."""
if not os.path.isfile(weights_file):
raise ValueError(f'"{weights_file}" is not a valid weight file.')
with open(weights_file, 'rb') as fp:
np.fromfile(fp, dtype=np.int32, count=5)
return np.fromfile(fp, dtype=np.float32)
def build_network():
"""Build YOLOv3 network."""
network = YOLOV3DarkNet53(is_training=True)
params = network.get_parameters()
params = [p for p in params if 'backbone' in p.name]
return params
def convert(weights_file, output_file):
"""Conver weight to mindspore ckpt."""
params = build_network()
weights = load_weight(weights_file)
index = 0
param_list = []
for i in range(0, len(params), 5):
weight = params[i]
mean = params[i+1]
var = params[i+2]
gamma = params[i+3]
beta = params[i+4]
beta_data = weights[index: index+beta.size()].reshape(beta.shape)
index += beta.size()
gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape)
index += gamma.size()
mean_data = weights[index: index+mean.size()].reshape(mean.shape)
index += mean.size()
var_data = weights[index: index + var.size()].reshape(var.shape)
index += var.size()
weight_data = weights[index: index+weight.size()].reshape(weight.shape)
index += weight.size()
param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape,
'data': Tensor(weight_data)})
param_list.append({'name': mean.name, 'type': mean.dtype, 'shape': mean.shape, 'data': Tensor(mean_data)})
param_list.append({'name': var.name, 'type': var.dtype, 'shape': var.shape, 'data': Tensor(var_data)})
param_list.append({'name': gamma.name, 'type': gamma.dtype, 'shape': gamma.shape, 'data': Tensor(gamma_data)})
param_list.append({'name': beta.name, 'type': beta.dtype, 'shape': beta.shape, 'data': Tensor(beta_data)})
save_checkpoint(param_list, output_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="yolov3 weight convert.")
parser.add_argument("--input_file", type=str, default="./darknet53.conv.74", help="input file path.")
parser.add_argument("--output_file", type=str, default="./ackbone_darknet53.ckpt", help="output file path.")
args_opt = parser.parse_args()
convert(args_opt.input_file, args_opt.output_file)

View File

@ -115,39 +115,38 @@ class DarkNet(nn.Cell):
out_channels[0],
kernel_size=3,
stride=2)
self.conv2 = conv_block(in_channels[1],
out_channels[1],
kernel_size=3,
stride=2)
self.conv3 = conv_block(in_channels[2],
out_channels[2],
kernel_size=3,
stride=2)
self.conv4 = conv_block(in_channels[3],
out_channels[3],
kernel_size=3,
stride=2)
self.conv5 = conv_block(in_channels[4],
out_channels[4],
kernel_size=3,
stride=2)
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=out_channels[0],
out_channel=out_channels[0])
self.conv2 = conv_block(in_channels[1],
out_channels[1],
kernel_size=3,
stride=2)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=out_channels[1],
out_channel=out_channels[1])
self.conv3 = conv_block(in_channels[2],
out_channels[2],
kernel_size=3,
stride=2)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=out_channels[2],
out_channel=out_channels[2])
self.conv4 = conv_block(in_channels[3],
out_channels[3],
kernel_size=3,
stride=2)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=out_channels[3],
out_channel=out_channels[3])
self.conv5 = conv_block(in_channels[4],
out_channels[4],
kernel_size=3,
stride=2)
self.layer5 = self._make_layer(block,
layer_nums[4],
in_channel=out_channels[4],