changshun 2021-04-26 11:08:43 +08:00 committed by fancyshun
parent 196d65c0bd
commit 129234a155
22 changed files with 2267 additions and 0 deletions

View File

@ -0,0 +1,255 @@
# Contents
- [STGCN Description](#stgcn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Export MindIR](#export-mindir)
    - [Inference Process](#inference-process)
        - [Usage](#usage)
        - [Result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [STGCN Description](#contents)

STGCN is a spatio-temporal convolutional network used mainly for traffic forecasting. The paper proposes a novel deep-learning framework, the spatio-temporal graph convolutional network (STGCN), for time-series prediction in the traffic domain. The problem is formulated on graphs and the model is built with a purely convolutional structure, which yields faster training with far fewer parameters. By modeling multi-scale traffic networks, STGCN effectively captures comprehensive spatio-temporal correlations and consistently outperforms the state of the art on various real-world traffic datasets.

[Paper](https://arxiv.org/abs/1709.04875): Bing Yu, Haoteng Yin, and Zhanxing Zhu. "Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting." Proceedings of the 27th International Joint Conference on Artificial Intelligence. 2018.
# [Model Architecture](#contents)

The STGCN model consists of two spatio-temporal convolution blocks followed by an output layer. Each spatio-temporal convolution block combines temporal convolution blocks with a spatial convolution block, and the spatial convolution block supports two graph convolution variants: ChebConv and GCNConv.
# [Dataset](#contents)

Datasets used:

- PeMSD7 (PeMSD7-M, PeMSD7-L)
- BJER4

For download-availability reasons, only the [PeMSD7-M](https://github.com/hazdzz/STGCN/tree/main/data/train/road_traffic/pemsd7-m) dataset is used here.
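A quick, hypothetical way to inspect the velocity file (the loader in `src/dataloader.py` reads it with the same call; the file name `vel.csv` is the default used by the scripts):

```python
# PeMSD7-M velocity matrix: rows are 5-minute time steps,
# columns are the 228 monitoring stations.
import pandas as pd

df = pd.read_csv("vel.csv", header=None)
print(df.shape)  # (num_time_steps, 228)
```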
# [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
    - Prepare a hardware environment with Ascend or GPU processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

- running on Ascend with default parameters

```shell
# standalone training
python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebconv" --n_pred=9
# distributed training
bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```text
├── STGCN
    ├── scripts
        ├── run_distribute_train.sh    // training on Ascend with 8 devices
        ├── run_eval_ascend.sh         // evaluation on Ascend
    ├── src
        ├── model
            ├── layers.py              // model layers
            ├── metric.py              // network with loss cell
            ├── models.py              // network model
        ├── config.py                  // parameter configuration
        ├── dataloader.py              // dataset creation
        ├── utility.py                 // laplacian matrix calculation and evaluation metrics
        ├── weight_init.py             // layernorm weight init
    ├── train.py                       // training entry
    ├── test.py                        // evaluation entry
    ├── postprocess.py                 // compute accuracy for Ascend 310
    ├── preprocess.py                  // generate dataset for Ascend 310
    ├── README.md                      // descriptions
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for STGCN
```python
stgcn_chebconv_45min_cfg = edict({
'learning_rate': 0.003,
'n_his': 12,
'n_pred': 9,
'n_vertex': 228,
'epochs': 500,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.7,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func':"glu",
'graph_conv_type': "chebconv",
'mat_type': "wid_sym_normd_lap_mat",
})
```
For more information, please check `config.py`.
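As a sketch of how the scripts choose among the six presets: the pair `(graph_conv_type, n_pred)` selects one config, and the preset name encodes the forecast horizon as `n_pred * time_intvl` minutes (run from the STGCN directory for the import to resolve):

```python
# n_pred=9 with chebconv selects the 45-minute preset (9 intervals x 5 minutes).
from src.config import stgcn_chebconv_45min_cfg as cfg

print(cfg.n_pred * cfg.time_intvl)  # 45
```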
## [Training Process](#contents)

### Training

- running on Ascend

```shell
# standalone training
python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebconv" --n_pred=9
# distributed training (8 devices)
bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type
```

For distributed training, RANK_TABLE_FILE must be placed in the scripts folder; see [how to generate it](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

During training, the current epoch and step together with the loss are printed to the terminal:
```text
epoch: 1 step: 139, loss is 0.429
epoch time: 203885.163 ms, per step time: 1466.800 ms
epoch: 2 step: 139, loss is 0.2097
epoch time: 6330.939 ms, per step time: 45.546 ms
epoch: 3 step: 139, loss is 0.4192
epoch time: 6364.882 ms, per step time: 45.791 ms
epoch: 4 step: 139, loss is 0.2917
epoch time: 6378.299 ms, per step time: 45.887 ms
epoch: 5 step: 139, loss is 0.2365
epoch time: 6369.215 ms, per step time: 45.822 ms
epoch: 6 step: 139, loss is 0.2269
epoch time: 6389.238 ms, per step time: 45.966 ms
epoch: 7 step: 139, loss is 0.3071
epoch time: 6365.901 ms, per step time: 45.798 ms
epoch: 8 step: 139, loss is 0.2336
epoch time: 6358.127 ms, per step time: 45.742 ms
epoch: 9 step: 139, loss is 0.2812
epoch time: 6333.794 ms, per step time: 45.567 ms
epoch: 10 step: 139, loss is 0.2622
epoch time: 6334.013 ms, per step time: 45.568 ms
...
```
The checkpoints of this model are saved under the path given by `--train_url`.
## [Evaluation Process](#contents)

### Evaluation

- evaluating on Ascend with the PeMSD7-M test set

When running the command, pass the checkpoint directory, the checkpoint file name, the spatial graph convolution type, and the prediction interval.

```shell
python test.py --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --ckpt_name="" --graph_conv_type="" --n_pred=9
# evaluate with the script
bash scripts/run_eval_ascend.sh data_path ckpt_url ckpt_name device_id graph_conv_type n_pred
```

The above python command runs in the terminal, and you can check the result there. The accuracy on the test set is reported as follows:
```text
MAE 3.23 | MAPE 8.32 | RMSE 6.06
```
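These three metrics are plain element-wise statistics over the de-normalized predictions; a minimal sketch mirroring `evaluate_metric` in `src/utility.py` (the two arrays are made-up values):

```python
import numpy as np

y = np.array([60.0, 55.0, 62.0])       # ground-truth speeds (made up)
y_pred = np.array([58.0, 57.0, 61.0])  # model predictions (made up)
d = np.abs(y - y_pred)
mae = d.mean()                         # mean absolute error
mape = (d / y).mean() * 100            # mean absolute percentage error
rmse = np.sqrt((d ** 2).mean())        # root mean squared error
print(f'MAE {mae:.2f} | MAPE {mape:.2f} | RMSE {rmse:.2f}')
```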
## [Export MindIR](#contents)
```shell
python export.py --data_url [DATA_URL] --ckpt_file [CKPT_PATH] --n_pred [N_PRED] --graph_conv_type [GRAPH_CONV_TYPE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
## [Inference Process](#contents)

### Usage

Before performing inference, the MINDIR file must be exported by export.py. Input files must be in BIN format.
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
```
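The BIN files are raw binary dumps of NumPy arrays; a sketch of the round trip used by `preprocess.py` (writing) and `postprocess.py` (reading), with the shape assumed from the default config:

```python
import numpy as np

# one preprocessed input batch: (batch_size, 1, n_his, n_vertex)
x = np.zeros([1, 1, 12, 228], np.float32)
x.tofile("STGCN_data_bs1_0.bin")  # as written by preprocess.py
x2 = np.fromfile("STGCN_data_bs1_0.bin", np.float32).reshape(x.shape)
assert (x == x2).all()
```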
### Result

Inference results are saved in the current path; you can find the accuracy in the acc.log file.
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

#### STGCN on PeMSD7-M (ChebConv, n_pred=9)

| Parameters                 | ModelArts
| -------------------------- | -----------------------------------------------------------
| Model Version              | STGCN
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G
| Uploaded Date              | 05/07/2021 (month/day/year)
| MindSpore Version          | 1.2.0
| Dataset                    | PeMSD7-M
| Training Parameters        | epoch=500, steps=139, batch_size=8, lr=0.003
| Optimizer                  | AdamWeightDecay
| Loss Function              | MSE Loss
| outputs                    | predicted traffic speed
| Loss                       | 0.183
| Speed                      | 45.601 ms/step (8 devices)
| Scripts                    | [STGCN script]
### Inference Performance
#### STGCN on PeMSD7-M (ChebConv, n_pred=9)

| Parameters          | Ascend
| ------------------- | ---------------------------
| Model Version       | STGCN
| Resource            | Ascend 910
| Uploaded Date       | 05/07/2021 (month/day/year)
| MindSpore Version   | 1.2.0
| Dataset             | PeMSD7-M
| batch_size          | 8
| outputs             | predicted traffic speed
| MAE                 | 3.23
| MAPE                | 8.32
| RMSE                | 6.06
| Model for inference | about 6M (.ckpt file)
# [Description of Random Situation](#contents)

The random seed is set in train.py.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif

View File

@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> model_inputs = model.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
auto input0_files = GetAllFiles(FLAGS_input0_path);
if (input0_files.empty()) {
std::cout << "ERROR: input data empty." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
auto input0 = ReadFileToTensor(input0_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
input0.Data().get(), input0.DataSize());
for (auto shape : model_inputs[0].Shape()) {
std::cout << "model input shape" << shape << std::endl;
}
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input0_files[i], outputs);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
std::cout << "output size:" << outputSize << std::endl;
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
import pandas as pd
import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src import dataloader, utility
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
from src.model import models
parser = argparse.ArgumentParser(description='STGCN export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_url', type=str, help='Train dataset directory.')
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--n_pred", type=int, default=3, help="The number of time interval for predcition.")
parser.add_argument("--graph_conv_type", type=str, default="chebconv", help="Grapg convolution type.")
parser.add_argument("--file_name", type=str, default="stgcn", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
if args.graph_conv_type == "chebconv":
if args.n_pred == 9:
cfg = stgcn_chebconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_chebconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_chebconv_15min_cfg
else:
raise ValueError("Unsupported n_pred.")
elif args.graph_conv_type == "gcnconv":
if args.n_pred == 9:
cfg = stgcn_gcnconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_gcnconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_gcnconv_15min_cfg
else:
raise ValueError("Unsupported pred.")
else:
raise ValueError("Unsupported graph_conv_type.")
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
cfg.Ks = 2
# blocks: settings of channel size in st_conv_blocks and output layer,
# using the bottleneck design in st_conv_blocks
blocks = []
blocks.append([1])
for l in range(cfg.stblock_num):
blocks.append([64, 16, 64])
if Ko == 0:
blocks.append([128])
elif Ko > 0:
blocks.append([128, 128])
blocks.append([1])
day_slot = int(24 * 60 / cfg.time_intvl)
cfg.n_pred = args.n_pred
time_pred = cfg.n_pred * cfg.time_intvl
time_pred_str = str(time_pred) + '_mins'
context.set_context(device_id=args.device_id)
device_num = 1
cfg.batch_size = cfg.batch_size*int(8/device_num)
device_id = args.device_id
data_dir = args.data_url + '/'
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
if n_vertex_vel == n_vertex_adj:
n_vertex = n_vertex_vel
else:
raise ValueError(f'ERROR: number of vertices in dataset is not equal to number of \
vertices in weighted adjacency matrix.')
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
conv_matrix = Tensor(Tensor.from_numpy(mat), ms.float32)
if cfg.graph_conv_type == "chebconv":
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
elif cfg.graph_conv_type == "gcnconv":
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
net = stgcn_conv
if __name__ == '__main__':
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([args.batch_size, 1, cfg.n_his, n_vertex]), ms.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

View File

@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""compute acc for ascend 310"""
import os
import argparse
import numpy as np
from src.config import stgcn_chebconv_45min_cfg
from src import dataloader
from sklearn import preprocessing
parser = argparse.ArgumentParser('mindspore stgcn testing')
# Path for data
parser.add_argument('--data_url', type=str, default='./data/', help='Test dataset directory.')
parser.add_argument('--label_dir', type=str, default='', help='label data directory.')
parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
# Super parameters for testing
parser.add_argument('--n_pred', type=int, default=9, help='The number of time intervals for prediction')
args, _ = parser.parse_known_args()
cfg = stgcn_chebconv_45min_cfg
cfg.batch_size = 1
if __name__ == "__main__":
zscore = preprocessing.StandardScaler()
rst_path = args.result_dir
labels = np.load(args.label_dir)
dataset = dataloader.create_dataset(args.data_url+args.data_path, \
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
mae, sum_y, mape, mse = [], [], [], []
for i in range(len(os.listdir(rst_path))):
file_name = os.path.join(rst_path, "STGCN_data_bs" + str(cfg.batch_size) + '_' + str(i) + '_0.bin')
output = np.fromfile(file_name, np.float16)
output = zscore.inverse_transform(output)
label = zscore.inverse_transform(labels[i])
d = np.abs(label - output)
mae += d.tolist()
sum_y += label.tolist()
mape += (d / label).tolist()
mse += (d ** 2).tolist()
MAE = np.array(mae).mean()
MAPE = np.array(mape).mean()
RMSE = np.sqrt(np.array(mse).mean())
print(f'MAE {MAE:.2f} | MAPE {MAPE*100:.2f} | RMSE {RMSE:.2f}')

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate dataset for ascend 310"""
import os
import argparse
import numpy as np
from src.config import stgcn_chebconv_45min_cfg
from src import dataloader
from sklearn import preprocessing
parser = argparse.ArgumentParser('mindspore stgcn testing')
parser.add_argument('--device_target', type=str, default='Ascend', \
help='device where the code will be implemented. (Default: Ascend)')
# Path for data and checkpoint
parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
# Super parameters for testing
parser.add_argument('--n_pred', type=int, default=9, help='The number of time intervals for prediction')
args, _ = parser.parse_known_args()
cfg = stgcn_chebconv_45min_cfg
cfg.batch_size = 1
if __name__ == "__main__":
zscore = preprocessing.StandardScaler()
dataset = dataloader.create_dataset(args.data_url+args.data_path, \
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
img_path = os.path.join(args.result_path, "00_data")
os.mkdir(img_path)
label_list = []
# dataset is an instance of Dataset object
iterator = dataset.create_dict_iterator(output_numpy=True)
for i, data in enumerate(iterator):
file_name = "STGCN_data_bs" + str(cfg.batch_size) + "_" + str(i) + ".bin"
file_path = img_path + "/" + file_name
data['inputs'].tofile(file_path)
label_list.append(data['labels'])
np.save(args.result_path + "label_ids.npy", label_list)
print("="*20, "export bin files finished", "="*20)

View File

@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
echo "Usage: sh run_distribute_train.sh [train_code_path][data_path][n_pred][graph_conv_type]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
train_code_path=$(get_real_path $1)
echo $train_code_path
if [ ! -d $train_code_path ]
then
echo "error: train_code_path=$train_code_path is not a dictionary."
exit 1
fi
data_path=$(get_real_path $2)
echo $data_path
if [ ! -d $data_path ]
then
echo "error: train_code_path=$train_code_path is not a dictionary."
exit 1
fi
ulimit -c unlimited
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=${train_code_path}/scripts/hccl_8p_01234567_127.0.0.1.json
export RANK_SIZE=8
export RANK_START_ID=0
export n_pred=$3
export graph_conv_type=$4
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
if [ -d ${train_code_path}/device${DEVICE_ID} ]; then
rm -rf ${train_code_path}/device${DEVICE_ID}
fi
mkdir ${train_code_path}/device${DEVICE_ID}
cd ${train_code_path}/device${DEVICE_ID} || exit
python ${train_code_path}/train.py --data_url=${data_path} \
--train_url=./checkpoint \
--run_distribute=True \
--run_modelarts=False \
--n_pred=$n_pred \
--graph_conv_type=$graph_conv_type > out.log 2>&1 &
done

View File

@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 6 ]
then
echo "Usage: sh run_standalone_eval_ascend.sh [data_path][ckpt_url][ckpt_name][device_id][graph_conv_type][n_pred]"
exit 1
fi
export Data_path=$1
export Ckpt_path=$2
export Ckpt_name=$3
export Device_id=$4
export Graph_conv_type=$5
export N_pred=$6
python test.py --data_url=$Data_path \
--train_url=./checkpoint \
--run_distribute=False \
--run_modelarts=False \
--device_id=$Device_id \
--ckpt_url=$Ckpt_path \
--ckpt_name=$Ckpt_name \
--n_pred=$N_pred \
--graph_conv_type=$Graph_conv_type > test.log 2>&1 &

View File

@ -0,0 +1,130 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
DEVICE_TARGET must choose from ['GPU', 'CPU', 'Ascend']
NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'.
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
dataset_path=$(get_real_path $2)
if [ "$3" == "y" ] || [ "$3" == "n" ];then
need_preprocess=$3
else
echo "weather need preprocess or not, it's value must be in [y, n]"
exit 1
fi
if [ "$4" == "GPU" ] || [ "$4" == "CPU" ] || [ "$4" == "Ascend" ];then
device_target=$4
else
echo "device_target must be in ['GPU', 'CPU', 'Ascend']"
exit 1
fi
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
echo "mindir name: "$model
echo "dataset path: "$dataset_path
echo "need preprocess: "$need_preprocess
echo "device_target: "$device_target
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
if [ -d preprocess_Result ]; then
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --data_url=$dataset_path --result_path=./preprocess_Result/ --device_target=$device_target &>preprocess.log
}
function compile_app()
{
cd ../ascend310_infer || exit
bash build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
}
function cal_acc()
{
python3.7 ../postprocess.py --data_url=$dataset_path --result_dir=./result_Files --label_dir=./preprocess_Result/label_ids.npy --device_target=$device_target &> acc.log
}
if [ $need_preprocess == "y" ]; then
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess dataset failed"
exit 1
fi
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi

View File

@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
stgcn_chebconv_45min_cfg = edict({
'learning_rate': 0.003,
'n_his': 12,
'n_pred': 9,
'epochs': 500,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.7,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "chebconv",
'mat_type': "wid_sym_normd_lap_mat",
})
stgcn_chebconv_30min_cfg = edict({
'learning_rate': 0.003,
'n_his': 12,
'n_pred': 6,
'epochs': 500,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.7,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "chebconv",
'mat_type': "wid_sym_normd_lap_mat",
})
stgcn_chebconv_15min_cfg = edict({
'learning_rate': 0.002,
'n_his': 12,
'n_pred': 3,
'epochs': 100,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.999,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "chebconv",
'mat_type': "wid_rw_normd_lap_mat",
})
stgcn_gcnconv_45min_cfg = edict({
'learning_rate': 0.003,
'n_his': 12,
'n_pred': 9,
'epochs': 500,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.7,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "gcnconv",
'mat_type': "hat_sym_normd_lap_mat",
})
stgcn_gcnconv_30min_cfg = edict({
'learning_rate': 0.003,
'n_his': 12,
'n_pred': 6,
'epochs': 500,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.7,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "gcnconv",
'mat_type': "hat_sym_normd_lap_mat",
})
stgcn_gcnconv_15min_cfg = edict({
'learning_rate': 0.002,
'n_his': 12,
'n_pred': 3,
'epochs': 100,
'batch_size': 8,
'decay_epoch': 10,
'gamma': 0.9999,
'stblock_num': 2,
'Ks': 2,
'Kt': 3,
'time_intvl': 5,
'drop_rate': 0.5,
'weight_decay_rate': 0.0005,
'gated_act_func': "glu",
'graph_conv_type': "gcnconv",
'mat_type': "hat_rw_normd_lap_mat",
})

View File

@ -0,0 +1,106 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
process dataset.
"""
import math
import numpy as np
import pandas as pd
import mindspore.dataset as ds
class STGCNDataset:
""" BRDNetDataset.
Args:
mode: 0 means train;1 means val;2 means test
"""
def __init__(self, data_path, n_his, n_pred, zscore, mode=0):
self.df = pd.read_csv(data_path, header=None)
self.data_col = self.df.shape[0]
# recommended dataset split rate as train: val: test = 60: 20: 20, 70: 15: 15 or 80: 10: 10
# using dataset split rate as train: val: test = 70: 15: 15
self.val_and_test_rate = 0.15
self.len_val = int(math.floor(self.data_col * self.val_and_test_rate))
self.len_test = int(math.floor(self.data_col * self.val_and_test_rate))
self.len_train = int(self.data_col - self.len_val - self.len_test)
self.dataset_train = self.df[: self.len_train]
self.dataset_val = self.df[self.len_train: self.len_train + self.len_val]
self.dataset_test = self.df[self.len_train + self.len_val:]
self.dataset_train = zscore.fit_transform(self.dataset_train)
self.dataset_val = zscore.transform(self.dataset_val)
self.dataset_test = zscore.transform(self.dataset_test)
if mode == 0:
self.dataset = self.dataset_train
elif mode == 1:
self.dataset = self.dataset_val
else:
self.dataset = self.dataset_test
self.n_his = n_his
self.n_pred = n_pred
self.n_vertex = self.dataset.shape[1]
self.len_record = len(self.dataset)
self.num = self.len_record - self.n_his - self.n_pred
self.x = np.zeros([self.num, 1, self.n_his, self.n_vertex], np.float32)
self.y = np.zeros([self.num, self.n_vertex], np.float32)
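        # Sliding-window layout: with n_his=12 and n_pred=9, sample i uses
        # rows [i, i + 12) as its input window and row i + 20 as its label.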
for i in range(self.num):
head = i
tail = i + self.n_his
self.x[i, :, :, :] = self.dataset[head: tail].reshape(1, self.n_his, self.n_vertex)
self.y[i] = self.dataset[tail + self.n_pred - 1]
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
x[index]: input of network
y[index]: label of network
"""
return self.x[index], self.y[index]
def __len__(self):
return self.num
def load_weighted_adjacency_matrix(file_path):
df = pd.read_csv(file_path, header=None)
return df.to_numpy()
def create_dataset(data_path, batch_size, n_his, n_pred, zscore, is_sigle, device_num=1, device_id=0, mode=0):
"""
generate dataset for train or test.
"""
data = STGCNDataset(data_path, n_his, n_pred, zscore, mode=mode)
shuffle = True
if mode != 0:
shuffle = False
if not is_sigle:
dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, \
shuffle=shuffle, num_shards=device_num, shard_id=device_id)
else:
dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, shuffle=shuffle)
dataset = dataset.batch(batch_size)
return dataset

View File

@ -0,0 +1,306 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network layer"""
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
class Align(nn.Cell):
"""align"""
def __init__(self, c_in, c_out):
super(Align, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.align_conv = nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=1, \
pad_mode='valid', weight_init='he_uniform')
self.concat = ops.Concat(axis=1)
self.zeros = ops.Zeros()
def construct(self, x):
x_align = x
if self.c_in > self.c_out:
x_align = self.align_conv(x)
elif self.c_in < self.c_out:
batch_size, _, timestep, n_vertex = x.shape
y = self.zeros((batch_size, self.c_out - self.c_in, timestep, n_vertex), x.dtype)
x_align = self.concat((x, y))
return x_align
class CausalConv2d(nn.Cell):
"""causal conv2d"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
enable_padding=False, dilation=1, groups=1, bias=True):
super(CausalConv2d, self).__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride)
if isinstance(dilation, int):
dilation = (dilation, dilation)
if enable_padding:
self.__padding = [int((kernel_size[i] - 1) * dilation[i]) for i in range(len(kernel_size))]
else:
self.__padding = 0
if isinstance(self.__padding, int):
self.left_padding = (self.__padding, self.__padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \
padding=0, pad_mode='valid', dilation=dilation, group=groups, has_bias=bias, weight_init='he_uniform')
self.pad = ops.Pad(((0, 0), (0, 0), (self.left_padding[0], 0), (self.left_padding[1], 0)))
def construct(self, x):
if self.__padding != 0:
x = self.pad(x)
result = self.conv2d(x)
return result
class TemporalConvLayer(nn.Cell):
"""
# Temporal Convolution Layer (GLU)
#
    # |-------------------------------| * residual connection *
    # |                               |
    # |    |--->--- causal conv ----- + -------|
    # -----|----|                              ⊙ ------>
    #      |--->--- causal conv --- sigmoid ---|
    #
    # param x: tensor, [batch_size, c_in, timestep, n_vertex]
"""
def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
super(TemporalConvLayer, self).__init__()
self.Kt = Kt
self.c_in = c_in
self.c_out = c_out
self.n_vertex = n_vertex
self.act_func = act_func
self.align = Align(self.c_in, self.c_out)
self.causal_conv = CausalConv2d(in_channels=self.c_in, out_channels=2 * self.c_out, \
kernel_size=(self.Kt, 1), enable_padding=False, dilation=1)
self.linear = nn.Dense(self.n_vertex, self.n_vertex).to_float(mstype.float16)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.add = ops.Add()
self.mul = ops.Mul()
self.split = ops.Split(axis=1, output_num=2)
def construct(self, x):
"""TemporalConvLayer compute"""
x_in = self.align(x)
x_in = x_in[:, :, self.Kt - 1:, :]
x_causal_conv = self.causal_conv(x)
x_tc_out = x_causal_conv
x_pq = self.split(x_tc_out)
x_p = x_pq[0]
x_q = x_pq[1]
x_glu = x_causal_conv
x_gtu = x_causal_conv
if self.act_func == 'glu':
# (x_p + x_in) ⊙ Sigmoid(x_q)
x_glu = self.mul(self.add(x_p, x_in), self.sigmoid(x_q))
x_tc_out = x_glu
# Temporal Convolution Layer (GTU)
elif self.act_func == 'gtu':
# Tanh(x_p + x_in) ⊙ Sigmoid(x_q)
x_gtu = self.mul(self.tanh(self.add(x_p, x_in)), self.sigmoid(x_q))
x_tc_out = x_gtu
return x_tc_out
class ChebConv(nn.Cell):
"""cheb conv"""
def __init__(self, c_in, c_out, Ks, chebconv_matrix):
super(ChebConv, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.Ks = Ks
self.chebconv_matrix = chebconv_matrix
self.matmul = ops.MatMul()
self.stack = ops.Stack(axis=0)
self.reshape = ops.Reshape()
self.bias_add = ops.BiasAdd()
self.weight = ms.Parameter(initializer('normal', (self.Ks, self.c_in, self.c_out)), name='weight')
self.bias = ms.Parameter(initializer('Uniform', [self.c_out]), name='bias')
def construct(self, x):
"""chebconv compute"""
_, c_in, _, n_vertex = x.shape
# Using recurrence relation to reduce time complexity from O(n^2) to O(K|E|),
# where K = Ks - 1
x = self.reshape(x, (n_vertex, -1))
x_0 = x
x_1 = self.matmul(self.chebconv_matrix, x)
x_list = []
if self.Ks - 1 == 0:
x_list = [x_0]
elif self.Ks - 1 == 1:
x_list = [x_0, x_1]
elif self.Ks - 1 >= 2:
x_list = [x_0, x_1]
for k in range(2, self.Ks):
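                # Chebyshev recurrence: T_k = 2 * L * T_{k-1} - T_{k-2},
                # applied to the rescaled Laplacian acting on the signal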
x_list.append(self.matmul(2 * self.chebconv_matrix, x_list[k - 1]) - x_list[k - 2])
x_tensor = self.stack(x_list)
x_mul = self.matmul(self.reshape(x_tensor, (-1, self.Ks * c_in)), self.reshape(self.weight, \
(self.Ks * c_in, -1)))
x_mul = self.reshape(x_mul, (-1, self.c_out))
x_chebconv = self.bias_add(x_mul, self.bias)
return x_chebconv
class GCNConv(nn.Cell):
"""gcn conv"""
def __init__(self, c_in, c_out, gcnconv_matrix):
super(GCNConv, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.gcnconv_matrix = gcnconv_matrix
self.matmul = ops.MatMul()
self.reshape = ops.Reshape()
self.bias_add = ops.BiasAdd()
self.weight = ms.Parameter(initializer('he_uniform', (self.c_in, self.c_out)), name='weight')
self.bias = ms.Parameter(initializer('Uniform', [self.c_out]), name='bias')
def construct(self, x):
"""gcnconv compute"""
_, c_in, _, n_vertex = x.shape
x_first_mul = self.matmul(self.reshape(x, (-1, c_in)), self.weight)
x_first_mul = self.reshape(x_first_mul, (n_vertex, -1))
x_second_mul = self.matmul(self.gcnconv_matrix, x_first_mul)
x_second_mul = self.reshape(x_second_mul, (-1, self.c_out))
if self.bias is not None:
x_gcnconv_out = self.bias_add(x_second_mul, self.bias)
else:
x_gcnconv_out = x_second_mul
return x_gcnconv_out
class GraphConvLayer(nn.Cell):
"""grarh conv layer"""
def __init__(self, Ks, c_in, c_out, graph_conv_type, graph_conv_matrix):
super(GraphConvLayer, self).__init__()
self.Ks = Ks
self.c_in = c_in
self.c_out = c_out
self.align = Align(self.c_in, self.c_out)
self.graph_conv_type = graph_conv_type
self.graph_conv_matrix = graph_conv_matrix
if self.graph_conv_type == "chebconv":
self.chebconv = ChebConv(self.c_out, self.c_out, self.Ks, self.graph_conv_matrix)
elif self.graph_conv_type == "gcnconv":
self.gcnconv = GCNConv(self.c_out, self.c_out, self.graph_conv_matrix)
self.reshape = ops.Reshape()
self.add = ops.Add()
def construct(self, x):
"""GraphConvLayer compute"""
x_gc_in = self.align(x)
batch_size, _, T, n_vertex = x_gc_in.shape
x_gc = x_gc_in
if self.graph_conv_type == "chebconv":
x_gc = self.chebconv(x_gc_in)
elif self.graph_conv_type == "gcnconv":
x_gc = self.gcnconv(x_gc_in)
x_gc_with_rc = self.add(self.reshape(x_gc, (batch_size, self.c_out, T, n_vertex)), x_gc_in)
x_gc_out = x_gc_with_rc
return x_gc_out
class STConvBlock(nn.Cell):
"""
    # STConv Block contains 'TGTND' structure
    # T: Gated Temporal Convolution Layer (GLU or GTU)
    # G: Graph Convolution Layer (ChebConv or GCNConv)
    # T: Gated Temporal Convolution Layer (GLU or GTU)
    # N: Layer Normalization
    # D: Dropout
    # params: Kt, Ks, n_vertex
"""
def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, gated_act_func, graph_conv_type, \
graph_conv_matrix, drop_rate):
super(STConvBlock, self).__init__()
self.Kt = Kt
self.Ks = Ks
self.n_vertex = n_vertex
self.last_block_channel = last_block_channel
self.channels = channels
self.gated_act_func = gated_act_func
self.enable_gated_act_func = True
self.graph_conv_type = graph_conv_type
self.graph_conv_matrix = graph_conv_matrix
self.drop_rate = drop_rate
self.tmp_conv1 = TemporalConvLayer(self.Kt, self.last_block_channel, self.channels[0], \
self.n_vertex, self.gated_act_func)
self.graph_conv = GraphConvLayer(self.Ks, self.channels[0], self.channels[1], \
self.graph_conv_type, self.graph_conv_matrix)
self.tmp_conv2 = TemporalConvLayer(self.Kt, self.channels[1], self.channels[2], \
self.n_vertex, self.gated_act_func)
self.tc2_ln = nn.LayerNorm([self.n_vertex, self.channels[2]], begin_norm_axis=2, \
begin_params_axis=2, epsilon=1e-05)
self.relu = nn.ReLU()
self.do = nn.Dropout(keep_prob=self.drop_rate)
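        # note: nn.Dropout's keep_prob is the probability of keeping an activation
        # in MindSpore 1.x; with drop_rate=0.5 the two readings coincide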
self.transpose = ops.Transpose()
def construct(self, x):
"""STConvBlock compute"""
x_tmp_conv1 = self.tmp_conv1(x)
x_graph_conv = self.graph_conv(x_tmp_conv1)
x_act_func = self.relu(x_graph_conv)
x_tmp_conv2 = self.tmp_conv2(x_act_func)
x_tc2_ln = self.transpose(x_tmp_conv2, (0, 2, 3, 1))
x_tc2_ln = self.tc2_ln(x_tc2_ln)
x_tc2_ln = self.transpose(x_tc2_ln, (0, 3, 1, 2))
x_do = self.do(x_tc2_ln)
x_st_conv_out = x_do
return x_st_conv_out
class OutputBlock(nn.Cell):
"""
# Output block contains 'TNFF' structure
# T: Gated Temporal Convolution Layer (GLU or GTU)
    # N: Layer Normalization
# F: Fully-Connected Layer
# F: Fully-Connected Layer
"""
def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, gated_act_func, drop_rate):
super(OutputBlock, self).__init__()
self.Ko = Ko
self.last_block_channel = last_block_channel
self.channels = channels
self.end_channel = end_channel
self.n_vertex = n_vertex
self.gated_act_func = gated_act_func
self.drop_rate = drop_rate
self.tmp_conv1 = TemporalConvLayer(self.Ko, self.last_block_channel, \
self.channels[0], self.n_vertex, self.gated_act_func)
self.fc1 = nn.Dense(self.channels[0], self.channels[1]).to_float(mstype.float16)
self.fc2 = nn.Dense(self.channels[1], self.end_channel).to_float(mstype.float16)
self.tc1_ln = nn.LayerNorm([self.n_vertex, self.channels[0]], begin_norm_axis=2, \
begin_params_axis=2, epsilon=1e-05)
self.sigmoid = nn.Sigmoid()
self.transpose = ops.Transpose()
def construct(self, x):
"""OutputBlock compute"""
x_tc1 = self.tmp_conv1(x)
x_tc1_ln = self.tc1_ln(self.transpose(x_tc1, (0, 2, 3, 1)))
x_fc1 = self.fc1(x_tc1_ln)
x_act_func = self.sigmoid(x_fc1)
x_fc2 = self.transpose(self.fc2(x_act_func), (0, 3, 1, 2))
x_out = x_fc2
return x_out

View File

@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
stgcn network with loss.
"""
import mindspore.nn as nn
import mindspore.ops as P
class LossCellWithNetwork(nn.Cell):
"""STGCN loss."""
def __init__(self, network):
super(LossCellWithNetwork, self).__init__()
self.loss = nn.MSELoss()
self.network = network
self.reshape = P.Reshape()
def construct(self, x, label):
x = self.network(x)
x = self.reshape(x, (len(x), -1))
label = self.reshape(label, (len(label), -1))
STGCN_loss = self.loss(x, label)
return STGCN_loss

View File

@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
stgcn network.
"""
import mindspore.nn as nn
from src.model import layers
class STGCN_Conv(nn.Cell):
"""
# STGCN(ChebConv) contains 'TGTND TGTND TNFF' structure
# ChebConv is the graph convolution from ChebyNet.
# Using the Chebyshev polynomials of the first kind to
# approximate graph convolution kernel from Spectral CNN.
# GCNConv is the graph convolution from GCN.
# T: Gated Temporal Convolution Layer (GLU or GTU)
# G: Graph Convolution Layer (ChebConv)
# T: Gated Temporal Convolution Layer (GLU or GTU)
    # N: Layer Normalization
# D: Dropout
# T: Gated Temporal Convolution Layer (GLU or GTU)
# G: Graph Convolution Layer (ChebConv)
# T: Gated Temporal Convolution Layer (GLU or GTU)
    # N: Layer Normalization
# D: Dropout
# T: Gated Temporal Convolution Layer (GLU or GTU)
# N: Layer Normalization
# F: Fully-Connected Layer
# F: Fully-Connected Layer
"""
def __init__(self, Kt, Ks, blocks, T, n_vertex, gated_act_func, graph_conv_type, chebconv_matrix, drop_rate):
super(STGCN_Conv, self).__init__()
modules = []
for l in range(len(blocks) - 3):
modules.append(layers.STConvBlock(Kt, Ks, n_vertex, blocks[l][-1], blocks[l+1], \
gated_act_func, graph_conv_type, chebconv_matrix, drop_rate))
self.st_blocks = nn.SequentialCell(modules)
Ko = T - (len(blocks) - 3) * 2 * (Kt - 1)
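        # Ko: time steps remaining for the output block; each ST block trims
        # 2 * (Kt - 1) steps through its two temporal convolutions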
self.Ko = Ko
self.output = layers.OutputBlock(self.Ko, blocks[-3][-1], blocks[-2], \
blocks[-1][0], n_vertex, gated_act_func, drop_rate)
def construct(self, x):
x_stbs = self.st_blocks(x)
x_out = self.output(x_stbs)
return x_out

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Calculate the Laplacian matrix, used for the network weights.
Evaluate the performance of the network.
"""
import numpy as np
import mindspore.ops as ops
from scipy.linalg import fractional_matrix_power
from scipy.sparse.linalg import eigs
def calculate_laplacian_matrix(adj_mat, mat_type):
"""
calculate laplacian matrix used for graph convolution layer.
"""
n_vertex = adj_mat.shape[0]
# row sum
deg_mat_row = np.asmatrix(np.diag(np.sum(adj_mat, axis=1)))
# column sum
#deg_mat_col = np.asmatrix(np.diag(np.sum(adj_mat, axis=0)))
deg_mat = deg_mat_row
adj_mat = np.asmatrix(adj_mat)
id_mat = np.asmatrix(np.identity(n_vertex))
# Combinatorial
com_lap_mat = deg_mat - adj_mat
# For SpectraConv
# To [0, 1]
sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(deg_mat, -0.5), \
com_lap_mat), fractional_matrix_power(deg_mat, -0.5))
# For ChebConv
# From [0, 1] to [-1, 1]
lambda_max_sym = eigs(sym_normd_lap_mat, k=1, which='LR')[0][0].real
wid_sym_normd_lap_mat = 2 * sym_normd_lap_mat / lambda_max_sym - id_mat
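    # i.e. L_tilde = 2 * L_sym / lambda_max - I, rescaling the spectrum from
    # [0, lambda_max] to [-1, 1] as required by the Chebyshev polynomials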
# For GCNConv
wid_deg_mat = deg_mat + id_mat
wid_adj_mat = adj_mat + id_mat
hat_sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(wid_deg_mat, -0.5), \
wid_adj_mat), fractional_matrix_power(wid_deg_mat, -0.5))
# Random Walk
rw_lap_mat = np.matmul(np.linalg.matrix_power(deg_mat, -1), adj_mat)
# For SpectraConv
# To [0, 1]
rw_normd_lap_mat = id_mat - rw_lap_mat
# For ChebConv
# From [0, 1] to [-1, 1]
lambda_max_rw = eigs(rw_lap_mat, k=1, which='LR')[0][0].real
wid_rw_normd_lap_mat = 2 * rw_normd_lap_mat / lambda_max_rw - id_mat
# For GCNConv
wid_deg_mat = deg_mat + id_mat
wid_adj_mat = adj_mat + id_mat
hat_rw_normd_lap_mat = np.matmul(np.linalg.matrix_power(wid_deg_mat, -1), wid_adj_mat)
if mat_type == 'wid_sym_normd_lap_mat':
return wid_sym_normd_lap_mat
if mat_type == 'hat_sym_normd_lap_mat':
return hat_sym_normd_lap_mat
if mat_type == 'wid_rw_normd_lap_mat':
return wid_rw_normd_lap_mat
if mat_type == 'hat_rw_normd_lap_mat':
return hat_rw_normd_lap_mat
raise ValueError(f'ERROR: "{mat_type}" is unknown.')
def evaluate_metric(model, dataset, scaler):
"""
Evaluate the performance of the network, returning MAE, RMSE and MAPE.
"""
mae, sum_y, mape, mse = [], [], [], []
for data in dataset.create_dict_iterator():
x = data['inputs']
y = data['labels']
y_pred = model(x)
y_pred = ops.Reshape()(y_pred, (len(y_pred), -1))
y_pred = scaler.inverse_transform(y_pred.asnumpy()).reshape(-1)
y = scaler.inverse_transform(y.asnumpy()).reshape(-1)
d = np.abs(y - y_pred)
mae += d.tolist()
sum_y += y.tolist()
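# Note: MAPE divides by the ground-truth values, so it is undefined for
# samples whose label is zero.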
mape += (d / y).tolist()
mse += (d ** 2).tolist()
MAE = np.array(mae).mean()
MAPE = np.array(mape).mean()
RMSE = np.sqrt(np.array(mse).mean())
#WMAPE = np.sum(np.array(mae)) / np.sum(np.array(sum_y))
return MAE, RMSE, MAPE

test.py

@ -0,0 +1,190 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
testing network performance.
"""
import os
import ast
import argparse
import pandas as pd
from sklearn import preprocessing
from mindspore.common import dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.train.model import ParallelMode
from src.model import models
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
from src import dataloader, utility
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
parser = argparse.ArgumentParser('mindspore stgcn testing')
parser.add_argument('--device_target', type=str, default='Ascend', \
help='device where the code will be implemented. (Default: Ascend)')
# The way of testing
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts.')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_id', type=int, default=0, help='Device id.')
# Path for data and checkpoint
parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
parser.add_argument('--train_url', type=str, default='', help='Output directory.')
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of velocity.')
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of the weighted adjacency matrix.')
parser.add_argument('--ckpt_url', type=str, default='', help='The path of the checkpoint.')
parser.add_argument('--ckpt_name', type=str, default="", help='The name of the checkpoint.')
# Super parameters for testing
parser.add_argument('--n_pred', type=int, default=3, help='The number of time intervals for prediction.')
#network
parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Graph convolution type.')
#dataset
args, _ = parser.parse_known_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
if args.graph_conv_type == "chebconv":
if args.n_pred == 9:
cfg = stgcn_chebconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_chebconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_chebconv_15min_cfg
else:
raise ValueError("Unsupported n_pred.")
elif args.graph_conv_type == "gcnconv":
if args.n_pred == 9:
cfg = stgcn_gcnconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_gcnconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_gcnconv_15min_cfg
else:
raise ValueError("Unsupported pred.")
else:
raise ValueError("Unsupported graph_conv_type.")
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
raise ValueError(f'ERROR: Kt={cfg.Kt} and stblock_num={cfg.stblock_num} are unacceptable.')
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
cfg.Ks = 2
# blocks: settings of channel size in st_conv_blocks and output layer,
# using the bottleneck design in st_conv_blocks
blocks = []
blocks.append([1])
for _ in range(cfg.stblock_num):
blocks.append([64, 16, 64])
if Ko == 0:
blocks.append([128])
elif Ko > 0:
blocks.append([128, 128])
blocks.append([1])
day_slot = int(24 * 60 / cfg.time_intvl)
time_pred = cfg.n_pred * cfg.time_intvl
time_pred_str = str(time_pred) + '_mins'
if args.run_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
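# Scale the per-device batch size so the global batch size stays the same as
# the 8-device setting (the configured batch_size is defined per device at 8P).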
cfg.batch_size = cfg.batch_size*int(8/device_num)
local_data_url = '/cache/data'
local_ckpt_url = '/cache/ckpt'
mox.file.copy_parallel(args.data_url, local_data_url)
mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
if device_num > 1:
init()
context.set_auto_parallel_context(device_num=device_num, \
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
data_dir = local_data_url + '/'
local_ckpt_url = local_ckpt_url + '/'
else:
if args.run_distribute:
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
cfg.batch_size = cfg.batch_size*int(8/device_num)
context.set_context(device_id=device_id)
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, \
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
else:
device_num = 1
device_id = args.device_id
context.set_context(device_id=args.device_id)
data_dir = args.data_url + '/'
local_ckpt_url = args.ckpt_url + '/'
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
if n_vertex_vel == n_vertex_adj:
n_vertex = n_vertex_vel
else:
raise ValueError('ERROR: the number of vertices in the dataset does not equal '
'the number of vertices in the weighted adjacency matrix.')
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
if cfg.graph_conv_type == "chebconv":
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is not a valid mat_type for chebconv.')
elif cfg.graph_conv_type == "gcnconv":
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is not a valid mat_type for gcnconv.')
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
net = stgcn_conv
if __name__ == "__main__":
zscore = preprocessing.StandardScaler()
if args.run_modelarts or args.run_distribute:
dataset = dataloader.create_dataset(data_dir+args.data_path, \
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, False, device_num, device_id, mode=2)
else:
dataset = dataloader.create_dataset(data_dir+args.data_path, \
cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, device_num, device_id, mode=2)
data_len = dataset.get_dataset_size()
param_dict = load_checkpoint(local_ckpt_url+args.ckpt_name)
load_param_into_net(net, param_dict)
test_MAE, test_RMSE, test_MAPE = utility.evaluate_metric(net, dataset, zscore)
print(f'MAE {test_MAE:.2f} | MAPE {test_MAPE*100:.2f}% | RMSE {test_RMSE:.2f}')

train.py

@ -0,0 +1,222 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train network.
"""
import os
import argparse
import ast
import pandas as pd
from sklearn import preprocessing
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.communication.management import init
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor
from mindspore.common import set_seed
from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
from src import dataloader, utility
from src.model import models, metric
set_seed(1)
parser = argparse.ArgumentParser('mindspore stgcn training')
# The way of training
parser.add_argument('--device_target', type=str, default='Ascend', \
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts')
parser.add_argument('--save_check_point', type=ast.literal_eval, default=True, help='Whether to save checkpoints.')
# Path for data and checkpoint
parser.add_argument('--data_url', type=str, required=True, help='Train dataset directory.')
parser.add_argument('--train_url', type=str, required=True, help='Save checkpoint directory.')
parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of velocity.')
parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of the weighted adjacency matrix.')
# Super parameters for training
parser.add_argument('--n_pred', type=int, default=3, help='The number of time intervals for prediction (default: 3).')
parser.add_argument('--opt', type=str, default='AdamW', help='Optimizer (default: AdamW).')
#network
parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Graph convolution type.')
args, _ = parser.parse_known_args()
if args.graph_conv_type == "chebconv":
if args.n_pred == 9:
cfg = stgcn_chebconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_chebconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_chebconv_15min_cfg
else:
raise ValueError("Unsupported n_pred.")
elif args.graph_conv_type == "gcnconv":
if args.n_pred == 9:
cfg = stgcn_gcnconv_45min_cfg
elif args.n_pred == 6:
cfg = stgcn_gcnconv_30min_cfg
elif args.n_pred == 3:
cfg = stgcn_gcnconv_15min_cfg
else:
raise ValueError("Unsupported pred.")
else:
raise ValueError("Unsupported graph_conv_type.")
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
raise ValueError(f'ERROR: Kt={cfg.Kt} and stblock_num={cfg.stblock_num} are unacceptable.')
Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
cfg.Ks = 2
# blocks: settings of channel size in st_conv_blocks and output layer,
# using the bottleneck design in st_conv_blocks
blocks = []
blocks.append([1])
for _ in range(cfg.stblock_num):
blocks.append([64, 16, 64])
if Ko == 0:
blocks.append([128])
elif Ko > 0:
blocks.append([128, 128])
blocks.append([1])
day_slot = int(24 * 60 / cfg.time_intvl)
time_pred = cfg.n_pred * cfg.time_intvl
time_pred_str = str(time_pred) + '_mins'
if args.run_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
cfg.batch_size = cfg.batch_size*int(8/device_num)
context.set_context(device_id=device_id)
local_data_url = '/cache/data'
local_train_url = '/cache/train'
#mox.file.make_dirs(local_train_url)
mox.file.copy_parallel(args.data_url, local_data_url)
if device_num > 1:
init()
#context.set_auto_parallel_context(parameter_broadcast=True)
context.set_auto_parallel_context(device_num=device_num, \
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
data_dir = local_data_url + '/'
else:
if args.run_distribute:
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
cfg.batch_size = cfg.batch_size*int(8/device_num)
context.set_context(device_id=device_id)
init()
context.reset_auto_parallel_context()
#context.set_auto_parallel_context(parameter_broadcast=True)
context.set_auto_parallel_context(device_num=device_num, \
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
device_num = 1
cfg.batch_size = cfg.batch_size*int(8/device_num)
device_id = args.device_id
data_dir = args.data_url + '/'
model_save_path = args.train_url + cfg.graph_conv_type + '_' + time_pred_str
adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
if n_vertex_vel == n_vertex_adj:
n_vertex = n_vertex_vel
else:
raise ValueError(f"ERROR: number of vertices in dataset is not equal to number \
of vertices in weighted adjacency matrix.")
mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
if cfg.graph_conv_type == "chebconv":
if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is not a valid mat_type for chebconv.')
elif cfg.graph_conv_type == "gcnconv":
if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
raise ValueError(f'ERROR: {cfg.mat_type} is not a valid mat_type for gcnconv.')
stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
net = stgcn_conv
if __name__ == "__main__":
#start training
zscore = preprocessing.StandardScaler()
if args.run_modelarts or args.run_distribute:
dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
cfg.n_pred, zscore, False, device_num, device_id, mode=0)
else:
dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
cfg.n_pred, zscore, True, device_num, device_id, mode=0)
data_len = dataset.get_dataset_size()
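# Exponential decay schedule: the learning rate is multiplied by gamma once
# every decay_epoch epochs over the full training run.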
learning_rate = nn.exponential_decay_lr(learning_rate=cfg.learning_rate, decay_rate=cfg.gamma, \
total_step=data_len*cfg.epochs, step_per_epoch=data_len, decay_epoch=cfg.decay_epoch)
if args.opt == "RMSProp":
optimizer = nn.RMSProp(net.trainable_params(), learning_rate=learning_rate)
elif args.opt == "Adam":
optimizer = nn.Adam(net.trainable_params(), learning_rate=learning_rate, \
weight_decay=cfg.weight_decay_rate)
elif args.opt == "AdamW":
optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=learning_rate, \
weight_decay=cfg.weight_decay_rate)
else:
raise ValueError(f'ERROR: optimizer {args.opt} is undefined.')
loss_cb = LossMonitor()
time_cb = TimeMonitor(data_size=data_len)
callbacks = [time_cb, loss_cb]
#save training results
if args.save_check_point and (device_num == 1 or device_id == 0):
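# save_checkpoint_steps equals the total number of training steps, so
# effectively a single checkpoint is written when training finishes.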
config_ck = CheckpointConfig(
save_checkpoint_steps=data_len*cfg.epochs, keep_checkpoint_max=cfg.epochs)
if args.run_modelarts:
ckpoint_cb = ModelCheckpoint(prefix='STGCN'+cfg.graph_conv_type+str(cfg.n_pred)+'-', \
directory=local_train_url, config=config_ck)
else:
ckpoint_cb = ModelCheckpoint(prefix='STGCN', directory=model_save_path, config=config_ck)
callbacks += [ckpoint_cb]
net = metric.LossCellWithNetwork(net)
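# LossCellWithNetwork wraps the network together with its training loss;
# amp_level 'O3' runs the wrapped network in float16 on Ascend.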
model = Model(net, optimizer=optimizer, amp_level='O3')
model.train(cfg.epochs, dataset, callbacks=callbacks)
if args.run_modelarts:
mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)