forked from mindspore-Ecosystem/mindspore
add pointnet2
This commit is contained in:
parent
7eac436c56
commit
1d41475e61
@@ -0,0 +1,230 @@

# Contents

- [Contents](#contents)
- [PointNet2 Description](#pointnet2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [PointNet2 Description](#contents)

PointNet++ was proposed in 2017. It is a hierarchical neural network that applies PointNet recursively on a nested partitioning of the input point set. By exploiting metric space distances, the network is able to learn local features with increasing contextual scales. Experiments show that PointNet++ learns deep point set features efficiently and robustly.

[Paper](http://arxiv.org/abs/1706.02413): Qi, Charles R., et al. "PointNet++: Deep hierarchical feature learning on point sets in a metric space." arXiv preprint arXiv:1706.02413 (2017).

# [Model Architecture](#contents)

The hierarchical structure of PointNet++ is composed of a number of *set abstraction* levels. At each level, a set of points is processed and abstracted to produce a new set with fewer elements. A set abstraction level is made of three key layers: a *Sampling layer*, a *Grouping layer*, and a *PointNet layer*. The *Sampling layer* selects a set of points from the input points, which defines the centroids of local regions. The *Grouping layer* then constructs local region sets by finding "neighboring" points around the centroids. The *PointNet layer* uses a mini-PointNet to encode local region patterns into feature vectors. A minimal sketch of how the stacked set-abstraction levels are used in this implementation is shown below.
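
For orientation, the following sketch runs a dummy batch through the classification network defined later in this commit (`src/pointnet2.py`, which stacks the three set-abstraction levels `sa1`/`sa2`/`sa3` described above). It assumes the repository root as the working directory and an Ascend MindSpore environment, which is what the provided scripts target.

```python
# Minimal forward-pass sketch of the PointNet2 classifier added in this commit.
import numpy as np
import mindspore as ms
from mindspore import Tensor, context

from src.pointnet2 import PointNet2

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", max_call_depth=2048)

net = PointNet2(num_class=40, normal_channel=False)       # 3 input channels: x, y, z
net.set_train(False)
points = Tensor(np.random.rand(2, 1024, 3), ms.float32)   # [batch, num_point, 3]
log_probs = net(points)                                    # [2, 40] class log-probabilities
print(log_probs.shape)
```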

# [Dataset](#contents)

Dataset used: aligned [ModelNet40](<https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip>)

- Dataset size: 6.48 GB. Each point cloud contains 2048 points uniformly sampled from a shape surface. Each cloud is zero-mean and normalized into a unit sphere.
    - Train: 5.18 GB, 9843 point clouds
    - Test: 1.3 GB, 2468 point clouds
- Data format: TXT files
- Note: Data will be processed in src/dataset.py (see the loading sketch below).
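
As a rough illustration of the on-disk format that `src/dataset.py` expects (one comma-separated point per row: x, y, z, nx, ny, nz), a single sample can be loaded as follows; the file path is only an example.

```python
# Load one ModelNet40 sample the same way src/dataset.py does (path is illustrative).
import numpy as np

point_cloud = np.loadtxt(
    "modelnet40_normal_resampled/airplane/airplane_0001.txt",
    delimiter=",").astype(np.float32)            # shape: [N, 6]
points = point_cloud[:1024, :]                   # keep the first --num_point points
xyz, normals = points[:, 0:3], points[:, 3:6]    # coordinates and surface normals
```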

# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```shell
# Run stand-alone training
bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh modelnet40_normal_resampled save pointnet2.ckpt

# Run distributed training
bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]
# example:
bash scripts/run_distributed_train.sh hccl_8p_01234567_127.0.0.1.json modelnet40_normal_resampled save pointnet2.ckpt

# Evaluate
bash scripts/run_eval.sh [DATA_PATH] [CKPT_NAME]
# example:
bash scripts/run_eval.sh modelnet40_normal_resampled pointnet2.ckpt
```

# [Script Description](#contents)

# [Script and Sample Code](#contents)

```bash
.
└── PointNet2
    ├── scripts
    │   ├── run_distributed_train.sh    # launch distributed training with Ascend platform (8p)
    │   ├── run_eval.sh                 # launch evaluating with Ascend platform
    │   └── run_standalone_train.sh     # launch standalone training with Ascend platform (1p)
    ├── src
    │   ├── callbacks.py                # callbacks definition
    │   ├── dataset.py                  # data preprocessing
    │   ├── layers.py                   # network layers initialization
    │   ├── lr_scheduler.py             # learning rate scheduler
    │   ├── pointnet2.py                # network definition
    │   ├── pointnet2_utils.py          # network definition utils
    │   └── provider.py                 # data preprocessing for training
    ├── eval.py                         # eval net
    ├── export.py                       # export checkpoint into AIR/ONNX/MINDIR models
    ├── README.md
    ├── requirements.txt
    └── train.py                        # train net
```

# [Script Parameters](#contents)

```bash
Major parameters in train.py are as follows:
--batch_size        # Training batch size.
--epoch             # Total training epochs.
--learning_rate     # Training learning rate.
--optimizer         # Optimizer for training. Optional values are "Adam", "SGD".
--data_path         # The path to the train and evaluation datasets.
--loss_per_epoch    # The number of times to print the loss value per epoch.
--save_dir          # The path to save files generated during training.
--use_normals       # Whether to use normals data in training.
--pretrained_ckpt   # The file path to load a checkpoint from.
--enable_modelarts  # Whether to use ModelArts.
```

# [Training Process](#contents)

## Training

- running on Ascend

```shell
# Run stand-alone training
bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]
# example:
bash scripts/run_standalone_train.sh modelnet40_normal_resampled save pointnet2.ckpt

# Run distributed training
bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]
# example:
bash scripts/run_distributed_train.sh hccl_8p_01234567_127.0.0.1.json modelnet40_normal_resampled save pointnet2.ckpt
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance. For specific operations, see the instructions in [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

After training, loss values will be printed as follows:

```bash
# train log
epoch: 1 step: 410, loss is 1.4731973
epoch time: 704454.051 ms, per step time: 1718.181 ms
epoch: 2 step: 410, loss is 1.0621885
epoch time: 471478.224 ms, per step time: 1149.947 ms
epoch: 3 step: 410, loss is 1.176581
epoch time: 471530.000 ms, per step time: 1150.073 ms
epoch: 4 step: 410, loss is 1.0118457
epoch time: 471498.514 ms, per step time: 1149.996 ms
epoch: 5 step: 410, loss is 0.47454038
epoch time: 471535.602 ms, per step time: 1150.087 ms
...
```

Model checkpoints will be saved in the `SAVE_DIR` directory.

# [Evaluation Process](#contents)

## Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

- running on Ascend

```shell
# Evaluate
bash scripts/run_eval.sh [DATA_PATH] [CKPT_NAME]
# example:
bash scripts/run_eval.sh modelnet40_normal_resampled pointnet2.ckpt
```

You can view the results in the file "eval.log". The accuracy on the test dataset will be reported as follows:

```bash
# grep "Accuracy: " eval.log
'Accuracy': 0.9146
```

# [Model Description](#contents)

## [Performance](#contents)

## Training Performance

| Parameters                 | Ascend                                               |
| -------------------------- | ---------------------------------------------------- |
| Model Version              | PointNet++                                           |
| Resource                   | Ascend 910; CPU 24 cores; Memory 256 GB; OS Euler2.8 |
| Uploaded Date              | 08/31/2021 (month/day/year)                          |
| MindSpore Version          | 1.3.0                                                |
| Dataset                    | ModelNet40                                           |
| Training Parameters        | epoch=200, steps=82000, batch_size=24, lr=0.001      |
| Optimizer                  | Adam                                                 |
| Loss Function              | NLLLoss                                              |
| Outputs                    | probability                                          |
| Loss                       | 0.01                                                 |
| Speed                      | 1.2 s/step (1p)                                      |
| Total time                 | 27.3 h (1p)                                          |
| Checkpoint for Fine tuning | 17 MB (.ckpt file)                                   |

## Inference Performance

| Parameters        | Ascend                                               |
| ----------------- | ---------------------------------------------------- |
| Model Version     | PointNet++                                           |
| Resource          | Ascend 910; CPU 24 cores; Memory 256 GB; OS Euler2.8 |
| Uploaded Date     | 08/31/2021 (month/day/year)                          |
| MindSpore Version | 1.3.0                                                |
| Dataset           | ModelNet40                                           |
| Batch_size        | 24                                                   |
| Outputs           | probability                                          |
| Accuracy          | 91.5% (1p)                                           |
| Total time        | 2.5 min                                              |

# [Description of Random Situation](#contents)

We use random operations (NumPy `np.random`) in dataset.py, provider.py and pointnet2_utils.py, so results may vary slightly between runs; one way to fix the seeds is sketched below.
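
If fully reproducible runs are needed, one possible (hypothetical, not part of the original scripts) addition is to fix the seeds before building the dataset and network; `mindspore.set_seed` is assumed to be available in the targeted MindSpore version.

```python
# Suggested sketch for fixing random seeds (assumed addition, not in the original scripts).
import numpy as np
import mindspore as ms

np.random.seed(1)   # covers the NumPy randomness in dataset.py, provider.py and pointnet2_utils.py
ms.set_seed(1)      # covers MindSpore-side randomness (e.g. weight initialization, dropout)
```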

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""

import argparse
import ast
import os
import time

import mindspore.dataset as ds
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.dataset import DatasetGenerator
from src.pointnet2 import PointNet2, NLLLoss


def parse_args():
    """PARAMETERS"""
    parser = argparse.ArgumentParser('MindSpore PointNet++ Eval Configurations.')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--data_path', type=str, default='../data/modelnet40_normal_resampled/', help='data path')
    parser.add_argument('--pretrained_ckpt', type=str, default='')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--use_normals', type=ast.literal_eval, default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')

    parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
    parser.add_argument('--enable_modelarts', type=ast.literal_eval, default=False)
    parser.add_argument('--data_url', type=str)
    parser.add_argument('--train_url', type=str)

    return parser.parse_known_args()[0]


def run_eval():
    """Run eval"""
    args = parse_args()

    # INIT
    device_id = int(os.getenv('DEVICE_ID', '0'))

    if args.enable_modelarts:
        import moxing as mox

        local_data_url = "/cache/data"
        mox.file.copy_parallel(args.data_url, local_data_url)
        pretrained_ckpt_path = "/cache/pretrained_ckpt/pretrained.ckpt"
        mox.file.copy_parallel(args.pretrained_ckpt, pretrained_ckpt_path)
        local_eval_url = "/cache/eval_out"
        mox.file.copy_parallel(args.train_url, local_eval_url)
    else:
        local_data_url = args.data_path
        pretrained_ckpt_path = args.pretrained_ckpt

    if args.platform == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
        context.set_context(max_call_depth=2048)
    else:
        raise ValueError("Unsupported platform.")

    print(args)

    # DATA LOADING
    print('Load dataset ...')
    data_path = local_data_url

    num_workers = 8
    test_ds_generator = DatasetGenerator(root=data_path, args=args, split='test', process_data=args.process_data)
    test_ds = ds.GeneratorDataset(test_ds_generator, ["data", "label"], num_parallel_workers=num_workers, shuffle=False)
    test_ds = test_ds.batch(batch_size=args.batch_size, drop_remainder=True, num_parallel_workers=num_workers)

    # MODEL LOADING
    net = PointNet2(args.num_category, args.use_normals)

    # load checkpoint
    print("Load checkpoint: ", args.pretrained_ckpt)
    param_dict = load_checkpoint(pretrained_ckpt_path)
    load_param_into_net(net, param_dict)

    net_loss = NLLLoss()

    model = Model(net, net_loss, metrics={"Accuracy": Accuracy()})

    # EVAL
    net.set_train(False)
    print('Starting eval ...')
    time_start = time.time()
    print('Time: ', time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))

    result = model.eval(test_ds, dataset_sink_mode=True)
    print("result : {}".format(result))

    # END
    print('Total time cost: {} min'.format("%.2f" % ((time.time() - time_start) / 60)))

    if args.enable_modelarts:
        mox.file.copy_parallel(local_eval_url, args.train_url)


if __name__ == '__main__':
    run_eval()
@@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import ast
import os

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context

from src.pointnet2 import PointNet2

parser = argparse.ArgumentParser(description='PointNet2 export')
parser.add_argument("--enable_modelarts", type=ast.literal_eval, default=False,
                    help="Run on ModelArts, default is false.")
parser.add_argument('--data_url', default=None, help='Directory containing the dataset.')
parser.add_argument('--train_url', default=None, help='Directory containing the checkpoint file.')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
parser.add_argument("--batch_size", type=int, default=24, help="batch size")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')  # channels = 6 if true
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
context.set_context(max_call_depth=2048)

if args.enable_modelarts:
    import moxing as mox

    local_data_url = "/cache/data"
    mox.file.copy_parallel(args.data_url, local_data_url)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    local_output_url = '/cache/ckpt' + str(device_id)
    mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
                           dst_url=os.path.join(local_output_url, args.ckpt_file))
else:
    local_output_url = '.'

if __name__ == '__main__':

    net = PointNet2(args.num_category, args.use_normals)

    param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_file))
    print('load ckpt')
    load_param_into_net(net, param_dict)
    print('load ckpt to net')
    net.set_train(False)
    # The network consumes 6 channels (xyz + normals) when --use_normals is set, otherwise 3.
    num_channels = 6 if args.use_normals else 3
    input_arr = Tensor(np.ones([args.batch_size, 1024, num_channels]), mstype.float32)
    print('input')
    export(net, input_arr, file_name="PointNet2", file_format=args.file_format)
    if args.enable_modelarts:
        file_name = "PointNet2." + args.file_format.lower()
        mox.file.copy_parallel(src_url=file_name,
                               dst_url=os.path.join(args.train_url, file_name))
@@ -0,0 +1,93 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================

PRETRAINED_CKPT=""

if [ $# != 3 ] && [ $# != 4 ]
then
    echo "Usage: bash scripts/run_distributed_train.sh [RANK_TABLE_FILE] [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]"
    echo "============================================================"
    echo "[RANK_TABLE_FILE]: The path to the HCCL configuration file in JSON format."
    echo "[DATA_PATH]: The path to the train and evaluation datasets."
    echo "[SAVE_DIR]: The path to save files generated during training."
    echo "[PRETRAINED_CKPT]: (optional) The path to the checkpoint file."
    echo "============================================================"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ $# -ge 3 ]
then
    RANK_TABLE_FILE=$(get_real_path $1)
    DATA_PATH=$(get_real_path $2)
    SAVE_DIR=$(get_real_path $3)

    if [ ! -f $RANK_TABLE_FILE ]
    then
        echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
        exit 1
    fi

    if [ ! -d $DATA_PATH ]
    then
        echo "error: DATA_PATH=$DATA_PATH is not a directory"
        exit 1
    fi
fi

if [ $# -ge 4 ]
then
    PRETRAINED_CKPT=$(get_real_path $4)
    if [ ! -f $PRETRAINED_CKPT ]
    then
        echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
        exit 1
    fi
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_FILE

for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp *.py ./device$i
    cp -r scripts ./device$i
    cp -r src ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log

    python train.py \
        --data_path=$DATA_PATH \
        --pretrained_ckpt=$PRETRAINED_CKPT \
        --save_dir=$SAVE_DIR > train.log 2>&1 &

    cd ../
done
@@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -m unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0


if [ $# != 2 ]
then
    echo "Usage: bash scripts/run_eval.sh [DATA_PATH] [PRETRAINED_CKPT]"
    echo "============================================================"
    echo "[DATA_PATH]: The path to the train and evaluation datasets."
    echo "[PRETRAINED_CKPT]: The path to the checkpoint file."
    echo "============================================================"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATA_PATH=$(get_real_path $1)
PRETRAINED_CKPT=$(get_real_path $2)

if [ ! -d $DATA_PATH ]
then
    echo "error: DATA_PATH=$DATA_PATH is not a directory"
    exit 1
fi

if [ ! -f $PRETRAINED_CKPT ]
then
    echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
    exit 1
fi

python eval.py \
    --data_path=$DATA_PATH \
    --pretrained_ckpt=$PRETRAINED_CKPT > eval.log 2>&1 &

echo 'running'
@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

PRETRAINED_CKPT=""

ulimit -m unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: bash scripts/run_standalone_train.sh [DATA_PATH] [SAVE_DIR] [PRETRAINED_CKPT(optional)]"
    echo "============================================================"
    echo "[DATA_PATH]: The path to the train and evaluation datasets."
    echo "[SAVE_DIR]: The path to save files generated during training."
    echo "[PRETRAINED_CKPT]: (optional) The path to the checkpoint file."
    echo "============================================================"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ $# -ge 2 ]
then
    DATA_PATH=$(get_real_path $1)
    SAVE_DIR=$(get_real_path $2)

    if [ ! -d $DATA_PATH ]
    then
        echo "error: DATA_PATH=$DATA_PATH is not a directory"
        exit 1
    fi
fi

if [ $# -ge 3 ]
then
    PRETRAINED_CKPT=$(get_real_path $3)
    if [ ! -f $PRETRAINED_CKPT ]
    then
        echo "error: PRETRAINED_CKPT=$PRETRAINED_CKPT is not a file"
        exit 1
    fi
fi

python train.py \
    --data_path=$DATA_PATH \
    --pretrained_ckpt=$PRETRAINED_CKPT \
    --save_dir=$SAVE_DIR > train.log 2>&1 &

echo 'running'
@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callbacks"""

import moxing as mox
from mindspore.train.callback import Callback


class MoxCallBack(Callback):
    """Copy training outputs back to the remote (ModelArts) location every `mox_freq` epochs."""

    def __init__(self, local_train_url, train_url, mox_freq):
        super(MoxCallBack, self).__init__()
        self.local_train_url = local_train_url
        self.train_url = train_url
        self.mox_freq = mox_freq

    def epoch_end(self, run_context):
        """Mox files at the end of each epoch"""
        cb_params = run_context.original_args()
        if cb_params.cur_epoch_num % self.mox_freq == 0:
            mox.file.copy_parallel(self.local_train_url, self.train_url)
@@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data preprocessing"""

import os
import pickle

import numpy as np


def pc_normalize(pc):
    """normalize point cloud"""
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        point: point cloud data, [N, D]
        npoint: number of samples
    Return:
        point: sampled point cloud, [npoint, D]
    """
    N, _ = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


class DatasetGenerator:
    """DatasetGenerator"""

    def __init__(self, root, args, split='train', process_data=False):
        self.root = root
        self.npoints = args.num_point
        self.process_data = process_data
        self.uniform = args.use_uniform_sample
        self.use_normals = args.use_normals
        self.num_category = args.num_category

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
        else:
            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        shape_ids = {}
        if self.num_category == 10:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert split in ('train', 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        if self.uniform:
            self.save_path = os.path.join(root,
                                          'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        if self.process_data:
            if not os.path.exists(self.save_path):
                print('Processing data %s (only runs the first time)...' % self.save_path)
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)

                for index in range(len(self.datapath)):
                    fn = self.datapath[index]
                    cls = self.classes[self.datapath[index][0]]
                    cls = np.array([cls]).astype(np.int32)
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

                    if self.uniform:
                        point_set = farthest_point_sample(point_set, self.npoints)
                    else:
                        point_set = point_set[0:self.npoints, :]

                    self.list_of_points[index] = point_set
                    self.list_of_labels[index] = cls

                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)
            else:
                print('Load processed data from %s...' % self.save_path)
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)

    def __len__(self):
        return len(self.datapath)

    def __getitem__(self, index):
        """get item"""
        if self.process_data:
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            label = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]

        if not self.use_normals:
            point_set = point_set[:, 0:3]

        return point_set, label[0]
@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network layers initialization"""

import math

import mindspore.nn as nn
from mindspore.common.initializer import initializer, HeUniform, Uniform


def calculate_fan_in_and_fan_out(shape):
    """
    calculate fan_in and fan_out

    Args:
        shape (tuple): input shape.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
    """
    dimensions = len(shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = shape[1]
        fan_out = shape[0]
    else:
        num_input_fmaps = shape[1]
        num_output_fmaps = shape[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = shape[2] * shape[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


class Dense(nn.Dense):
    """Dense layer with Kaiming-uniform weight and fan-in-bounded uniform bias initialization."""

    def __init__(self, in_channels, out_channels, has_bias=True, activation=None):
        super(Dense, self).__init__(in_channels, out_channels, weight_init='normal', bias_init='zeros',
                                    has_bias=has_bias, activation=activation)
        self.reset_parameters()

    def reset_parameters(self):
        """reset parameters"""
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
        if self.has_bias:
            fan_in, _ = calculate_fan_in_and_fan_out(self.weight.shape)
            bound = 1 / math.sqrt(fan_in)
            self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))


class Conv2d(nn.Conv2d):
    """Conv2d layer with Kaiming-uniform weight and fan-in-bounded uniform bias initialization."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1,
                 group=1, has_bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group,
                                     has_bias, weight_init='normal', bias_init='zeros')
        self.reset_parameters()

    def reset_parameters(self):
        """reset parameters"""
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
        if self.has_bias:
            fan_in, _ = calculate_fan_in_and_fan_out(self.weight.shape)
            bound = 1 / math.sqrt(fan_in)
            self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))
@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""learning rate scheduler"""

import numpy as np


class MultiStepLR:
    """
    Multi-step learning rate scheduler

    Decays the learning rate by gamma once the number of epochs reaches one of the milestones.

    Args:
        lr (float): Initial learning rate which is the lower boundary in the cycle.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
        steps_per_epoch (int): The number of steps per epoch to train for.
        max_epoch (int): The number of epochs to train for.

    Outputs:
        numpy.ndarray, shape=(steps_per_epoch * max_epoch,)

    Example:
        >>> # lr = 0.05 if epoch < 30
        >>> # lr = 0.005 if 30 <= epoch < 80
        >>> # lr = 0.0005 if epoch >= 80
        >>> scheduler = MultiStepLR(lr=0.05, milestones=[30, 80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)
        >>> lr = scheduler.get_lr()
    """

    def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch):
        self.lr = lr
        self.milestones = milestones
        self.gamma = gamma
        self.steps_per_epoch = steps_per_epoch
        self.max_epoch = max_epoch
        self.total_steps = int(max_epoch * steps_per_epoch)

    def get_lr(self):
        """get learning rate"""
        lr_each_step = []
        current_lr = self.lr
        for i in range(self.total_steps):
            cur_ep = i // self.steps_per_epoch
            if i % self.steps_per_epoch == 0 and cur_ep in self.milestones:
                current_lr = current_lr * self.gamma
            lr_each_step.append(current_lr)

        return np.array(lr_each_step).astype(np.float32)
@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network definition"""

import mindspore.nn as nn
import mindspore.ops as P
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F

from src.layers import Dense
from src.pointnet2_utils import PointNetSetAbstraction


class PointNet2(nn.Cell):
    """PointNet2"""

    def __init__(self, num_class, normal_channel=False):
        super(PointNet2, self).__init__()
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel

        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
                                          in_channel=in_channel, mlp=[64, 64, 128],
                                          group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64,
                                          in_channel=128 + 3, mlp=[128, 128, 256],
                                          group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None,
                                          in_channel=256 + 3, mlp=[256, 512, 1024],
                                          group_all=True)

        self.fc1 = Dense(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.6)
        self.fc2 = Dense(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = Dense(256, num_class)

        self.relu = P.ReLU()
        self.reshape = P.Reshape()
        self.log_softmax = P.LogSoftmax()
        self.transpose = P.Transpose()

    def construct(self, xyz):
        """
        construct method
        """
        if self.normal_channel:
            norm = self.transpose(xyz[:, :, 3:], (0, 2, 1))
            xyz = xyz[:, :, :3]
        else:
            norm = None
        l1_xyz, l1_points = self.sa1(xyz, norm)  # [B, 3, 512], [B, 128, 512]
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  # [B, 3, 128], [B, 256, 128]
        _, l3_points = self.sa3(l2_xyz, l2_points)  # [B, 3, 1], [B, 1024, 1]
        x = self.reshape(l3_points, (-1, 1024))
        x = self.drop1(self.relu(self.bn1(self.fc1(x))))
        x = self.drop2(self.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = self.log_softmax(x)
        return x


class NLLLoss(_Loss):
    """NLL loss"""

    def __init__(self, reduction='mean'):
        super(NLLLoss, self).__init__(reduction)
        self.one_hot = P.OneHot()
        self.reduce_sum = P.ReduceSum()

    def construct(self, logits, label):
        """
        construct method
        """
        label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
        loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
        return loss
@@ -0,0 +1,237 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network definition utils"""

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as P
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import constexpr

from src.layers import Conv2d


@constexpr
def generate_tensor_fps(B, N):
    """generate tensor"""
    farthest = Tensor(np.random.randint(N, size=(B,)), ms.int32)
    return farthest


@constexpr
def generate_tensor_batch_indices(B):
    """generate tensor"""
    return Tensor(np.arange(B), ms.int32)


def square_distance(src, dst):
    """
    Calculate the Euclidean distance between each pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * P.BatchMatMul()(src, P.Transpose()(dst, (0, 2, 1)))
    dist += P.Reshape()(P.ReduceSum()(src ** 2, -1), (B, N, 1))
    dist += P.Reshape()(P.ReduceSum()(dst ** 2, -1), (B, 1, M))
    return dist


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, nsample]
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, nsample, C]
    """
    shape = idx.shape
    batch_indices = generate_tensor_batch_indices(shape[0])
    if len(shape) == 2:
        batch_indices = batch_indices.view(shape[0], 1)
    else:
        batch_indices = batch_indices.view(shape[0], 1, 1)
    batch_indices = batch_indices.expand_as(idx)
    index = P.Concat(-1)((batch_indices.reshape(idx.shape + (1,)), idx.reshape(idx.shape + (1,))))
    new_points = P.GatherNd()(points, index)
    return new_points


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    B, N, _ = xyz.shape
    centroids = mnp.zeros((npoint, B), ms.int32)
    distance = mnp.ones((B, N), ms.int32) * 1e9
    farthest = generate_tensor_fps(B, N)
    batch_indices = generate_tensor_batch_indices(B)
    for i in range(npoint):
        centroids = P.Cast()(centroids, ms.float32)
        farthest = P.Cast()(farthest, ms.float32)
        centroids[i] = farthest
        centroids = P.Cast()(centroids, ms.int32)
        farthest = P.Cast()(farthest, ms.int32)
        index = P.Concat(-1)((batch_indices.reshape(batch_indices.shape + (1,)),
                              farthest.reshape(farthest.shape + (1,))))
        centroid = P.GatherNd()(xyz, index).reshape((B, 1, 3))
        dist = P.ReduceSum()((xyz - centroid) ** 2, -1)
        distance = P.Minimum()(distance, dist)
        farthest = P.Argmax()(distance)
    return P.Transpose()(centroids, (1, 0))


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = mnp.arange(0, N, 1, ms.int32).view(1, 1, N)
    group_idx = P.Tile()(group_idx, (B, S, 1))
    sqrdists = square_distance(new_xyz, xyz)

    idx = sqrdists > radius ** 2
    group_idx = P.Select()(idx, -1 * P.OnesLike()(group_idx), group_idx)
    group_idx = P.Cast()(group_idx, ms.float32)
    group_idx, _ = P.TopK()(group_idx, nsample)
    group_idx = P.Cast()(group_idx, ms.int32)

    group_first = group_idx[:, :, 0].view(B, S, 1)
    group_first = P.Tile()(group_first, (1, 1, nsample))  # [B, S, nsample]

    index = group_idx != -1
    group_first = P.Select()(index, -1 * P.OnesLike()(group_first), group_first)
    group_idx = P.Maximum()(group_idx, group_first)

    return group_idx


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    B, N, C = xyz.shape
    grouped_xyz = P.Reshape()(xyz, (B, 1, N, C))
    new_points = P.Concat(-1)((grouped_xyz, P.Reshape()(points, (B, 1, N, -1))))
    return new_points


def sample_and_group(npoint, radius, nsample, xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, _, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, S)  # [B, S]
    new_xyz = index_points(xyz, fps_idx)  # [B, S, C]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, S, nsample]
    grouped_xyz = index_points(xyz, idx)  # [B, S, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = P.Concat(-1)((grouped_xyz_norm, grouped_points))  # [B, S, nsample, C+D]
    else:
        new_points = grouped_xyz_norm

    return new_xyz, new_points


class PointNetSetAbstraction(nn.Cell):
    """PointNetSetAbstraction"""

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all

        self.conv1 = Conv2d(in_channel, mlp[0], 1)
        self.bn1 = nn.BatchNorm2d(mlp[0])
        self.conv2 = Conv2d(mlp[0], mlp[1], 1)
        self.bn2 = nn.BatchNorm2d(mlp[1])
        self.conv3 = Conv2d(mlp[1], mlp[2], 1)
        self.bn3 = nn.BatchNorm2d(mlp[2])

        self.relu = P.ReLU()
        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()

    def construct(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        if points is not None:
            points = self.transpose(points, (0, 2, 1))

        if self.group_all:
            new_points = sample_and_group_all(xyz, points)
            new_xyz = None
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
            # new_xyz: sampled points position data, [B, npoint, C]
            # new_points: sampled points data, [B, npoint, nsample, C+D]

        # Rearrange [B, npoint, nsample, C+D] into the NCHW layout [B, C+D, nsample, npoint]
        # expected by the Conv2d/BatchNorm2d stack below (equivalent to permute(0, 3, 2, 1)).
        d1, d2, d3, d4 = new_points.shape
        new_points = self.transpose(new_points.reshape((d1, d2 * d3, d4)), (0, 2, 1))
        new_points = self.transpose(new_points.reshape((d1 * d4, d2, d3)), (0, 2, 1)).reshape((d1, d4, d3, d2))

        new_points = self.relu(self.bn1(self.conv1(new_points)))
        new_points = self.relu(self.bn2(self.conv2(new_points)))
        new_points = self.relu(self.bn3(self.conv3(new_points)))

        # Max-pool over the nsample dimension to get one feature vector per sampled centroid.
        new_points = self.reduce_max(new_points, 2)

        return new_xyz, new_points
@ -0,0 +1,298 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""data preprocessing for training"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def normalize_data(batch_data):
|
||||
"""
|
||||
Normalize the batch data, use coordinates of the block centered at origin,
|
||||
Input:
|
||||
BxNxC array
|
||||
Output:
|
||||
BxNxC array
|
||||
"""
|
||||
B, N, C = batch_data.shape
|
||||
normal_data = np.zeros((B, N, C))
|
||||
for b in range(B):
|
||||
pc = batch_data[b]
|
||||
centroid = np.mean(pc, axis=0)
|
||||
pc = pc - centroid
|
||||
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
|
||||
pc = pc / m
|
||||
normal_data[b] = pc
|
||||
return normal_data
|
||||
|
||||
|
||||
def shuffle_data(data, labels):
|
||||
"""
|
||||
Shuffle data and labels.
|
||||
Input:
|
||||
data: B,N,... numpy array
|
||||
label: B,... numpy array
|
||||
Return:
|
||||
shuffled data, label and shuffle indices
|
||||
"""
|
||||
idx = np.arange(len(labels))
|
||||
np.random.shuffle(idx)
|
||||
return data[idx, ...], labels[idx], idx
|
||||
|
||||
|
||||
def shuffle_points(batch_data):
|
||||
"""
|
||||
Shuffle orders of points in each point cloud -- changes FPS behavior.
|
||||
Use the same shuffling idx for the entire batch.
|
||||
Input:
|
||||
BxNxC array
|
||||
Output:
|
||||
BxNxC array
|
||||
"""
|
||||
idx = np.arange(batch_data.shape[1])
|
||||
np.random.shuffle(idx)
|
||||
return batch_data[:, idx, :]
|
||||
|
||||
|
||||
def rotate_point_cloud(batch_data):
|
||||
"""
|
||||
Randomly rotate the point clouds to augment the dataset
|
||||
rotation is per shape based along up direction
|
||||
Input:
|
||||
BxNx3 array, original batch of point clouds
|
||||
Return:
|
||||
BxNx3 array, rotated batch of point clouds
|
||||
"""
|
||||
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
|
||||
for k in range(batch_data.shape[0]):
|
||||
rotation_angle = np.random.uniform() * 2 * np.pi
|
||||
cosval = np.cos(rotation_angle)
|
||||
sinval = np.sin(rotation_angle)
|
||||
rotation_matrix = np.array([[cosval, 0, sinval],
|
||||
[0, 1, 0],
|
||||
[-sinval, 0, cosval]])
|
||||
shape_pc = batch_data[k, ...]
|
||||
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
|
||||
return rotated_data
|
||||
|
||||
|
||||
def rotate_point_cloud_z(batch_data):
|
||||
"""
|
||||
Randomly rotate the point clouds to augment the dataset
|
||||
rotation is per shape based along up direction
|
||||
Input:
|
||||
BxNx3 array, original batch of point clouds
|
||||
Return:
|
||||
BxNx3 array, rotated batch of point clouds
|
||||
"""
|
||||
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
|
||||
for k in range(batch_data.shape[0]):
|
||||
rotation_angle = np.random.uniform() * 2 * np.pi
|
||||
cosval = np.cos(rotation_angle)
|
||||
sinval = np.sin(rotation_angle)
|
||||
rotation_matrix = np.array([[cosval, sinval, 0],
|
||||
[-sinval, cosval, 0],
|
||||
[0, 0, 1]])
|
||||
shape_pc = batch_data[k, ...]
|
||||
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
|
||||
return rotated_data
|
||||
|
||||
|
||||
def rotate_point_cloud_with_normal(batch_xyz_normal):
|
||||
"""
|
||||
Randomly rotate XYZ, normal point cloud.
|
||||
Input:
|
||||
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
|
||||
Output:
|
||||
B,N,6, rotated XYZ, normal point cloud
|
||||
"""
|
||||
for k in range(batch_xyz_normal.shape[0]):
|
||||
rotation_angle = np.random.uniform() * 2 * np.pi
|
||||
cosval = np.cos(rotation_angle)
|
||||
sinval = np.sin(rotation_angle)
|
||||
rotation_matrix = np.array([[cosval, 0, sinval],
|
||||
[0, 1, 0],
|
||||
[-sinval, 0, cosval]])
|
||||
shape_pc = batch_xyz_normal[k, :, 0:3]
|
||||
shape_normal = batch_xyz_normal[k, :, 3:6]
|
||||
batch_xyz_normal[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
|
||||
batch_xyz_normal[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
|
||||
return batch_xyz_normal
|
||||
|
||||
|
||||
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """
    Randomly perturb the point clouds by small rotations.
    Input:
        BxNx6 array, original batch of point clouds and point normals
    Return:
        BxNx6 array, rotated batch of point clouds and point normals
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(angles[0]), -np.sin(angles[0])],
                       [0, np.sin(angles[0]), np.cos(angles[0])]])
        Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
                       [0, 1, 0],
                       [-np.sin(angles[1]), 0, np.cos(angles[1])]])
        Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
                       [np.sin(angles[2]), np.cos(angles[2]), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
    return rotated_data

def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """
    Rotate the point cloud around the up (Y) axis by a given angle.
    Input:
        BxNx3 array, original batch of point clouds
        scalar, angle of rotation
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, :, 0:3]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data

def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """
    Rotate the point cloud (with normals) around the up (Y) axis by a given angle.
    Input:
        BxNx6 array, original batch of point clouds with normals
        scalar, angle of rotation
    Return:
        BxNx6 array, rotated batch of point clouds with normals
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
    return rotated_data

def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """
    Randomly perturb the point clouds by small rotations
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(angles[0]), -np.sin(angles[0])],
                       [0, np.sin(angles[0]), np.cos(angles[0])]])
        Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
                       [0, 1, 0],
                       [-np.sin(angles[1]), 0, np.cos(angles[1])]])
        Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
                       [np.sin(angles[2]), np.cos(angles[2]), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """
    Randomly jitter points.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, processed batch of point clouds
    """
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data

def shift_point_cloud(batch_data, shift_range=0.1):
    """
    Randomly shift point cloud. Shift is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted batch of point clouds
    """
    B, _, _ = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """
    Randomly scale the point cloud. Scale is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, scaled batch of point clouds
    """
    B, _, _ = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    for batch_index in range(B):
        batch_data[batch_index, :, :] *= scales[batch_index]
    return batch_data

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """
    Randomly drop points per point cloud; dropped points are overwritten with the first point.
    Input:
        batch_pc: BxNx3 array
    Return:
        BxNx3 array with dropped points replaced by the first point
    """
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # in [0, 0.875)
        drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
        if drop_idx.size > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # set dropped points to the first point
    return batch_pc

class RandomInputDropout:
    """Random input dropout, scaling and shifting applied per batch during training."""

    def __init__(self):
        pass

    def __call__(self, data, label, batchInfo):
        # batchInfo is supplied by MindSpore's per_batch_map and is not used here
        data = np.array(data)
        label = np.array(label)
        data = random_point_dropout(data)
        data[:, :, 0:3] = random_scale_point_cloud(data[:, :, 0:3])
        data[:, :, 0:3] = shift_point_cloud(data[:, :, 0:3])
        return data, label
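

# A minimal, optional self-check sketch (not part of the training pipeline): it assumes the
# BxNx3 layout documented above and simply chains a few of the augmentations on a random
# batch to confirm that shapes are preserved. The batch size and point count are arbitrary.
if __name__ == "__main__":
    demo_batch = np.random.rand(2, 1024, 3).astype(np.float32)  # hypothetical B=2, N=1024
    demo_batch = rotate_point_cloud(demo_batch)
    demo_batch = jitter_point_cloud(demo_batch)
    demo_batch = random_scale_point_cloud(demo_batch)
    demo_batch = shift_point_cloud(demo_batch)
    print("augmented batch shape:", demo_batch.shape)  # expected: (2, 1024, 3)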
@@ -0,0 +1,203 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training"""

import argparse
import ast
import os
import sys
import time

import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import Model, Tensor, context, load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.profiler import Profiler
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.dataset import DatasetGenerator
from src.lr_scheduler import MultiStepLR
from src.pointnet2 import PointNet2, NLLLoss
from src.provider import RandomInputDropout

def parse_args():
    """PARAMETERS"""
    parser = argparse.ArgumentParser('MindSpore PointNet++ Training Configurations.')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')  # 24
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')  # 200
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')  # 0.001
    parser.add_argument('--optimizer', type=str, default='Adam', choices=('Adam', 'SGD'),
                        help='optimizer for training')  # Adam
    parser.add_argument('--data_path', type=str, default='../modelnet40_normal_resampled/', help='data path')
    parser.add_argument('--pretrained_ckpt', type=str, default='')
    parser.add_argument('--loss_per_epoch', type=int, default=5, help='times to print loss value per epoch')
    parser.add_argument('--save_dir', type=str, default='./save', help='save root')

    parser.add_argument('--use_normals', type=ast.literal_eval, default=False, help='use normals')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')

    parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
    parser.add_argument('--enable_profiling', type=ast.literal_eval, default=False)

    parser.add_argument('--enable_modelarts', type=ast.literal_eval, default=False)
    parser.add_argument('--data_url', type=str)
    parser.add_argument('--train_url', type=str)
    parser.add_argument('--mox_freq', type=int, default=10, help='mox frequency')

    return parser.parse_known_args()[0]

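
# Illustrative invocations built from the flags above (paths and values are placeholders,
# not prescriptions):
#   single-device training with defaults:
#     python train.py --data_path ../modelnet40_normal_resampled/ --platform Ascend
#   training with surface normals and a smaller batch:
#     python train.py --data_path ../modelnet40_normal_resampled/ --use_normals True --batch_size 16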


def run_train():
    """run train"""
    args = parse_args()

    # INIT
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv('RANK_SIZE', '1'))
    rank_id = int(os.getenv('RANK_ID', '0'))

    if args.platform == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
        context.set_context(max_call_depth=2048)

        if device_num > 1:
            init()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    else:
        raise ValueError("Unsupported platform.")

    if args.enable_modelarts:
        import moxing as mox

        local_data_url = "/cache/data"
        mox.file.copy_parallel(args.data_url, local_data_url)
        if args.pretrained_ckpt.endswith('.ckpt'):
            pretrained_ckpt_path = "/cache/pretrained_ckpt/pretrained.ckpt"
            mox.file.copy_parallel(args.pretrained_ckpt, pretrained_ckpt_path)
        local_train_url = "/cache/train_output"
        save_dir = local_train_url
        if rank_id == 0:
            mox.file.copy_parallel(os.path.join(args.train_url, 'log_train.txt'),
                                   os.path.join(save_dir, 'log_train.txt'))
            # redirect stdout so training logs are kept with the ModelArts output (rank 0 only)
            log_file = open(os.path.join(save_dir, 'log_train.txt'), 'w')
            sys.stdout = log_file
    else:
        local_data_url = args.data_path
        if args.pretrained_ckpt.endswith('.ckpt'):
            pretrained_ckpt_path = args.pretrained_ckpt
        local_train_url = args.save_dir
        save_dir = args.save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    if args.enable_profiling:
        profiler = Profiler()

    print(args)

    # DATA LOADING
    print('Load dataset ...')
    data_path = local_data_url

    num_workers = 4
    train_ds_generator = DatasetGenerator(root=data_path, args=args, split='train', process_data=args.process_data)
    if device_num > 1:
        train_ds = ds.GeneratorDataset(train_ds_generator, ["data", "label"], num_parallel_workers=num_workers,
                                       shuffle=True, shard_id=rank_id, num_shards=device_num)
    else:
        train_ds = ds.GeneratorDataset(train_ds_generator, ["data", "label"], num_parallel_workers=num_workers,
                                       shuffle=True)
    # batch-level augmentation (random point dropout, scaling, shifting) applied via per_batch_map
    random_input_dropout = RandomInputDropout()
    train_ds = train_ds.batch(batch_size=args.batch_size, per_batch_map=random_input_dropout,
                              input_columns=["data", "label"], drop_remainder=True, num_parallel_workers=num_workers)

    steps_per_epoch = train_ds.get_dataset_size()

    # MODEL
    net = PointNet2(args.num_category, args.use_normals)

    # load checkpoint
    if args.pretrained_ckpt.endswith('.ckpt'):
        print("Load checkpoint: %s" % args.pretrained_ckpt)
        param_dict = load_checkpoint(pretrained_ckpt_path)
        load_param_into_net(net, param_dict)

    net_loss = NLLLoss()

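    # Scheduling note (descriptive): lr_epochs marks every 20th epoch up to 200, and the 0.7
    # passed to MultiStepLR below appears to be the decay factor, so the learning rate is
    # expected to shrink by 0.7x at each milestone, expanded to one value per training step.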
    lr_epochs = list(range(20, 201, 20))
    lr_fun = MultiStepLR(args.learning_rate, lr_epochs, 0.7, steps_per_epoch, args.epoch)
    lr = lr_fun.get_lr()

    if args.optimizer == 'Adam':
        net_opt = nn.Adam(
            net.trainable_params(),
            learning_rate=Tensor(lr),
            beta1=0.9,
            beta2=0.999,
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        net_opt = nn.SGD(net.trainable_params(), learning_rate=args.learning_rate, momentum=0.9)

    model = Model(net, net_loss, net_opt)

    config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=args.epoch)
    ckpt_cb = ModelCheckpoint(prefix="ckpt_pointnet2", directory=local_train_url, config=config_ck)

    loss_freq = max(steps_per_epoch // args.loss_per_epoch, 1)

    cb = []
    cb += [TimeMonitor()]
    cb += [LossMonitor(loss_freq)]
    if (not args.enable_modelarts) or (rank_id == 0):
        cb += [ckpt_cb]

    if args.enable_modelarts:
        from src.callbacks import MoxCallBack
        cb += [MoxCallBack(local_train_url, args.train_url, args.mox_freq)]

    # TRAINING
    net.set_train()
    print('Starting training ...')
    print('Time: ', time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))

    if args.enable_modelarts:
        mox.file.copy_parallel(local_train_url, args.train_url)

    time_start = time.time()

    model.train(epoch=args.epoch, train_dataset=train_ds, callbacks=cb, dataset_sink_mode=True)

    # END
    print('End of training.')
    print('Total time cost: {} min'.format("%.2f" % ((time.time() - time_start) / 60)))

    if args.enable_profiling:
        profiler.analyse()

    if args.enable_modelarts and rank_id == 0:
        log_file.close()
        mox.file.copy_parallel(local_train_url, args.train_url)


if __name__ == '__main__':
    run_train()
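
# Runtime note (assumptions about the launch environment): run_train() reads DEVICE_ID,
# RANK_SIZE and RANK_ID from the environment, so a multi-device launch is expected to set
# them per process before invoking `python train.py`, e.g.
#   export RANK_SIZE=8; export RANK_ID=0; export DEVICE_ID=0   # values differ per process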