add pointnet2

This commit is contained in:
WPX 2021-08-31 15:59:46 +08:00
parent 7eac436c56
commit 1d41475e61
15 changed files with 1799 additions and 0 deletions

View File

@ -0,0 +1,230 @@
# Contents
- [Contents](#contents)
- [PointNet2 Description](#pointnet2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [PointNet2 Description](#contents)
PointNet++ was proposed in 2017, it is a hierarchical neural network that applies PointNet recursively on a nested
partitioning of the input point set. By exploiting metric space distances, this network is able to learn local features
with increasing contextual scales. Experiments show that our network called PointNet++ is able to learn deep point set
features efficiently and robustly.
[Paper](http://arxiv.org/abs/1706.02413): Qi, Charles R., et al. "Pointnet++: Deep hierarchical feature learning on
point sets in a metric space." arXiv preprint arXiv:1706.02413 (2017).
# [Model Architecture](#contents)
The hierarchical structure of PointNet++ is composed by a number of *set abstraction* levels. At each level, a set of
points is processed and abstracted to produce a new set with fewer elements. The set abstraction level is made of three
key layers: *Sampling layer*, *Grouping layer* and *PointNet layer*. The *Sampling* *layer* selects a set of points from
input points, which defines the centroids of local regions. *Grouping* *layer* then constructs local region sets by
finding “neighboring” points around the centroids. *PointNet* *layer* uses a mini-PointNet to encode local region
patterns into feature vectors.
# [Dataset](#contents)
Dataset used: alignment [ModelNet40](<https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip>)
- Dataset size6.48GEach point cloud contains 2048 points uniformly sampled from a shape surface. Each cloud is
zero-mean and normalized into a unit sphere.
- Train5.18G, 9843 point clouds
- Test1.3G, 2468 point clouds
- Data formattxt files
- NoteData will be processed in src/dataset.py
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# Run stand-alone training
bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh modelnet40_normal_resampled save pointnet2.ckpt
# Run distributed training
bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh hccl_8p_01234567_127.0.0.1.json modelnet40_normal_resampled save pointnet2.ckpt
# Evaluate
bash scripts/run_eval.sh [DATA_PATH] [CKPT_NAME]
# example:
bash scripts/run_eval.sh modelnet40_normal_resampled pointnet2.ckpt
```
# [Script Description](#contents)
# [Script and Sample Code](#contents)
```bash
├── .
├── PointNet2
├── scripts
│ ├── run_distribute_train.sh # launch distributed training with ascend platform (8p)
│ ├── run_eval.sh # launch evaluating with ascend platform
│ └── run_train_ascend.sh # launch standalone training with ascend platform (1p)
├── src
│ ├── callbacks.py # callbacks definition
│ ├── dataset.py # data preprocessing
│ ├── layers.py # network layers initialization
│ ├── lr_scheduler.py # learning rate scheduler
│ ├── PointNet2.py # network definition
│ ├── PointNet2_utils.py # network definition utils
│ └── provider.py # data preprocessing for training
├── eval.py # eval net
├── README.md
├── requirements.txt
└── train.py # train net
```
# [Script Parameters](#contents)
```bash
Major parameters in train.py are as follows:
--batch_size # Training batch size.
--epoch # Total training epochs.
--learning_rate # Training learning rate.
--optimizer # Optimizer for training. Optional values are "Adam", "SGD".
--data_path # The path to the train and evaluation datasets.
--loss_per_epoch # The times to print loss value per epoch.
--save_dir # The path to save files generated during training.
--use_normals # Whether to use normals data in training.
--pretrained_ckpt # The file path to load checkpoint.
--enable_modelarts # Whether to use modelarts.
```
# [Training Process](#contents)
## Training
- running on Ascend
```shell
# Run stand-alone training
bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh modelnet40_normal_resampled save pointnet2.ckpt
# Run distributed training
bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh hccl_8p_01234567_127.0.0.1.json modelnet40_normal_resampled save pointnet2.ckpt
```
Distributed training requires the creation of an HCCL configuration file in JSON format in advance. For specific
operations, see the instructions
in [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
After training, the loss value will be achieved as follows:
```bash
# train log
epoch: 1 step: 410, loss is 1.4731973
epoch time: 704454.051 ms, per step time: 1718.181 ms
epoch: 2 step: 410, loss is 1.0621885
epoch time: 471478.224 ms, per step time: 1149.947 ms
epoch: 3 step: 410, loss is 1.176581
epoch time: 471530.000 ms, per step time: 1150.073 ms
epoch: 4 step: 410, loss is 1.0118457
epoch time: 471498.514 ms, per step time: 1149.996 ms
epoch: 5 step: 410, loss is 0.47454038
epoch time: 471535.602 ms, per step time: 1150.087 ms
...
```
The model checkpoint will be saved in the 'SAVE_DIR' directory.
# [Evaluation Process](#contents)
## Evaluation
Before running the command below, please check the checkpoint path used for evaluation.
- running on Ascend
```shell
# Evaluate
bash scripts/run_eval.sh [DATA_PATH] [CKPT_NAME]
# example:
bash scripts/run_eval.sh modelnet40_normal_resampled pointnet2.ckpt
```
You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```bash
# grep "Accuracy: " eval.log
'Accuracy': 0.9146
```
# [Model Description](#contents)
## [Performance](#contents)
## Training Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | PointNet++ |
| Resource | Ascend 910; CPU 24cores; Memory 256G; OS Euler2.8 |
| uploaded Date | 08/31/2021 (month/day/year) |
| MindSpore Version | 1.3.0 |
| Dataset | ModelNet40 |
| Training Parameters | epoch=200, steps=82000, batch_size=24, lr=0.001 |
| Optimizer | Adam |
| Loss Function | NLLLoss |
| outputs | probability |
| Loss | 0.01 |
| Speed | 1.2 s/step (1p) |
| Total time | 27.3 h (1p) |
| Checkpoint for Fine tuning | 17 MB (.ckpt file) |
## Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | PointNet++ |
| Resource | Ascend 910; CPU 24cores; Memory 256G; OS Euler2.8 |
| Uploaded Date | 08/31/2021 (month/day/year) |
| MindSpore Version | 1.3.0 |
| Dataset | ModelNet40 |
| Batch_size | 24 |
| Outputs | probability |
| Accuracy | 91.5% (1p) |
| Total time | 2.5 min |
# [Description of Random Situation](#contents)
We use random seed in dataset.py, provider.py and pointnet2_utils.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import argparse
import ast
import os
import time
import mindspore.dataset as ds
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import DatasetGenerator
from src.pointnet2 import PointNet2, NLLLoss
def parse_args():
"""PARAMETERS"""
parser = argparse.ArgumentParser('MindSpore PointNet++ Eval Configurations.')
parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
parser.add_argument('--data_path', type=str, default='../data/modelnet40_normal_resampled/', help='data path')
parser.add_argument('--pretrained_ckpt', type=str, default='')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
parser.add_argument('--use_normals', type=ast.literal_eval, default=False, help='use normals')
parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
parser.add_argument('--enable_modelarts', type=ast.literal_eval, default=False)
parser.add_argument('--data_url', type=str)
parser.add_argument('--train_url', type=str)
return parser.parse_known_args()[0]
def run_eval():
"""Run eval"""
args = parse_args()
# INIT
device_id = int(os.getenv('DEVICE_ID', '0'))
if args.enable_modelarts:
import moxing as mox
local_data_url = "/cache/data"
mox.file.copy_parallel(args.data_url, local_data_url)
pretrained_ckpt_path = "/cache/pretrained_ckpt/pretrained.ckpt"
mox.file.copy_parallel(args.pretrained_ckpt, pretrained_ckpt_path)
local_eval_url = "/cache/eval_out"
mox.file.copy_parallel(args.train_url, local_eval_url)
else:
local_data_url = args.data_path
pretrained_ckpt_path = args.pretrained_ckpt
if args.platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(max_call_depth=2048)
else:
raise ValueError("Unsupported platform.")
print(args)
# DATA LOADING
print('Load dataset ...')
data_path = local_data_url
num_workers = 8
test_ds_generator = DatasetGenerator(root=data_path, args=args, split='test', process_data=args.process_data)
test_ds = ds.GeneratorDataset(test_ds_generator, ["data", "label"], num_parallel_workers=num_workers, shuffle=False)
test_ds = test_ds.batch(batch_size=args.batch_size, drop_remainder=True, num_parallel_workers=num_workers)
# MODEL LOADING
net = PointNet2(args.num_category, args.use_normals)
# load checkpoint
print("Load checkpoint: ", args.pretrained_ckpt)
param_dict = load_checkpoint(pretrained_ckpt_path)
load_param_into_net(net, param_dict)
net_loss = NLLLoss()
model = Model(net, net_loss, metrics={"Accuracy": Accuracy()})
# EVAL
net.set_train(False)
print('Starting eval ...')
time_start = time.time()
print('Time: ', time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
result = model.eval(test_ds, dataset_sink_mode=True)
print("result : {}".format(result))
# END
print('Total time cost: {} min'.format("%.2f" % ((time.time() - time_start) / 60)))
if args.enable_modelarts:
mox.file.copy_parallel(local_eval_url, args.train_url)
if __name__ == '__main__':
run_eval()

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import ast
import os
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.pointnet2 import PointNet2
parser = argparse.ArgumentParser(description='PointNet2 export')
parser.add_argument("--enable_modelarts", type=ast.literal_eval, default=False,
help="Run on modelArt, default is false.")
parser.add_argument('--data_url', default=None, help='Directory contains dataset.')
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
parser.add_argument("--batch_size", type=int, default=24, help="batch size")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--use_normals', action='store_true', default=False, help='use normals') # channels = 6 if true
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
context.set_context(max_call_depth=2048)
if args.enable_modelarts:
import moxing as mox
local_data_url = "/cache/data"
mox.file.copy_parallel(args.data_url, local_data_url)
device_id = int(os.getenv('DEVICE_ID'))
local_output_url = '/cache/ckpt' + str(device_id)
mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
dst_url=os.path.join(local_output_url, args.ckpt_file))
else:
local_output_url = '.'
if __name__ == '__main__':
net = PointNet2(args.num_category, args.use_normals)
param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_file))
print('load ckpt')
load_param_into_net(net, param_dict)
print('load ckpt to net')
net.set_train(False)
input_arr = Tensor(np.ones([args.batch_size, 1024, 3]), mstype.float32)
print('input')
export(net, input_arr, file_name="PointNet2", file_format=args.file_format)
if args.enable_modelarts:
file_name = "PointNet2." + args.file_format.lower()
mox.file.copy_parallel(src_url=file_name,
dst_url=os.path.join(args.train_url, file_name))

View File

@ -0,0 +1,93 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================
PRETRAINED_CKPT=""
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]"
echo "============================================================"
echo "[RANK_TABLE_FILE]: The path to the HCCL configuration file in JSON format."
echo "[DATA_PATH]: The path to the train and evaluation datasets."
echo "[SAVE_DIR]: The path to save files generated during training."
echo "[PRETRAINDE_CKPT]: (optional) The path to the checkpoint file."
echo "============================================================"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# -ge 3 ]
then
RANK_TABLE_FILE=$(get_real_path $1)
DATA_PATH=$(get_real_path $2)
SAVE_DIR=$(get_real_path $3)
if [ ! -f $RANK_TABLE_FILE ]
then
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
exit 1
fi
if [ ! -d $DATA_PATH ]
then
echo "error: DATA_PATH=$DATA_PATH is not a directory"
exit 1
fi
fi
if [ $# -ge 4 ]
then
PRETRAINED_CKPT=$(get_real_path $4)
if [ ! -f $PRETRAINED_CKPT ]
then
echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
exit 1
fi
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_FILE
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cp *.py ./device$i
cp -r scripts ./device$i
cp -r src ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
python train.py \
--data_path=$DATA_PATH \
--pretrained_ckpt=$PRETRAINED_CKPT \
--save_dir=$SAVE_DIR > train.log 2>&1 &
cd ../
done

View File

@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -m unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ $# != 2 ]
then
echo "Usage: bash scripts/run_eval.sh [DATA_PATH] [PRETRAINDE_CKPT]"
echo "============================================================"
echo "[DATA_PATH]: The path to the train and evaluation datasets."
echo "[PRETRAINDE_CKPT]: The path to the checkpoint file."
echo "============================================================"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATA_PATH=$(get_real_path $1)
PRETRAINED_CKPT=$(get_real_path $2)
if [ ! -d $DATA_PATH ]
then
echo "error: DATA_PATH=$DATA_PATH is not a directory"
exit 1
fi
if [ ! -f $PRETRAINED_CKPT ]
then
echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
exit 1
fi
python eval.py \
--data_path=$DATA_PATH \
--pretrained_ckpt=$PRETRAINED_CKPT > eval.log 2>&1 &
echo 'running'

View File

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
PRETRAINED_CKPT=""
ulimit -m unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINDE_CKPT(optional)]"
echo "============================================================"
echo "[DATA_PATH]: The path to the train and evaluation datasets."
echo "[SAVE_DIR]: The path to save files generated during training."
echo "[PRETRAINDE_CKPT]: (optional) The path to the checkpoint file."
echo "============================================================"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# -ge 2 ]
then
DATA_PATH=$(get_real_path $1)
SAVE_DIR=$(get_real_path $2)
if [ ! -d $DATA_PATH ]
then
echo "error: DATA_PATH=$DATA_PATH is not a directory"
exit 1
fi
fi
if [ $# -ge 3 ]
then
PRETRAINED_CKPT=$(get_real_path $3)
if [ ! -f $PRETRAINED_CKPT ]
then
echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
exit 1
fi
fi
python train.py \
--data_path=$DATA_PATH \
--pretrained_ckpt=$PRETRAINED_CKPT \
--save_dir=$SAVE_DIR > train.log 2>&1 &
echo 'running'

View File

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callbacks"""
import moxing as mox
from mindspore.train.callback import Callback
class MoxCallBack(Callback):
"""Mox training files from online"""
def __init__(self, local_train_url, train_url, mox_freq):
super(MoxCallBack, self).__init__()
self.local_train_url = local_train_url
self.train_url = train_url
self.mox_freq = mox_freq
def epoch_end(self, run_context):
"""Mox files at the end of each epoch"""
cb_params = run_context.original_args()
if cb_params.cur_epoch_num % self.mox_freq == 0:
mox.file.copy_parallel(self.local_train_url, self.train_url)

View File

@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data preprocessing"""
import os
import pickle
import numpy as np
def pc_normalize(pc):
"""normalize point cloud"""
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
def farthest_point_sample(point, npoint):
"""
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, _ = point.shape
xyz = point[:, :3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
for i in range(npoint):
centroids[i] = farthest
centroid = xyz[farthest, :]
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
return point
class DatasetGenerator:
"""DatasetGenerator"""
def __init__(self, root, args, split='train', process_data=False):
self.root = root
self.npoints = args.num_point
self.process_data = process_data
self.uniform = args.use_uniform_sample
self.use_normals = args.use_normals
self.num_category = args.num_category
if self.num_category == 10:
self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
else:
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
shape_ids = {}
if self.num_category == 10:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
else:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
assert split in ('train', 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
if self.uniform:
self.save_path = os.path.join(root,
'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
else:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))
if self.process_data:
if not os.path.exists(self.save_path):
print('Processing data %s (only running in the first time)...' % self.save_path)
self.list_of_points = [None] * len(self.datapath)
self.list_of_labels = [None] * len(self.datapath)
for index in range(len(self.datapath)):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
self.list_of_points[index] = point_set
self.list_of_labels[index] = cls
with open(self.save_path, 'wb') as f:
pickle.dump([self.list_of_points, self.list_of_labels], f)
else:
print('Load processed data from %s...' % self.save_path)
with open(self.save_path, 'rb') as f:
self.list_of_points, self.list_of_labels = pickle.load(f)
def __len__(self):
return len(self.datapath)
def __getitem__(self, index):
"""get item"""
if self.process_data:
point_set, label = self.list_of_points[index], self.list_of_labels[index]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
label = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
if not self.use_normals:
point_set = point_set[:, 0:3]
return point_set, label[0]

View File

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network layers initialization"""
import math
import mindspore.nn as nn
from mindspore.common.initializer import initializer, HeUniform, Uniform
def calculate_fan_in_and_fan_out(shape):
"""
calculate fan_in and fan_out
Args:
shape (tuple): input shape.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dimensions = len(shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = shape[1]
fan_out = shape[0]
else:
num_input_fmaps = shape[1]
num_output_fmaps = shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = shape[2] * shape[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class Dense(nn.Dense):
"""Dense"""
def __init__(self, in_channels, out_channels, has_bias=True, activation=None):
super(Dense, self).__init__(in_channels, out_channels, weight_init='normal', bias_init='zeros',
has_bias=has_bias, activation=activation)
self.reset_parameters()
def reset_parameters(self):
"""reset parameters"""
self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
if self.has_bias:
fan_in, _ = calculate_fan_in_and_fan_out(self.weight.shape)
bound = 1 / math.sqrt(fan_in)
self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))
class Conv2d(nn.Conv2d):
"""Conv2d"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1,
group=1, has_bias=True):
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group,
has_bias, weight_init='normal', bias_init='zeros')
self.reset_parameters()
def reset_parameters(self):
"""reset parameters"""
self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
if self.has_bias:
fan_in, _ = calculate_fan_in_and_fan_out(self.weight.shape)
bound = 1 / math.sqrt(fan_in)
self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))

View File

@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate scheduler"""
import numpy as np
class MultiStepLR:
"""
Multi-step learning rate scheduler
Decays the learning rate by gamma once the number of epoch reaches one of the milestones.
Args:
lr (float): Initial learning rate which is the lower boundary in the cycle.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
steps_per_epoch (int): The number of steps per epoch to train for.
max_epoch (int): The number of epochs to train for.
warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0
Outputs:
numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
>>> lr = scheduler.get_lr()
"""
def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
self.lr = lr
self.milestones = milestones
self.gamma = gamma
self.steps_per_epoch = steps_per_epoch
self.max_epoch = max_epoch
self.total_steps = int(max_epoch * steps_per_epoch)
def get_lr(self):
"""get learning rate"""
lr_each_step = []
current_lr = self.lr
for i in range(self.total_steps):
cur_ep = i // self.steps_per_epoch
if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
current_lr = current_lr * self.gamma
lr_each_step.append(current_lr)
return np.array(lr_each_step).astype(np.float32)

View File

@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network definition"""
import mindspore.nn as nn
import mindspore.ops as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from src.layers import Dense
from src.pointnet2_utils import PointNetSetAbstraction
class PointNet2(nn.Cell):
"""PointNet2"""
def __init__(self, num_class, normal_channel=False):
super(PointNet2, self).__init__()
in_channel = 6 if normal_channel else 3
self.normal_channel = normal_channel
self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
in_channel=in_channel, mlp=[64, 64, 128],
group_all=False)
self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64,
in_channel=128 + 3, mlp=[128, 128, 256],
group_all=False)
self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None,
in_channel=256 + 3, mlp=[256, 512, 1024],
group_all=True)
self.fc1 = Dense(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
self.drop1 = nn.Dropout(0.6)
self.fc2 = Dense(512, 256)
self.bn2 = nn.BatchNorm1d(256)
self.drop2 = nn.Dropout(0.5)
self.fc3 = Dense(256, num_class)
self.relu = P.ReLU()
self.reshape = P.Reshape()
self.log_softmax = P.LogSoftmax()
self.transpose = P.Transpose()
def construct(self, xyz):
"""
construct method
"""
if self.normal_channel:
norm = self.transpose(xyz[:, :, 3:], (0, 2, 1))
xyz = xyz[:, :, :3]
else:
norm = None
l1_xyz, l1_points = self.sa1(xyz, norm) # [B, 3, 512], [B, 128, 512]
l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) # [B, 3, 128], [B, 256, 128]
_, l3_points = self.sa3(l2_xyz, l2_points) # [B, 3, 1], [B, 1024, 1]
x = self.reshape(l3_points, (-1, 1024))
x = self.drop1(self.relu(self.bn1(self.fc1(x))))
x = self.drop2(self.relu(self.bn2(self.fc2(x))))
x = self.fc3(x)
x = self.log_softmax(x)
return x
class NLLLoss(_Loss):
"""NLL loss"""
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()
self.reduce_sum = P.ReduceSum()
def construct(self, logits, label):
"""
construct method
"""
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
return loss

View File

@ -0,0 +1,237 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network definition utils"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as P
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import constexpr
from src.layers import Conv2d
@constexpr
def generate_tensor_fps(B, N):
"""generate tensor"""
farthest = Tensor(np.random.randint(N, size=(B,)), ms.int32)
return farthest
@constexpr
def generate_tensor_batch_indices(B):
"""generate tensor"""
return Tensor(np.arange(B), ms.int32)
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * P.BatchMatMul()(src, P.Transpose()(dst, (0, 2, 1)))
dist += P.Reshape()(P.ReduceSum()(src ** 2, -1), (B, N, 1))
dist += P.Reshape()(P.ReduceSum()(dst ** 2, -1), (B, 1, M))
return dist
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S] or [B, S, nsample]
Return:
new_points:, indexed points data, [B, S, C] or [B, S, nsample, C]
"""
shape = idx.shape
batch_indices = generate_tensor_batch_indices(shape[0])
if len(shape) == 2:
batch_indices = batch_indices.view(shape[0], 1)
else:
batch_indices = batch_indices.view(shape[0], 1, 1)
batch_indices = batch_indices.expand_as(idx)
index = P.Concat(-1)((batch_indices.reshape(idx.shape + (1,)), idx.reshape(idx.shape + (1,))))
new_points = P.GatherNd()(points, index)
return new_points
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
B, N, _ = xyz.shape
centroids = mnp.zeros((npoint, B), ms.int32)
distance = mnp.ones((B, N), ms.int32) * 1e9
farthest = generate_tensor_fps(B, N)
batch_indices = generate_tensor_batch_indices(B)
for i in range(npoint):
centroids = P.Cast()(centroids, ms.float32)
farthest = P.Cast()(farthest, ms.float32)
centroids[i] = farthest
centroids = P.Cast()(centroids, ms.int32)
farthest = P.Cast()(farthest, ms.int32)
index = P.Concat(-1)((batch_indices.reshape(batch_indices.shape + (1,)),
farthest.reshape(farthest.shape + (1,))))
centroid = P.GatherNd()(xyz, index).reshape((B, 1, 3))
dist = P.ReduceSum()((xyz - centroid) ** 2, -1)
distance = P.Minimum()(distance, dist)
farthest = P.Argmax()(distance)
return P.Transpose()(centroids, (1, 0))
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
B, N, _ = xyz.shape
_, S, _ = new_xyz.shape
group_idx = mnp.arange(0, N, 1, ms.int32).view(1, 1, N)
group_idx = P.Tile()(group_idx, (B, S, 1))
sqrdists = square_distance(new_xyz, xyz)
idx = sqrdists > radius ** 2
group_idx = P.Select()(idx, -1 * P.OnesLike()(group_idx), group_idx)
group_idx = P.Cast()(group_idx, ms.float32)
group_idx, _ = P.TopK()(group_idx, nsample)
group_idx = P.Cast()(group_idx, ms.int32)
group_first = group_idx[:, :, 0].view(B, S, 1)
group_first = P.Tile()(group_first, (1, 1, nsample)) # [B, S, nsample]
index = group_idx != -1
group_first = P.Select()(index, -1 * P.OnesLike()(group_first), group_first)
group_idx = P.Maximum()(group_idx, group_first)
return group_idx
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
B, N, C = xyz.shape
grouped_xyz = P.Reshape()(xyz, (B, 1, N, C))
new_points = P.Concat(-1)((grouped_xyz, P.Reshape()(points, (B, 1, N, -1))))
return new_points
def sample_and_group(npoint, radius, nsample, xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, _, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, S) # [B, S]
new_xyz = index_points(xyz, fps_idx) # [B, S, C]
idx = query_ball_point(radius, nsample, xyz, new_xyz) # [B, S, nsample]
grouped_xyz = index_points(xyz, idx) # [B, S, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = P.Concat(-1)((grouped_xyz_norm, grouped_points)) # [B, S, nsample, C+D]
else:
new_points = grouped_xyz_norm
return new_xyz, new_points
class PointNetSetAbstraction(nn.Cell):
"""PointNetSetAbstraction"""
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.group_all = group_all
self.conv1 = Conv2d(in_channel, mlp[0], 1)
self.bn1 = nn.BatchNorm2d(mlp[0])
self.conv2 = Conv2d(mlp[0], mlp[1], 1)
self.bn2 = nn.BatchNorm2d(mlp[1])
self.conv3 = Conv2d(mlp[1], mlp[2], 1)
self.bn3 = nn.BatchNorm2d(mlp[2])
self.relu = P.ReLU()
self.transpose = P.Transpose()
self.reduce_max = P.ReduceMax()
def construct(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
if points is not None:
points = self.transpose(points, (0, 2, 1))
if self.group_all:
new_points = sample_and_group_all(xyz, points)
new_xyz = None
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
d1, d2, d3, d4 = new_points.shape
new_points = self.transpose(new_points.reshape((d1, d2 * d3, d4)), (0, 2, 1))
new_points = self.transpose(new_points.reshape((d1 * d4, d2, d3)), (0, 2, 1)).reshape((d1, d4, d3, d2))
new_points = self.relu(self.bn1(self.conv1(new_points)))
new_points = self.relu(self.bn2(self.conv2(new_points)))
new_points = self.relu(self.bn3(self.conv3(new_points)))
new_points = self.reduce_max(new_points, 2)
return new_xyz, new_points

View File

@ -0,0 +1,298 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data preprocessing for training"""
import numpy as np
def normalize_data(batch_data):
"""
Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
BxNxC array
"""
B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C))
for b in range(B):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
def shuffle_data(data, labels):
"""
Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
"""
Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
"""
Randomly rotate the point clouds to augment the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_z(batch_data):
"""
Randomly rotate the point clouds to augment the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
"""
Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
"""
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
"""
Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1]])
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
"""
Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, :, 0:3]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
"""
Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
BxNx6 array, rotated batch of point clouds with normal
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
"""
Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1]])
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
"""
Randomly jitter points.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, processed batch of point clouds
"""
B, N, C = batch_data.shape
assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data
return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
"""
Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, _, _ = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B):
batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
"""
Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, _, _ = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
""" batch_pc: BxNx3 """
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
if drop_idx.any():
batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point
return batch_pc
class RandomInputDropout:
"""random input dropout during training"""
def __init__(self):
pass
def __call__(self, data, label, batchInfo):
data = np.array(data)
label = np.array(label)
data = random_point_dropout(data)
data[:, :, 0:3] = random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = shift_point_cloud(data[:, :, 0:3])
return data, label

View File

@ -0,0 +1,203 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training"""
import argparse
import ast
import os
import sys
import time
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import Model, Tensor, context, load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.profiler import Profiler
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.dataset import DatasetGenerator
from src.lr_scheduler import MultiStepLR
from src.pointnet2 import PointNet2, NLLLoss
from src.provider import RandomInputDropout
def parse_args():
"""PARAMETERS"""
parser = argparse.ArgumentParser('MindSpore PointNet++ Training Configurations.')
parser.add_argument('--batch_size', type=int, default=24, help='batch size in training') # 24
parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training') # 200
parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training') # 0.001
parser.add_argument('--optimizer', type=str, default='Adam', choices=('Adam', 'SGD'),
help='optimizer for training') # Adam
parser.add_argument('--data_path', type=str, default='../modelnet40_normal_resampled/', help='data path')
parser.add_argument('--pretrained_ckpt', type=str, default='')
parser.add_argument('--loss_per_epoch', type=int, default=5, help='times to print loss value per epoch')
parser.add_argument('--save_dir', type=str, default='./save', help='save root')
parser.add_argument('--use_normals', type=ast.literal_eval, default=False, help='use normals')
parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
parser.add_argument('--enable_profiling', type=ast.literal_eval, default=False)
parser.add_argument('--enable_modelarts', type=ast.literal_eval, default=False)
parser.add_argument('--data_url', type=str)
parser.add_argument('--train_url', type=str)
parser.add_argument('--mox_freq', type=int, default=10, help='mox frequency')
return parser.parse_known_args()[0]
def run_train():
"""run train"""
args = parse_args()
# INIT
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
rank_id = int(os.getenv('RANK_ID', '0'))
if args.platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(max_call_depth=2048)
if device_num > 1:
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
else:
raise ValueError("Unsupported platform.")
if args.enable_modelarts:
import moxing as mox
local_data_url = "/cache/data"
mox.file.copy_parallel(args.data_url, local_data_url)
if args.pretrained_ckpt.endswith('.ckpt'):
pretrained_ckpt_path = "/cache/pretrained_ckpt/pretrained.ckpt"
mox.file.copy_parallel(args.pretrained_ckpt, pretrained_ckpt_path)
local_train_url = "/cache/train_output"
save_dir = local_train_url
if rank_id == 0:
mox.file.copy_parallel(os.path.join(args.train_url, 'log_train.txt'),
os.path.join(save_dir, 'log_train.txt'))
log_file = open(os.path.join(save_dir, 'log_train.txt'), 'w')
sys.stdout = log_file
else:
local_data_url = args.data_path
if args.pretrained_ckpt.endswith('.ckpt'):
pretrained_ckpt_path = args.pretrained_ckpt
local_train_url = args.save_dir
save_dir = args.save_dir
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if args.enable_profiling:
profiler = Profiler()
print(args)
# DATA LOADING
print('Load dataset ...')
data_path = local_data_url
num_workers = 4
train_ds_generator = DatasetGenerator(root=data_path, args=args, split='train', process_data=args.process_data)
if device_num > 1:
train_ds = ds.GeneratorDataset(train_ds_generator, ["data", "label"], num_parallel_workers=num_workers,
shuffle=True, shard_id=rank_id, num_shards=device_num)
else:
train_ds = ds.GeneratorDataset(train_ds_generator, ["data", "label"], num_parallel_workers=num_workers,
shuffle=True)
random_input_dropout = RandomInputDropout()
train_ds = train_ds.batch(batch_size=args.batch_size, per_batch_map=random_input_dropout,
input_columns=["data", "label"], drop_remainder=True, num_parallel_workers=num_workers)
steps_per_epoch = train_ds.get_dataset_size()
# MODEL
net = PointNet2(args.num_category, args.use_normals)
# load checkpoint
if args.pretrained_ckpt.endswith('.ckpt'):
print("Load checkpoint: %s" % args.pretrained_ckpt)
param_dict = load_checkpoint(pretrained_ckpt_path)
load_param_into_net(net, param_dict)
net_loss = NLLLoss()
lr_epochs = list(range(20, 201, 20))
lr_fun = MultiStepLR(args.learning_rate, lr_epochs, 0.7, steps_per_epoch, args.epoch)
lr = lr_fun.get_lr()
if args.optimizer == 'Adam':
net_opt = nn.Adam(
net.trainable_params(),
learning_rate=Tensor(lr),
beta1=0.9,
beta2=0.999,
eps=1e-08,
weight_decay=args.decay_rate
)
else:
net_opt = nn.SGD(net.trainable_params(), learning_rate=args.learning_rate, momentum=0.9)
model = Model(net, net_loss, net_opt)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=args.epoch)
ckpt_cb = ModelCheckpoint(prefix="ckpt_pointnet2", directory=local_train_url, config=config_ck)
loss_freq = max(steps_per_epoch // args.loss_per_epoch, 1)
cb = []
cb += [TimeMonitor()]
cb += [LossMonitor(loss_freq)]
if (not args.enable_modelarts) or (rank_id == 0):
cb += [ckpt_cb]
if args.enable_modelarts:
from src.callbacks import MoxCallBack
cb += [MoxCallBack(local_train_url, args.train_url, args.mox_freq)]
# TRAINING
net.set_train()
print('Starting training ...')
print('Time: ', time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
if args.enable_modelarts:
mox.file.copy_parallel(local_train_url, args.train_url)
time_start = time.time()
model.train(epoch=args.epoch, train_dataset=train_ds, callbacks=cb, dataset_sink_mode=True)
# END
print('End of training.')
print('Total time cost: {} min'.format("%.2f" % ((time.time() - time_start) / 60)))
if args.enable_profiling:
profiler.analyse()
if args.enable_modelarts and rank_id == 0:
log_file.close()
mox.file.copy_parallel(local_train_url, args.train_url)
if __name__ == '__main__':
run_train()