forked from mindspore-Ecosystem/mindspore
add mcnn
This commit is contained in:
parent c3c9dbd0a4
commit 16fe01f2b6
@ -0,0 +1,195 @@
# Contents

- [MCNN Description](#mcnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [MCNN Description](#contents)

MCNN is a Multi-column Convolutional Neural Network that can accurately estimate the crowd count in a single image taken from almost any perspective.

[Paper](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Zhang_Single-Image_Crowd_Counting_CVPR_2016_paper.pdf): Yingying Zhang, Desen Zhou, Siqin Chen, Shenghua Gao, Yi Ma. Single-Image Crowd Counting via Multi-Column Convolutional Neural Network.

# [Model Architecture](#contents)

MCNN contains three parallel CNNs whose filters have local receptive fields of different sizes. For simplicity, all columns use the same network structure (conv-pooling-conv-pooling); only the sizes and numbers of filters differ. Max pooling is applied over 2x2 regions, and the Rectified Linear Unit (ReLU) is used as the activation function because of its good performance for CNNs.
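
The sketch below only illustrates the column-and-fuse idea with a simplified two-layer version of each column (the real network is `src/mcnn.py` in this commit): each column processes the same grey image with a different filter size, the three feature maps are concatenated along the channel axis, and a 1x1 convolution fuses them into a single-channel density map at 1/4 of the input resolution.

```python
import mindspore.nn as nn
import mindspore.ops as ops


class MultiColumnSketch(nn.Cell):
    """Simplified illustration of the MCNN column-and-fuse structure (not the training network)."""
    def __init__(self):
        super(MultiColumnSketch, self).__init__()
        # one column per receptive-field scale (large / medium / small filters)
        self.col_large = nn.SequentialCell(nn.Conv2d(1, 8, 9, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
                                           nn.Conv2d(8, 8, 7, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2))
        self.col_medium = nn.SequentialCell(nn.Conv2d(1, 10, 7, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
                                            nn.Conv2d(10, 10, 5, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2))
        self.col_small = nn.SequentialCell(nn.Conv2d(1, 12, 5, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2),
                                           nn.Conv2d(12, 12, 3, pad_mode='same'), nn.ReLU(), nn.MaxPool2d(2, 2))
        self.concat = ops.Concat(axis=1)
        # fuse 8 + 10 + 12 = 30 channels into a single-channel density map
        self.fuse = nn.Conv2d(30, 1, 1, pad_mode='same')

    def construct(self, x):
        x = self.concat((self.col_large(x), self.col_medium(x), self.col_small(x)))
        return self.fuse(x)
```
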
# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or with any other dataset widely used in this domain. The following sections describe how to run the scripts using the dataset below.

Dataset used: [ShanghaitechA](<https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0>)

```text
├─data
    ├─formatted_trainval
        ├─shanghaitech_part_A_patches_9
            ├─train
            ├─train-den
            ├─val
            ├─val-den
    ├─original
        ├─shanghaitech
            ├─part_A_final
                ├─train_data
                    ├─images
                    ├─ground_truth
                ├─test_data
                    ├─images
                    ├─ground_truth
                    ├─ground_truth_csv
```

- Note: the formatted_trainval directory is generated by the script [create_training_set_shtech](https://github.com/svishwa/crowdcount-mcnn/blob/master/data_preparation/create_training_set_shtech.m).

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# enter script dir, train MCNN example
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt

# enter script dir, evaluate MCNN example
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── cv
    ├── MCNN
        ├── README.md                    // descriptions about MCNN
        ├── scripts
        │   ├── run_distribute_train.sh  // distributed training on Ascend
        │   ├── run_eval.sh              // evaluation on Ascend
        │   ├── run_standalone_train.sh  // standalone training on Ascend
        ├── src
        │   ├── dataset.py               // creating dataset
        │   ├── mcnn.py                  // mcnn architecture
        │   ├── config.py                // parameter configuration
        │   ├── data_loader.py           // dataset loader (grey-scale, padded training patches)
        │   ├── data_loader_3channel.py  // dataset loader for full test images
        │   ├── evaluate_model.py        // evaluate model
        │   ├── generator_lr.py          // learning-rate generator
        │   ├── Mcnn_Callback.py         // mcnn callback
        ├── train.py                     // training script
        ├── eval.py                      // evaluation script
        ├── export.py                    // export script
```

## [Script Parameters](#contents)

```python
Major parameters in train.py and config.py are as follows:

--data_path: The absolute full path to the train and evaluation datasets.
--epoch_size: Total number of training epochs.
--batch_size: Training batch size.
--device_target: Device where the code will be implemented. Optional values are "Ascend", "GPU".
--ckpt_path: The absolute full path to the checkpoint file saved after training.
--train_path: Data of the training dataset.
--train_gt_path: Labels (density maps) of the training dataset.
--val_path: Data of the validation dataset.
--val_gt_path: Labels (density maps) of the validation dataset.
```
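
For reference, the standalone launch script in this commit ultimately invokes train.py roughly as below; the paths are placeholders, so adjust them to your environment (see scripts/run_standalone_train.sh for the exact call).

```bash
# illustrative direct invocation of train.py; all paths below are placeholders
python -u train.py \
    --device_id=0 \
    --train_path=./formatted_trainval/shanghaitech_part_A_patches_9/train \
    --train_gt_path=./formatted_trainval/shanghaitech_part_A_patches_9/train_den \
    --val_path=./formatted_trainval/shanghaitech_part_A_patches_9/val \
    --val_gt_path=./formatted_trainval/shanghaitech_part_A_patches_9/val_den \
    --ckpt_path=./ckpt > train.log 2>&1 &
```
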
## [Training Process](#contents)

### Training

- running on Ascend

```bash
# enter script dir, and run the distribute script
sh run_distribute_train.sh ./hccl_table.json ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt

# enter script dir, and run the standalone script
sh run_standalone_train_ascend.sh 0 ./formatted_trainval/shanghaitech_part_A_patches_9/train ./formatted_trainval/shanghaitech_part_A_patches_9/train_den ./formatted_trainval/shanghaitech_part_A_patches_9/val ./formatted_trainval/shanghaitech_part_A_patches_9/val_den ./ckpt
```

After training, loss values are printed as follows:

```text
# grep "loss is " log
epoch: 1 step: 305, loss is 0.00041025918
epoch: 2 step: 305, loss is 3.7117527e-05
...
epoch: 798 step: 305, loss is 0.000332611
epoch: 799 step: 305, loss is 2.6959011e-05
epoch: 800 step: 305, loss is 5.6599742e-06
...
```

The model checkpoint will be saved in the current directory.
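
If you launched training through the wrapper scripts, the log file is written inside the working directory they create, so the same grep can be pointed there (directory names are taken from the scripts in this commit):

```bash
# standalone training (run_standalone_train_ascend.sh)
grep "loss is " ./train/log
# distributed training (run_distribute_train.sh), device 0
grep "loss is " ./train_parallel0/log
```
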
## [Evaluation Process](#contents)

### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

- running on Ascend

```bash
# enter script dir, and run the script
sh run_standalone_eval_ascend.sh 0 ./original/shanghaitech/part_A_final/test_data/images ./original/shanghaitech/part_A_final/test_data/ground_truth_csv ./train/ckpt/best.ckpt
```

You can view the results in the file "eval_log". The accuracy on the test dataset is as follows:

```text
# grep "MAE: " eval_log
MAE: 105.87984801910736 MSE: 161.6687899899305
```
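
For reference, these metrics are the per-image counting errors accumulated in eval.py: the estimated count of an image is the sum of its predicted density map, MAE is the mean absolute count error, and the reported MSE is in fact the root of the mean squared count error. A minimal numpy sketch of the same computation:

```python
import numpy as np

def counting_errors(gt_density_maps, predicted_density_maps):
    """MAE and root-MSE between ground-truth and estimated crowd counts,
    where each count is the sum of a density map (mirrors the loop in eval.py)."""
    abs_err, sq_err = 0.0, 0.0
    for gt_map, et_map in zip(gt_density_maps, predicted_density_maps):
        gt_count, et_count = np.sum(gt_map), np.sum(et_map)
        abs_err += abs(gt_count - et_count)
        sq_err += (gt_count - et_count) ** 2
    n = len(gt_density_maps)
    return abs_err / n, np.sqrt(sq_err / n)
```
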
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                        |
| -------------------------- | ------------------------------------------------------------- |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; Memory 755 GB            |
| Uploaded Date              | 06/29/2021 (month/day/year)                                   |
| MindSpore Version          | 1.2.0                                                         |
| Dataset                    | ShanghaitechA                                                 |
| Training Parameters        | steps=2439, batch_size=1                                      |
| Optimizer                  | Adam                                                          |
| Outputs                    | density map                                                   |
| Speed                      | 5.79 ms/step                                                  |
| Total time                 | 23 mins                                                       |
| Checkpoint for Fine tuning | 500.94 KB (.ckpt file)                                        |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/MCNN> |

# [Description of Random Situation](#contents)

Random seeds are fixed in train.py and eval.py (np.random.seed / set_seed), and data_loader.py sets the shuffle seed inside ```ImageDataLoader``` when shuffling is enabled.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval mcnn example ########################
evaluate mcnn with a trained checkpoint (.ckpt):
python eval.py
"""

import os
import argparse
import ast
import numpy as np
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint
from src.dataset import create_dataset
from src.mcnn import MCNN
from src.data_loader_3channel import ImageDataLoader_3channel

local_path = '/cache/val_path'
local_gt_path = '/cache/val_gt_path'
local_ckpt_url = '/cache/ckpt'
ckptpath = "obs://lhb1234/MCNN/ckpt"

parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
parser.add_argument('--run_offline', type=ast.literal_eval,
                    default=True, help='run offline: False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')
parser.add_argument('--data_url', default=None, help='Location of data.')
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
parser.add_argument('--val_path', required=True,
                    default='/data/mcnn/original/shanghaitech/part_A_final/test_data/images',
                    help='Location of data.')
parser.add_argument('--val_gt_path', required=True,
                    default='/data/mcnn/original/shanghaitech/part_A_final/test_data/ground_truth_csv',
                    help='Location of data.')
args = parser.parse_args()
set_seed(64678)


if __name__ == "__main__":
    device_num = int(os.getenv("RANK_SIZE"))

    device_target = args.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    context.set_context(save_graphs=False)

    if device_target == "Ascend":
        context.set_context(device_id=args.device_id)
    else:
        raise ValueError("Unsupported platform.")

    if args.run_offline:
        local_path = args.val_path
        local_gt_path = args.val_gt_path
        local_ckpt_url = args.ckpt_path
    else:
        import moxing as mox
        mox.file.copy_parallel(src_url=args.val_path, dst_url=local_path)
        mox.file.copy_parallel(src_url=args.val_gt_path, dst_url=local_gt_path)
        mox.file.copy_parallel(src_url=ckptpath, dst_url=local_ckpt_url)

    data_loader_val = ImageDataLoader_3channel(local_path, local_gt_path, shuffle=False, gt_downsample=True,
                                               pre_load=True)
    ds_val = create_dataset(data_loader_val, target=args.device_target, train=False)
    ds_val = ds_val.batch(1)
    network = MCNN()

    model_name = local_ckpt_url
    print(model_name)
    mae = 0.0
    mse = 0.0
    load_checkpoint(model_name, net=network)
    network.set_train(False)
    # accumulate absolute and squared count errors over the test set
    for sample in ds_val.create_dict_iterator():
        im_data = sample['data']
        gt_data = sample['gt_density']
        density_map = network(im_data)
        gt_count = np.sum(gt_data.asnumpy())
        et_count = np.sum(density_map.asnumpy())
        mae += abs(gt_count - et_count)
        mse += ((gt_count - et_count) * (gt_count - et_count))
    mae = mae / ds_val.get_dataset_size()
    mse = np.sqrt(mse / ds_val.get_dataset_size())
    print('MAE:', mae, ' MSE:', mse)
@ -0,0 +1,48 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""

import argparse
import numpy as np
import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.config import crowd_cfg as cfg
from src.mcnn import MCNN

parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
parser.add_argument("--device_id", type=int, default=4, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="mcnn", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == "__main__":

    # define network
    network = MCNN()
    # load network checkpoint
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, param_dict)

    # export network; the input shape comes from crowd_cfg, which must define image_height and image_width
    inputs = Tensor(np.ones([args.batch_size, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
    export(network, inputs, file_name=args.file_name, file_format=args.file_format)
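
For reference, a sketch of how the export script above can be invoked once a checkpoint exists; note that it reads cfg.image_height and cfg.image_width from src/config.py, so those keys must be present in crowd_cfg, and the checkpoint path below is a placeholder.

```bash
# illustrative invocation of export.py; the checkpoint path is a placeholder
python export.py --ckpt_file ./train/ckpt/best.ckpt --file_name mcnn --file_format MINDIR --device_target Ascend --device_id 0
```
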
@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited
export RANK_SIZE=8
export DEVICE_NUM=8
export RANK_TABLE_FILE=$1
export TRAIN_PATH=$2
export TRAIN_GT_PATH=$3
export VAL_PATH=$4
export VAL_GT_PATH=$5
export CKPT_PATH=$6

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=${i}
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
        --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
    cd ..
done
@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]
then
    echo "Usage: sh run_eval.sh [DEVICE_ID] [VAL_PATH] [VAL_GT_PATH] [CKPT_PATH]"
    exit 1
fi

ulimit -u unlimited
export RANK_SIZE=1
export DEVICE_ID=$1
export VAL_PATH=$2
export VAL_GT_PATH=$3
export CKPT_PATH=$4

if [ -d "eval" ];
then
    rm -rf ./eval
fi

mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python -u eval.py --device_id=$DEVICE_ID --val_path=$VAL_PATH \
    --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
cd ..
@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export RANK_SIZE=1
export DEVICE_ID=$1
export TRAIN_PATH=$2
export TRAIN_GT_PATH=$3
export VAL_PATH=$4
export VAL_GT_PATH=$5
export CKPT_PATH=$6

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log

if [ $# == 6 ]
then
    python -u train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --train_gt_path=$TRAIN_GT_PATH \
        --val_path=$VAL_PATH --val_gt_path=$VAL_GT_PATH --ckpt_path=$CKPT_PATH &> log &
fi
cd ..
@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This is the mcnn callback program"""
import os
from mindspore.train.callback import Callback
from mindspore.train.serialization import save_checkpoint
from src.evaluate_model import evaluate_model


class mcnn_callback(Callback):
    """Evaluate the network after each epoch and keep the best checkpoint."""
    def __init__(self, net, eval_data, run_offline, ckpt_path):
        super(mcnn_callback, self).__init__()
        self.net = net
        self.eval_data = eval_data
        self.best_mae = 999999
        self.best_mse = 999999
        self.best_epoch = 0
        self.path_url = "/cache/train_output"
        self.run_offline = run_offline
        self.ckpt_path = ckpt_path

    def epoch_end(self, run_context):
        """Compute MAE/MSE on the validation set and save the best checkpoint."""
        # print(self.net.trainable_params()[0].data.asnumpy()[0][0])
        mae, mse = evaluate_model(self.net, self.eval_data)
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        if cur_epoch % 2 == 0:
            if mae < self.best_mae:
                self.best_mae = mae
                self.best_mse = mse
                self.best_epoch = cur_epoch
                device_id = int(os.getenv("DEVICE_ID"))
                device_num = int(os.getenv("RANK_SIZE"))
                if (device_num == 1) or (device_num == 8 and device_id == 0):
                    # save_checkpoint(self.net, path_url+'/best.ckpt')
                    if self.run_offline:
                        self.path_url = self.ckpt_path
                        if not os.path.exists(self.path_url):
                            os.makedirs(self.path_url, exist_ok=True)
                    save_checkpoint(self.net, os.path.join(self.path_url, 'best.ckpt'))

            log_text = 'EPOCH: %d, MAE: %.1f, MSE: %0.1f' % (cur_epoch, mae, mse)
            print(log_text)
            log_text = 'BEST MAE: %0.1f, BEST MSE: %0.1f, BEST EPOCH: %s' \
                % (self.best_mae, self.best_mse, self.best_epoch)
            print(log_text)
@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""

from easydict import EasyDict as edict

crowd_cfg = edict({
    'lr': 0.000028,  # 0.00001 if device_num == 1; 0.00003 if device_num == 8
    'momentum': 0.0,
    'epoch_size': 800,
    'batch_size': 1,
    'buffer_size': 1000,
    'save_checkpoint_steps': 1,
    'keep_checkpoint_max': 10,
    'air_name': "mcnn",
})
@ -0,0 +1,136 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""image dataloader"""

import os
import random
import cv2
import numpy as np
import pandas as pd


class ImageDataLoader():
    """Loads grey-scale crowd images and their ground-truth density maps (CSV), padded to 256x256."""
    def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False):
        # pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing.
        # This avoids frequent file reads. Use this only for small datasets.
        self.data_path = data_path
        self.gt_path = gt_path
        self.gt_downsample = gt_downsample
        self.pre_load = pre_load
        self.data_files = [filename for filename in os.listdir(data_path) \
                           if os.path.isfile(os.path.join(data_path, filename))]
        self.data_files.sort()
        self.shuffle = shuffle
        if shuffle:
            random.seed(2468)
        self.num_samples = len(self.data_files)
        self.blob_list = {}
        self.id_list = range(0, self.num_samples)
        if self.pre_load:
            print('Pre-loading the data. This may take a while...')
            idx = 0
            for fname in self.data_files:

                img = cv2.imread(os.path.join(self.data_path, fname), 0)
                img = img.astype(np.float32, copy=False)
                ht = img.shape[0]
                wd = img.shape[1]
                ht_1 = (ht // 4) * 4
                wd_1 = (wd // 4) * 4
                img = cv2.resize(img, (wd_1, ht_1))

                # pad the image (and density map) to 256 x 256
                hang = (256 - ht_1) // 2
                lie = (256 - wd_1) // 2
                img = np.pad(img, ((hang, hang), (lie, lie)), 'constant')

                img = img.reshape((1, img.shape[0], img.shape[1]))
                den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
                                  header=None).values
                den = den.astype(np.float32, copy=False)
                if self.gt_downsample:
                    den = np.pad(den, ((hang, hang), (lie, lie)), 'constant')
                    # print(den.shape)
                    wd_1 = wd_1 // 4
                    ht_1 = ht_1 // 4
                    den = cv2.resize(den, (64, 64))
                    den = den * ((wd * ht) / (wd_1 * ht_1))
                else:
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))

                den = den.reshape((1, den.shape[0], den.shape[1]))
                blob = {}
                blob['data'] = img
                blob['gt_density'] = den
                blob['fname'] = fname
                self.blob_list[idx] = blob
                idx = idx + 1
                if idx % 100 == 0:
                    print('Loaded ', idx, '/', self.num_samples, 'files')

            print('Completed Loading ', idx, 'files')

    def __iter__(self):
        if self.shuffle:
            if self.pre_load:
                random.shuffle(list(self.id_list))
            else:
                random.shuffle(list(self.data_files))
        files = self.data_files
        id_list = self.id_list

        for idx in id_list:
            if self.pre_load:
                blob = self.blob_list[idx]
                blob['idx'] = idx
            else:
                fname = files[idx]
                img = cv2.imread(os.path.join(self.data_path, fname), 0)
                img = img.astype(np.float32, copy=False)
                ht = img.shape[0]
                wd = img.shape[1]
                ht_1 = (ht // 4) * 4
                wd_1 = (wd // 4) * 4
                img = cv2.resize(img, (wd_1, ht_1))

                hang = (256 - ht_1) // 2
                lie = (256 - wd_1) // 2
                img = np.pad(img, ((hang, hang), (lie, lie)), 'constant')

                img = img.reshape((1, img.shape[0], img.shape[1]))
                den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
                                  header=None).values
                den = den.astype(np.float32, copy=False)

                if self.gt_downsample:
                    den = np.pad(den, ((hang, hang), (lie, lie)), 'constant')
                    wd_1 = wd_1 // 4
                    ht_1 = ht_1 // 4
                    den = cv2.resize(den, (64, 64))
                    den = den * ((wd * ht) / (wd_1 * ht_1))
                else:
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))

                den = den.reshape((1, den.shape[0], den.shape[1]))
                blob = {}
                blob['data'] = img
                blob['gt_density'] = den
                blob['fname'] = fname

            yield blob

    def get_num_samples(self):
        return self.num_samples
@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ImageDataLoader_3channel"""
import os
import random
import cv2
import numpy as np
import pandas as pd


class ImageDataLoader_3channel():
    """Loads full test images and their ground-truth density maps (CSV) without padding."""
    def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False):
        # pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing.
        # This avoids frequent file reads. Use this only for small datasets.
        self.data_path = data_path
        self.gt_path = gt_path
        self.gt_downsample = gt_downsample
        self.pre_load = pre_load
        self.data_files = [filename for filename in os.listdir(data_path) \
                           if os.path.isfile(os.path.join(data_path, filename))]
        self.data_files.sort()
        self.shuffle = shuffle
        if shuffle:
            random.seed(2468)
        self.num_samples = len(self.data_files)
        self.blob_list = {}
        self.id_list = range(0, self.num_samples)
        if self.pre_load:
            print('Pre-loading the data. This may take a while...')
            idx = 0
            for fname in self.data_files:

                img = cv2.imread(os.path.join(self.data_path, fname), 0)
                img = img.astype(np.float32, copy=False)
                ht = img.shape[0]
                wd = img.shape[1]
                ht_1 = (ht // 4) * 4
                wd_1 = (wd // 4) * 4
                img = cv2.resize(img, (wd_1, ht_1))

                img = img.reshape((1, img.shape[0], img.shape[1]))
                den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
                                  header=None).values
                den = den.astype(np.float32, copy=False)
                if self.gt_downsample:
                    # print(den.shape)
                    wd_1 = wd_1 // 4
                    ht_1 = ht_1 // 4
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))
                else:
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))

                den = den.reshape((1, den.shape[0], den.shape[1]))
                blob = {}
                blob['data'] = img
                blob['gt_density'] = den
                blob['fname'] = fname
                self.blob_list[idx] = blob
                idx = idx + 1
                if idx % 100 == 0:
                    print('Loaded ', idx, '/', self.num_samples, 'files')

            print('Completed Loading ', idx, 'files')

    def __iter__(self):
        if self.shuffle:
            if self.pre_load:
                random.shuffle(list(self.id_list))
            else:
                random.shuffle(list(self.data_files))
        files = self.data_files
        id_list = self.id_list

        for idx in id_list:
            if self.pre_load:
                blob = self.blob_list[idx]
                blob['idx'] = idx
            else:
                fname = files[idx]
                img = cv2.imread(os.path.join(self.data_path, fname), 0)
                img = img.astype(np.float32, copy=False)
                ht = img.shape[0]
                wd = img.shape[1]
                ht_1 = (ht // 4) * 4
                wd_1 = (wd // 4) * 4
                img = cv2.resize(img, (wd_1, ht_1))

                img = img.reshape((1, img.shape[0], img.shape[1]))
                den = pd.read_csv(os.path.join(self.gt_path, os.path.splitext(fname)[0] + '.csv'), sep=',',
                                  header=None).values
                den = den.astype(np.float32, copy=False)

                if self.gt_downsample:
                    wd_1 = wd_1 // 4
                    ht_1 = ht_1 // 4
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))
                else:
                    den = cv2.resize(den, (wd_1, ht_1))
                    den = den * ((wd * ht) / (wd_1 * ht_1))

                den = den.reshape((1, den.shape[0], den.shape[1]))
                blob = {}
                blob['data'] = img
                blob['gt_density'] = den
                blob['fname'] = fname

            yield blob

    def get_num_samples(self):
        return self.num_samples
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import math
import mindspore.dataset as ds


def create_dataset(data_loader, target="Ascend", train=True):
    """Build a GeneratorDataset from an ImageDataLoader instance."""
    datalist = []
    labellist = []
    for blob in data_loader:
        datalist.append(blob['data'])
        labellist.append(blob['gt_density'])

    class GetDatasetGenerator:
        """Wraps the pre-loaded (image, density map) pairs for GeneratorDataset."""
        def __init__(self):
            self.__data = datalist
            self.__label = labellist

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    class MySampler():
        """Splits the samples evenly across devices for data-parallel training."""
        def __init__(self, dataset, local_rank, world_size):
            self.__num_data = len(dataset)
            self.__local_rank = local_rank
            self.__world_size = world_size
            self.samples_per_rank = int(math.ceil(self.__num_data / float(self.__world_size)))
            self.total_num_samples = self.samples_per_rank * self.__world_size

        def __iter__(self):
            indices = list(range(self.__num_data))
            indices.extend(indices[:self.total_num_samples - len(indices)])
            indices = indices[self.__local_rank:self.total_num_samples:self.__world_size]
            return iter(indices)

        def __len__(self):
            return self.samples_per_rank

    dataset_generator = GetDatasetGenerator()
    sampler = MySampler(dataset_generator, local_rank=0, world_size=8)

    if target == "Ascend":
        # device_num, rank_id = _get_rank_info()
        device_num = int(os.getenv("RANK_SIZE"))
        rank_id = int(os.getenv("DEVICE_ID"))
        sampler = MySampler(dataset_generator, local_rank=rank_id, world_size=8)
    if target != "Ascend" or device_num == 1 or (not train):
        data_set = ds.GeneratorDataset(dataset_generator, ["data", "gt_density"])
    else:
        data_set = ds.GeneratorDataset(dataset_generator, ["data", "gt_density"], num_parallel_workers=8,
                                       num_shards=device_num, shard_id=rank_id, sampler=sampler)

    return data_set
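
A minimal sketch of how this data pipeline is wired together, mirroring train.py later in this commit; the dataset paths are placeholders, and RANK_SIZE/DEVICE_ID are set here only because the launch scripts normally export them.

```python
import os
from src.data_loader import ImageDataLoader
from src.dataset import create_dataset

os.environ.setdefault("RANK_SIZE", "1")   # normally exported by the launch scripts
os.environ.setdefault("DEVICE_ID", "0")

# hypothetical local paths to the formatted training patches and density-map CSVs
train_loader = ImageDataLoader('./train', './train_den', shuffle=True, gt_downsample=True, pre_load=True)
ds_train = create_dataset(train_loader, target="Ascend", train=True)
ds_train = ds_train.batch(1)  # cfg['batch_size'] is 1 in src/config.py
print(ds_train.get_dataset_size())
```
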
@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate model"""
import numpy as np


def evaluate_model(model, dataset):
    """Compute MAE and root-MSE of the estimated crowd counts over a dataset."""
    net = model
    print("*******************************************************************************************************")
    mae = 0.0
    mse = 0.0
    net.set_train(False)
    for sample in dataset.create_dict_iterator():
        im_data = sample['data']
        gt_data = sample['gt_density']
        density_map = net(im_data)
        gt_count = np.sum(gt_data.asnumpy())
        et_count = np.sum(density_map.asnumpy())
        mae += abs(gt_count - et_count)
        mse += ((gt_count - et_count) * (gt_count - et_count))
    mae = mae / (dataset.get_dataset_size())
    mse = np.sqrt(mse / dataset.get_dataset_size())
    return mae, mse
@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""

import numpy as np


def get_lr_sha(current_step, lr_max, total_epochs, steps_per_epoch):
    """
    generate learning rate array

    Args:
        current_step(int): current step of the training
        lr_max(float): max learning rate
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    decay_epoch_index = [0.5 * total_steps, 0.75 * total_steps]
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr = lr_max
        elif i < decay_epoch_index[1]:
            lr = lr_max * 0.1
        else:
            lr = lr_max * 0.01
        lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
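
A minimal usage sketch for the step-decay schedule above (the numbers are illustrative, not from this commit): the learning rate stays at lr_max for the first half of the total steps, drops to 10% for the next quarter, and to 1% for the rest.

```python
import numpy as np
from src.generator_lr import get_lr_sha

# hypothetical example: 4 epochs x 5 steps per epoch = 20 total steps
lr = get_lr_sha(current_step=0, lr_max=0.01, total_epochs=4, steps_per_epoch=5)
print(lr.shape)       # (20,)
print(np.unique(lr))  # the three plateaus: 0.0001, 0.001, 0.01 (up to float32 rounding)
```
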
@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This is the mcnn model"""

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype


class Conv2d(nn.Cell):
    """Conv2d block: convolution with optional BatchNorm and ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False, bn=False):
        super(Conv2d, self).__init__()
        padding = int((kernel_size - 1) / 2) if same_padding else 0
        # padding = 'same' if same_padding else 'valid'
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              pad_mode='pad', padding=padding, has_bias=True)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None
        # TODO: init weights
        self._initialize_weights()

    def construct(self, x):
        """define Conv2d network"""
        x = self.conv(x)
        # if self.bn is not None:
        #     x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

    def _initialize_weights(self):
        """initialize weights"""
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))


def np_to_tensor(x, is_cuda=True, is_training=False):
    """convert a numpy array to a float32 Tensor"""
    if is_training:
        v = Tensor(x, mstype.float32)
    else:
        v = Tensor(x, mstype.float32)  # with torch.no_grad():
    return v


class MCNN(nn.Cell):
    '''
    Multi-column CNN
        -Implementation of Single Image Crowd Counting via Multi-column CNN (Zhang et al.)
    '''
    def __init__(self, bn=False):
        super(MCNN, self).__init__()

        self.branch1 = nn.SequentialCell(Conv2d(1, 16, 9, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(16, 32, 7, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(32, 16, 7, same_padding=True, bn=bn),
                                         Conv2d(16, 8, 7, same_padding=True, bn=bn))

        self.branch2 = nn.SequentialCell(Conv2d(1, 20, 7, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(20, 40, 5, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(40, 20, 5, same_padding=True, bn=bn),
                                         Conv2d(20, 10, 5, same_padding=True, bn=bn))

        self.branch3 = nn.SequentialCell(Conv2d(1, 24, 5, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(24, 48, 3, same_padding=True, bn=bn),
                                         nn.MaxPool2d(2, 2),
                                         Conv2d(48, 24, 3, same_padding=True, bn=bn),
                                         Conv2d(24, 12, 3, same_padding=True, bn=bn))

        self.fuse = nn.SequentialCell([Conv2d(30, 1, 1, same_padding=True, bn=bn)])

        # TODO: init weights
        self._initialize_weights()

    def construct(self, im_data):
        """define network"""
        x1 = self.branch1(im_data)
        x2 = self.branch2(im_data)
        x3 = self.branch3(im_data)
        op = ops.Concat(1)
        x = op((x1, x2, x3))
        x = self.fuse(x)
        return x

    def _initialize_weights(self):
        """initialize weights"""
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
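
A quick smoke test for the MCNN cell above (a sketch, assuming MindSpore is installed and PyNative mode works on the available backend): a 256x256 single-channel input yields a 64x64 single-channel density map, matching the 4x-downsampled ground truth produced by the data loaders.

```python
import numpy as np
from mindspore import Tensor, context
from src.mcnn import MCNN

context.set_context(mode=context.PYNATIVE_MODE)  # backend defaults to whatever is available
net = MCNN()
dummy = Tensor(np.random.rand(1, 1, 256, 256).astype(np.float32))
density_map = net(dummy)
print(density_map.shape)                      # (1, 1, 64, 64)
print(float(np.sum(density_map.asnumpy())))   # estimated crowd count = sum of the density map
```
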
@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train mcnn example ########################
train mcnn and get network model files(.ckpt):
python train.py
"""
import os
import argparse
import ast
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train import Model
from src.data_loader import ImageDataLoader
from src.config import crowd_cfg as cfg
from src.dataset import create_dataset
from src.mcnn import MCNN
from src.generator_lr import get_lr_sha
from src.Mcnn_Callback import mcnn_callback

parser = argparse.ArgumentParser(description='MindSpore MCNN Example')
parser.add_argument('--run_offline', type=ast.literal_eval,
                    default=True, help='run offline: False or True')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
parser.add_argument('--ckpt_path', type=str, default="/cache/train_output", help='Location of ckpt.')

parser.add_argument('--data_url', default=None, help='Location of data.')
parser.add_argument('--train_url', default=None, help='Location of training outputs.')

parser.add_argument('--train_path', required=True, default=None, help='Location of training data.')
parser.add_argument('--train_gt_path', required=True, default=None, help='Location of training ground truth.')
parser.add_argument('--val_path', required=True,
                    default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val',
                    help='Location of validation data.')
parser.add_argument('--val_gt_path', required=True,
                    default='/data/formatted_trainval/shanghaitech_part_A_patches_9/val_den',
                    help='Location of validation ground truth.')
args = parser.parse_args()
rand_seed = 64678
np.random.seed(rand_seed)

if __name__ == "__main__":
    device_num = int(os.getenv("RANK_SIZE"))

    print("device_num:", device_num)
    device_target = args.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    context.set_context(save_graphs=False)

    if device_target == "Ascend":
        context.set_context(device_id=args.device_id)

        if device_num > 1:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
    else:
        raise ValueError("Unsupported platform.")

    if args.run_offline:
        local_data1_url = args.train_path
        local_data2_url = args.train_gt_path
        local_data3_url = args.val_path
        local_data4_url = args.val_gt_path
    else:
        import moxing as mox
        local_data1_url = '/cache/train_path'
        local_data2_url = '/cache/train_gt_path'
        local_data3_url = '/cache/val_path'
        local_data4_url = '/cache/val_gt_path'

        mox.file.copy_parallel(src_url=args.train_path, dst_url=local_data1_url)  # pcl
        mox.file.copy_parallel(src_url=args.train_gt_path, dst_url=local_data2_url)  # pcl
        mox.file.copy_parallel(src_url=args.val_path, dst_url=local_data3_url)  # pcl
        mox.file.copy_parallel(src_url=args.val_gt_path, dst_url=local_data4_url)  # pcl

    data_loader = ImageDataLoader(local_data1_url, local_data2_url, shuffle=True, gt_downsample=True, pre_load=True)
    data_loader_val = ImageDataLoader(local_data3_url, local_data4_url,
                                      shuffle=False, gt_downsample=True, pre_load=True)
    ds_train = create_dataset(data_loader, target=args.device_target)
    ds_val = create_dataset(data_loader_val, target=args.device_target, train=False)

    ds_train = ds_train.batch(cfg['batch_size'])
    ds_val = ds_val.batch(1)

    network = MCNN()
    net_loss = nn.MSELoss(reduction='mean')
    lr = Tensor(get_lr_sha(0, cfg['lr'], cfg['epoch_size'], ds_train.get_dataset_size()))
    net_opt = nn.Adam(list(filter(lambda p: p.requires_grad, network.get_parameters())), learning_rate=lr)

    if args.device_target != "Ascend":
        model = Model(network, net_loss, net_opt)
    else:
        model = Model(network, net_loss, net_opt, amp_level="O2")

    print("============== Starting Training ==============")
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    eval_callback = mcnn_callback(network, ds_val, args.run_offline, args.ckpt_path)
    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, eval_callback, LossMonitor(1)])
    if not args.run_offline:
        mox.file.copy_parallel(src_url='/cache/train_output', dst_url="obs://lhb1234/MCNN/ckpt")