add cpu mode for maskrcnn-mobilenetv1

This commit is contained in:
huangbo77 2021-08-10 15:08:13 +08:00
parent b59e9c3e20
commit 22f2b0e38a
16 changed files with 420 additions and 226 deletions

View File

@ -103,7 +103,7 @@ void SearchSortedCPUKernel<S, T>::CheckParam(const std::vector<AddressPtr> &inpu
}
}
};
CPUKernelUtils::ParallelFor(task, static_cast<size_t>(list_count));
CPUKernelUtils::ParallelFor(task, IntToSize(list_count));
}
} // namespace kernel
} // namespace mindspore

View File

@ -58,8 +58,8 @@ Note that you can run the scripts based on the dataset mentioned in original pap
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor.
- HardwareAscend/CPU
- Prepare hardware environment with Ascend or CPU processor.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
@ -78,7 +78,7 @@ pip install mmcv=0.2.14
1. Download the dataset COCO2017.
2. Change the COCO_ROOT and other settings you need in `config.py`. The directory structure should look like the follows:
2. Change the COCO_ROOT and other settings you need in `default_config.yaml`. The directory structure should look like the follows:
```
.
@ -90,24 +90,31 @@ pip install mmcv=0.2.14
└─train2017
```
If you use your own dataset to train the network, **Select dataset to other when run script.**
If you use your own dataset to train the network, **Select dataset to other when run script.**
Create a txt file to store dataset information organized in the way as shown as following:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is an image annotation split by spaces. The first column is a relative path of image, followed by columns containing box and class information in the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), which can be set in `config.py`.
Each row is an image annotation split by spaces. The first column is a relative path of image, followed by columns containing box and class information in the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), which can be set in `default_config.yaml`.
3. Execute train script.
After dataset preparation, you can start training as follows:
```
```bash
On Ascend:
# distributed training
bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT]
# standalone training
bash run_standalone_train.sh [PRETRAINED_CKPT]
On CPU:
# standalone training
bash run_standalone_train_cpu.sh [PRETRAINED_PATH](optional)
```
Note:
@ -116,27 +123,32 @@ pip install mmcv=0.2.14
3. For large models like maskrcnn_mobilenetv1, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size.
4. Execute eval script.
After training, you can start evaluation as follows:
```bash
# Evaluation
bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
```
After training, you can start evaluation as follows:
Note:
1. VALIDATION_JSON_FILE is a label json file for evaluation.
```bash
# Evaluation on Ascend
bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
# Evaluation on CPU
bash run_eval_cpu.sh [ANN_FILE] [CHECKPOINT_PATH]
```
Note:
1. VALIDATION_JSON_FILE is a label json file for evaluation.
5. Execute inference script.
After training, you can start inference as follows:
```shell
# inference
bash run_infer_310.sh [MODEL_PATH] [DATA_PATH] [ANN_FILE_PATH]
```
After training, you can start inference as follows:
Note:
1. MODEL_PATH is a model file, exported by export script file.
2. ANN_FILE_PATH is a annotation file for inference.
```shell
# inference
bash run_infer_310.sh [MODEL_PATH] [DATA_PATH] [ANN_FILE_PATH]
```
Note:
1. MODEL_PATH is a model file, exported by export script file.
2. ANN_FILE_PATH is a annotation file for inference.
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
@ -284,14 +296,16 @@ pip install mmcv=0.2.14
```shell
.
└─MaskRcnn
└─MaskRcnn_Mobilenetv1
├─README.md # README
├─ascend310_infer #application for 310 inference
├─ascend310_infer # application for 310 inference
├─scripts # shell script
├─run_standalone_train.sh # training in standalone mode(1pcs)
├─run_distribute_train.sh # training in parallel mode(8 pcs)
├─run_infer_310.sh #shell script for 310 inference
└─run_eval.sh # evaluation
├─run_standalone_train.sh # training in standalone mode on Ascend(1pcs)
├─run_standalone_train_cpu.sh # training in standalone mode on CPU(1pcs)
├─run_distribute_train.sh # training in parallel mode on Ascend(8 pcs)
├─run_infer_310.sh # shell script for 310 inference
├─run_eval_cpu.sh # evaluation on CPU
└─run_eval.sh # evaluation on Ascend
├─src
├─maskrcnn_mobilenetv1
├─__init__.py
@ -306,11 +320,18 @@ pip install mmcv=0.2.14
├─mobilenetv1.py # backbone network
├─roi_align.py # roi align network
└─rpn.py # reagion proposal network
├─config.py # network configuration
├─util.py # routine operation
├─model_utils # network configuration
├─__init__.py
├─config.py # network configuration
├─device_adapter.py # Get cloud ID
├─local_adapter.py # Get local ID
├─moxing_adapter.py # Parameter processing
├─dataset.py # dataset utils
├─lr_schedule.py # leanring rate geneatore
├─network_define.py # network define for maskrcnn_mobilenetv1
└─util.py # routine operation
├─default_config.yaml # default configuration settings
├─mindspore_hub_conf.py # mindspore hub interface
├─export.py #script to export AIR,MINDIR model
├─eval.py # evaluation scripts
@ -323,11 +344,18 @@ pip install mmcv=0.2.14
### [Training Script Parameters](#contents)
```bash
On Ascend:
# distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL]
# standalone training
Usage: bash run_standalone_train.sh [PRETRAINED_MODEL]
On CPU:
# standalone training
Usage: bash run_standalone_train_cpu.sh [PRETRAINED_MODEL](optional)
```
### [Parameters Configuration](#contents)
@ -474,20 +502,27 @@ Usage: bash run_standalone_train.sh [PRETRAINED_MODEL]
## [Training Process](#contents)
- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/docs/programming_guide/en/master/dataset_sample.html) for more information about dataset.
- Set options in `default_config.yaml`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/docs/programming_guide/en/master/dataset_sample.html) for more information about dataset.
### [Training](#content)
- Run `run_standalone_train.sh` for non-distributed training of maskrcnn_mobilenetv1 model.
- Run `run_standalone_train.sh` for non-distributed training of maskrcnn_mobilenetv1 model on Ascend.
```bash
# standalone training
bash run_standalone_train.sh [PRETRAINED_MODEL]
```
- Run `run_standalone_train_cpu.sh` for non-distributed training of maskrcnn_mobilenetv1 model on CPU.
```bash
# standalone training
bash run_standalone_train_cpu.sh [PRETRAINED_MODEL](optional)
```
### [Distributed Training](#content)
- Run `run_distribute_train.sh` for distributed training of Mask model.
- Run `run_distribute_train.sh` for distributed training of Mask model on Ascend.
```bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL]
@ -526,7 +561,7 @@ bash run_eval.sh [VALIDATION_ANN_FILE_JSON] [CHECKPOINT_PATH]
```
> As for the COCO2017 dataset, VALIDATION_ANN_FILE_JSON is refer to the annotations/instances_val2017.json in the dataset directory.
> checkpoint can be produced and saved in training process, whose folder name begins with "train/checkpoint" or "train_parallel*/checkpoint".
> Checkpoint can be produced and saved in training process, whose folder name begins with "train/checkpoint" or "train_parallel*/checkpoint".
### [Evaluation result](#content)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,7 +31,9 @@ from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=config.device_id)
def maskrcnn_eval(dataset_path, ckpt_path, ann_file):
"""MaskRcnn evaluation."""

