modify gcn network for clould
This commit is contained in:
parent
b52534c9c9
commit
cccd62cde4
|
@ -87,6 +87,49 @@ sh run_process_data.sh ./data cora
|
|||
sh run_process_data.sh ./data citeseer
|
||||
```
|
||||
|
||||
- Running on local with Ascend
|
||||
|
||||
```bash
|
||||
# run train with cora or citeseer dataset, DATASET_NAME is cora or citeseer
|
||||
sh run_train.sh [DATASET_NAME]
|
||||
```
|
||||
|
||||
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```bash
|
||||
# Train cora 1p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_dir='/cache/data/cora'" on default_config.yaml file.
|
||||
# Set "train_nodes_num=140" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_dir='/cache/data/cora'" on the website UI interface.
|
||||
# Add "train_nodes_num=140" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload dataset to S3 bucket.
|
||||
# (3) Set the code directory to "/path/gcn" on the website UI interface.
|
||||
# (4) Set the startup file to "train.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
#
|
||||
# Train citeseer 1p on ModelArts
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_dir='/cache/data/citeseer'" on default_config.yaml file.
|
||||
# Set "train_nodes_num=120" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_dir='/cache/data/citeseer'" on the website UI interface.
|
||||
# Add "train_nodes_num=120" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Upload dataset to S3 bucket.
|
||||
# (3) Set the code directory to "/path/gcn" on the website UI interface.
|
||||
# (4) Set the startup file to "train.py" on the website UI interface.
|
||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (6) Create your job.
|
||||
```
|
||||
|
||||
### [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
@ -95,7 +138,16 @@ sh run_process_data.sh ./data citeseer
|
|||
.
|
||||
└─gcn
|
||||
├─README.md
|
||||
├─README_CN.md
|
||||
├─model_utils
|
||||
| ├─__init__.py # init file
|
||||
| ├─config.py # Parse arguments
|
||||
| ├─device_adapter.py # Device adapter for ModelArts
|
||||
| ├─local_adapter.py # Local adapter
|
||||
| └─moxing_adapter.py # Moxing adapter for ModelArts
|
||||
|
|
||||
├─scripts
|
||||
| ├─run_infer_310.sh # shell script for infer on Ascend 310
|
||||
| ├─run_process_data.sh # Generate dataset in mindrecord format
|
||||
| └─run_train.sh # Launch training, now only Ascend backend is supported.
|
||||
|
|
||||
|
@ -105,6 +157,11 @@ sh run_process_data.sh ./data citeseer
|
|||
| ├─gcn.py # GCN backbone
|
||||
| └─metrics.py # Loss and accuracy
|
||||
|
|
||||
├─default_config.py # Configurations
|
||||
├─export.py # export scripts
|
||||
├─mindspore_hub_conf.py # mindspore_hub_conf scripts
|
||||
├─postprocess.py # postprocess script
|
||||
├─preprocess.py # preprocess scripts
|
||||
└─train.py # Train net, evaluation is performed after every training epoch. After the verification result converges, the training stops, then testing is performed.
|
||||
```
|
||||
|
||||
|
|
|
@ -95,6 +95,49 @@ sh run_process_data.sh ./data cora
|
|||
sh run_process_data.sh ./data citeseer
|
||||
```
|
||||
|
||||
- Running on local with Ascend
|
||||
|
||||
```bash
|
||||
# 在 cora 或 citeseer 数据集上训练, DATASET_NAME 设置为 cora 或 citeseer
|
||||
sh run_train.sh [DATASET_NAME]
|
||||
```
|
||||
|
||||
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```bash
|
||||
# 在 ModelArts 上使用 单卡训练 cora 数据集
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "data_dir='/cache/data/cora'"
|
||||
# 在 default_config.yaml 文件中设置 "train_nodes_num=140"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "data_dir='/cache/data/cora'"
|
||||
# 在网页上设置 "train_nodes_num=140"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 上传你的数据集到 S3 桶上
|
||||
# (3) 在网页上设置你的代码路径为 "/path/gcn"
|
||||
# (4) 在网页上设置启动文件为 "train.py"
|
||||
# (5) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (6) 创建训练作业
|
||||
#
|
||||
# 在 ModelArts 上使用 单卡训练 citeseer 数据集
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "data_dir='/cache/data/citeseer'"
|
||||
# 在 default_config.yaml 文件中设置 "train_nodes_num=120"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "data_dir='/cache/data/citeseer'"
|
||||
# 在网页上设置 "train_nodes_num=120"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 上传你的数据集到 S3 桶上
|
||||
# (3) 在网页上设置你的代码路径为 "/path/gcn"
|
||||
# (4) 在网页上设置启动文件为 "train.py"
|
||||
# (5) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (6) 创建训练作业
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 脚本及样例代码
|
||||
|
@ -103,7 +146,16 @@ sh run_process_data.sh ./data citeseer
|
|||
.
|
||||
└─gcn
|
||||
├─README.md
|
||||
├─README_CN.md
|
||||
├─model_utils
|
||||
| ├─__init__.py # 初始化文件
|
||||
| ├─config.py # 参数配置
|
||||
| ├─device_adapter.py # ModelArts的设备适配器
|
||||
| ├─local_adapter.py # 本地适配器
|
||||
| └─moxing_adapter.py # ModelArts的模型适配器
|
||||
|
|
||||
├─scripts
|
||||
| ├─run_infer_310.sh # Ascend310 推理shell脚本
|
||||
| ├─run_process_data.sh # 生成MindRecord格式的数据集
|
||||
| └─run_train.sh # 启动训练,目前只支持Ascend后端
|
||||
|
|
||||
|
@ -113,6 +165,11 @@ sh run_process_data.sh ./data citeseer
|
|||
| ├─gcn.py # GCN骨干
|
||||
| └─metrics.py # 损失和准确率
|
||||
|
|
||||
├─default_config.py # 配置文件
|
||||
├─export.py # 导出脚本
|
||||
├─mindspore_hub_conf.py # mindspore hub 脚本
|
||||
├─postprocess.py # 后处理脚本
|
||||
├─preprocess.py # 预处理脚本
|
||||
└─train.py # 训练网络,每个训练轮次后评估验证结果收敛后,训练停止,然后进行测试。
|
||||
```
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: False
|
||||
modelarts_dataset_unzip_name: ""
|
||||
|
||||
# ==============================================================================
|
||||
# train options
|
||||
data_dir: "./data/cora/cora_mr"
|
||||
train_nodes_num: 140
|
||||
eval_nodes_num: 500
|
||||
test_nodes_num: 1000
|
||||
save_TSNE: False
|
||||
save_ckptpath: "ckpts/"
|
||||
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
data_dir: "Dataset directory"
|
||||
train_nodes_num: "Nodes numbers for training"
|
||||
eval_nodes_num: "Nodes numbers for evaluation"
|
||||
test_nodes_num: "Nodes numbers for test"
|
||||
save_TSNE: "Whether to save t-SNE graph"
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -35,8 +35,10 @@ then
|
|||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
|
|
@ -18,8 +18,6 @@ GCN training script.
|
|||
"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
@ -34,6 +32,10 @@ from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
|||
from src.config import ConfigGCN
|
||||
from src.dataset import get_adj_features_labels, get_mask
|
||||
|
||||
from model_utils.config import config as default_args
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
def t_SNE(out_feature, dim):
|
||||
t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
|
||||
|
@ -46,27 +48,79 @@ def update_graph(i, data, scat, plot):
|
|||
return scat, plot
|
||||
|
||||
|
||||
def train():
|
||||
"""Train model."""
|
||||
parser = argparse.ArgumentParser(description='GCN')
|
||||
parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
|
||||
parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
|
||||
parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
|
||||
parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
|
||||
parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
|
||||
args_opt = parser.parse_args()
|
||||
if not os.path.exists("ckpts"):
|
||||
os.mkdir("ckpts")
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, default_args.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if default_args.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(default_args.data_path, default_args.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(default_args.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if default_args.device_target == "Ascend":
|
||||
device_id = get_device_id()
|
||||
device_num = get_device_num()
|
||||
else:
|
||||
raise ValueError("Not support device_target.")
|
||||
|
||||
if device_id % min(device_num, 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(device_id, zip_file_1, save_dir_1))
|
||||
|
||||
default_args.save_ckptpath = os.path.join(default_args.output_path, default_args.save_ckptpath)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
"""Train model."""
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)
|
||||
adj, feature, label_onehot, label = get_adj_features_labels(default_args.data_dir)
|
||||
|
||||
nodes_num = label_onehot.shape[0]
|
||||
train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
|
||||
eval_mask = get_mask(nodes_num, args_opt.train_nodes_num, args_opt.train_nodes_num + args_opt.eval_nodes_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
|
||||
train_mask = get_mask(nodes_num, 0, default_args.train_nodes_num)
|
||||
eval_mask = get_mask(nodes_num, default_args.train_nodes_num,
|
||||
default_args.train_nodes_num + default_args.eval_nodes_num)
|
||||
test_mask = get_mask(nodes_num, nodes_num - default_args.test_nodes_num, nodes_num)
|
||||
|
||||
class_num = label_onehot.shape[1]
|
||||
input_dim = feature.shape[1]
|
||||
|
@ -81,7 +135,7 @@ def train():
|
|||
|
||||
loss_list = []
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
if default_args.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data = []
|
||||
|
@ -108,7 +162,7 @@ def train():
|
|||
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
|
||||
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
if default_args.save_TSNE:
|
||||
out_feature = gcn_net()
|
||||
tsne_result = t_SNE(out_feature.asnumpy(), 2)
|
||||
graph_data.append(tsne_result)
|
||||
|
@ -116,9 +170,12 @@ def train():
|
|||
if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
|
||||
print("Early stopping...")
|
||||
break
|
||||
save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
|
||||
if not os.path.isdir(default_args.save_ckptpath):
|
||||
os.makedirs(default_args.save_ckptpath)
|
||||
ckpt_path = os.path.join(default_args.save_ckptpath, "gcn.ckpt")
|
||||
save_checkpoint(gcn_net, ckpt_path)
|
||||
gcn_net_test = GCN(config, input_dim, class_num)
|
||||
load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
|
||||
load_checkpoint(ckpt_path, net=gcn_net_test)
|
||||
gcn_net_test.add_flags_recursive(fp16=True)
|
||||
|
||||
test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
|
||||
|
@ -130,10 +187,10 @@ def train():
|
|||
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
|
||||
"accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
|
||||
|
||||
if args_opt.save_TSNE:
|
||||
if default_args.save_TSNE:
|
||||
ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1), fargs=(graph_data, scat, plt))
|
||||
ani.save('t-SNE_visualization.gif', writer='imagemagick')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
run_train()
|
||||
|
|
Loading…
Reference in New Issue