diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc index c05bb28cde9..d08b161dcc4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc @@ -103,7 +103,7 @@ void SearchSortedCPUKernel::CheckParam(const std::vector &inpu } } }; - CPUKernelUtils::ParallelFor(task, static_cast(list_count)); + CPUKernelUtils::ParallelFor(task, IntToSize(list_count)); } } // namespace kernel } // namespace mindspore diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/README.md b/model_zoo/official/cv/maskrcnn_mobilenetv1/README.md index a82498b9f1e..a616368eb1c 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/README.md +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/README.md @@ -58,8 +58,8 @@ Note that you can run the scripts based on the dataset mentioned in original pap # [Environment Requirements](#contents) -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. +- Hardware(Ascend/CPU) + - Prepare hardware environment with Ascend or CPU processor. - Framework - [MindSpore](https://gitee.com/mindspore/mindspore) - For more information, please check the resources below: @@ -78,7 +78,7 @@ pip install mmcv=0.2.14 1. Download the dataset COCO2017. -2. Change the COCO_ROOT and other settings you need in `config.py`. The directory structure should look like the follows: +2. Change the COCO_ROOT and other settings you need in `default_config.yaml`. The directory structure should look like the follows: ``` . @@ -90,24 +90,31 @@ pip install mmcv=0.2.14 └─train2017 ``` - If you use your own dataset to train the network, **Select dataset to other when run script.** + If you use your own dataset to train the network, **Select dataset to other when run script.** Create a txt file to store dataset information organized in the way as shown as following: ``` train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2 ``` - Each row is an image annotation split by spaces. The first column is a relative path of image, followed by columns containing box and class information in the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), which can be set in `config.py`. + Each row is an image annotation split by spaces. The first column is a relative path of image, followed by columns containing box and class information in the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), which can be set in `default_config.yaml`. 3. Execute train script. After dataset preparation, you can start training as follows: - ``` + ```bash + On Ascend: + # distributed training bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT] # standalone training bash run_standalone_train.sh [PRETRAINED_CKPT] + + On CPU: + + # standalone training + bash run_standalone_train_cpu.sh [PRETRAINED_PATH](optional) ``` Note: @@ -116,27 +123,32 @@ pip install mmcv=0.2.14 3. For large models like maskrcnn_mobilenetv1, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size. 4. Execute eval script. - After training, you can start evaluation as follows: - ```bash - # Evaluation - bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH] - ``` + After training, you can start evaluation as follows: - Note: - 1. VALIDATION_JSON_FILE is a label json file for evaluation. + ```bash + # Evaluation on Ascend + bash run_eval.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH] + + # Evaluation on CPU + bash run_eval_cpu.sh [ANN_FILE] [CHECKPOINT_PATH] + ``` + + Note: + 1. VALIDATION_JSON_FILE is a label json file for evaluation. 5. Execute inference script. - After training, you can start inference as follows: - ```shell - # inference - bash run_infer_310.sh [MODEL_PATH] [DATA_PATH] [ANN_FILE_PATH] - ``` + After training, you can start inference as follows: - Note: - 1. MODEL_PATH is a model file, exported by export script file. - 2. ANN_FILE_PATH is a annotation file for inference. + ```shell + # inference + bash run_infer_310.sh [MODEL_PATH] [DATA_PATH] [ANN_FILE_PATH] + ``` + + Note: + 1. MODEL_PATH is a model file, exported by export script file. + 2. ANN_FILE_PATH is a annotation file for inference. - Running on [ModelArts](https://support.huaweicloud.com/modelarts/) @@ -284,14 +296,16 @@ pip install mmcv=0.2.14 ```shell . -└─MaskRcnn +└─MaskRcnn_Mobilenetv1 ├─README.md # README - ├─ascend310_infer #application for 310 inference + ├─ascend310_infer # application for 310 inference ├─scripts # shell script - ├─run_standalone_train.sh # training in standalone mode(1pcs) - ├─run_distribute_train.sh # training in parallel mode(8 pcs) - ├─run_infer_310.sh #shell script for 310 inference - └─run_eval.sh # evaluation + ├─run_standalone_train.sh # training in standalone mode on Ascend(1pcs) + ├─run_standalone_train_cpu.sh # training in standalone mode on CPU(1pcs) + ├─run_distribute_train.sh # training in parallel mode on Ascend(8 pcs) + ├─run_infer_310.sh # shell script for 310 inference + ├─run_eval_cpu.sh # evaluation on CPU + └─run_eval.sh # evaluation on Ascend ├─src ├─maskrcnn_mobilenetv1 ├─__init__.py @@ -306,11 +320,18 @@ pip install mmcv=0.2.14 ├─mobilenetv1.py # backbone network ├─roi_align.py # roi align network └─rpn.py # reagion proposal network - ├─config.py # network configuration + ├─util.py # routine operation + ├─model_utils # network configuration + ├─__init__.py + ├─config.py # network configuration + ├─device_adapter.py # Get cloud ID + ├─local_adapter.py # Get local ID + ├─moxing_adapter.py # Parameter processing ├─dataset.py # dataset utils ├─lr_schedule.py # leanring rate geneatore ├─network_define.py # network define for maskrcnn_mobilenetv1 └─util.py # routine operation + ├─default_config.yaml # default configuration settings ├─mindspore_hub_conf.py # mindspore hub interface ├─export.py #script to export AIR,MINDIR model ├─eval.py # evaluation scripts @@ -323,11 +344,18 @@ pip install mmcv=0.2.14 ### [Training Script Parameters](#contents) ```bash +On Ascend: + # distributed training Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL] # standalone training Usage: bash run_standalone_train.sh [PRETRAINED_MODEL] + +On CPU: + +# standalone training +Usage: bash run_standalone_train_cpu.sh [PRETRAINED_MODEL](optional) ``` ### [Parameters Configuration](#contents) @@ -474,20 +502,27 @@ Usage: bash run_standalone_train.sh [PRETRAINED_MODEL] ## [Training Process](#contents) -- Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/docs/programming_guide/en/master/dataset_sample.html) for more information about dataset. +- Set options in `default_config.yaml`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/docs/programming_guide/en/master/dataset_sample.html) for more information about dataset. ### [Training](#content) -- Run `run_standalone_train.sh` for non-distributed training of maskrcnn_mobilenetv1 model. +- Run `run_standalone_train.sh` for non-distributed training of maskrcnn_mobilenetv1 model on Ascend. ```bash # standalone training bash run_standalone_train.sh [PRETRAINED_MODEL] ``` +- Run `run_standalone_train_cpu.sh` for non-distributed training of maskrcnn_mobilenetv1 model on CPU. + +```bash +# standalone training +bash run_standalone_train_cpu.sh [PRETRAINED_MODEL](optional) +``` + ### [Distributed Training](#content) -- Run `run_distribute_train.sh` for distributed training of Mask model. +- Run `run_distribute_train.sh` for distributed training of Mask model on Ascend. ```bash bash run_distribute_train.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL] @@ -526,7 +561,7 @@ bash run_eval.sh [VALIDATION_ANN_FILE_JSON] [CHECKPOINT_PATH] ``` > As for the COCO2017 dataset, VALIDATION_ANN_FILE_JSON is refer to the annotations/instances_val2017.json in the dataset directory. -> checkpoint can be produced and saved in training process, whose folder name begins with "train/checkpoint" or "train_parallel*/checkpoint". +> Checkpoint can be produced and saved in training process, whose folder name begins with "train/checkpoint" or "train_parallel*/checkpoint". ### [Evaluation result](#content) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/eval.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/eval.py index 2fe4998b145..056ede03896 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/eval.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/eval.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,9 @@ from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks set_seed(1) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id()) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) def maskrcnn_eval(dataset_path, ckpt_path, ann_file): """MaskRcnn evaluation.""" diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_eval_cpu.sh b/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_eval_cpu.sh new file mode 100644 index 00000000000..94166a19af5 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_eval_cpu.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: bash run_eval_cpu.sh [ANN_FILE] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) +echo $PATH1 +echo $PATH2 + +if [ ! -f $PATH1 ] +then + echo "error: ANN_FILE=$PATH1 is not a file" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp ../*.yaml ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start eval for CPU" +python eval.py --ann_file=$PATH1 --checkpoint_path=$PATH2 --device_target=CPU > cpu_eval_log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_standalone_train_cpu.sh b/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_standalone_train_cpu.sh new file mode 100644 index 00000000000..57e856a9235 --- /dev/null +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/scripts/run_standalone_train_cpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 0 ] && [ $# != 1 ] +then + echo "Usage: bash run_standalone_train_cpu.sh [PRETRAINED_PATH](optional)" + exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +if [ $# == 1 ] +then + PATH1=$(get_real_path $1) + echo $PATH1 +fi + +ulimit -u unlimited + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp ../*.yaml ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for CPU" +env > env.log +if [ $# == 1 ] +then + python train.py --do_train=True --pre_trained=$PATH1 --device_target=CPU > cpu_training_log.txt 2>&1 & +fi + +if [ $# == 0 ] +then + python train.py --do_train=True --pre_trained="" --device_target=CPU > cpu_training_log.txt 2>&1 & +fi + +cd .. diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/dataset.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/dataset.py index 99e4a7d85aa..094a0856fc1 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/dataset.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ from numpy import random import mindspore.dataset as de import mindspore.dataset.vision.c_transforms as C from mindspore.mindrecord import FileWriter +from mindspore import context from src.model_utils.config import config @@ -264,7 +265,7 @@ def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mas def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): """imnormalize operation for image""" - img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True) + img_data = mmcv.imnormalize(img, np.array([123.675, 116.28, 103.53]), np.array([58.395, 57.12, 57.375]), True) img_data = img_data.astype(np.float32) return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) @@ -284,10 +285,15 @@ def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): """transpose operation for image""" + if context.get_context("device_target") == "CPU": + platform_dtype = np.float32 + else: + platform_dtype = np.float16 + img_data = img.transpose(2, 0, 1).copy() - img_data = img_data.astype(np.float16) - img_shape = img_shape.astype(np.float16) - gt_bboxes = gt_bboxes.astype(np.float16) + img_data = img_data.astype(platform_dtype) + img_shape = img_shape.astype(platform_dtype) + gt_bboxes = gt_bboxes.astype(platform_dtype) gt_label = gt_label.astype(np.int32) gt_num = gt_num.astype(np.bool) gt_mask_data = gt_mask.astype(np.bool) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample.py index 537792c79da..ae1477f51f7 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype +from mindspore import context class BboxAssignSample(nn.Cell): @@ -79,7 +80,6 @@ class BboxAssignSample(nn.Cell): self.reshape = P.Reshape() self.equal = P.Equal() self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)) - self.scatterNdUpdate = P.ScatterNdUpdate() self.scatterNd = P.ScatterNd() self.logicalnot = P.LogicalNot() self.tile = P.Tile() @@ -93,8 +93,13 @@ class BboxAssignSample(nn.Cell): self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) - self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) - self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) + + if context.get_context("device_target") == "CPU": + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float32)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float32)) + else: + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample_stage2.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample_stage2.py index 8165fffa1d0..dcb31f4473b 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample_stage2.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/bbox_assign_sample_stage2.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.common.tensor import Tensor +from mindspore import context class BboxAssignSampleForRcnn(nn.Cell): """ @@ -78,8 +79,12 @@ class BboxAssignSampleForRcnn(nn.Cell): self.tile = P.Tile() # Check - self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) - self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) + if context.get_context("device_target") == "CPU": + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float32)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float32)) + else: + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) # Init tensor self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) @@ -91,8 +96,13 @@ class BboxAssignSampleForRcnn(nn.Cell): self.gt_ignores = Tensor(np.array(-1 * np.ones(self.num_gts), dtype=np.int32)) self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) - self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16)) - self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) + + if context.get_context("device_target") == "CPU": + self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float32)) + self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.int32)) + else: + self.bboxs_neg_mask = Tensor(np.zeros((self.num_expected_neg, 4), dtype=np.float16)) + self.labels_neg_mask = Tensor(np.array(np.zeros(self.num_expected_neg), dtype=np.uint8)) self.reshape_shape_pos = (self.num_expected_pos, 1) self.reshape_shape_neg = (self.num_expected_neg, 1) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py index d40413a622e..649c2ae62fa 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/fpn_neck.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer +from mindspore import context def bias_init_zeros(shape): @@ -66,6 +67,10 @@ class FeatPyramidNeck(nn.Cell): out_channels, num_outs): super(FeatPyramidNeck, self).__init__() + if context.get_context("device_target") == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 self.num_outs = num_outs self.in_channels = in_channels self.fpn_layer = len(self.in_channels) @@ -96,9 +101,9 @@ class FeatPyramidNeck(nn.Cell): x += (self.lateral_convs_list[i](inputs[i]),) y = (x[3],) - y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),) - y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),) - y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),) + y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), self.platform_mstype),) + y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), self.platform_mstype),) + y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), self.platform_mstype),) z = () for i in range(self.fpn_layer - 1, -1, -1): diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/mask_rcnn_mobilenetv1.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/mask_rcnn_mobilenetv1.py index 7bde4e78568..86efb268d68 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/mask_rcnn_mobilenetv1.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/mask_rcnn_mobilenetv1.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from mindspore.ops import operations as P from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore.ops import functional as F +from mindspore import context from .mobilenetv1 import MobileNetV1_FeatureSelector from .bbox_assign_sample_stage2 import BboxAssignSampleForRcnn from .fpn_neck import FeatPyramidNeck @@ -59,16 +60,15 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.anchor_strides = config.anchor_strides self.target_means = tuple(config.rcnn_target_means) self.target_stds = tuple(config.rcnn_target_stds) + self.init_datatype() # Anchor generator anchor_base_sizes = None - self.anchor_base_sizes = list( - self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes + self.anchor_base_sizes = list(self.anchor_strides) if anchor_base_sizes is None else anchor_base_sizes self.anchor_generators = [] for anchor_base in self.anchor_base_sizes: - self.anchor_generators.append( - AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios)) + self.anchor_generators.append(AnchorGenerator(anchor_base, self.anchor_scales, self.anchor_ratios)) self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales) @@ -78,30 +78,21 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.anchor_list = self.get_anchors(featmap_sizes) # Backbone mobilenetv1 - self.backbone = MobileNetV1_FeatureSelector(1001, features_only=True).to_float(mstype.float16) + self.backbone = MobileNetV1_FeatureSelector(1001, features_only=True).to_float(self.platform_mstype) # Fpn - self.fpn_ncek = FeatPyramidNeck(config.fpn_in_channels, - config.fpn_out_channels, - config.fpn_num_outs) + self.fpn_neck = FeatPyramidNeck(config.fpn_in_channels, config.fpn_out_channels, config.fpn_num_outs) # Rpn and rpn loss - self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(np.uint8)) - self.rpn_with_loss = RPN(config, - self.train_batch_size, - config.rpn_in_channels, - config.rpn_feat_channels, - config.num_anchors, - config.rpn_cls_out_channels) + self.gt_labels_stage1 = Tensor(np.ones((self.train_batch_size, config.num_gts)).astype(self.int_dtype)) + + self.rpn_with_loss = RPN(config, self.train_batch_size, config.rpn_in_channels, config.rpn_feat_channels, + config.num_anchors, config.rpn_cls_out_channels) # Proposal - self.proposal_generator = Proposal(config, - self.train_batch_size, - config.activate_num_classes, + self.proposal_generator = Proposal(config, self.train_batch_size, config.activate_num_classes, config.use_sigmoid_cls) self.proposal_generator.set_train_local(config, True) - self.proposal_generator_test = Proposal(config, - config.test_batch_size, - config.activate_num_classes, + self.proposal_generator_test = Proposal(config, config.test_batch_size, config.activate_num_classes, config.use_sigmoid_cls) self.proposal_generator_test.set_train_local(config, False) @@ -112,40 +103,24 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): stds=self.target_stds) # Roi - self.roi_align = SingleRoIExtractor(config, - config.roi_layer, - config.roi_align_out_channels, - config.roi_align_featmap_strides, - self.train_batch_size, - config.roi_align_finest_scale, - mask=False) + self.roi_align = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels, + config.roi_align_featmap_strides, self.train_batch_size, + config.roi_align_finest_scale, mask=False) self.roi_align.set_train_local(config, True) - self.roi_align_mask = SingleRoIExtractor(config, - config.roi_layer, - config.roi_align_out_channels, - config.roi_align_featmap_strides, - self.train_batch_size, - config.roi_align_finest_scale, - mask=True) + self.roi_align_mask = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels, + config.roi_align_featmap_strides, self.train_batch_size, + config.roi_align_finest_scale, mask=True) self.roi_align_mask.set_train_local(config, True) - self.roi_align_test = SingleRoIExtractor(config, - config.roi_layer, - config.roi_align_out_channels, - config.roi_align_featmap_strides, - 1, - config.roi_align_finest_scale, + self.roi_align_test = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels, + config.roi_align_featmap_strides, 1, config.roi_align_finest_scale, mask=False) self.roi_align_test.set_train_local(config, False) - self.roi_align_mask_test = SingleRoIExtractor(config, - config.roi_layer, - config.roi_align_out_channels, - config.roi_align_featmap_strides, - 1, - config.roi_align_finest_scale, - mask=True) + self.roi_align_mask_test = SingleRoIExtractor(config, config.roi_layer, config.roi_align_out_channels, + config.roi_align_featmap_strides, 1, + config.roi_align_finest_scale, mask=True) self.roi_align_mask_test.set_train_local(config, False) # Rcnn @@ -176,7 +151,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.rpn_max_num = config.rpn_max_num - self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(np.float16)) + self.zeros_for_nms = Tensor(np.zeros((self.rpn_max_num, 3)).astype(self.platform_dtype)) self.ones_mask = np.ones((self.rpn_max_num, 1)).astype(np.bool) self.zeros_mask = np.zeros((self.rpn_max_num, 1)).astype(np.bool) self.bbox_mask = Tensor(np.concatenate((self.ones_mask, self.zeros_mask, @@ -184,10 +159,11 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.nms_pad_mask = Tensor(np.concatenate((self.ones_mask, self.ones_mask, self.ones_mask, self.ones_mask, self.zeros_mask), axis=1)) - self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_score_thr) - self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * 0) - self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(np.float16) * -1) - self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(np.float16) * config.test_iou_thr) + self.test_score_thresh = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype) + * config.test_score_thr) + self.test_score_zeros = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype) * 0) + self.test_box_zeros = Tensor(np.ones((self.rpn_max_num, 4)).astype(self.platform_dtype) * -1) + self.test_iou_thr = Tensor(np.ones((self.rpn_max_num, 1)).astype(self.platform_dtype) * config.test_iou_thr) self.test_max_per_img = config.test_max_per_img self.nms_test = P.NMSWithMask(config.test_iou_thr) self.softmax = P.Softmax(axis=1) @@ -201,42 +177,14 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.concat_end = (self.num_classes - 1) # Init tensor - roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, - dtype=np.float16) for i in range(self.train_batch_size)] + self.init_tensors(config) - roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=np.float16) \ - for i in range(self.test_batch_size)] - - self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) - self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test)) - - roi_align_index_pos = [np.array(np.ones((config.num_expected_pos_stage2, 1)) * i, - dtype=np.float16) for i in range(self.train_batch_size)] - self.roi_align_index_tensor_pos = Tensor(np.concatenate(roi_align_index_pos)) - - self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(np.float16)) - self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(np.float16)) - self.rcnn_loss_mask_fb_weight = Tensor(np.array(config.rcnn_loss_mask_fb_weight).astype(np.float16)) - - self.argmax_with_value = P.ArgMaxWithValue(axis=1) - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.onehot = P.OneHot() - self.reducesum = P.ReduceSum() - self.sigmoid = P.Sigmoid() - self.expand_dims = P.ExpandDims() - self.test_mask_fb_zeros = Tensor(np.zeros((self.rpn_max_num, 28, 28)).astype(np.float16)) - self.value = Tensor(1.0, mstype.float16) def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids, gt_masks): x = self.backbone(img_data) - x = self.fpn_ncek(x) + x = self.fpn_neck(x) - rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = self.rpn_with_loss(x, - img_metas, - self.anchor_list, - gt_bboxes, - self.gt_labels_stage1, - gt_valids) + rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss, _ = \ + self.rpn_with_loss(x, img_metas, self.anchor_list, gt_bboxes, self.gt_labels_stage1, gt_valids) if self.training: proposal, proposal_mask = self.proposal_generator(cls_score, bbox_pred, self.anchor_list) @@ -258,23 +206,13 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): if self.training: for i in range(self.train_batch_size): gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) - - gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) - gt_labels_i = self.cast(gt_labels_i, mstype.uint8) - - gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) - gt_valids_i = self.cast(gt_valids_i, mstype.bool_) - - gt_masks_i = self.squeeze(gt_masks[i:i + 1:1, ::]) - gt_masks_i = self.cast(gt_masks_i, mstype.bool_) + gt_labels_i = self.cast(self.squeeze(gt_labels[i:i + 1:1, ::]), self.int_mstype) + gt_valids_i = self.cast(self.squeeze(gt_valids[i:i + 1:1, ::]), mstype.bool_) + gt_masks_i = self.cast(self.squeeze(gt_masks[i:i + 1:1, ::]), mstype.bool_) bboxes, deltas, labels, mask, pos_bboxes, pos_mask_fb, pos_labels, pos_mask = \ - self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i, - gt_labels_i, - proposal_mask[i], - proposal[i][::, 0:4:1], - gt_valids_i, - gt_masks_i) + self.bbox_assigner_sampler_for_rcnn(gt_bboxes_i, gt_labels_i, proposal_mask[i], \ + proposal[i][::, 0:4:1], gt_valids_i, gt_masks_i) bboxes_tuple += (bboxes,) deltas_tuple += (deltas,) labels_tuple += (labels,) @@ -288,14 +226,12 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): bbox_targets = self.concat(deltas_tuple) rcnn_labels = self.concat(labels_tuple) bbox_targets = F.stop_gradient(bbox_targets) - rcnn_labels = F.stop_gradient(rcnn_labels) - rcnn_labels = self.cast(rcnn_labels, mstype.int32) + rcnn_labels = self.cast(F.stop_gradient(rcnn_labels), mstype.int32) rcnn_pos_masks_fb = self.concat(pos_mask_fb_tuple) rcnn_pos_masks_fb = F.stop_gradient(rcnn_pos_masks_fb) rcnn_pos_labels = self.concat(pos_labels_tuple) - rcnn_pos_labels = F.stop_gradient(rcnn_pos_labels) - rcnn_pos_labels = self.cast(rcnn_pos_labels, mstype.int32) + rcnn_pos_labels = self.cast(F.stop_gradient(rcnn_pos_labels), mstype.int32) else: mask_tuple += proposal_mask bbox_targets = proposal_mask @@ -316,8 +252,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): pos_bboxes_all = pos_bboxes_tuple[0] rois = self.concat_1((self.roi_align_index_tensor, bboxes_all)) pos_rois = self.concat_1((self.roi_align_index_tensor_pos, pos_bboxes_all)) - pos_rois = self.cast(pos_rois, mstype.float32) - pos_rois = F.stop_gradient(pos_rois) + pos_rois = F.stop_gradient(self.cast(pos_rois, mstype.float32)) else: if self.test_batch_size > 1: bboxes_all = self.concat(bboxes_tuple) @@ -325,24 +260,17 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): bboxes_all = bboxes_tuple[0] rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all)) - rois = self.cast(rois, mstype.float32) - rois = F.stop_gradient(rois) + rois = F.stop_gradient(self.cast(rois, mstype.float32)) if self.training: - roi_feats = self.roi_align(rois, - self.cast(x[0], mstype.float32), - self.cast(x[1], mstype.float32), - self.cast(x[2], mstype.float32), - self.cast(x[3], mstype.float32)) + roi_feats = self.roi_align(rois, self.cast(x[0], mstype.float32), self.cast(x[1], mstype.float32), \ + self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32)) else: - roi_feats = self.roi_align_test(rois, - self.cast(x[0], mstype.float32), - self.cast(x[1], mstype.float32), - self.cast(x[2], mstype.float32), - self.cast(x[3], mstype.float32)) + roi_feats = self.roi_align_test(rois, self.cast(x[0], mstype.float32), self.cast(x[1], mstype.float32), \ + self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32)) - roi_feats = self.cast(roi_feats, mstype.float16) + roi_feats = self.cast(roi_feats, self.platform_mstype) rcnn_masks = self.concat(mask_tuple) rcnn_masks = F.stop_gradient(rcnn_masks) rcnn_mask_squeeze = self.squeeze(self.cast(rcnn_masks, mstype.bool_)) @@ -351,22 +279,15 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): rcnn_pos_masks = F.stop_gradient(rcnn_pos_masks) rcnn_pos_mask_squeeze = self.squeeze(self.cast(rcnn_pos_masks, mstype.bool_)) - rcnn_cls_loss, rcnn_reg_loss = self.rcnn_cls(roi_feats, - bbox_targets, - rcnn_labels, - rcnn_mask_squeeze) + rcnn_cls_loss, rcnn_reg_loss = self.rcnn_cls(roi_feats, bbox_targets, rcnn_labels, rcnn_mask_squeeze) output = () if self.training: - roi_feats_mask = self.roi_align_mask(pos_rois, - self.cast(x[0], mstype.float32), - self.cast(x[1], mstype.float32), - self.cast(x[2], mstype.float32), + roi_feats_mask = self.roi_align_mask(pos_rois, self.cast(x[0], mstype.float32), + self.cast(x[1], mstype.float32), self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32)) - roi_feats_mask = self.cast(roi_feats_mask, mstype.float16) - rcnn_mask_fb_loss = self.rcnn_mask(roi_feats_mask, - rcnn_pos_labels, - rcnn_pos_mask_squeeze, + roi_feats_mask = self.cast(roi_feats_mask, self.platform_mstype) + rcnn_mask_fb_loss = self.rcnn_mask(roi_feats_mask, rcnn_pos_labels, rcnn_pos_mask_squeeze, \ rcnn_pos_masks_fb) rcnn_loss = self.rcnn_loss_cls_weight * rcnn_cls_loss + self.rcnn_loss_reg_weight * rcnn_reg_loss + \ @@ -374,7 +295,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): output += (rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_fb_loss) else: mask_fb_pred_all = self.rcnn_mask_test(x, bboxes_all, rcnn_cls_loss, rcnn_reg_loss) - output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, + output = self.get_det_bboxes(rcnn_cls_loss, rcnn_reg_loss, rcnn_masks, bboxes_all, \ img_metas, mask_fb_pred_all) return output @@ -526,7 +447,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): for i in range(num_levels): anchors = self.anchor_generators[i].grid_anchors( featmap_sizes[i], self.anchor_strides[i]) - multi_level_anchors += (Tensor(anchors.astype(np.float16)),) + multi_level_anchors += (Tensor(anchors.astype(self.platform_dtype)),) return multi_level_anchors @@ -543,7 +464,7 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): for i in range(self.test_batch_size): cls_score_max_index, _ = self.argmax_with_value(cls_scores_all[i]) cls_score_max_index = self.cast(self.onehot(cls_score_max_index, self.num_classes, - self.on_value, self.off_value), mstype.float16) + self.on_value, self.off_value), self.platform_mstype) cls_score_max_index = self.expand_dims(cls_score_max_index, -1) cls_score_max_index = self.tile(cls_score_max_index, (1, 1, 4)) reg_pred_max = reg_pred_all[i] * cls_score_max_index @@ -559,6 +480,47 @@ class Mask_Rcnn_Mobilenetv1(nn.Cell): self.cast(x[1], mstype.float32), self.cast(x[2], mstype.float32), self.cast(x[3], mstype.float32)) - roi_feats_mask_test = self.cast(roi_feats_mask_test, mstype.float16) + roi_feats_mask_test = self.cast(roi_feats_mask_test, self.platform_mstype) mask_fb_pred_all = self.rcnn_mask(roi_feats_mask_test) return mask_fb_pred_all + + def init_datatype(self): + self.platform = context.get_context("device_target") + if self.platform == "CPU": + self.platform_dtype = np.float32 + self.platform_mstype = mstype.float32 + self.int_dtype = np.int32 + self.int_mstype = mstype.int32 + else: + self.platform_dtype = np.float16 + self.platform_mstype = mstype.float16 + self.int_dtype = np.uint8 + self.int_mstype = mstype.uint8 + + def init_tensors(self, config): + roi_align_index = [np.array(np.ones((config.num_expected_pos_stage2 + config.num_expected_neg_stage2, 1)) * i, + dtype=self.platform_dtype) for i in range(self.train_batch_size)] + + roi_align_index_test = [np.array(np.ones((config.rpn_max_num, 1)) * i, dtype=self.platform_dtype) \ + for i in range(self.test_batch_size)] + + self.roi_align_index_tensor = Tensor(np.concatenate(roi_align_index)) + self.roi_align_index_test_tensor = Tensor(np.concatenate(roi_align_index_test)) + + roi_align_index_pos = [np.array(np.ones((config.num_expected_pos_stage2, 1)) * i, + dtype=self.platform_dtype) for i in range(self.train_batch_size)] + self.roi_align_index_tensor_pos = Tensor(np.concatenate(roi_align_index_pos)) + + self.rcnn_loss_cls_weight = Tensor(np.array(config.rcnn_loss_cls_weight).astype(self.platform_dtype)) + self.rcnn_loss_reg_weight = Tensor(np.array(config.rcnn_loss_reg_weight).astype(self.platform_dtype)) + self.rcnn_loss_mask_fb_weight = Tensor(np.array(config.rcnn_loss_mask_fb_weight).astype(self.platform_dtype)) + + self.argmax_with_value = P.ArgMaxWithValue(axis=1) + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.onehot = P.OneHot() + self.reducesum = P.ReduceSum() + self.sigmoid = P.Sigmoid() + self.expand_dims = P.ExpandDims() + self.test_mask_fb_zeros = Tensor(np.zeros((self.rpn_max_num, 28, 28)).astype(self.platform_dtype)) + self.value = Tensor(1.0, self.platform_mstype) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/proposal_generator.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/proposal_generator.py index 3c7ae5f7d93..d32223cdb55 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/proposal_generator.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/proposal_generator.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore import Tensor - +from mindspore import context class Proposal(nn.Cell): """ @@ -104,6 +104,8 @@ class Proposal(nn.Cell): self.multi_10 = Tensor(10.0, mstype.float16) + self.platform = context.get_context("device_target") + def set_train_local(self, config, training=True): """Set training flag.""" self.training_local = training @@ -174,6 +176,10 @@ class Proposal(nn.Cell): proposals_decode = self.decode(anchors_sorted, bboxes_sorted) proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape[idx]))) + + if self.platform == "CPU": + proposals_decode = self.cast(proposals_decode, mstype.float32) + proposals, _, mask_valid = self.nms(proposals_decode) mlvl_proposals = mlvl_proposals + (proposals,) @@ -184,7 +190,10 @@ class Proposal(nn.Cell): _, _, _, _, scores = self.split(proposals) scores = self.squeeze(scores) - topk_mask = self.cast(self.topK_mask, mstype.float16) + if self.platform == "CPU": + topk_mask = self.cast(self.topK_mask, mstype.float32) + else: + topk_mask = self.cast(self.topK_mask, mstype.float16) scores_using = self.select(masks, scores, topk_mask) _, topk_inds = self.topKv2(scores_using, self.max_num) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py index d96c2461632..6b35ab3222e 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_cls.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter +from mindspore import context class DenseNoTranpose(nn.Cell): """Dense method""" @@ -40,20 +41,25 @@ class FpnCls(nn.Cell): """dense layer of classification and box head""" def __init__(self, input_channels, output_channels, num_classes, pool_size): super(FpnCls, self).__init__() + if context.get_context("device_target") == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 representation_size = input_channels * pool_size * pool_size shape_0 = (output_channels, representation_size) weights_0 = initializer("XavierUniform", shape=shape_0[::-1], dtype=mstype.float32) shape_1 = (output_channels, output_channels) weights_1 = initializer("XavierUniform", shape=shape_1[::-1], dtype=mstype.float32) - self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0).to_float(mstype.float16) - self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(mstype.float16) + self.shared_fc_0 = DenseNoTranpose(representation_size, output_channels, weights_0) \ + .to_float(self.platform_mstype) + self.shared_fc_1 = DenseNoTranpose(output_channels, output_channels, weights_1).to_float(self.platform_mstype) cls_weight = initializer('Normal', shape=[num_classes, output_channels][::-1], dtype=mstype.float32) reg_weight = initializer('Normal', shape=[num_classes * 4, output_channels][::-1], dtype=mstype.float32) - self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(mstype.float16) - self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(mstype.float16) + self.cls_scores = DenseNoTranpose(output_channels, num_classes, cls_weight).to_float(self.platform_mstype) + self.reg_scores = DenseNoTranpose(output_channels, num_classes * 4, reg_weight).to_float(self.platform_mstype) self.relu = P.ReLU() self.flatten = P.Flatten() @@ -99,8 +105,10 @@ class RcnnCls(nn.Cell): ): super(RcnnCls, self).__init__() cfg = config - self.rcnn_loss_cls_weight = Tensor(np.array(cfg.rcnn_loss_cls_weight).astype(np.float16)) - self.rcnn_loss_reg_weight = Tensor(np.array(cfg.rcnn_loss_reg_weight).astype(np.float16)) + if context.get_context("device_target") == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 self.rcnn_fc_out_channels = cfg.rcnn_fc_out_channels self.target_means = target_means self.target_stds = target_stds @@ -128,7 +136,6 @@ class RcnnCls(nn.Cell): self.on_value = Tensor(1.0, mstype.float32) self.off_value = Tensor(0.0, mstype.float32) - self.value = Tensor(1.0, mstype.float16) self.num_bboxes = (cfg.num_expected_pos_stage2 + cfg.num_expected_neg_stage2) * batch_size @@ -143,7 +150,8 @@ class RcnnCls(nn.Cell): if self.training: bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels - labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), mstype.float16) + labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value) + labels = self.cast(labels, self.platform_mstype) bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) loss_cls, loss_reg = self.loss(x_cls, x_reg, @@ -160,13 +168,13 @@ class RcnnCls(nn.Cell): """Loss method.""" # loss_cls loss_cls, _ = self.loss_cls(cls_score, labels) - weights = self.cast(weights, mstype.float16) + weights = self.cast(weights, self.platform_mstype) loss_cls = loss_cls * weights loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,)) # loss_reg bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), - mstype.float16) + self.platform_mstype) bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background pos_bbox_pred = self.reshape(bbox_pred, (self.num_bboxes, -1, 4)) loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets) diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py index 08e4f9c3e6d..93cc2b9ef41 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rcnn_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common.initializer import initializer +from mindspore import context def _conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode='pad'): """Conv2D wrapper.""" @@ -45,27 +46,32 @@ class FpnMask(nn.Cell): """conv layers of mask head""" def __init__(self, input_channels, output_channels, num_classes): super(FpnMask, self).__init__() + self.platform = context.get_context("device_target") + if self.platform == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 self.mask_conv1 = _conv(input_channels, output_channels, kernel_size=3, - pad_mode="same").to_float(mstype.float16) + pad_mode="same").to_float(self.platform_mstype) self.mask_relu1 = P.ReLU() self.mask_conv2 = _conv(output_channels, output_channels, kernel_size=3, - pad_mode="same").to_float(mstype.float16) + pad_mode="same").to_float(self.platform_mstype) self.mask_relu2 = P.ReLU() self.mask_conv3 = _conv(output_channels, output_channels, kernel_size=3, - pad_mode="same").to_float(mstype.float16) + pad_mode="same").to_float(self.platform_mstype) self.mask_relu3 = P.ReLU() self.mask_conv4 = _conv(output_channels, output_channels, kernel_size=3, - pad_mode="same").to_float(mstype.float16) + pad_mode="same").to_float(self.platform_mstype) self.mask_relu4 = P.ReLU() self.mask_deconv5 = _convTanspose(output_channels, output_channels, kernel_size=2, - stride=2, pad_mode="valid").to_float(mstype.float16) + stride=2, pad_mode="valid").to_float(self.platform_mstype) self.mask_relu5 = P.ReLU() self.mask_conv6 = _conv(output_channels, num_classes, kernel_size=1, stride=1, - pad_mode="valid").to_float(mstype.float16) + pad_mode="valid").to_float(self.platform_mstype) def construct(self, x): x = self.mask_conv1(x) @@ -114,6 +120,11 @@ class RcnnMask(nn.Cell): ): super(RcnnMask, self).__init__() cfg = config + self.platform = context.get_context("device_target") + if self.platform == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 self.rcnn_loss_mask_fb_weight = Tensor(np.array(cfg.rcnn_loss_mask_fb_weight).astype(np.float16)) self.rcnn_mask_out_channels = cfg.rcnn_mask_out_channels self.target_means = target_means @@ -130,7 +141,7 @@ class RcnnMask(nn.Cell): self.cast = P.Cast() self.sum_loss = P.ReduceSum() self.tile = P.Tile() - self.expandims = P.ExpandDims() + self.expanddims = P.ExpandDims() self.on_value = Tensor(1.0, mstype.float32) self.off_value = Tensor(0.0, mstype.float32) @@ -140,13 +151,14 @@ class RcnnMask(nn.Cell): rmv_first[:, 0] = np.zeros((self.num_bboxes,)) self.rmv_first_tensor = Tensor(rmv_first.astype(np.float16)) self.mean_loss = P.ReduceMean() + self.maximum = P.Maximum() def construct(self, mask_featuremap, labels=None, mask=None, mask_fb_targets=None): x_mask_fb = self.fpn_mask(mask_featuremap) if self.training: bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels - mask_fb_targets = self.tile(self.expandims(mask_fb_targets, 1), (1, self.num_classes, 1, 1)) + mask_fb_targets = self.tile(self.expanddims(mask_fb_targets, 1), (1, self.num_classes, 1, 1)) loss_mask_fb = self.loss(x_mask_fb, bbox_weights, mask, mask_fb_targets) out = loss_mask_fb @@ -158,17 +170,21 @@ class RcnnMask(nn.Cell): def loss(self, masks_fb_pred, bbox_weights, weights, masks_fb_targets): """Loss method.""" - weights = self.cast(weights, mstype.float16) + weights = self.cast(weights, self.platform_mstype) bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value), - mstype.float16) + self.platform_mstype) bbox_weights = bbox_weights * self.rmv_first_tensor # * self.rmv_first_tensor exclude background # loss_mask_fb - masks_fb_targets = self.cast(masks_fb_targets, mstype.float16) + masks_fb_targets = self.cast(masks_fb_targets, self.platform_mstype) loss_mask_fb = self.loss_mask(masks_fb_pred, masks_fb_targets) loss_mask_fb = self.mean_loss(loss_mask_fb, (2, 3)) loss_mask_fb = loss_mask_fb * bbox_weights - loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,)) + if self.platform == "CPU": + sum_weight = self.sum_loss(weights, (0,)) + loss_mask_fb = loss_mask_fb / self.maximum(self.expanddims(sum_weight, 0), 1) + else: + loss_mask_fb = loss_mask_fb / self.sum_loss(weights, (0,)) loss_mask_fb = self.sum_loss(loss_mask_fb, (0, 1)) return loss_mask_fb diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py index b7effb3d1bb..5ab88584c5c 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/maskrcnn_mobilenetv1/rpn.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from mindspore.ops import operations as P from mindspore import Tensor from mindspore.ops import functional as F from mindspore.common.initializer import initializer +from mindspore import context from .bbox_assign_sample import BboxAssignSample @@ -100,6 +101,10 @@ class RPN(nn.Cell): cls_out_channels): super(RPN, self).__init__() cfg_rpn = config + if context.get_context("device_target") == "CPU": + self.platform_mstype = mstype.float32 + else: + self.platform_mstype = mstype.float16 self.num_bboxes = cfg_rpn.num_bboxes self.slice_index = () self.feature_anchor_shape = () @@ -180,7 +185,7 @@ class RPN(nn.Cell): for i in range(num_layers): rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \ weight_conv, bias_conv, weight_cls, \ - bias_cls, weight_reg, bias_reg).to_float(mstype.float16)) + bias_cls, weight_reg, bias_reg).to_float(self.platform_mstype)) for i in range(1, num_layers): rpn_layer[i].rpn_conv.weight = rpn_layer[0].rpn_conv.weight @@ -248,9 +253,9 @@ class RPN(nn.Cell): mstype.bool_), anchor_using_list, gt_valids_i) - bbox_weight = self.cast(bbox_weight, mstype.float16) - label = self.cast(label, mstype.float16) - label_weight = self.cast(label_weight, mstype.float16) + bbox_weight = self.cast(bbox_weight, self.platform_mstype) + label = self.cast(label, self.platform_mstype) + label_weight = self.cast(label_weight, self.platform_mstype) for j in range(self.num_layers): begin = self.slice_index[j] diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py index 4c5b4a89b45..45e2773bcc3 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ from mindspore.ops import composite as C from mindspore import ParameterTuple from mindspore.train.callback import Callback from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore import context from src.maskrcnn_mobilenetv1.mask_rcnn_mobilenetv1 import Mask_Rcnn_Mobilenetv1 time_stamp_init = False @@ -97,6 +98,8 @@ class LossCallBack(Callback): time_stamp_current = time.time() total_loss = self.loss_sum/self.count + print("%lu epoch: %s step: %s total_loss: %.5f" % + (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss)) loss_file = open("./loss_{}.log".format(self.rank_id), "a+") loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" % (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, @@ -164,7 +167,10 @@ class TrainOneStepCell(nn.Cell): self.optimizer = optimizer self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) + if context.get_context("device_target") == "CPU": + self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32)) + else: + self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) self.reduce_flag = reduce_flag self.hyper_map = C.HyperMap() if reduce_flag: diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/train.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/train.py index d073cad3b56..22f6615eb5a 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/train.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/train.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -96,13 +96,15 @@ def modelarts_pre_process(): config.pre_trained = os.path.join(config.output_path, config.pre_trained) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id()) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) @moxing_wrapper(pre_process=modelarts_pre_process) def train_maskrcnn_mobilenetv1(): config.mindrecord_dir = os.path.join(config.coco_root, config.mindrecord_dir) print('config:\n', config) - print("Start train for maskrcnn_mobilenetv1!") + print("Start training for maskrcnn_mobilenetv1!") if not config.do_eval and config.run_distribute: rank = get_rank_id() device_num = get_device_num()