View File

@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: bash run_eval_cpu.sh [ANN_FILE] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -f $PATH1 ]
then
echo "error: ANN_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for CPU"
python eval.py --ann_file=$PATH1 --checkpoint_path=$PATH2 --device_target=CPU > cpu_eval_log.txt 2>&1 &
cd ..

View File

@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 0 ] && [ $# != 1 ]
then
echo "Usage: bash run_standalone_train_cpu.sh [PRETRAINED_PATH](optional)"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# == 1 ]
then
PATH1=$(get_real_path $1)
echo $PATH1
fi
ulimit -u unlimited
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for CPU"
env > env.log
if [ $# == 1 ]
then
python train.py --do_train=True --pre_trained=$PATH1 --device_target=CPU > cpu_training_log.txt 2>&1 &
fi
if [ $# == 0 ]
then
python train.py --do_train=True --pre_trained="" --device_target=CPU > cpu_training_log.txt 2>&1 &
fi
cd ..

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ from numpy import random
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from mindspore import context
from src.model_utils.config import config
@ -264,7 +265,7 @@ def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mas
def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
"""imnormalize operation for image"""
img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True)
img_data = img_data.astype(np.float32)
return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
@ -284,10 +285,15 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
"""transpose operation for image"""
if context.get_context("device_target") == "CPU":
platform_dtype = np.float32
else:
platform_dtype = np.float16
img_data = img.transpose(2, 0, 1).copy()
img_data = img_data.astype(np.float16)
img_shape = img_shape.astype(np.float16)
gt_bboxes = gt_bboxes.astype(np.float16)
img_data = img_data.astype(platform_dtype)
img_shape = img_shape.astype(platform_dtype)
gt_bboxes = gt_bboxes.astype(platform_dtype)
gt_label = gt_label.astype(np.int32)
gt_num = gt_num.astype(np.bool)
gt_mask_data = gt_mask.astype(np.bool)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore import context
class BboxAssignSample(nn.Cell):
@ -79,7 +80,6 @@ class BboxAssignSample(nn.Cell):
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
self.scatterNdUpdate = P.ScatterNdUpdate()
self.scatterNd = P.ScatterNd()
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()
@ -93,8 +93,13 @@ class BboxAssignSample(nn.Cell):
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
if context.get_context("device_target") == "CPU":
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float32))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float32))
else:
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids):

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore import context
class BboxAssignSampleForRcnn(nn.Cell):
"""
@ -78,8 +79,12 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.tile = P.Tile()
# Check
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
if context.get_context("device_target") == "CPU":
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float32))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float32))
else:
self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16))
self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16))
# Init tensor
self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32))
@ -91,8 +96,13 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32))
self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16))
self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool))
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
if context.get_context("device_target") == "CPU":
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float32))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.int32))
else:
self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16))
self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8))
self.reshape_shape_pos = (self.num_expected_pos, 1)
self.reshape_shape_neg = (self.num_expected_neg, 1)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore import context
def bias_init_zeros(shape):
@ -66,6 +67,10 @@ class FeatPyramidNeck(nn.Cell):
out_channels,
num_outs):
super(FeatPyramidNeck, self).__init__()
if context.get_context("device_target") == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
self.num_outs = num_outs
self.in_channels = in_channels
self.fpn_layer = len(self.in_channels)
@ -96,9 +101,9 @@ class FeatPyramidNeck(nn.Cell):
x += (self.lateral_convs_list[i](inputs[i]),)
y = (x[3],)
y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),)
y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),)
y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),)
y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), self.platform_mstype),)
y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), self.platform_mstype),)
y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), self.platform_mstype),)
z = ()
for i in range(self.fpn_layer - 1, -1, -1):

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore import context
from .mobilenetv1 import MobileNetV1_FeatureSelector
from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
from .fpn_neck import FeatPyramidNeck
@ -59,16 +60,15 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.anchor_strides = config.anchor_strides
self.target_means = tuple(config.rcnn_target_means)
self.target_stds = tuple(config.rcnn_target_stds)
self.init_datatype()
# Anchor generator
anchor_base_sizes = None
self.anchor_base_sizes = list(
self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_base_sizes = list(self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
self.anchor_generators.append(AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
@ -78,30 +78,21 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.anchor_list = self.get_anchors(featmap_sizes)
# Backbone mobilenetv1
self.backbone = MobileNetV1_FeatureSelector(1001, features_only=True).to_float(mstype.float16)
self.backbone = MobileNetV1_FeatureSelector(1001, features_only=True).to_float(self.platform_mstype)
# Fpn
self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels,
config.fpn_out_channels,
config.fpn_num_outs)
self.fpn_neck = FeatPyramidNeck(config.fpn_in_channels, config.fpn_out_channels, config.fpn_num_outs)
# Rpn and rpn loss
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8))
self.rpn_with_loss = RPN(config,
self.train_batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(self.int_dtype))
self.rpn_with_loss = RPN(config, self.train_batch_size, config.rpn_in_channels, config.rpn_feat_channels,
config.num_anchors, config.rpn_cls_out_channels)
# Proposal
self.proposal_generator = Proposal(config,
self.train_batch_size,
config.activate_num_classes,
self.proposal_generator = Proposal(config, self.train_batch_size, config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator.set_train_local(config, True)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
self.proposal_generator_test = Proposal(config, config.test_batch_size, config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
@ -112,40 +103,24 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
stds=self.target_stds)
# Roi
self.roi_align = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
self.train_batch_size,
config.roi_align_finest_scale,
mask=False)
self.roi_align = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels,
config.roi_align_featmap_strides, self.train_batch_size,
config.roi_align_finest_scale, mask=False)
self.roi_align.set_train_local(config, True)
self.roi_align_mask = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
self.train_batch_size,
config.roi_align_finest_scale,
mask=True)
self.roi_align_mask = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels,
config.roi_align_featmap_strides, self.train_batch_size,
config.roi_align_finest_scale, mask=True)
self.roi_align_mask.set_train_local(config, True)
self.roi_align_test = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
1,
config.roi_align_finest_scale,
self.roi_align_test = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels,
config.roi_align_featmap_strides, 1, config.roi_align_finest_scale,
mask=False)
self.roi_align_test.set_train_local(config, False)
self.roi_align_mask_test = SingleRoIExtractor(config,
config.roi_layer,
config.roi_align_out_channels,
config.roi_align_featmap_strides,
1,
config.roi_align_finest_scale,
mask=True)
self.roi_align_mask_test = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels,
config.roi_align_featmap_strides, 1,
config.roi_align_finest_scale, mask=True)
self.roi_align_mask_test.set_train_local(config, False)
# Rcnn
@ -176,7 +151,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.rpn_max_num = config.rpn_max_num
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16))
self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.platform_dtype))
self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool)
self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool)
self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask,
@ -184,10 +159,11 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask,
self.ones_mask, self.ones_mask, self.zeros_mask), axis=1))
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr)
self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype)
* config.test_score_thr)
self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype) * 0)
self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.platform_dtype) * -1)
self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype) * config.test_iou_thr)
self.test_max_per_img = config.test_max_per_img
self.nms_test = P.NMSWithMask(config.test_iou_thr)
self.softmax = P.Softmax(axis=1)
@ -201,42 +177,14 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.concat_end = (self.num_classes - 1)
# Init tensor
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]
self.init_tensors(config)
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
roi_align_index_pos = [np.array(np.ones((config.num_expected_pos_stage2, 1)) * i,
dtype=np.float16) for i in range(self.train_batch_size)]
self.roi_align_index_tensor_pos = Tensor(np.concatenate(roi_align_index_pos))
self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(np.float16))
self.rcnn_loss_mask_fb_weight = Tensor(np.array(config.rcnn_loss_mask_fb_weight).astype(np.float16))
self.argmax_with_value = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.onehot = P.OneHot()
self.reducesum = P.ReduceSum()
self.sigmoid = P.Sigmoid()
self.expand_dims = P.ExpandDims()
self.test_mask_fb_zeros = Tensor(np.zeros((self.rpn_max_num, 28, 28)).astype(np.float16))
self.value = Tensor(1.0, mstype.float16)
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids, gt_masks):
x = self.backbone(img_data)
x = self.fpn_ncek(x)
x = self.fpn_neck(x)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x,
img_metas,
self.anchor_list,
gt_bboxes,
self.gt_labels_stage1,
gt_valids)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = \
self.rpn_with_loss(x, img_metas, self.anchor_list, gt_bboxes, self.gt_labels_stage1, gt_valids)
if self.training:
proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list)
@ -258,23 +206,13 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
if self.training:
for i in range(self.train_batch_size):
gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::])
gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::])
gt_labels_i = self.cast(gt_labels_i, mstype.uint8)
gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::])
gt_valids_i = self.cast(gt_valids_i, mstype.bool_)
gt_masks_i = self.squeeze(gt_masks[i:i + 1:1, ::])
gt_masks_i = self.cast(gt_masks_i, mstype.bool_)
gt_labels_i = self.cast(self.squeeze(gt_labels[i:i + 1:1, ::]), self.int_mstype)
gt_valids_i = self.cast(self.squeeze(gt_valids[i:i + 1:1, ::]), mstype.bool_)
gt_masks_i = self.cast(self.squeeze(gt_masks[i:i + 1:1, ::]), mstype.bool_)
bboxes, deltas, labels, mask, pos_bboxes, pos_mask_fb, pos_labels, pos_mask = \
self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i,
gt_labels_i,
proposal_mask[i],
proposal[i][::, 0:4:1],
gt_valids_i,
gt_masks_i)
self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i, gt_labels_i, proposal_mask[i], \
proposal[i][::, 0:4:1], gt_valids_i, gt_masks_i)
bboxes_tuple += (bboxes,)
deltas_tuple += (deltas,)
labels_tuple += (labels,)
@ -288,14 +226,12 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
bbox_targets = self.concat(deltas_tuple)
rcnn_labels = self.concat(labels_tuple)
bbox_targets = F.stop_gradient(bbox_targets)
rcnn_labels = F.stop_gradient(rcnn_labels)
rcnn_labels = self.cast(rcnn_labels, mstype.int32)
rcnn_labels = self.cast(F.stop_gradient(rcnn_labels), mstype.int32)
rcnn_pos_masks_fb = self.concat(pos_mask_fb_tuple)
rcnn_pos_masks_fb = F.stop_gradient(rcnn_pos_masks_fb)
rcnn_pos_labels = self.concat(pos_labels_tuple)
rcnn_pos_labels = F.stop_gradient(rcnn_pos_labels)
rcnn_pos_labels = self.cast(rcnn_pos_labels, mstype.int32)
rcnn_pos_labels = self.cast(F.stop_gradient(rcnn_pos_labels), mstype.int32)
else:
mask_tuple += proposal_mask
bbox_targets = proposal_mask
@ -316,8 +252,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
pos_bboxes_all = pos_bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_tensor, bboxes_all))
pos_rois = self.concat_1((self.roi_align_index_tensor_pos, pos_bboxes_all))
pos_rois = self.cast(pos_rois, mstype.float32)
pos_rois = F.stop_gradient(pos_rois)
pos_rois = F.stop_gradient(self.cast(pos_rois, mstype.float32))
else:
if self.test_batch_size > 1:
bboxes_all = self.concat(bboxes_tuple)
@ -325,24 +260,17 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
bboxes_all = bboxes_tuple[0]
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
rois = self.cast(rois, mstype.float32)
rois = F.stop_gradient(rois)
rois = F.stop_gradient(self.cast(rois, mstype.float32))
if self.training:
roi_feats = self.roi_align(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats = self.roi_align(rois, self.cast(x[0], mstype.float32), self.cast(x[1], mstype.float32), \
self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32))
else:
roi_feats = self.roi_align_test(rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats = self.roi_align_test(rois, self.cast(x[0], mstype.float32), self.cast(x[1], mstype.float32), \
self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32))
roi_feats = self.cast(roi_feats, mstype.float16)
roi_feats = self.cast(roi_feats, self.platform_mstype)
rcnn_masks = self.concat(mask_tuple)
rcnn_masks = F.stop_gradient(rcnn_masks)
rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_))
@ -351,22 +279,15 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
rcnn_pos_masks = F.stop_gradient(rcnn_pos_masks)
rcnn_pos_mask_squeeze = self.squeeze(self.cast(rcnn_pos_masks, mstype.bool_))
rcnn_cls_loss, rcnn_reg_loss = self.rcnn_cls(roi_feats,
bbox_targets,
rcnn_labels,
rcnn_mask_squeeze)
rcnn_cls_loss, rcnn_reg_loss = self.rcnn_cls(roi_feats, bbox_targets, rcnn_labels, rcnn_mask_squeeze)
output = ()
if self.training:
roi_feats_mask = self.roi_align_mask(pos_rois,
self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
roi_feats_mask = self.roi_align_mask(pos_rois, self.cast(x[0], mstype.float32),
self.cast(x[1], mstype.float32), self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats_mask = self.cast(roi_feats_mask, mstype.float16)
rcnn_mask_fb_loss = self.rcnn_mask(roi_feats_mask,
rcnn_pos_labels,
rcnn_pos_mask_squeeze,
roi_feats_mask = self.cast(roi_feats_mask, self.platform_mstype)
rcnn_mask_fb_loss = self.rcnn_mask(roi_feats_mask, rcnn_pos_labels, rcnn_pos_mask_squeeze, \
rcnn_pos_masks_fb)
rcnn_loss = self.rcnn_loss_cls_weight * rcnn_cls_loss + self.rcnn_loss_reg_weight * rcnn_reg_loss + \
@ -374,7 +295,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_fb_loss)
else:
mask_fb_pred_all = self.rcnn_mask_test(x, bboxes_all, rcnn_cls_loss, rcnn_reg_loss)
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all,
output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, \
img_metas, mask_fb_pred_all)
return output
@ -526,7 +447,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors += (Tensor(anchors.astype(np.float16)),)
multi_level_anchors += (Tensor(anchors.astype(self.platform_dtype)),)
return multi_level_anchors
@ -543,7 +464,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
for i in range(self.test_batch_size):
cls_score_max_index, _ = self.argmax_with_value(cls_scores_all[i])
cls_score_max_index = self.cast(self.onehot(cls_score_max_index, self.num_classes,
self.on_value, self.off_value), mstype.float16)
self.on_value, self.off_value), self.platform_mstype)
cls_score_max_index = self.expand_dims(cls_score_max_index, -1)
cls_score_max_index = self.tile(cls_score_max_index, (1, 1, 4))
reg_pred_max = reg_pred_all[i] * cls_score_max_index
@ -559,6 +480,47 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell):
self.cast(x[1], mstype.float32),
self.cast(x[2], mstype.float32),
self.cast(x[3], mstype.float32))
roi_feats_mask_test = self.cast(roi_feats_mask_test, mstype.float16)
roi_feats_mask_test = self.cast(roi_feats_mask_test, self.platform_mstype)
mask_fb_pred_all = self.rcnn_mask(roi_feats_mask_test)
return mask_fb_pred_all
def init_datatype(self):
self.platform = context.get_context("device_target")
if self.platform == "CPU":
self.platform_dtype = np.float32
self.platform_mstype = mstype.float32
self.int_dtype = np.int32
self.int_mstype = mstype.int32
else:
self.platform_dtype = np.float16
self.platform_mstype = mstype.float16
self.int_dtype = np.uint8
self.int_mstype = mstype.uint8
def init_tensors(self, config):
roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i,
dtype=self.platform_dtype) for i in range(self.train_batch_size)]
roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.platform_dtype) \
for i in range(self.test_batch_size)]
self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index))
self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test))
roi_align_index_pos = [np.array(np.ones((config.num_expected_pos_stage2, 1)) * i,
dtype=self.platform_dtype) for i in range(self.train_batch_size)]
self.roi_align_index_tensor_pos = Tensor(np.concatenate(roi_align_index_pos))
self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(self.platform_dtype))
self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(self.platform_dtype))
self.rcnn_loss_mask_fb_weight = Tensor(np.array(config.rcnn_loss_mask_fb_weight).astype(self.platform_dtype))
self.argmax_with_value = P.ArgMaxWithValue(axis=1)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.onehot = P.OneHot()
self.reducesum = P.ReduceSum()
self.sigmoid = P.Sigmoid()
self.expand_dims = P.ExpandDims()
self.test_mask_fb_zeros = Tensor(np.zeros((self.rpn_max_num, 28, 28)).astype(self.platform_dtype))
self.value = Tensor(1.0, self.platform_mstype)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,7 +19,7 @@ import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore import context
class Proposal(nn.Cell):
"""
@ -104,6 +104,8 @@ class Proposal(nn.Cell):
self.multi_10 = Tensor(10.0, mstype.float16)
self.platform = context.get_context("device_target")
def set_train_local(self, config, training=True):
"""Set training flag."""
self.training_local = training
@ -174,6 +176,10 @@ class Proposal(nn.Cell):
proposals_decode = self.decode(anchors_sorted, bboxes_sorted)
proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx])))
if self.platform == "CPU":
proposals_decode = self.cast(proposals_decode, mstype.float32)
proposals, _, mask_valid = self.nms(proposals_decode)
mlvl_proposals = mlvl_proposals + (proposals,)
@ -184,7 +190,10 @@ class Proposal(nn.Cell):
_, _, _, _, scores = self.split(proposals)
scores = self.squeeze(scores)
topk_mask = self.cast(self.topK_mask, mstype.float16)
if self.platform == "CPU":
topk_mask = self.cast(self.topK_mask, mstype.float32)
else:
topk_mask = self.cast(self.topK_mask, mstype.float16)
scores_using = self.select(masks, scores, topk_mask)
_, topk_inds = self.topKv2(scores_using, self.max_num)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore import context
class DenseNoTranpose(nn.Cell):
"""Dense method"""
@ -40,20 +41,25 @@ class FpnCls(nn.Cell):
"""dense layer of classification and box head"""
def __init__(self, input_channels, output_channels, num_classes, pool_size):
super(FpnCls, self).__init__()
if context.get_context("device_target") == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
representation_size = input_channels * pool_size * pool_size
shape_0 = (output_channels, representation_size)
weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float32)
shape_1 = (output_channels, output_channels)
weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float32)
self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0).to_float(mstype.float16)
self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(mstype.float16)
self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0) \
.to_float(self.platform_mstype)
self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(self.platform_mstype)
cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1],
dtype=mstype.float32)
reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1],
dtype=mstype.float32)
self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(mstype.float16)
self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(mstype.float16)
self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(self.platform_mstype)
self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(self.platform_mstype)
self.relu = P.ReLU()
self.flatten = P.Flatten()
@ -99,8 +105,10 @@ class RcnnCls(nn.Cell):
):
super(RcnnCls, self).__init__()
cfg = config
self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16))
self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16))
if context.get_context("device_target") == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels
self.target_means = target_means
self.target_stds = target_stds
@ -128,7 +136,6 @@ class RcnnCls(nn.Cell):
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.value = Tensor(1.0, mstype.float16)
self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size
@ -143,7 +150,8 @@ class RcnnCls(nn.Cell):
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16)
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value)
labels = self.cast(labels, self.platform_mstype)
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1))
loss_cls, loss_reg = self.loss(x_cls, x_reg,
@ -160,13 +168,13 @@ class RcnnCls(nn.Cell):
"""Loss method."""
# loss_cls
loss_cls, _ = self.loss_cls(cls_score, labels)
weights = self.cast(weights, mstype.float16)
weights = self.cast(weights, self.platform_mstype)
loss_cls = loss_cls * weights
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
# loss_reg
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
self.platform_mstype)
bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background
pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4))
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore import context
def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'):
"""Conv2D wrapper."""
@ -45,27 +46,32 @@ class FpnMask(nn.Cell):
"""conv layers of mask head"""
def __init__(self, input_channels, output_channels, num_classes):
super(FpnMask, self).__init__()
self.platform = context.get_context("device_target")
if self.platform == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3,
pad_mode="same").to_float(mstype.float16)
pad_mode="same").to_float(self.platform_mstype)
self.mask_relu1 = P.ReLU()
self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3,
pad_mode="same").to_float(mstype.float16)
pad_mode="same").to_float(self.platform_mstype)
self.mask_relu2 = P.ReLU()
self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3,
pad_mode="same").to_float(mstype.float16)
pad_mode="same").to_float(self.platform_mstype)
self.mask_relu3 = P.ReLU()
self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3,
pad_mode="same").to_float(mstype.float16)
pad_mode="same").to_float(self.platform_mstype)
self.mask_relu4 = P.ReLU()
self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2,
stride=2, pad_mode="valid").to_float(mstype.float16)
stride=2, pad_mode="valid").to_float(self.platform_mstype)
self.mask_relu5 = P.ReLU()
self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1,
pad_mode="valid").to_float(mstype.float16)
pad_mode="valid").to_float(self.platform_mstype)
def construct(self, x):
x = self.mask_conv1(x)
@ -114,6 +120,11 @@ class RcnnMask(nn.Cell):
):
super(RcnnMask, self).__init__()
cfg = config
self.platform = context.get_context("device_target")
if self.platform == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
self.rcnn_loss_mask_fb_weight = Tensor(np.array(cfg.rcnn_loss_mask_fb_weight).astype(np.float16))
self.rcnn_mask_out_channels = cfg.rcnn_mask_out_channels
self.target_means = target_means
@ -130,7 +141,7 @@ class RcnnMask(nn.Cell):
self.cast = P.Cast()
self.sum_loss = P.ReduceSum()
self.tile = P.Tile()
self.expandims = P.ExpandDims()
self.expanddims = P.ExpandDims()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
@ -140,13 +151,14 @@ class RcnnMask(nn.Cell):
rmv_first[:, 0] = np.zeros((self.num_bboxes,))
self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16))
self.mean_loss = P.ReduceMean()
self.maximum = P.Maximum()
def construct(self, mask_featuremap, labels=None, mask=None, mask_fb_targets=None):
x_mask_fb = self.fpn_mask(mask_featuremap)
if self.training:
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels
mask_fb_targets = self.tile(self.expandims(mask_fb_targets, 1), (1, self.num_classes, 1, 1))
mask_fb_targets = self.tile(self.expanddims(mask_fb_targets, 1), (1, self.num_classes, 1, 1))
loss_mask_fb = self.loss(x_mask_fb, bbox_weights, mask, mask_fb_targets)
out = loss_mask_fb
@ -158,17 +170,21 @@ class RcnnMask(nn.Cell):
def loss(self, masks_fb_pred, bbox_weights, weights, masks_fb_targets):
"""Loss method."""
weights = self.cast(weights, mstype.float16)
weights = self.cast(weights, self.platform_mstype)
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
mstype.float16)
self.platform_mstype)
bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background
# loss_mask_fb
masks_fb_targets = self.cast(masks_fb_targets, mstype.float16)
masks_fb_targets = self.cast(masks_fb_targets, self.platform_mstype)
loss_mask_fb = self.loss_mask(masks_fb_pred, masks_fb_targets)
loss_mask_fb = self.mean_loss(loss_mask_fb, (2, 3))
loss_mask_fb = loss_mask_fb * bbox_weights
loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,))
if self.platform == "CPU":
sum_weight = self.sum_loss(weights, (0,))
loss_mask_fb = loss_mask_fb / self.maximum(self.expanddims(sum_weight, 0), 1)
else:
loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,))
loss_mask_fb = self.sum_loss(loss_mask_fb, (0, 1))
return loss_mask_fb

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore import context
from .bbox_assign_sample import BboxAssignSample
@ -100,6 +101,10 @@ class RPN(nn.Cell):
cls_out_channels):
super(RPN, self).__init__()
cfg_rpn = config
if context.get_context("device_target") == "CPU":
self.platform_mstype = mstype.float32
else:
self.platform_mstype = mstype.float16
self.num_bboxes = cfg_rpn.num_bboxes
self.slice_index = ()
self.feature_anchor_shape = ()
@ -180,7 +185,7 @@ class RPN(nn.Cell):
for i in range(num_layers):
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
weight_conv, bias_conv, weight_cls, \
bias_cls, weight_reg, bias_reg).to_float(mstype.float16))
bias_cls, weight_reg, bias_reg).to_float(self.platform_mstype))
for i in range(1, num_layers):
rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
@ -248,9 +253,9 @@ class RPN(nn.Cell):
mstype.bool_),
anchor_using_list, gt_valids_i)
bbox_weight = self.cast(bbox_weight, mstype.float16)
label = self.cast(label, mstype.float16)
label_weight = self.cast(label_weight, mstype.float16)
bbox_weight = self.cast(bbox_weight, self.platform_mstype)
label = self.cast(label, self.platform_mstype)
label_weight = self.cast(label_weight, self.platform_mstype)
for j in range(self.num_layers):
begin = self.slice_index[j]

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import context
from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1
time_stamp_init = False
@ -97,6 +98,8 @@ class LossCallBack(Callback):
time_stamp_current = time.time()
total_loss = self.loss_sum/self.count
print("%lu epoch: %s step: %s total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss))
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
@ -164,7 +167,10 @@ class TrainOneStepCell(nn.Cell):
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
if context.get_context("device_target") == "CPU":
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
else:
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag
self.hyper_map = C.HyperMap()
if reduce_flag:

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-21 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -96,13 +96,15 @@ def modelarts_pre_process():
config.pre_trained = os.path.join(config.output_path, config.pre_trained)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=config.device_id)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_maskrcnn_mobilenetv1():
config.mindrecord_dir = os.path.join(config.coco_root, config.mindrecord_dir)
print('config:\n', config)
print("Start train for maskrcnn_mobilenetv1!")
print("Start training for maskrcnn_mobilenetv1!")
if not config.do_eval and config.run_distribute:
rank = get_rank_id()
device_num = get_device_num()