modify gcn network for cloud

zhanghuiyao 2021-06-04 11:35:07 +08:00
parent b52534c9c9
commit cccd62cde4
10 changed files with 533 additions and 23 deletions


@ -87,6 +87,49 @@ sh run_process_data.sh ./data cora
sh run_process_data.sh ./data citeseer
```
- Running on local with Ascend
```bash
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
sh run_train.sh [DATASET_NAME]
```
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
```bash
# Train on cora with 1 device (1p) on ModelArts
# (1) Do either a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "data_dir='/cache/data/cora'" in the default_config.yaml file.
#        Set "train_nodes_num=140" in the default_config.yaml file.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" in the web UI.
#        Add "data_dir='/cache/data/cora'" in the web UI.
#        Add "train_nodes_num=140" in the web UI.
#        Add any other parameters you need in the web UI.
# (2) Upload the dataset to an S3 bucket.
# (3) Set the code directory to "/path/gcn" in the web UI.
# (4) Set the startup file to "train.py" in the web UI.
# (5) Set the "Dataset path", "Output file path" and "Job log path" to your own paths in the web UI.
# (6) Create your job.
#
# Train on citeseer with 1 device (1p) on ModelArts
# (1) Do either a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "data_dir='/cache/data/citeseer'" in the default_config.yaml file.
#        Set "train_nodes_num=120" in the default_config.yaml file.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" in the web UI.
#        Add "data_dir='/cache/data/citeseer'" in the web UI.
#        Add "train_nodes_num=120" in the web UI.
#        Add any other parameters you need in the web UI.
# (2) Upload the dataset to an S3 bucket.
# (3) Set the code directory to "/path/gcn" in the web UI.
# (4) Set the startup file to "train.py" in the web UI.
# (5) Set the "Dataset path", "Output file path" and "Job log path" to your own paths in the web UI.
# (6) Create your job.
```
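
Option b passes each setting as a command line argument that overrides the default from default_config.yaml. The sketch below shows one way such an override can work, assuming a small in-memory dict in place of the YAML file. `DEFAULTS` and `build_parser` are hypothetical names used only for illustration.

```python
import argparse
import ast

# Hypothetical stand-in for the defaults normally read from default_config.yaml.
DEFAULTS = {
    "enable_modelarts": False,
    "data_dir": "./data/cora/cora_mr",
    "train_nodes_num": 140,
}


def build_parser(defaults):
    """Create one --option per default, typed after the default value."""
    parser = argparse.ArgumentParser(description="override sketch")
    for name, value in defaults.items():
        # bool("False") is True, so booleans are parsed as Python literals instead.
        arg_type = ast.literal_eval if isinstance(value, bool) else type(value)
        parser.add_argument("--" + name, type=arg_type, default=value)
    return parser


# Arguments in the form a job might pass them for the citeseer run.
args = build_parser(DEFAULTS).parse_args([
    "--enable_modelarts=True",
    "--data_dir=/cache/data/citeseer",
    "--train_nodes_num=120",
])
print(args.enable_modelarts, args.data_dir, args.train_nodes_num)
```

Any option that is not passed keeps its default, so a local run without arguments still uses the cora settings.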
## [Script Description](#contents)
### [Script and Sample Code](#contents)
@ -95,7 +138,16 @@ sh run_process_data.sh ./data citeseer
.
└─gcn
├─README.md
├─README_CN.md
├─model_utils
| ├─__init__.py # init file
| ├─config.py # Parse arguments
| ├─device_adapter.py # Device adapter for ModelArts
| ├─local_adapter.py # Local adapter
| └─moxing_adapter.py # Moxing adapter for ModelArts
|
├─scripts
| ├─run_infer_310.sh # shell script for infer on Ascend 310
| ├─run_process_data.sh # Generate dataset in mindrecord format
| └─run_train.sh # Launch training, now only Ascend backend is supported.
|
@ -105,6 +157,11 @@ sh run_process_data.sh ./data citeseer
| ├─gcn.py # GCN backbone
| └─metrics.py # Loss and accuracy
|
├─default_config.yaml # Configurations
├─export.py # export scripts
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
├─postprocess.py # postprocess script
├─preprocess.py # preprocess scripts
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops, then testing is performed.
```


@ -95,6 +95,49 @@ sh run_process_data.sh ./data cora
sh run_process_data.sh ./data citeseer
```
- Running on local with Ascend
```bash
# Train on the cora or citeseer dataset; DATASET_NAME is cora or citeseer
sh run_train.sh [DATASET_NAME]
```
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
```bash
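# Train on the cora dataset with a single device on ModelArts
# (1) Do either a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "data_dir='/cache/data/cora'" in the default_config.yaml file.
#        Set "train_nodes_num=140" in the default_config.yaml file.
#        Set any other parameters in the default_config.yaml file.
#     b. Set "enable_modelarts=True" on the web page.
#        Set "data_dir='/cache/data/cora'" on the web page.
#        Set "train_nodes_num=140" on the web page.
#        Set any other parameters on the web page.
# (2) Upload your dataset to an S3 bucket.
# (3) Set your code directory to "/path/gcn" on the web page.
# (4) Set the startup file to "train.py" on the web page.
# (5) Set the "training dataset", "training output file path", "job log path" and so on, on the web page.
# (6) Create the training job.
#
# Train on the citeseer dataset with a single device on ModelArts
# (1) Do either a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "data_dir='/cache/data/citeseer'" in the default_config.yaml file.
#        Set "train_nodes_num=120" in the default_config.yaml file.
#        Set any other parameters in the default_config.yaml file.
#     b. Set "enable_modelarts=True" on the web page.
#        Set "data_dir='/cache/data/citeseer'" on the web page.
#        Set "train_nodes_num=120" on the web page.
#        Set any other parameters on the web page.
# (2) Upload your dataset to an S3 bucket.
# (3) Set your code directory to "/path/gcn" on the web page.
# (4) Set the startup file to "train.py" on the web page.
# (5) Set the "training dataset", "training output file path", "job log path" and so on, on the web page.
# (6) Create the training job.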
```
## Script Description
### Script and Sample Code
@ -103,7 +146,16 @@ sh run_process_data.sh ./data citeseer
.
└─gcn
├─README.md
├─README_CN.md
├─model_utils
| ├─__init__.py # Init file
| ├─config.py # Parameter configuration
| ├─device_adapter.py # Device adapter for ModelArts
| ├─local_adapter.py # Local adapter
| └─moxing_adapter.py # Moxing adapter for ModelArts
|
├─scripts
| ├─run_infer_310.sh # Shell script for inference on Ascend 310
| ├─run_process_data.sh # Generate the dataset in MindRecord format
| └─run_train.sh # Launch training; currently only the Ascend backend is supported
|
@ -113,6 +165,11 @@ sh run_process_data.sh ./data citeseer
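| ├─gcn.py # GCN backbone
| └─metrics.py # Loss and accuracy
|
├─default_config.yaml # Configuration file
├─export.py # Export script
├─mindspore_hub_conf.py # MindSpore Hub script
├─postprocess.py # Postprocess script
├─preprocess.py # Preprocess script
└─train.py # Train the network. Evaluation runs after every training epoch; once the validation result converges, training stops and testing is performed.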
```


@ -0,0 +1,32 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# train options
data_dir: "./data/cora/cora_mr"
train_nodes_num: 140
eval_nodes_num: 500
test_nodes_num: 1000
save_TSNE: False
save_ckptpath: "ckpts/"
---
# Help description for each configuration
data_dir: "Dataset directory"
train_nodes_num: "Nodes numbers for training"
eval_nodes_num: "Nodes numbers for evaluation"
test_nodes_num: "Nodes numbers for test"
save_TSNE: "Whether to save t-SNE graph"
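
The three node counts in this config determine which slices of the node list are used for training, evaluation and testing. The sketch below reproduces the index arithmetic that train.py passes to `get_mask`, assuming `get_mask` treats its begin and end arguments as a half-open range and using the 2708 nodes of the Cora graph. `split_ranges` is a hypothetical helper, not part of the repository.

```python
def split_ranges(nodes_num, train_nodes_num, eval_nodes_num, test_nodes_num):
    """Return half-open (begin, end) index ranges for the train, eval and test masks."""
    train = (0, train_nodes_num)
    evaluation = (train_nodes_num, train_nodes_num + eval_nodes_num)
    test = (nodes_num - test_nodes_num, nodes_num)
    return train, evaluation, test


# Cora has 2708 nodes; the counts are the defaults from default_config.yaml.
train, evaluation, test = split_ranges(2708, 140, 500, 1000)
print(train, evaluation, test)  # (0, 140) (140, 640) (1708, 2708)
```

With these defaults the training and evaluation ranges are adjacent, while the test range is taken from the end of the node list and does not overlap them.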


@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.
    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Allowed values for each option.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        # Only scalar options are exposed on the command line; lists and dicts stay yaml-only.
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please refer to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                # bool("False") is True, so booleans are parsed as Python literals.
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.
    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except Exception as err:
            # Chain the original error so the underlying cause stays visible.
            raise ValueError("Failed to parse yaml") from err
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.
    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)


config = get_config()
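
The precedence in this module is that command line values overwrite yaml values with the same key, and the merged dict is then exposed as attributes. A minimal, self-contained sketch of that behaviour follows; it re-implements a reduced `Config` and `merge` on plain dicts rather than importing the module, and the nested `optimizer` entry is a made-up example.

```python
from pprint import pformat


class Config:
    """Reduced copy of the namespace idea: nested dicts become nested attributes."""

    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            if isinstance(value, (list, tuple)):
                setattr(self, key, [Config(x) if isinstance(x, dict) else x for x in value])
            else:
                setattr(self, key, Config(value) if isinstance(value, dict) else value)

    def __repr__(self):
        return pformat(self.__dict__)


def merge(cli_values, cfg):
    """Command line values overwrite yaml values with the same key."""
    for key, value in cli_values.items():
        cfg[key] = value
    return cfg


# "optimizer" is a hypothetical nested entry used to show attribute nesting.
yaml_values = {"data_dir": "./data/cora/cora_mr", "train_nodes_num": 140,
               "optimizer": {"lr": 0.01}}
cli_values = {"train_nodes_num": 120}

config = Config(merge(cli_values, yaml_values))
print(config.train_nodes_num)  # 120
print(config.data_dir)         # ./data/cora/cora_mr
print(config.optimizer.lr)     # 0.01
```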


@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]


@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
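
The local adapter reads the device layout from environment variables and falls back to single-device defaults. A small sketch of that lookup is shown below, assuming a launcher has exported `DEVICE_ID` and `RANK_SIZE`. `get_env_int` is a hypothetical helper that condenses the three getters into one function.

```python
import os


def get_env_int(name, default):
    """Read an integer from the environment, falling back to a default string."""
    return int(os.getenv(name, default))


# Simulate what a launcher might export for the second of eight devices.
os.environ["DEVICE_ID"] = "1"
os.environ["RANK_SIZE"] = "8"
os.environ.pop("RANK_ID", None)  # leave RANK_ID unset to exercise the fallback

print(get_env_int("DEVICE_ID", "0"))  # 1
print(get_env_int("RANK_SIZE", "1"))  # 8
print(get_env_int("RANK_ID", "0"))    # 0
```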


@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config

_global_sync_count = 0


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    # os.getenv returns None when JOB_ID is unset, so test truthiness rather than != "".
    job_id = job_id if job_id else "default"
    return job_id


def sync_data(from_path, to_path):
    """
    Download data from remote OBS to a local directory if from_path is a remote URL and to_path is a local path.
    Upload data from a local directory to remote OBS in the opposite case.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    # All other devices wait until the copying device has created the lock file.
    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Run the main function
            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
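
The synchronization in `sync_data` relies on one device per server doing the copy and creating a lock file, while the other devices wait for that file. The sketch below illustrates the pattern with the standard library only: `shutil.copytree` stands in for `mox.file.copy_parallel`, a temporary directory stands in for `/tmp`, and the polling loop is left out so the example finishes immediately. `is_leader` and `sync_once` are hypothetical names.

```python
import os
import shutil
import tempfile


def is_leader(device_id, device_num, devices_per_server=8):
    """True for the one device per server that should perform the copy."""
    return device_id % min(device_num, devices_per_server) == 0


def sync_once(src, dst, lock_path, device_id, device_num):
    """Copy src to dst on the leader only, then mark completion with a lock file."""
    if is_leader(device_id, device_num) and not os.path.exists(lock_path):
        shutil.copytree(src, dst, dirs_exist_ok=True)  # stand-in for mox.file.copy_parallel
        open(lock_path, "w").close()                   # portable alternative to os.mknod
        return True
    return False


work = tempfile.mkdtemp()
src, dst = os.path.join(work, "src"), os.path.join(work, "dst")
os.makedirs(src)
with open(os.path.join(src, "a.txt"), "w") as fout:
    fout.write("data")
lock = os.path.join(work, "copy_sync.lock0")

print(sync_once(src, dst, lock, device_id=3, device_num=8))   # False: not the leader
print(sync_once(src, dst, lock, device_id=0, device_num=8))   # True: the leader copies
print(sync_once(src, dst, lock, device_id=8, device_num=16))  # False: the lock already exists
print(os.listdir(dst))                                        # ['a.txt']
```

Because the lock name in the adapter carries a running counter, every call to `sync_data` gets its own lock file, so an earlier download does not suppress a later upload.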


@ -35,8 +35,10 @@ then
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp *.sh ./train
cp -r ../src ./train
cp -r ../model_utils ./train
cd ./train || exit
env > env.log
echo "start training for device $DEVICE_ID"


@ -18,8 +18,6 @@ GCN training script.
 """
 import os
 import time
-import argparse
-import ast
 import numpy as np
 from matplotlib import pyplot as plt
@ -34,6 +32,10 @@ from src.metrics import LossAccuracyWrapper, TrainNetWrapper
 from src.config import ConfigGCN
 from src.dataset import get_adj_features_labels, get_mask
+from model_utils.config import config as default_args
+from model_utils.moxing_adapter import moxing_wrapper
+from model_utils.device_adapter import get_device_id, get_device_num
 def t_SNE(out_feature, dim):
     t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
@ -46,27 +48,79 @@ def update_graph(i, data, scat, plot):
     return scat, plot
-def train():
-    """Train model."""
-    parser = argparse.ArgumentParser(description='GCN')
-    parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
-    parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
-    parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
-    parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
-    parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
-    args_opt = parser.parse_args()
-    if not os.path.exists("ckpts"):
-        os.mkdir("ckpts")
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, default_args.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done.")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+    if default_args.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(default_args.data_path, default_args.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(default_args.data_path)
+        sync_lock = "/tmp/unzip_sync.lock"
+        # Each server contains 8 devices at most.
+        if default_args.device_target == "Ascend":
+            device_id = get_device_id()
+            device_num = get_device_num()
+        else:
+            raise ValueError("Not support device_target.")
+        if device_id % min(device_num, 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(device_id, zip_file_1, save_dir_1))
+    default_args.save_ckptpath = os.path.join(default_args.output_path, default_args.save_ckptpath)
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    """Train model."""
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="Ascend", save_graphs=False)
     config = ConfigGCN()
-    adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)
+    adj, feature, label_onehot, label = get_adj_features_labels(default_args.data_dir)
     nodes_num = label_onehot.shape[0]
-    train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
-    eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
-    test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
+    train_mask = get_mask(nodes_num, 0, default_args.train_nodes_num)
+    eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
+                         default_args.train_nodes_num + default_args.eval_nodes_num)
+    test_mask = get_mask(nodes_num, nodes_num - default_args.test_nodes_num, nodes_num)
     class_num = label_onehot.shape[1]
     input_dim = feature.shape[1]
@ -81,7 +135,7 @@ def train():
     loss_list = []
-    if args_opt.save_TSNE:
+    if default_args.save_TSNE:
         out_feature = gcn_net()
         tsne_result = t_SNE(out_feature.asnumpy(), 2)
         graph_data = []
@ -108,7 +162,7 @@ def train():
               "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
               "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
-        if args_opt.save_TSNE:
+        if default_args.save_TSNE:
             out_feature = gcn_net()
             tsne_result = t_SNE(out_feature.asnumpy(), 2)
             graph_data.append(tsne_result)
@ -116,9 +170,12 @@ def train():
         if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
             print("Early stopping...")
             break
-    save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
+    if not os.path.isdir(default_args.save_ckptpath):
+        os.makedirs(default_args.save_ckptpath)
+    ckpt_path = os.path.join(default_args.save_ckptpath, "gcn.ckpt")
+    save_checkpoint(gcn_net, ckpt_path)
     gcn_net_test = GCN(config, input_dim, class_num)
-    load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
+    load_checkpoint(ckpt_path, net=gcn_net_test)
     gcn_net_test.add_flags_recursive(fp16=True)
     test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
@ -130,10 +187,10 @@ def train():
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),
           "accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
-    if args_opt.save_TSNE:
+    if default_args.save_TSNE:
         ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
         ani.save('t-SNE_visualization.gif', writer='imagemagick')
 if __name__ == '__main__':
-    train()
+    run_train()
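
The early stopping test in the training loop compares the latest loss with the mean of the losses just before it. The sketch below isolates that condition, assuming `loss_list` holds one validation loss per epoch and using a window of 3 purely for illustration. `should_stop` is a hypothetical helper, and `statistics.mean` replaces `np.mean` to keep the example free of third-party imports.

```python
from statistics import mean


def should_stop(epoch, loss_list, early_stopping):
    """Stop when the latest loss exceeds the mean of the previous `early_stopping` losses."""
    if epoch <= early_stopping:
        return False
    window = loss_list[-(early_stopping + 1):-1]  # the losses just before the latest one
    return loss_list[-1] > mean(window)


losses = [1.0, 0.8, 0.6, 0.5, 0.45, 0.44, 0.47]
# 0.47 is above mean(0.5, 0.45, 0.44), so training would stop here.
print(should_stop(epoch=6, loss_list=losses, early_stopping=3))       # True
# 0.44 is below mean(0.6, 0.5, 0.45), so training would continue.
print(should_stop(epoch=5, loss_list=losses[:-1], early_stopping=3))  # False
```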