deepsort
run on Ascend with 8p deepsort description deepsort run on Ascend with 8p deepsort description de deepsort de deepsort deepsort description deepsort remove deepsort run on Ascend with 8p deepsort description de deepsort de deepsort deepsort description deepsort dele Deepsort stgcn rm stgcn
This commit is contained in:
parent
196d65c0bd
commit
129234a155
|
@ -0,0 +1,255 @@
|
||||||
|
# Contents
|
||||||
|
|
||||||
|
- [STGCN 介绍](#STGCN-介绍)
|
||||||
|
- [模型架构](#模型架构)
|
||||||
|
- [数据集](#数据集)
|
||||||
|
- [环境要求](#环境要求)
|
||||||
|
- [快速开始](#快速开始)
|
||||||
|
- [脚本介绍](#脚本介绍)
|
||||||
|
- [脚本以及简单代码](#脚本以及简单代码)
|
||||||
|
- [脚本参数](#脚本参数)
|
||||||
|
- [训练步骤](#训练步骤)
|
||||||
|
- [训练](#训练)
|
||||||
|
- [评估步骤](#评估步骤)
|
||||||
|
- [评估](#评估)
|
||||||
|
- [导出mindir模型](#导出mindir模型)
|
||||||
|
- [推理过程](#推理过程)
|
||||||
|
- [用法](#用法)
|
||||||
|
- [结果](#结果)
|
||||||
|
- [模型介绍](#模型介绍)
|
||||||
|
- [性能](#性能)
|
||||||
|
- [评估性能](#评估性能)
|
||||||
|
- [随机事件介绍](#随机事件介绍)
|
||||||
|
- [ModelZoo 主页](#ModelZoo-主页)
|
||||||
|
|
||||||
|
# [STGCN 介绍](#contents)
|
||||||
|
|
||||||
|
STGCN主要用于交通预测领域,是一种时空卷积网络。在STGCN文章中提出一种新颖的深度学习框架——时空图卷积网络(STGCN),解决在通领域的时间序列预测问题。在定义图上的问题,并用纯卷积结构建立模型,这使得使用更少的参数能带来更快的训练速度。STGCN通过建模多尺度交通网络有效捕获全面的时空相关性,且在各种真实世界交通数据集始终保持SOTA。
|
||||||
|
|
||||||
|
[Paper](https://arxiv.org/abs/1709.04875): Bing yu, Haoteng Yin, and Zhanxing Zhu. "Spatio-Temporal Graph Convolutional Networks:
|
||||||
|
A Deep Learning Framework for Traffic Forecasting." Proceedings of the 27th International Joint Conference on Artificial Intelligence. 2017.
|
||||||
|
|
||||||
|
# [模型架构](#contents)
|
||||||
|
|
||||||
|
STGCN模型结构是由两个时空卷积快和一个输出层构成。时空卷积块分为时域卷积块和空域卷积块。空域卷积块有两种不同卷积方式,分别为:Cheb和GCN。
|
||||||
|
|
||||||
|
# [数据集](#contents)
|
||||||
|
|
||||||
|
Dataset used:
|
||||||
|
|
||||||
|
PeMED7(PeMSD7-m、PeMSD7-L)
|
||||||
|
BJER4
|
||||||
|
|
||||||
|
由于数据集下载原因,只找到了[PeMSD7-M](https://github.com/hazdzz/STGCN/tree/main/data/train/road_traffic/pemsd7-m)数据集。
|
||||||
|
|
||||||
|
# [环境要求](#contents)
|
||||||
|
|
||||||
|
- 硬件(Ascend/GPU)
|
||||||
|
- 需要准备具有Ascend或GPU处理能力的硬件环境.
|
||||||
|
- 框架
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- 如需获取更多信息,请查看如下链接:
|
||||||
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
|
# [快速开始](#contents)
|
||||||
|
|
||||||
|
在通过官方网站安装MindSpore之后,你可以通过如下步骤开始训练以及评估:
|
||||||
|
|
||||||
|
- running on Ascend with default parameters
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 单卡训练
|
||||||
|
python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebgcn" --n_pred=9
|
||||||
|
|
||||||
|
# 多卡训练
|
||||||
|
bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type
|
||||||
|
```
|
||||||
|
|
||||||
|
# [脚本介绍](#contents)
|
||||||
|
|
||||||
|
## [脚本以及简单代码](#contents)
|
||||||
|
|
||||||
|
```python
|
||||||
|
├── STGCN
|
||||||
|
├── scripts
|
||||||
|
├── run_distribute_train.sh //traing on Ascend with 8P
|
||||||
|
├── run_eval_ascend.sh //testing on Ascend
|
||||||
|
├── src
|
||||||
|
├── model
|
||||||
|
├──layers.py // model layer
|
||||||
|
├──metric.py // network with losscell
|
||||||
|
├──models.py // network model
|
||||||
|
├──config.py // parameter
|
||||||
|
├──dataloder.py // creating dataset
|
||||||
|
├──utility.py // calculate laplacian matrix and evaluate metric
|
||||||
|
├──weight_init.py // layernorm weight init
|
||||||
|
├── train.py // traing network
|
||||||
|
├── test.py // tesing network performance
|
||||||
|
├── postprocess.py // compute accuracy for ascend310
|
||||||
|
├── preprocess.py // process dataset for ascend310
|
||||||
|
├── README.md // descriptions
|
||||||
|
```
|
||||||
|
|
||||||
|
## [脚本参数](#contents)
|
||||||
|
|
||||||
|
训练以及评估的参数可以在config.py中设置
|
||||||
|
|
||||||
|
- config for STGCN
|
||||||
|
|
||||||
|
```python
|
||||||
|
stgcn_chebconv_45min_cfg = edict({
|
||||||
|
'learning_rate': 0.003,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 9,
|
||||||
|
'n_vertex': 228,
|
||||||
|
'epochs': 500,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.7,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func':"glu",
|
||||||
|
'graph_conv_type': "chebconv",
|
||||||
|
'mat_type': "wid_sym_normd_lap_mat",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
如需查看更多信息,请查看`config.py`.
|
||||||
|
|
||||||
|
## [训练步骤](#contents)
|
||||||
|
|
||||||
|
### 训练
|
||||||
|
|
||||||
|
- running on Ascend
|
||||||
|
|
||||||
|
```python
|
||||||
|
#1P训练
|
||||||
|
python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebgcn" --n_pred=9
|
||||||
|
#8P训练
|
||||||
|
bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type
|
||||||
|
```
|
||||||
|
|
||||||
|
8P训练时需要将RANK_TABLE_FILE放在scripts文件夹中,RANK_TABLE_FILE[生成方法](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
|
||||||
|
|
||||||
|
训练时,训练过程中的epch和step以及此时的loss和精确度会呈现在终端上:
|
||||||
|
|
||||||
|
```python
|
||||||
|
epoch: 1 step: 139, loss is 0.429
|
||||||
|
epoch time: 203885.163 ms, per step time: 1466.800 ms
|
||||||
|
epoch: 2 step: 139, loss is 0.2097
|
||||||
|
epoch time: 6330.939 ms, per step time: 45.546 ms
|
||||||
|
epoch: 3 step: 139, loss is 0.4192
|
||||||
|
epoch time: 6364.882 ms, per step time: 45.791 ms
|
||||||
|
epoch: 4 step: 139, loss is 0.2917
|
||||||
|
epoch time: 6378.299 ms, per step time: 45.887 ms
|
||||||
|
epoch: 5 step: 139, loss is 0.2365
|
||||||
|
epoch time: 6369.215 ms, per step time: 45.822 ms
|
||||||
|
epoch: 6 step: 139, loss is 0.2269
|
||||||
|
epoch time: 6389.238 ms, per step time: 45.966 ms
|
||||||
|
epoch: 7 step: 139, loss is 0.3071
|
||||||
|
epoch time: 6365.901 ms, per step time: 45.798 ms
|
||||||
|
epoch: 8 step: 139, loss is 0.2336
|
||||||
|
epoch time: 6358.127 ms, per step time: 45.742 ms
|
||||||
|
epoch: 9 step: 139, loss is 0.2812
|
||||||
|
epoch time: 6333.794 ms, per step time: 45.567 ms
|
||||||
|
epoch: 10 step: 139, loss is 0.2622
|
||||||
|
epoch time: 6334.013 ms, per step time: 45.568 ms
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
此模型的checkpoint存储在train_url路径中
|
||||||
|
|
||||||
|
## [评估步骤](#contents)
|
||||||
|
|
||||||
|
### 评估
|
||||||
|
|
||||||
|
- 在Ascend上使用PeMSD7-m 测试集进行评估
|
||||||
|
|
||||||
|
在使用命令运行时,需要传入模型参数地址、模型参数名称、空域卷积方式、预测时段。
|
||||||
|
|
||||||
|
```python
|
||||||
|
python test.py --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --ckpt_name="" --graph_conv_type="" --n_pred=9
|
||||||
|
#使用脚本评估
|
||||||
|
bash scripts/run_eval_ascend.sh data_path ckpt_url ckpt_name device_id graph_conv_type n_pred
|
||||||
|
```
|
||||||
|
|
||||||
|
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以如下方式呈现:
|
||||||
|
|
||||||
|
```python
|
||||||
|
MAE 3.23 | MAPE 8.32 | RMSE 6.06
|
||||||
|
```
|
||||||
|
|
||||||
|
## [导出mindir模型](#contents)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python export.py --data_url [DATA_URL] --ckpt_file [CKPT_PATH] --n_pred [N_PRED] --graph_conv_type [GRAPH_CONV_TYPE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||||
|
```
|
||||||
|
|
||||||
|
## [推理过程](#contents)
|
||||||
|
|
||||||
|
### 用法
|
||||||
|
|
||||||
|
执行推断之前,minirir文件必须由export.py导出。输入文件必须为bin格式
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Ascend310 inference
|
||||||
|
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 结果
|
||||||
|
|
||||||
|
推理结果保存在当前路径中,您可以在acc.log文件中找到结果
|
||||||
|
|
||||||
|
# [模型介绍](#contents)
|
||||||
|
|
||||||
|
## [性能](#contents)
|
||||||
|
|
||||||
|
### 评估性能
|
||||||
|
|
||||||
|
#### STGCN on PeMSD7-m (Cheb,n_pred=9)
|
||||||
|
|
||||||
|
| Parameters | ModelArts
|
||||||
|
| -------------------------- | -----------------------------------------------------------
|
||||||
|
| Model Version | STGCN
|
||||||
|
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G
|
||||||
|
| uploaded Date | 05/07/2021 (month/day/year)
|
||||||
|
| MindSpore Version | 1.2.0
|
||||||
|
| Dataset | PeMSD7-m
|
||||||
|
| Training Parameters | epoch=500, steps=139, batch_size = 8, lr=0.003
|
||||||
|
| Optimizer | AdamWeightDecay
|
||||||
|
| Loss Function | MES Loss
|
||||||
|
| outputs | probability
|
||||||
|
| Loss | 0.183
|
||||||
|
| Speed | 8pc: 45.601 ms/step;
|
||||||
|
| Scripts | [STGCN script]
|
||||||
|
|
||||||
|
### Inference Performance
|
||||||
|
|
||||||
|
#### STGCN on PeMSD7-m (Cheb,n_pred=9)
|
||||||
|
|
||||||
|
| Parameters | Ascend
|
||||||
|
| ------------------- | ---------------------------
|
||||||
|
| Model Version | STGCN
|
||||||
|
| Resource | Ascend 910
|
||||||
|
| Uploaded Date | 05/07/2021 (month/day/year)
|
||||||
|
| MindSpore Version | 1.2.0
|
||||||
|
| Dataset | PeMSD7-m
|
||||||
|
| batch_size | 8
|
||||||
|
| outputs | probability
|
||||||
|
| MAE | 3.23
|
||||||
|
| MAPE | 8.32
|
||||||
|
| RMSE | 6.06
|
||||||
|
| Model for inference | about 6M(.ckpt fil)
|
||||||
|
|
||||||
|
# [随机事件介绍](#contents)
|
||||||
|
|
||||||
|
我们在train.py中设置了随机种子
|
||||||
|
|
||||||
|
# [ModelZoo 主页](#contents)
|
||||||
|
|
||||||
|
请查看官方网站 [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,14 @@
|
||||||
|
cmake_minimum_required(VERSION 3.14.1)
|
||||||
|
project(Ascend310Infer)
|
||||||
|
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||||
|
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||||
|
option(MINDSPORE_PATH "mindspore install path" "")
|
||||||
|
include_directories(${MINDSPORE_PATH})
|
||||||
|
include_directories(${MINDSPORE_PATH}/include)
|
||||||
|
include_directories(${PROJECT_SRC_ROOT})
|
||||||
|
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||||
|
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||||
|
|
||||||
|
add_executable(main src/main.cc src/utils.cc)
|
||||||
|
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,29 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
if [ -d out ]; then
|
||||||
|
rm -rf out
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir out
|
||||||
|
cd out || exit
|
||||||
|
|
||||||
|
if [ -f "Makefile" ]; then
|
||||||
|
make clean
|
||||||
|
fi
|
||||||
|
|
||||||
|
cmake .. \
|
||||||
|
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||||
|
make
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||||
|
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||||
|
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/types.h"
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||||
|
DIR *OpenDir(std::string_view dirName);
|
||||||
|
std::string RealPath(std::string_view path);
|
||||||
|
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||||
|
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||||
|
#endif
|
|
@ -0,0 +1,134 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <sys/time.h>
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iosfwd>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
#include "include/dataset/execute.h"
|
||||||
|
#include "include/dataset/vision.h"
|
||||||
|
#include "inc/utils.h"
|
||||||
|
|
||||||
|
using mindspore::Context;
|
||||||
|
using mindspore::Serialization;
|
||||||
|
using mindspore::Model;
|
||||||
|
using mindspore::Status;
|
||||||
|
using mindspore::MSTensor;
|
||||||
|
using mindspore::dataset::Execute;
|
||||||
|
using mindspore::ModelType;
|
||||||
|
using mindspore::GraphCell;
|
||||||
|
using mindspore::kSuccess;
|
||||||
|
|
||||||
|
DEFINE_string(mindir_path, "", "mindir path");
|
||||||
|
DEFINE_string(input0_path, ".", "input0 path");
|
||||||
|
DEFINE_int32(device_id, 0, "device id");
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
if (RealPath(FLAGS_mindir_path).empty()) {
|
||||||
|
std::cout << "Invalid mindir" << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
|
ascend310->SetDeviceID(FLAGS_device_id);
|
||||||
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
|
mindspore::Graph graph;
|
||||||
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||||
|
|
||||||
|
Model model;
|
||||||
|
Status ret = model.Build(GraphCell(graph), context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
std::cout << "ERROR: Build failed." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MSTensor> model_inputs = model.GetInputs();
|
||||||
|
if (model_inputs.empty()) {
|
||||||
|
std::cout << "Invalid model, inputs is empty." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input0_files = GetAllFiles(FLAGS_input0_path);
|
||||||
|
|
||||||
|
if (input0_files.empty()) {
|
||||||
|
std::cout << "ERROR: input data empty." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<double, double> costTime_map;
|
||||||
|
size_t size = input0_files.size();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
struct timeval start = {0};
|
||||||
|
struct timeval end = {0};
|
||||||
|
double startTimeMs;
|
||||||
|
double endTimeMs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
|
||||||
|
|
||||||
|
auto input0 = ReadFileToTensor(input0_files[i]);
|
||||||
|
|
||||||
|
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
|
||||||
|
input0.Data().get(), input0.DataSize());
|
||||||
|
|
||||||
|
for (auto shape : model_inputs[0].Shape()) {
|
||||||
|
std::cout << "model input shape" << shape << std::endl;
|
||||||
|
}
|
||||||
|
gettimeofday(&start, nullptr);
|
||||||
|
ret = model.Predict(inputs, &outputs);
|
||||||
|
gettimeofday(&end, nullptr);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||||
|
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||||
|
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||||
|
WriteResult(input0_files[i], outputs);
|
||||||
|
}
|
||||||
|
double average = 0.0;
|
||||||
|
int inferCount = 0;
|
||||||
|
|
||||||
|
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||||
|
double diff = 0.0;
|
||||||
|
diff = iter->second - iter->first;
|
||||||
|
average += diff;
|
||||||
|
inferCount++;
|
||||||
|
}
|
||||||
|
average = average / inferCount;
|
||||||
|
std::stringstream timeCost;
|
||||||
|
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||||
|
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||||
|
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||||
|
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||||
|
fileStream << timeCost.str();
|
||||||
|
fileStream.close();
|
||||||
|
costTime_map.clear();
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include "inc/utils.h"
|
||||||
|
|
||||||
|
using mindspore::MSTensor;
|
||||||
|
using mindspore::DataType;
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||||
|
struct dirent *filename;
|
||||||
|
DIR *dir = OpenDir(dirName);
|
||||||
|
if (dir == nullptr) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
std::vector<std::string> res;
|
||||||
|
while ((filename = readdir(dir)) != nullptr) {
|
||||||
|
std::string dName = std::string(filename->d_name);
|
||||||
|
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||||
|
}
|
||||||
|
std::sort(res.begin(), res.end());
|
||||||
|
for (auto &f : res) {
|
||||||
|
std::cout << "image file: " << f << std::endl;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||||
|
std::string homePath = "./result_Files";
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
size_t outputSize;
|
||||||
|
std::shared_ptr<const void> netOutput;
|
||||||
|
netOutput = outputs[i].Data();
|
||||||
|
outputSize = outputs[i].DataSize();
|
||||||
|
std::cout << "output size:" << outputSize << std::endl;
|
||||||
|
int pos = imageFile.rfind('/');
|
||||||
|
std::string fileName(imageFile, pos + 1);
|
||||||
|
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||||
|
std::string outFileName = homePath + "/" + fileName;
|
||||||
|
FILE * outputFile = fopen(outFileName.c_str(), "wb");
|
||||||
|
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||||
|
fclose(outputFile);
|
||||||
|
outputFile = nullptr;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||||
|
if (file.empty()) {
|
||||||
|
std::cout << "Pointer file is nullptr" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ifstream ifs(file);
|
||||||
|
if (!ifs.good()) {
|
||||||
|
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ifs.is_open()) {
|
||||||
|
std::cout << "File: " << file << "open failed" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::end);
|
||||||
|
size_t size = ifs.tellg();
|
||||||
|
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::beg);
|
||||||
|
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||||
|
ifs.close();
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DIR *OpenDir(std::string_view dirName) {
|
||||||
|
if (dirName.empty()) {
|
||||||
|
std::cout << " dirName is null ! " << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::string realPath = RealPath(dirName);
|
||||||
|
struct stat s;
|
||||||
|
lstat(realPath.c_str(), &s);
|
||||||
|
if (!S_ISDIR(s.st_mode)) {
|
||||||
|
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
DIR *dir;
|
||||||
|
dir = opendir(realPath.c_str());
|
||||||
|
if (dir == nullptr) {
|
||||||
|
std::cout << "Can not open dir " << dirName << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||||
|
return dir;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RealPath(std::string_view path) {
|
||||||
|
char realPathMem[PATH_MAX] = {0};
|
||||||
|
char *realPathRet = nullptr;
|
||||||
|
realPathRet = realpath(path.data(), realPathMem);
|
||||||
|
|
||||||
|
if (realPathRet == nullptr) {
|
||||||
|
std::cout << "File: " << path << " is not exist.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string realPath(realPathMem);
|
||||||
|
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||||
|
return realPath;
|
||||||
|
}
|
|
@ -0,0 +1,131 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
##############export checkpoint file into air, onnx, mindir models#################
|
||||||
|
python export.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||||
|
from src import dataloader, utility
|
||||||
|
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
|
||||||
|
from src.model import models
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Tracking')
|
||||||
|
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||||
|
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||||
|
choices=['Ascend', 'GPU', 'CPU'],
|
||||||
|
help='device where the code will be implemented (default: Ascend)')
|
||||||
|
parser.add_argument('--data_url', type=str, help='Train dataset directory.')
|
||||||
|
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
|
||||||
|
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
|
||||||
|
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||||
|
parser.add_argument("--n_pred", type=int, default=3, help="The number of time interval for predcition.")
|
||||||
|
parser.add_argument("--graph_conv_type", type=str, default="chebconv", help="Grapg convolution type.")
|
||||||
|
parser.add_argument("--file_name", type=str, default="stgcn", help="output file name.")
|
||||||
|
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||||
|
|
||||||
|
if args.graph_conv_type == "chebconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_chebconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_chebconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_chebconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported n_pred.")
|
||||||
|
elif args.graph_conv_type == "gcnconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_gcnconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_gcnconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_gcnconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported pred.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported graph_conv_type.")
|
||||||
|
|
||||||
|
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
|
||||||
|
raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
|
||||||
|
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
|
||||||
|
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
|
||||||
|
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
|
||||||
|
|
||||||
|
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
|
||||||
|
cfg.Ks = 2
|
||||||
|
|
||||||
|
# blocks: settings of channel size in st_conv_blocks and output layer,
|
||||||
|
# using the bottleneck design in st_conv_blocks
|
||||||
|
blocks = []
|
||||||
|
blocks.append([1])
|
||||||
|
for l in range(cfg.stblock_num):
|
||||||
|
blocks.append([64, 16, 64])
|
||||||
|
if Ko == 0:
|
||||||
|
blocks.append([128])
|
||||||
|
elif Ko > 0:
|
||||||
|
blocks.append([128, 128])
|
||||||
|
blocks.append([1])
|
||||||
|
|
||||||
|
|
||||||
|
day_slot = int(24 * 60 / cfg.time_intvl)
|
||||||
|
cfg.n_pred = cfg.n_pred
|
||||||
|
|
||||||
|
time_pred = cfg.n_pred * cfg.time_intvl
|
||||||
|
time_pred_str = str(time_pred) + '_mins'
|
||||||
|
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
device_num = 1
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
device_id = args.device_id
|
||||||
|
data_dir = args.data_url + '/'
|
||||||
|
|
||||||
|
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
|
||||||
|
|
||||||
|
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
|
||||||
|
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
|
||||||
|
if n_vertex_vel == n_vertex_adj:
|
||||||
|
n_vertex = n_vertex_vel
|
||||||
|
else:
|
||||||
|
raise ValueError(f'ERROR: number of vertices in dataset is not equal to number of \
|
||||||
|
vertices in weighted adjacency matrix.')
|
||||||
|
|
||||||
|
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
|
||||||
|
conv_matrix = Tensor(Tensor.from_numpy(mat), ms.float32)
|
||||||
|
if cfg.graph_conv_type == "chebconv":
|
||||||
|
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
elif cfg.graph_conv_type == "gcnconv":
|
||||||
|
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
|
||||||
|
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
|
||||||
|
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
|
||||||
|
net = stgcn_conv
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
input_arr = Tensor(np.zeros([args.batch_size, 1, cfg.n_his, n_vertex]), ms.float32)
|
||||||
|
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""compute acc for ascend 310"""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config import stgcn_chebconv_45min_cfg
|
||||||
|
from src import dataloader
|
||||||
|
from sklearn import preprocessing
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser('mindspore stgcn testing')
|
||||||
|
# Path for data
|
||||||
|
parser.add_argument('--data_url', type=str, default='./data/', help='Test dataset directory.')
|
||||||
|
parser.add_argument('--label_dir', type=str, default='', help='label data directory.')
|
||||||
|
parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
|
||||||
|
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
|
||||||
|
# Super parameters for testing
|
||||||
|
parser.add_argument('--n_pred', type=int, default=9, help='The number of time interval for predcition')
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
cfg = stgcn_chebconv_45min_cfg
|
||||||
|
cfg.batch_size = 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
zscore = preprocessing.StandardScaler()
|
||||||
|
|
||||||
|
rst_path = args.result_dir
|
||||||
|
labels = np.load(args.label_dir)
|
||||||
|
|
||||||
|
dataset = dataloader.create_dataset(args.data_url+args.data_path, \
|
||||||
|
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
|
||||||
|
|
||||||
|
mae, sum_y, mape, mse = [], [], [], []
|
||||||
|
|
||||||
|
for i in range(len(os.listdir(rst_path))):
|
||||||
|
file_name = os.path.join(rst_path, "STGCN_data_bs" + str(cfg.batch_size) + '_' + str(i) + '_0.bin')
|
||||||
|
output = np.fromfile(file_name, np.float16)
|
||||||
|
output = zscore.inverse_transform(output)
|
||||||
|
label = zscore.inverse_transform(labels[i])
|
||||||
|
|
||||||
|
d = np.abs(label - output)
|
||||||
|
mae += d.tolist()
|
||||||
|
sum_y += label.tolist()
|
||||||
|
mape += (d / label).tolist()
|
||||||
|
mse += (d ** 2).tolist()
|
||||||
|
|
||||||
|
MAE = np.array(mae).mean()
|
||||||
|
MAPE = np.array(mape).mean()
|
||||||
|
RMSE = np.sqrt(np.array(mse).mean())
|
||||||
|
|
||||||
|
print(f'MAE {MAE:.2f} | MAPE {MAPE*100:.2f} | RMSE {RMSE:.2f}')
|
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""generate dataset for ascend 310"""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config import stgcn_chebconv_45min_cfg
|
||||||
|
from src import dataloader
|
||||||
|
from sklearn import preprocessing
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser('mindspore stgcn testing')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', \
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
# Path for data and checkpoint
|
||||||
|
parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
|
||||||
|
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
|
||||||
|
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
|
||||||
|
# Super parameters for testing
|
||||||
|
parser.add_argument('--n_pred', type=int, default=9, help='The number of time interval for predcition')
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
cfg = stgcn_chebconv_45min_cfg
|
||||||
|
cfg.batch_size = 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
zscore = preprocessing.StandardScaler()
|
||||||
|
|
||||||
|
dataset = dataloader.create_dataset(args.data_url+args.data_path, \
|
||||||
|
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
|
||||||
|
|
||||||
|
img_path = os.path.join(args.result_path, "00_data")
|
||||||
|
os.mkdir(img_path)
|
||||||
|
|
||||||
|
label_list = []
|
||||||
|
# dataset is an instance of Dataset object
|
||||||
|
iterator = dataset.create_dict_iterator(output_numpy=True)
|
||||||
|
for i, data in enumerate(iterator):
|
||||||
|
file_name = "STGCN_data_bs" + str(cfg.batch_size) + "_" + str(i) + ".bin"
|
||||||
|
file_path = img_path + "/" + file_name
|
||||||
|
data['inputs'].tofile(file_path)
|
||||||
|
label_list.append(data['labels'])
|
||||||
|
|
||||||
|
np.save(args.result_path + "label_ids.npy", label_list)
|
||||||
|
print("="*20, "export bin files finished", "="*20)
|
|
@ -0,0 +1,83 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 4 ]; then
|
||||||
|
echo "Usage: sh run_distribute_train.sh [train_code_path][data_path][n_pred][graph_conv_type]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path() {
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
train_code_path=$(get_real_path $1)
|
||||||
|
echo $train_code_path
|
||||||
|
|
||||||
|
if [ ! -d $train_code_path ]
|
||||||
|
then
|
||||||
|
echo "error: train_code_path=$train_code_path is not a dictionary."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
data_path=$(get_real_path $2)
|
||||||
|
echo $train_code_path
|
||||||
|
|
||||||
|
if [ ! -d $data_path ]
|
||||||
|
then
|
||||||
|
echo "error: train_code_path=$train_code_path is not a dictionary."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ulimit -c unlimited
|
||||||
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
|
export RANK_TABLE_FILE=${train_code_path}/scripts/hccl_8p_01234567_127.0.0.1.json
|
||||||
|
export RANK_SIZE=8
|
||||||
|
export RANK_START_ID=0
|
||||||
|
export n_pred=$3
|
||||||
|
export graph_conv_type=$4
|
||||||
|
|
||||||
|
|
||||||
|
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||||
|
do
|
||||||
|
export RANK_ID=${i}
|
||||||
|
export DEVICE_ID=$((i + RANK_START_ID))
|
||||||
|
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||||
|
if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
|
||||||
|
rm -rf ${train_code_path}/device${DEVICE_ID}
|
||||||
|
fi
|
||||||
|
mkdir ${train_code_path}/device${DEVICE_ID}
|
||||||
|
cd ${train_code_path}/device${DEVICE_ID} || exit
|
||||||
|
python ${train_code_path}/train.py --data_url=${data_path} \
|
||||||
|
--train_url=./checkpoint \
|
||||||
|
--run_distribute=True \
|
||||||
|
--run_modelarts=False \
|
||||||
|
--n_pred=$n_pred \
|
||||||
|
--graph_conv_type=$graph_conv_type > out.log 2>&1 &
|
||||||
|
done
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/bin/bash
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 6 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_standalone_eval_ascend.sh [data_path][ckpt_url][ckpt_name][device_id][graph_conv_type][n_pred]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export Data_path=$1
|
||||||
|
export Ckpt_path=$2
|
||||||
|
export Ckpt_name=$3
|
||||||
|
export Device_id=$4
|
||||||
|
export Graph_conv_type=$5
|
||||||
|
export N_pred=$6
|
||||||
|
|
||||||
|
python test.py --data_url=$Data_path \
|
||||||
|
--train_url=./checkpoint \
|
||||||
|
--run_distribute=False \
|
||||||
|
--run_modelarts=False \
|
||||||
|
--device_id=$Device_id \
|
||||||
|
--ckpt_url=$Ckpt_path \
|
||||||
|
--ckpt_name=$Ckpt_name \
|
||||||
|
--n_pred=$N_pred \
|
||||||
|
--graph_conv_type=$Graph_conv_type > test.log 2>&1 &
|
|
@ -0,0 +1,130 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [[ $# -lt 4 || $# -gt 5 ]]; then
|
||||||
|
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
|
||||||
|
DEVICE_TARGET must choose from ['GPU', 'CPU', 'Ascend']
|
||||||
|
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'.
|
||||||
|
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
model=$(get_real_path $1)
|
||||||
|
dataset_path=$(get_real_path $2)
|
||||||
|
|
||||||
|
if [ "$3" == "y" ] || [ "$3" == "n" ];then
|
||||||
|
need_preprocess=$3
|
||||||
|
else
|
||||||
|
echo "weather need preprocess or not, it's value must be in [y, n]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$4" == "GPU" ] || [ "$4" == "CPU" ] || [ "$4" == "Ascend" ];then
|
||||||
|
device_target=$4
|
||||||
|
else
|
||||||
|
echo "device_target must be in ['GPU', 'CPU', 'Ascend']"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
device_id=0
|
||||||
|
if [ $# == 5 ]; then
|
||||||
|
device_id=$5
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "mindir name: "$model
|
||||||
|
echo "dataset path: "$dataset_path
|
||||||
|
echo "need preprocess: "$need_preprocess
|
||||||
|
echo "device_target: "$device_target
|
||||||
|
echo "device id: "$device_id
|
||||||
|
|
||||||
|
export ASCEND_HOME=/usr/local/Ascend/
|
||||||
|
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||||
|
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||||
|
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||||
|
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||||
|
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||||
|
else
|
||||||
|
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||||
|
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||||
|
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||||
|
fi
|
||||||
|
|
||||||
|
function preprocess_data()
|
||||||
|
{
|
||||||
|
if [ -d preprocess_Result ]; then
|
||||||
|
rm -rf ./preprocess_Result
|
||||||
|
fi
|
||||||
|
mkdir preprocess_Result
|
||||||
|
python3.7 ../preprocess.py --data_url=$dataset_path --result_path=./preprocess_Result/ --device_target=$device_target &>preprocess.log
|
||||||
|
}
|
||||||
|
|
||||||
|
function compile_app()
|
||||||
|
{
|
||||||
|
cd ../ascend310_infer || exit
|
||||||
|
bash build.sh &> build.log
|
||||||
|
}
|
||||||
|
|
||||||
|
function infer()
|
||||||
|
{
|
||||||
|
cd - || exit
|
||||||
|
if [ -d result_Files ]; then
|
||||||
|
rm -rf ./result_Files
|
||||||
|
fi
|
||||||
|
if [ -d time_Result ]; then
|
||||||
|
rm -rf ./time_Result
|
||||||
|
fi
|
||||||
|
mkdir result_Files
|
||||||
|
mkdir time_Result
|
||||||
|
|
||||||
|
../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
function cal_acc()
|
||||||
|
{
|
||||||
|
python3.7 ../postprocess.py --data_url=$dataset_path --result_dir=./result_Files --label_dir=./preprocess_Result/label_ids.npy --device_target=$device_target &> acc.log
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ $need_preprocess == "y" ]; then
|
||||||
|
preprocess_data
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "preprocess dataset failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
compile_app
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "compile app code failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
infer
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo " execute inference failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cal_acc
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "calculate accuracy failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
|
@ -0,0 +1,133 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in train.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
stgcn_chebconv_45min_cfg = edict({
|
||||||
|
'learning_rate': 0.003,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 9,
|
||||||
|
'epochs': 500,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.7,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "chebconv",
|
||||||
|
'mat_type': "wid_sym_normd_lap_mat",
|
||||||
|
})
|
||||||
|
|
||||||
|
stgcn_chebconv_30min_cfg = edict({
|
||||||
|
'learning_rate': 0.003,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 6,
|
||||||
|
'epochs': 500,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.7,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "chebconv",
|
||||||
|
'mat_type': "wid_sym_normd_lap_mat",
|
||||||
|
})
|
||||||
|
|
||||||
|
stgcn_chebconv_15min_cfg = edict({
|
||||||
|
'learning_rate': 0.002,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 3,
|
||||||
|
'epochs': 100,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.999,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "chebconv",
|
||||||
|
'mat_type': "wid_rw_normd_lap_mat",
|
||||||
|
})
|
||||||
|
|
||||||
|
stgcn_gcnconv_45min_cfg = edict({
|
||||||
|
'learning_rate': 0.003,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 9,
|
||||||
|
'epochs': 500,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.7,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "gcnconv",
|
||||||
|
'mat_type': "hat_sym_normd_lap_mat",
|
||||||
|
})
|
||||||
|
|
||||||
|
stgcn_gcnconv_30min_cfg = edict({
|
||||||
|
'learning_rate': 0.003,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 6,
|
||||||
|
'epochs': 500,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.7,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "gcnconv",
|
||||||
|
'mat_type': "hat_sym_normd_lap_mat",
|
||||||
|
})
|
||||||
|
|
||||||
|
stgcn_gcnconv_15min_cfg = edict({
|
||||||
|
'learning_rate': 0.002,
|
||||||
|
'n_his': 12,
|
||||||
|
'n_pred': 3,
|
||||||
|
'epochs': 100,
|
||||||
|
'batch_size': 8,
|
||||||
|
'decay_epoch': 10,
|
||||||
|
'gamma': 0.9999,
|
||||||
|
'stblock_num': 2,
|
||||||
|
'Ks': 2,
|
||||||
|
'Kt': 3,
|
||||||
|
'time_intvl': 5,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'weight_decay_rate': 0.0005,
|
||||||
|
'gated_act_func': "glu",
|
||||||
|
'graph_conv_type': "gcnconv",
|
||||||
|
'mat_type': "hat_rw_normd_lap_mat",
|
||||||
|
})
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
process dataset.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
|
||||||
|
class STGCNDataset:
|
||||||
|
""" BRDNetDataset.
|
||||||
|
Args:
|
||||||
|
mode: 0 means train;1 means val;2 means test
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, data_path, n_his, n_pred, zscore, mode=0):
|
||||||
|
|
||||||
|
self.df = pd.read_csv(data_path, header=None)
|
||||||
|
self.data_col = self.df.shape[0]
|
||||||
|
# recommended dataset split rate as train: val: test = 60: 20: 20, 70: 15: 15 or 80: 10: 10
|
||||||
|
# using dataset split rate as train: val: test = 70: 15: 15
|
||||||
|
self.val_and_test_rate = 0.15
|
||||||
|
|
||||||
|
self.len_val = int(math.floor(self.data_col * self.val_and_test_rate))
|
||||||
|
self.len_test = int(math.floor(self.data_col * self.val_and_test_rate))
|
||||||
|
self.len_train = int(self.data_col - self.len_val - self.len_test)
|
||||||
|
|
||||||
|
self.dataset_train = self.df[: self.len_train]
|
||||||
|
self.dataset_val = self.df[self.len_train: self.len_train + self.len_val]
|
||||||
|
self.dataset_test = self.df[self.len_train + self.len_val:]
|
||||||
|
|
||||||
|
self.dataset_train = zscore.fit_transform(self.dataset_train)
|
||||||
|
self.dataset_val = zscore.transform(self.dataset_val)
|
||||||
|
self.dataset_test = zscore.transform(self.dataset_test)
|
||||||
|
|
||||||
|
if mode == 0:
|
||||||
|
self.dataset = self.dataset_train
|
||||||
|
elif mode == 1:
|
||||||
|
self.dataset = self.dataset_val
|
||||||
|
else:
|
||||||
|
self.dataset = self.dataset_test
|
||||||
|
|
||||||
|
self.n_his = n_his
|
||||||
|
self.n_pred = n_pred
|
||||||
|
self.n_vertex = self.dataset.shape[1]
|
||||||
|
self.len_record = len(self.dataset)
|
||||||
|
self.num = self.len_record - self.n_his - self.n_pred
|
||||||
|
|
||||||
|
self.x = np.zeros([self.num, 1, self.n_his, self.n_vertex], np.float32)
|
||||||
|
self.y = np.zeros([self.num, self.n_vertex], np.float32)
|
||||||
|
|
||||||
|
for i in range(self.num):
|
||||||
|
head = i
|
||||||
|
tail = i + self.n_his
|
||||||
|
self.x[i, :, :, :] = self.dataset[head: tail].reshape(1, self.n_his, self.n_vertex)
|
||||||
|
self.y[i] = self.dataset[tail + self.n_pred - 1]
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
Returns:
|
||||||
|
x[index]: input of network
|
||||||
|
y[index]: label of network
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.x[index], self.y[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num
|
||||||
|
|
||||||
|
|
||||||
|
def load_weighted_adjacency_matrix(file_path):
|
||||||
|
df = pd.read_csv(file_path, header=None)
|
||||||
|
return df.to_numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(data_path, batch_size, n_his, n_pred, zscore, is_sigle, device_num=1, device_id=0, mode=0):
|
||||||
|
"""
|
||||||
|
generate dataset for train or test.
|
||||||
|
"""
|
||||||
|
data = STGCNDataset(data_path, n_his, n_pred, zscore, mode=mode)
|
||||||
|
shuffle = True
|
||||||
|
if mode != 0:
|
||||||
|
shuffle = False
|
||||||
|
if not is_sigle:
|
||||||
|
dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, \
|
||||||
|
shuffle=shuffle, num_shards=device_num, shard_id=device_id)
|
||||||
|
else:
|
||||||
|
dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, shuffle=shuffle)
|
||||||
|
dataset = dataset.batch(batch_size)
|
||||||
|
return dataset
|
|
@ -0,0 +1,306 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""network layer"""
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as ops
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
|
||||||
|
class Align(nn.Cell):
|
||||||
|
"""align"""
|
||||||
|
def __init__(self, c_in, c_out):
|
||||||
|
super(Align, self).__init__()
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_out = c_out
|
||||||
|
self.align_conv = nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=1, \
|
||||||
|
pad_mode='valid', weight_init='he_uniform')
|
||||||
|
self.concat = ops.Concat(axis=1)
|
||||||
|
self.zeros = ops.Zeros()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x_align = x
|
||||||
|
if self.c_in > self.c_out:
|
||||||
|
x_align = self.align_conv(x)
|
||||||
|
elif self.c_in < self.c_out:
|
||||||
|
batch_size, _, timestep, n_vertex = x.shape
|
||||||
|
y = self.zeros((batch_size, self.c_out - self.c_in, timestep, n_vertex), x.dtype)
|
||||||
|
x_align = self.concat((x, y))
|
||||||
|
return x_align
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Cell):
|
||||||
|
"""causal conv2d"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
||||||
|
enable_padding=False, dilation=1, groups=1, bias=True):
|
||||||
|
super(CausalConv2d, self).__init__()
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size, kernel_size)
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride)
|
||||||
|
if isinstance(dilation, int):
|
||||||
|
dilation = (dilation, dilation)
|
||||||
|
|
||||||
|
if enable_padding:
|
||||||
|
self.__padding = [int((kernel_size[i] - 1) * dilation[i]) for i in range(len(kernel_size))]
|
||||||
|
else:
|
||||||
|
self.__padding = 0
|
||||||
|
if isinstance(self.__padding, int):
|
||||||
|
self.left_padding = (self.__padding, self.__padding)
|
||||||
|
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \
|
||||||
|
padding=0, pad_mode='valid', dilation=dilation, group=groups, has_bias=bias, weight_init='he_uniform')
|
||||||
|
self.pad = ops.Pad(((0, 0), (0, 0), (self.left_padding[0], 0), (self.left_padding[1], 0)))
|
||||||
|
def construct(self, x):
|
||||||
|
if self.__padding != 0:
|
||||||
|
x = self.pad(x)
|
||||||
|
result = self.conv2d(x)
|
||||||
|
return result
|
||||||
|
|
||||||
|
class TemporalConvLayer(nn.Cell):
|
||||||
|
"""
|
||||||
|
# Temporal Convolution Layer (GLU)
|
||||||
|
#
|
||||||
|
# |-------------------------------| * residual connection *
|
||||||
|
# | |
|
||||||
|
# | |--->--- casual conv ----- + -------|
|
||||||
|
# -------|----| ⊙ ------>
|
||||||
|
# |--->--- casual conv --- sigmoid ---|
|
||||||
|
#
|
||||||
|
|
||||||
|
#param x: tensor, [batch_size, c_in, timestep, n_vertex]
|
||||||
|
"""
|
||||||
|
def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
|
||||||
|
super(TemporalConvLayer, self).__init__()
|
||||||
|
self.Kt = Kt
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_out = c_out
|
||||||
|
self.n_vertex = n_vertex
|
||||||
|
self.act_func = act_func
|
||||||
|
self.align = Align(self.c_in, self.c_out)
|
||||||
|
self.causal_conv = CausalConv2d(in_channels=self.c_in, out_channels=2 * self.c_out, \
|
||||||
|
kernel_size=(self.Kt, 1), enable_padding=False, dilation=1)
|
||||||
|
self.linear = nn.Dense(self.n_vertex, self.n_vertex).to_float(mstype.float16)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
self.add = ops.Add()
|
||||||
|
self.mul = ops.Mul()
|
||||||
|
self.split = ops.Split(axis=1, output_num=2)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""TemporalConvLayer compute"""
|
||||||
|
x_in = self.align(x)
|
||||||
|
x_in = x_in[:, :, self.Kt - 1:, :]
|
||||||
|
x_causal_conv = self.causal_conv(x)
|
||||||
|
x_tc_out = x_causal_conv
|
||||||
|
x_pq = self.split(x_tc_out)
|
||||||
|
x_p = x_pq[0]
|
||||||
|
x_q = x_pq[1]
|
||||||
|
x_glu = x_causal_conv
|
||||||
|
x_gtu = x_causal_conv
|
||||||
|
if self.act_func == 'glu':
|
||||||
|
# (x_p + x_in) ⊙ Sigmoid(x_q)
|
||||||
|
x_glu = self.mul(self.add(x_p, x_in), self.sigmoid(x_q))
|
||||||
|
x_tc_out = x_glu
|
||||||
|
# Temporal Convolution Layer (GTU)
|
||||||
|
elif self.act_func == 'gtu':
|
||||||
|
# Tanh(x_p + x_in) ⊙ Sigmoid(x_q)
|
||||||
|
x_gtu = self.mul(self.tanh(self.add(x_p, x_in)), self.sigmoid(x_q))
|
||||||
|
x_tc_out = x_gtu
|
||||||
|
return x_tc_out
|
||||||
|
|
||||||
|
class ChebConv(nn.Cell):
|
||||||
|
"""cheb conv"""
|
||||||
|
def __init__(self, c_in, c_out, Ks, chebconv_matrix):
|
||||||
|
super(ChebConv, self).__init__()
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_out = c_out
|
||||||
|
self.Ks = Ks
|
||||||
|
self.chebconv_matrix = chebconv_matrix
|
||||||
|
self.matmul = ops.MatMul()
|
||||||
|
self.stack = ops.Stack(axis=0)
|
||||||
|
self.reshape = ops.Reshape()
|
||||||
|
self.bias_add = ops.BiasAdd()
|
||||||
|
self.weight = ms.Parameter(initializer('normal', (self.Ks, self.c_in, self.c_out)), name='weight')
|
||||||
|
self.bias = ms.Parameter(initializer('Uniform', [self.c_out]), name='bias')
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""chebconv compute"""
|
||||||
|
_, c_in, _, n_vertex = x.shape
|
||||||
|
|
||||||
|
# Using recurrence relation to reduce time complexity from O(n^2) to O(K|E|),
|
||||||
|
# where K = Ks - 1
|
||||||
|
x = self.reshape(x, (n_vertex, -1))
|
||||||
|
x_0 = x
|
||||||
|
x_1 = self.matmul(self.chebconv_matrix, x)
|
||||||
|
x_list = []
|
||||||
|
if self.Ks - 1 == 0:
|
||||||
|
x_list = [x_0]
|
||||||
|
elif self.Ks - 1 == 1:
|
||||||
|
x_list = [x_0, x_1]
|
||||||
|
elif self.Ks - 1 >= 2:
|
||||||
|
x_list = [x_0, x_1]
|
||||||
|
for k in range(2, self.Ks):
|
||||||
|
x_list.append(self.matmul(2 * self.chebconv_matrix, x_list[k - 1]) - x_list[k - 2])
|
||||||
|
x_tensor = self.stack(x_list)
|
||||||
|
|
||||||
|
x_mul = self.matmul(self.reshape(x_tensor, (-1, self.Ks * c_in)), self.reshape(self.weight, \
|
||||||
|
(self.Ks * c_in, -1)))
|
||||||
|
x_mul = self.reshape(x_mul, (-1, self.c_out))
|
||||||
|
x_chebconv = self.bias_add(x_mul, self.bias)
|
||||||
|
return x_chebconv
|
||||||
|
|
||||||
|
class GCNConv(nn.Cell):
|
||||||
|
"""gcn conv"""
|
||||||
|
def __init__(self, c_in, c_out, gcnconv_matrix):
|
||||||
|
super(GCNConv, self).__init__()
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_out = c_out
|
||||||
|
self.gcnconv_matrix = gcnconv_matrix
|
||||||
|
self.matmul = ops.MatMul()
|
||||||
|
self.reshape = ops.Reshape()
|
||||||
|
self.bias_add = ops.BiasAdd()
|
||||||
|
|
||||||
|
self.weight = ms.Parameter(initializer('he_uniform', (self.c_in, self.c_out)), name='weight')
|
||||||
|
self.bias = ms.Parameter(initializer('Uniform', [self.c_out]), name='bias')
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""gcnconv compute"""
|
||||||
|
_, c_in, _, n_vertex = x.shape
|
||||||
|
x_first_mul = self.matmul(self.reshape(x, (-1, c_in)), self.weight)
|
||||||
|
x_first_mul = self.reshape(x_first_mul, (n_vertex, -1))
|
||||||
|
x_second_mul = self.matmul(self.gcnconv_matrix, x_first_mul)
|
||||||
|
x_second_mul = self.reshape(x_second_mul, (-1, self.c_out))
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
x_gcnconv_out = self.bias_add(x_second_mul, self.bias)
|
||||||
|
else:
|
||||||
|
x_gcnconv_out = x_second_mul
|
||||||
|
|
||||||
|
return x_gcnconv_out
|
||||||
|
|
||||||
|
class GraphConvLayer(nn.Cell):
|
||||||
|
"""grarh conv layer"""
|
||||||
|
def __init__(self, Ks, c_in, c_out, graph_conv_type, graph_conv_matrix):
|
||||||
|
super(GraphConvLayer, self).__init__()
|
||||||
|
self.Ks = Ks
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_out = c_out
|
||||||
|
self.align = Align(self.c_in, self.c_out)
|
||||||
|
self.graph_conv_type = graph_conv_type
|
||||||
|
self.graph_conv_matrix = graph_conv_matrix
|
||||||
|
if self.graph_conv_type == "chebconv":
|
||||||
|
self.chebconv = ChebConv(self.c_out, self.c_out, self.Ks, self.graph_conv_matrix)
|
||||||
|
elif self.graph_conv_type == "gcnconv":
|
||||||
|
self.gcnconv = GCNConv(self.c_out, self.c_out, self.graph_conv_matrix)
|
||||||
|
self.reshape = ops.Reshape()
|
||||||
|
self.add = ops.Add()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""GraphConvLayer compute"""
|
||||||
|
x_gc_in = self.align(x)
|
||||||
|
batch_size, _, T, n_vertex = x_gc_in.shape
|
||||||
|
x_gc = x_gc_in
|
||||||
|
if self.graph_conv_type == "chebconv":
|
||||||
|
x_gc = self.chebconv(x_gc_in)
|
||||||
|
elif self.graph_conv_type == "gcnconv":
|
||||||
|
x_gc = self.gcnconv(x_gc_in)
|
||||||
|
x_gc_with_rc = self.add(self.reshape(x_gc, (batch_size, self.c_out, T, n_vertex)), x_gc_in)
|
||||||
|
x_gc_out = x_gc_with_rc
|
||||||
|
return x_gc_out
|
||||||
|
|
||||||
|
class STConvBlock(nn.Cell):
|
||||||
|
"""
|
||||||
|
# STConv Block contains 'TNSATND' structure
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# G: Graph Convolution Layer (ChebConv or GCNConv)
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# N: Layer Normolization
|
||||||
|
# D: Dropout
|
||||||
|
#Kt Ks n_vertex
|
||||||
|
"""
|
||||||
|
def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, gated_act_func, graph_conv_type, \
|
||||||
|
graph_conv_matrix, drop_rate):
|
||||||
|
super(STConvBlock, self).__init__()
|
||||||
|
self.Kt = Kt
|
||||||
|
self.Ks = Ks
|
||||||
|
self.n_vertex = n_vertex
|
||||||
|
self.last_block_channel = last_block_channel
|
||||||
|
self.channels = channels
|
||||||
|
self.gated_act_func = gated_act_func
|
||||||
|
self.enable_gated_act_func = True
|
||||||
|
self.graph_conv_type = graph_conv_type
|
||||||
|
self.graph_conv_matrix = graph_conv_matrix
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
self.tmp_conv1 = TemporalConvLayer(self.Kt, self.last_block_channel, self.channels[0], \
|
||||||
|
self.n_vertex, self.gated_act_func)
|
||||||
|
self.graph_conv = GraphConvLayer(self.Ks, self.channels[0], self.channels[1], \
|
||||||
|
self.graph_conv_type, self.graph_conv_matrix)
|
||||||
|
self.tmp_conv2 = TemporalConvLayer(self.Kt, self.channels[1], self.channels[2], \
|
||||||
|
self.n_vertex, self.gated_act_func)
|
||||||
|
self.tc2_ln = nn.LayerNorm([self.n_vertex, self.channels[2]], begin_norm_axis=2, \
|
||||||
|
begin_params_axis=2, epsilon=1e-05)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.do = nn.Dropout(keep_prob=self.drop_rate)
|
||||||
|
self.transpose = ops.Transpose()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""STConvBlock compute"""
|
||||||
|
x_tmp_conv1 = self.tmp_conv1(x)
|
||||||
|
x_graph_conv = self.graph_conv(x_tmp_conv1)
|
||||||
|
x_act_func = self.relu(x_graph_conv)
|
||||||
|
x_tmp_conv2 = self.tmp_conv2(x_act_func)
|
||||||
|
x_tc2_ln = self.transpose(x_tmp_conv2, (0, 2, 3, 1))
|
||||||
|
x_tc2_ln = self.tc2_ln(x_tc2_ln)
|
||||||
|
x_tc2_ln = self.transpose(x_tc2_ln, (0, 3, 1, 2))
|
||||||
|
x_do = self.do(x_tc2_ln)
|
||||||
|
x_st_conv_out = x_do
|
||||||
|
return x_st_conv_out
|
||||||
|
|
||||||
|
class OutputBlock(nn.Cell):
|
||||||
|
"""
|
||||||
|
# Output block contains 'TNFF' structure
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# N: Layer Normolization
|
||||||
|
# F: Fully-Connected Layer
|
||||||
|
# F: Fully-Connected Layer
|
||||||
|
"""
|
||||||
|
def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, gated_act_func, drop_rate):
|
||||||
|
super(OutputBlock, self).__init__()
|
||||||
|
self.Ko = Ko
|
||||||
|
self.last_block_channel = last_block_channel
|
||||||
|
self.channels = channels
|
||||||
|
self.end_channel = end_channel
|
||||||
|
self.n_vertex = n_vertex
|
||||||
|
self.gated_act_func = gated_act_func
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
self.tmp_conv1 = TemporalConvLayer(self.Ko, self.last_block_channel, \
|
||||||
|
self.channels[0], self.n_vertex, self.gated_act_func)
|
||||||
|
self.fc1 = nn.Dense(self.channels[0], self.channels[1]).to_float(mstype.float16)
|
||||||
|
self.fc2 = nn.Dense(self.channels[1], self.end_channel).to_float(mstype.float16)
|
||||||
|
self.tc1_ln = nn.LayerNorm([self.n_vertex, self.channels[0]], begin_norm_axis=2, \
|
||||||
|
begin_params_axis=2, epsilon=1e-05)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.transpose = ops.Transpose()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""OutputBlock compute"""
|
||||||
|
x_tc1 = self.tmp_conv1(x)
|
||||||
|
x_tc1_ln = self.tc1_ln(self.transpose(x_tc1, (0, 2, 3, 1)))
|
||||||
|
x_fc1 = self.fc1(x_tc1_ln)
|
||||||
|
x_act_func = self.sigmoid(x_fc1)
|
||||||
|
x_fc2 = self.transpose(self.fc2(x_act_func), (0, 3, 1, 2))
|
||||||
|
x_out = x_fc2
|
||||||
|
return x_out
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
stgcn network with loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as P
|
||||||
|
|
||||||
|
class LossCellWithNetwork(nn.Cell):
|
||||||
|
"""STGCN loss."""
|
||||||
|
def __init__(self, network):
|
||||||
|
super(LossCellWithNetwork, self).__init__()
|
||||||
|
self.loss = nn.MSELoss()
|
||||||
|
self.network = network
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
|
def construct(self, x, label):
|
||||||
|
x = self.network(x)
|
||||||
|
x = self.reshape(x, (len(x), -1))
|
||||||
|
label = self.reshape(label, (len(label), -1))
|
||||||
|
STGCN_loss = self.loss(x, label)
|
||||||
|
return STGCN_loss
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
stgcn network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
from src.model import layers
|
||||||
|
|
||||||
|
class STGCN_Conv(nn.Cell):
|
||||||
|
"""
|
||||||
|
# STGCN(ChebConv) contains 'TGTND TGTND TNFF' structure
|
||||||
|
# ChebConv is the graph convolution from ChebyNet.
|
||||||
|
# Using the Chebyshev polynomials of the first kind to
|
||||||
|
# approximate graph convolution kernel from Spectral CNN.
|
||||||
|
|
||||||
|
# GCNConv is the graph convolution from GCN.
|
||||||
|
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# G: Graph Convolution Layer (ChebConv)
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# N: Layer Normolization
|
||||||
|
# D: Dropout
|
||||||
|
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# G: Graph Convolution Layer (ChebConv)
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# N: Layer Normolization
|
||||||
|
# D: Dropout
|
||||||
|
|
||||||
|
# T: Gated Temporal Convolution Layer (GLU or GTU)
|
||||||
|
# N: Layer Normalization
|
||||||
|
# F: Fully-Connected Layer
|
||||||
|
# F: Fully-Connected Layer
|
||||||
|
"""
|
||||||
|
def __init__(self, Kt, Ks, blocks, T, n_vertex, gated_act_func, graph_conv_type, chebconv_matrix, drop_rate):
|
||||||
|
super(STGCN_Conv, self).__init__()
|
||||||
|
modules = []
|
||||||
|
for l in range(len(blocks) - 3):
|
||||||
|
modules.append(layers.STConvBlock(Kt, Ks, n_vertex, blocks[l][-1], blocks[l+1], \
|
||||||
|
gated_act_func, graph_conv_type, chebconv_matrix, drop_rate))
|
||||||
|
self.st_blocks = nn.SequentialCell(modules)
|
||||||
|
Ko = T - (len(blocks) - 3) * 2 * (Kt - 1)
|
||||||
|
self.Ko = Ko
|
||||||
|
self.output = layers.OutputBlock(self.Ko, blocks[-3][-1], blocks[-2], \
|
||||||
|
blocks[-1][0], n_vertex, gated_act_func, drop_rate)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x_stbs = self.st_blocks(x)
|
||||||
|
x_out = self.output(x_stbs)
|
||||||
|
return x_out
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Calculate laplacian matrix, used to network weight.
|
||||||
|
Evaluate the performance of net work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.ops as ops
|
||||||
|
|
||||||
|
from scipy.linalg import fractional_matrix_power
|
||||||
|
from scipy.sparse.linalg import eigs
|
||||||
|
|
||||||
|
def calculate_laplacian_matrix(adj_mat, mat_type):
|
||||||
|
"""
|
||||||
|
calculate laplacian matrix used for graph convolution layer.
|
||||||
|
"""
|
||||||
|
n_vertex = adj_mat.shape[0]
|
||||||
|
|
||||||
|
# row sum
|
||||||
|
deg_mat_row = np.asmatrix(np.diag(np.sum(adj_mat, axis=1)))
|
||||||
|
# column sum
|
||||||
|
#deg_mat_col = np.asmatrix(np.diag(np.sum(adj_mat, axis=0)))
|
||||||
|
deg_mat = deg_mat_row
|
||||||
|
|
||||||
|
adj_mat = np.asmatrix(adj_mat)
|
||||||
|
id_mat = np.asmatrix(np.identity(n_vertex))
|
||||||
|
|
||||||
|
# Combinatorial
|
||||||
|
com_lap_mat = deg_mat - adj_mat
|
||||||
|
|
||||||
|
# For SpectraConv
|
||||||
|
# To [0, 1]
|
||||||
|
sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(deg_mat, -0.5), \
|
||||||
|
com_lap_mat), fractional_matrix_power(deg_mat, -0.5))
|
||||||
|
|
||||||
|
# For ChebConv
|
||||||
|
# From [0, 1] to [-1, 1]
|
||||||
|
lambda_max_sym = eigs(sym_normd_lap_mat, k=1, which='LR')[0][0].real
|
||||||
|
wid_sym_normd_lap_mat = 2 * sym_normd_lap_mat / lambda_max_sym - id_mat
|
||||||
|
|
||||||
|
# For GCNConv
|
||||||
|
wid_deg_mat = deg_mat + id_mat
|
||||||
|
wid_adj_mat = adj_mat + id_mat
|
||||||
|
hat_sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(wid_deg_mat, -0.5), \
|
||||||
|
wid_adj_mat), fractional_matrix_power(wid_deg_mat, -0.5))
|
||||||
|
|
||||||
|
# Random Walk
|
||||||
|
rw_lap_mat = np.matmul(np.linalg.matrix_power(deg_mat, -1), adj_mat)
|
||||||
|
|
||||||
|
# For SpectraConv
|
||||||
|
# To [0, 1]
|
||||||
|
rw_normd_lap_mat = id_mat - rw_lap_mat
|
||||||
|
|
||||||
|
# For ChebConv
|
||||||
|
# From [0, 1] to [-1, 1]
|
||||||
|
lambda_max_rw = eigs(rw_lap_mat, k=1, which='LR')[0][0].real
|
||||||
|
wid_rw_normd_lap_mat = 2 * rw_normd_lap_mat / lambda_max_rw - id_mat
|
||||||
|
|
||||||
|
# For GCNConv
|
||||||
|
wid_deg_mat = deg_mat + id_mat
|
||||||
|
wid_adj_mat = adj_mat + id_mat
|
||||||
|
hat_rw_normd_lap_mat = np.matmul(np.linalg.matrix_power(wid_deg_mat, -1), wid_adj_mat)
|
||||||
|
|
||||||
|
if mat_type == 'wid_sym_normd_lap_mat':
|
||||||
|
return wid_sym_normd_lap_mat
|
||||||
|
if mat_type == 'hat_sym_normd_lap_mat':
|
||||||
|
return hat_sym_normd_lap_mat
|
||||||
|
if mat_type == 'wid_rw_normd_lap_mat':
|
||||||
|
return wid_rw_normd_lap_mat
|
||||||
|
if mat_type == 'hat_rw_normd_lap_mat':
|
||||||
|
return hat_rw_normd_lap_mat
|
||||||
|
raise ValueError(f'ERROR: "{mat_type}" is unknown.')
|
||||||
|
|
||||||
|
def evaluate_metric(model, dataset, scaler):
|
||||||
|
"""
|
||||||
|
evaluate the performance of network.
|
||||||
|
"""
|
||||||
|
mae, sum_y, mape, mse = [], [], [], []
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
x = data['inputs']
|
||||||
|
y = data['labels']
|
||||||
|
y_pred = model(x)
|
||||||
|
y_pred = ops.Reshape()(y_pred, (len(y_pred), -1))
|
||||||
|
y_pred = scaler.inverse_transform(y_pred.asnumpy()).reshape(-1)
|
||||||
|
y = scaler.inverse_transform(y.asnumpy()).reshape(-1)
|
||||||
|
d = np.abs(y - y_pred)
|
||||||
|
mae += d.tolist()
|
||||||
|
sum_y += y.tolist()
|
||||||
|
mape += (d / y).tolist()
|
||||||
|
mse += (d ** 2).tolist()
|
||||||
|
MAE = np.array(mae).mean()
|
||||||
|
MAPE = np.array(mape).mean()
|
||||||
|
RMSE = np.sqrt(np.array(mse).mean())
|
||||||
|
#WMAPE = np.sum(np.array(mae)) / np.sum(np.array(sum_y))
|
||||||
|
|
||||||
|
return MAE, RMSE, MAPE
|
|
@ -0,0 +1,190 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
testing network performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn import preprocessing
|
||||||
|
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.communication.management import init
|
||||||
|
from mindspore.train.model import ParallelMode
|
||||||
|
|
||||||
|
from src.model import models
|
||||||
|
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
|
||||||
|
from src import dataloader, utility
|
||||||
|
|
||||||
|
os.system("export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python")
|
||||||
|
parser = argparse.ArgumentParser('mindspore stgcn testing')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', \
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
|
||||||
|
# The way of testing
|
||||||
|
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts.')
|
||||||
|
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help='Device id.')
|
||||||
|
|
||||||
|
# Path for data and checkpoint
|
||||||
|
parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
|
||||||
|
parser.add_argument('--train_url', type=str, default='', help='Output directory.')
|
||||||
|
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
|
||||||
|
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
|
||||||
|
parser.add_argument('--ckpt_url', type=str, default='', help='The path of checkpoint.')
|
||||||
|
parser.add_argument('--ckpt_name', type=str, default="", help='the name of checkpoint.')
|
||||||
|
|
||||||
|
# Super parameters for testing
|
||||||
|
parser.add_argument('--n_pred', type=int, default=3, help='The number of time interval for predcition')
|
||||||
|
|
||||||
|
#network
|
||||||
|
parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Grapg convolution type')
|
||||||
|
#dataset
|
||||||
|
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||||
|
|
||||||
|
if args.graph_conv_type == "chebconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_chebconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_chebconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_chebconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported n_pred.")
|
||||||
|
elif args.graph_conv_type == "gcnconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_gcnconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_gcnconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_gcnconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported pred.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported graph_conv_type.")
|
||||||
|
|
||||||
|
|
||||||
|
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
|
||||||
|
raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
|
||||||
|
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
|
||||||
|
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
|
||||||
|
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
|
||||||
|
cfg.Ks = 2
|
||||||
|
|
||||||
|
# blocks: settings of channel size in st_conv_blocks and output layer,
|
||||||
|
# using the bottleneck design in st_conv_blocks
|
||||||
|
blocks = []
|
||||||
|
blocks.append([1])
|
||||||
|
for l in range(cfg.stblock_num):
|
||||||
|
blocks.append([64, 16, 64])
|
||||||
|
if Ko == 0:
|
||||||
|
blocks.append([128])
|
||||||
|
elif Ko > 0:
|
||||||
|
blocks.append([128, 128])
|
||||||
|
blocks.append([1])
|
||||||
|
|
||||||
|
|
||||||
|
day_slot = int(24 * 60 / cfg.time_intvl)
|
||||||
|
cfg.n_pred = cfg.n_pred
|
||||||
|
|
||||||
|
time_pred = cfg.n_pred * cfg.time_intvl
|
||||||
|
time_pred_str = str(time_pred) + '_mins'
|
||||||
|
|
||||||
|
if args.run_modelarts:
|
||||||
|
import moxing as mox
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
local_data_url = '/cache/data'
|
||||||
|
local_ckpt_url = '/cache/ckpt'
|
||||||
|
mox.file.copy_parallel(args.data_url, local_data_url)
|
||||||
|
mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
|
||||||
|
if device_num > 1:
|
||||||
|
init()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, \
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||||
|
data_dir = local_data_url + '/'
|
||||||
|
local_ckpt_url = local_ckpt_url + '/'
|
||||||
|
else:
|
||||||
|
if args.run_distribute:
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
context.set_context(device_id=device_id)
|
||||||
|
init()
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, \
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||||
|
else:
|
||||||
|
device_num = 1
|
||||||
|
device_id = args.device_id
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
data_dir = args.data_url + '/'
|
||||||
|
local_ckpt_url = args.ckpt_url + '/'
|
||||||
|
|
||||||
|
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
|
||||||
|
|
||||||
|
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
|
||||||
|
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
|
||||||
|
if n_vertex_vel == n_vertex_adj:
|
||||||
|
n_vertex = n_vertex_vel
|
||||||
|
else:
|
||||||
|
raise ValueError(f'ERROR: number of vertices in dataset is not equal to \
|
||||||
|
number of vertices in weighted adjacency matrix.')
|
||||||
|
|
||||||
|
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
|
||||||
|
conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
|
||||||
|
if cfg.graph_conv_type == "chebconv":
|
||||||
|
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
elif cfg.graph_conv_type == "gcnconv":
|
||||||
|
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
|
||||||
|
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
|
||||||
|
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
|
||||||
|
net = stgcn_conv
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
zscore = preprocessing.StandardScaler()
|
||||||
|
if args.run_modelarts or args.run_distribute:
|
||||||
|
dataset = dataloader.create_dataset(data_dir+args.data_path, \
|
||||||
|
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, False, device_num, device_id, mode=2)
|
||||||
|
else:
|
||||||
|
dataset = dataloader.create_dataset(data_dir+args.data_path, \
|
||||||
|
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, device_num, device_id, mode=2)
|
||||||
|
data_len = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
param_dict = load_checkpoint(local_ckpt_url+args.ckpt_name)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
test_MAE, test_RMSE, test_MAPE = utility.evaluate_metric(net, dataset, zscore)
|
||||||
|
print(f'MAE {test_MAE:.2f} | MAPE {test_MAPE*100:.2f} | RMSE {test_RMSE:.2f}')
|
|
@ -0,0 +1,222 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
train network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn import preprocessing
|
||||||
|
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.communication.management import init
|
||||||
|
from mindspore.train.model import Model, ParallelMode
|
||||||
|
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
|
||||||
|
from src import dataloader, utility
|
||||||
|
from src.model import models, metric
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser('mindspore stgcn training')
|
||||||
|
|
||||||
|
# The way of training
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', \
|
||||||
|
help='device where the code will be implemented. (Default: Ascend)')
|
||||||
|
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help='Device id')
|
||||||
|
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts')
|
||||||
|
parser.add_argument('--save_check_point', type=bool, default=True, help='Whether save checkpoint')
|
||||||
|
|
||||||
|
# Path for data and checkpoint
|
||||||
|
parser.add_argument('--data_url', type=str, required=True, help='Train dataset directory.')
|
||||||
|
parser.add_argument('--train_url', type=str, required=True, help='Save checkpoint directory.')
|
||||||
|
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
|
||||||
|
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
|
||||||
|
|
||||||
|
# Super parameters for training
|
||||||
|
parser.add_argument('--n_pred', type=int, default=3, help='The number of time interval for predcition, default as 3')
|
||||||
|
parser.add_argument('--opt', type=str, default='AdamW', help='optimizer, default as AdamW')
|
||||||
|
|
||||||
|
#network
|
||||||
|
parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Grapg convolution type')
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
if args.graph_conv_type == "chebconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_chebconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_chebconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_chebconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported n_pred.")
|
||||||
|
elif args.graph_conv_type == "gcnconv":
|
||||||
|
if args.n_pred == 9:
|
||||||
|
cfg = stgcn_gcnconv_45min_cfg
|
||||||
|
elif args.n_pred == 6:
|
||||||
|
cfg = stgcn_gcnconv_30min_cfg
|
||||||
|
elif args.n_pred == 3:
|
||||||
|
cfg = stgcn_gcnconv_15min_cfg
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported pred.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported graph_conv_type.")
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
|
||||||
|
|
||||||
|
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
|
||||||
|
raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
|
||||||
|
|
||||||
|
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
|
||||||
|
|
||||||
|
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
|
||||||
|
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
|
||||||
|
|
||||||
|
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
|
||||||
|
cfg.Ks = 2
|
||||||
|
|
||||||
|
# blocks: settings of channel size in st_conv_blocks and output layer,
|
||||||
|
# using the bottleneck design in st_conv_blocks
|
||||||
|
blocks = []
|
||||||
|
blocks.append([1])
|
||||||
|
for l in range(cfg.stblock_num):
|
||||||
|
blocks.append([64, 16, 64])
|
||||||
|
if Ko == 0:
|
||||||
|
blocks.append([128])
|
||||||
|
elif Ko > 0:
|
||||||
|
blocks.append([128, 128])
|
||||||
|
blocks.append([1])
|
||||||
|
|
||||||
|
|
||||||
|
day_slot = int(24 * 60 / cfg.time_intvl)
|
||||||
|
cfg.n_pred = cfg.n_pred
|
||||||
|
|
||||||
|
time_pred = cfg.n_pred * cfg.time_intvl
|
||||||
|
time_pred_str = str(time_pred) + '_mins'
|
||||||
|
|
||||||
|
if args.run_modelarts:
|
||||||
|
import moxing as mox
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
context.set_context(device_id=device_id)
|
||||||
|
local_data_url = '/cache/data'
|
||||||
|
local_train_url = '/cache/train'
|
||||||
|
#mox.file.make_dirs(local_train_url)
|
||||||
|
mox.file.copy_parallel(args.data_url, local_data_url)
|
||||||
|
if device_num > 1:
|
||||||
|
init()
|
||||||
|
#context.set_auto_parallel_context(parameter_broadcast=True)
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, \
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||||
|
data_dir = local_data_url + '/'
|
||||||
|
else:
|
||||||
|
if args.run_distribute:
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
context.set_context(device_id=device_id)
|
||||||
|
init()
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
#context.set_auto_parallel_context(parameter_broadcast=True)
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, \
|
||||||
|
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
device_num = 1
|
||||||
|
cfg.batch_size = cfg.batch_size*int(8/device_num)
|
||||||
|
device_id = args.device_id
|
||||||
|
data_dir = args.data_url + '/'
|
||||||
|
model_save_path = args.train_url + cfg.graph_conv_type + '_' + time_pred_str
|
||||||
|
|
||||||
|
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
|
||||||
|
|
||||||
|
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
|
||||||
|
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
|
||||||
|
if n_vertex_vel == n_vertex_adj:
|
||||||
|
n_vertex = n_vertex_vel
|
||||||
|
else:
|
||||||
|
raise ValueError(f"ERROR: number of vertices in dataset is not equal to number \
|
||||||
|
of vertices in weighted adjacency matrix.")
|
||||||
|
|
||||||
|
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
|
||||||
|
conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
|
||||||
|
if cfg.graph_conv_type == "chebconv":
|
||||||
|
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
elif cfg.graph_conv_type == "gcnconv":
|
||||||
|
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
|
||||||
|
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
|
||||||
|
|
||||||
|
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
|
||||||
|
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
|
||||||
|
net = stgcn_conv
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#start training
|
||||||
|
|
||||||
|
zscore = preprocessing.StandardScaler()
|
||||||
|
if args.run_modelarts or args.run_distribute:
|
||||||
|
dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
|
||||||
|
cfg.n_pred, zscore, False, device_num, device_id, mode=0)
|
||||||
|
else:
|
||||||
|
dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
|
||||||
|
cfg.n_pred, zscore, True, device_num, device_id, mode=0)
|
||||||
|
data_len = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
learning_rate = nn.exponential_decay_lr(learning_rate=cfg.learning_rate, decay_rate=cfg.gamma, \
|
||||||
|
total_step=data_len*cfg.epochs, step_per_epoch=data_len, decay_epoch=cfg.decay_epoch)
|
||||||
|
if args.opt == "RMSProp":
|
||||||
|
optimizer = nn.RMSProp(net.trainable_params(), learning_rate=learning_rate)
|
||||||
|
elif args.opt == "Adam":
|
||||||
|
optimizer = nn.Adam(net.trainable_params(), learning_rate=learning_rate, \
|
||||||
|
weight_decay=cfg.weight_decay_rate)
|
||||||
|
elif args.opt == "AdamW":
|
||||||
|
optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=learning_rate, \
|
||||||
|
weight_decay=cfg.weight_decay_rate)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'ERROR: optimizer {args.opt} is undefined.')
|
||||||
|
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
time_cb = TimeMonitor(data_size=data_len)
|
||||||
|
callbacks = [time_cb, loss_cb]
|
||||||
|
|
||||||
|
#save training results
|
||||||
|
if args.save_check_point and (device_num == 1 or device_id == 0):
|
||||||
|
config_ck = CheckpointConfig(
|
||||||
|
save_checkpoint_steps=data_len*cfg.epochs, keep_checkpoint_max=cfg.epochs)
|
||||||
|
if args.run_modelarts:
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix='STGCN'+cfg.graph_conv_type+str(cfg.n_pred)+'-', \
|
||||||
|
directory=local_train_url, config=config_ck)
|
||||||
|
else:
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix='STGCN', directory=model_save_path, config=config_ck)
|
||||||
|
callbacks += [ckpoint_cb]
|
||||||
|
|
||||||
|
net = metric.LossCellWithNetwork(net)
|
||||||
|
model = Model(net, optimizer=optimizer, amp_level='O3')
|
||||||
|
|
||||||
|
model.train(cfg.epochs, dataset, callbacks=callbacks)
|
||||||
|
if args.run_modelarts:
|
||||||
|
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
|
Loading…
Reference in New Issue