forked from mindspore-Ecosystem/mindspore
!17738 modify BGCF network for cloud
From: @zhanghuiyao Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
This commit is contained in:
commit
2395a0b198
|
@ -21,7 +21,7 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
@ -95,6 +95,7 @@ def modelarts_pre_process():
|
|||
|
||||
# Each server contains at most 8 devices.
|
||||
if config.device_target == "GPU":
|
||||
init()
|
||||
device_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
elif config.device_target == "Ascend":
|
||||
|
|
|
@ -82,6 +82,7 @@ def modelarts_pre_process():
|
|||
|
||||
# Each server contains at most 8 devices.
|
||||
if config.device_target == "GPU":
|
||||
init()
|
||||
device_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
elif config.device_target == "Ascend":
|
||||
|
@ -129,7 +130,11 @@ def run_train():
|
|||
init()
|
||||
context.set_context(device_id=config.device_id)
|
||||
elif config.device_target == "GPU":
|
||||
init()
|
||||
if not config.enable_modelarts:
|
||||
init()
|
||||
else:
|
||||
if not config.need_modelarts_dataset_unzip:
|
||||
init()
|
||||
|
||||
device_num = config.group_size
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
|
@ -90,17 +90,17 @@ To ultilize the strong computation power of Ascend chip, and accelerate the trai
|
|||
|
||||
After installing MindSpore via the official website and correctly generating the dataset, you can start training and evaluation as follows.
|
||||
|
||||
- running on Ascend
|
||||
- Running on Ascend
|
||||
|
||||
```python
|
||||
# run training example with Amazon-Beauty dataset
|
||||
sh run_train_ascend.sh
|
||||
sh run_train_ascend.sh dataset_path
|
||||
|
||||
# run evaluation example with Amazon-Beauty dataset
|
||||
sh run_eval_ascend.sh
|
||||
sh run_eval_ascend.sh dataset_path
|
||||
```
|
||||
|
||||
- running on GPU
|
||||
- Running on GPU
|
||||
|
||||
```python
|
||||
# run training example with Amazon-Beauty dataset
|
||||
|
@ -110,6 +110,62 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
sh run_eval_gpu.sh 0 dataset_path
|
||||
```
|
||||
|
||||
- Running on ModelArts (to run on ModelArts, please check the official [ModelArts documentation](https://support.huaweicloud.com/modelarts/) first; then you can start training as follows)
|
||||
|
||||
- Train 1p on ModelArts Ascend/GPU
|
||||
|
||||
```python
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "datapath='/cache/data/amazon_beauty/data_mr'" on default_config.yaml file.
|
||||
# Set "ckptpath='./ckpts'" on default_config.yaml file.
|
||||
# (optional) Set "device_target='GPU'" on default_config.yaml file when running on GPU.
|
||||
# (optional) Set "num_epoch=680" on default_config.yaml file when running on GPU.
|
||||
# (optional) Set "dist_reg=0" on default_config.yaml file when running on GPU.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "datapath=/cache/data/amazon_beauty/data_mr" on the website UI interface.
|
||||
# Add "ckptpath=./ckpts" on the website UI interface.
|
||||
# (optional) Add "device_target=GPU" on the website UI interface when running on GPU.
|
||||
# (optional) Add "num_epoch=680" on the website UI interface when running on GPU.
|
||||
# (optional) Add "dist_reg=0" on the website UI interface when running on GPU.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare the converted dataset and zip it into a single file such as "amazon_beauty.zip" locally. (For the conversion process, refer to the dataset processing code above.)
|
||||
# (3) Upload the zipped dataset to your S3 bucket. (You could also upload the original dataset, but that can be very slow.)
|
||||
# (4) Set the code directory to "/path/bgcf" on the website UI interface.
|
||||
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
```
|
||||
|
||||
- Eval 1p on ModelArts Ascend/GPU
|
||||
|
||||
```python
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "datapath='/cache/data/amazon_beauty/data_mr'" on default_config.yaml file.
|
||||
# Set "ckptpath='/cache/checkpoint_path'" on default_config.yaml file.
|
||||
# Set "checkpoint_url='s3://dir_to_your_trained_ckpt/'" on default_config.yaml file.
|
||||
# (optional) Set "device_target='GPU'" on default_config.yaml file when running on GPU.
|
||||
# (optional) Set "num_epoch=680" on default_config.yaml file when running on GPU.
|
||||
# (optional) Set "dist_reg=0" on default_config.yaml file when running on GPU.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "datapath=/cache/data/amazon_beauty/data_mr" on the website UI interface.
|
||||
# Add "ckptpath='/cache/checkpoint_path'" on the website UI interface.
|
||||
# Add "checkpoint_url='s3://dir_to_your_trained_ckpt/'" on the website UI interface.
|
||||
# (optional) Add "device_target=GPU" on the website UI interface when running on GPU.
|
||||
# (optional) Add "num_epoch=680" on the website UI interface when running on GPU.
|
||||
# (optional) Add "dist_reg=0" on the website UI interface when running on GPU.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (2) Prepare the converted dataset and zip it into a single file such as "amazon_beauty.zip" locally. (For the conversion process, refer to the dataset processing code above.)
|
||||
# (3) Upload the zipped dataset to your S3 bucket. (You could also upload the original dataset, but that can be very slow.)
|
||||
# (4) Set the code directory to "/path/bgcf" on the website UI interface.
|
||||
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (7) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
@ -118,28 +174,35 @@ After installing MindSpore via the official website and Dataset is correctly gen
|
|||
.
|
||||
└─bgcf
|
||||
├─README.md
|
||||
├─README_CN.md
|
||||
├─model_utils
|
||||
| ├─__init__.py # Module init file
|
||||
| ├─config.py # Parse arguments
|
||||
| ├─device_adapter.py # Device adapter for ModelArts
|
||||
| ├─local_adapter.py # Local adapter
|
||||
| └─moxing_adapter.py # Moxing adapter for ModelArts
|
||||
├─scripts
|
||||
| ├─run_eval_ascend.sh # Launch evaluation in ascend
|
||||
| ├─run_eval_gpu.sh # Launch evaluation in gpu
|
||||
| ├─run_process_data_ascend.sh # Generate dataset in mindrecord format
|
||||
| ├─run_train_ascend.sh # Launch training in ascend
|
||||
| └─run_train_gpu.sh # Launch training in gpu
|
||||
|
|
||||
├─src
|
||||
| ├─bgcf.py # BGCF model
|
||||
| ├─callback.py # Callback function
|
||||
| ├─config.py # Training configurations
|
||||
| ├─dataset.py # Data preprocessing
|
||||
| ├─metrics.py # Recommendation metrics
|
||||
| └─utils.py # Utils for training bgcf
|
||||
|
|
||||
├─default_config.yaml # Configurations file
|
||||
├─mindspore_hub_conf.py # Mindspore hub file
|
||||
├─export.py # Export net
|
||||
├─eval.py # Evaluation net
|
||||
└─train.py # Train net
|
||||
```
|
||||
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
Parameters for both training and evaluation can be set in default_config.yaml.
|
||||
|
||||
- config for BGCF dataset
|
||||
|
||||
|
@ -154,7 +217,7 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
"neighbor_dropout": [0.0, 0.2, 0.3] # Dropout ratio for different aggregation layer
|
||||
```
|
||||
|
||||
config.py for more configuration.
|
||||
See default_config.yaml for more configuration options.
|
||||
|
||||
### [Training Process](#contents)
|
||||
|
||||
|
@ -163,7 +226,7 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
- running on Ascend
|
||||
|
||||
```python
|
||||
sh run_train_ascend.sh
|
||||
sh run_train_ascend.sh dataset_path
|
||||
```
|
||||
|
||||
Training result will be stored in the scripts path, whose folder name begins with "train". You can find the result like the
|
||||
|
@ -206,7 +269,7 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
- Evaluation on Ascend
|
||||
|
||||
```python
|
||||
sh run_eval_ascend.sh
|
||||
sh run_eval_ascend.sh dataset_path
|
||||
```
|
||||
|
||||
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
|
||||
|
@ -279,7 +342,7 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
|
||||
## [Description of random situation](#contents)
|
||||
|
||||
BGCF model contains lots of dropout operations, if you want to disable dropout, set the neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py.
|
||||
The BGCF model contains many dropout operations. To disable dropout, set neighbor_dropout to [0.0, 0.0, 0.0] in default_config.yaml.
|
||||
|
||||
## [ModelZoo Homepage](#contents)
|
||||
|
||||
|
|
|
@ -104,10 +104,10 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
```text
|
||||
|
||||
# 使用Amazon-Beauty数据集运行训练示例
|
||||
sh run_train_ascend.sh
|
||||
sh run_train_ascend.sh dataset_path
|
||||
|
||||
# 使用Amazon-Beauty数据集运行评估示例
|
||||
sh run_eval_ascend.sh
|
||||
sh run_eval_ascend.sh dataset_path
|
||||
|
||||
```
|
||||
|
||||
|
@ -123,37 +123,96 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
|
||||
```
|
||||
|
||||
- 在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
|
||||
- 在 ModelArts 上使用单卡训练(GPU or Ascend)
|
||||
|
||||
```python
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "datapath='/cache/data/amazon_beauty/data_mr'"
|
||||
# 在 default_config.yaml 文件中设置 "ckptpath='./ckpts'"
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "device_target='GPU'"
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "num_epoch=680"
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "dist_reg=0"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "datapath=/cache/data/amazon_beauty/data_mr"
|
||||
# 在网页上设置 "ckptpath=./ckpts"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "device_target=GPU"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "num_epoch=680"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "dist_reg=0"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 在本地准备转换好的数据集并将其压缩为一个文件,如:"amazon_beauty.zip" (数据集转换代码可以参考上面的Dataset章节)
|
||||
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||
# (4) 在网页上设置你的代码路径为 "/path/bgcf"
|
||||
# (5) 在网页上设置启动文件为 "train.py"
|
||||
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (7) 创建训练作业
|
||||
```
|
||||
|
||||
- 在 ModelArts 上使用单卡验证(GPU or Ascend)
|
||||
|
||||
```python
|
||||
# (1) 执行a或者b
|
||||
# a. 在 default_config.yaml 文件中设置 "enable_modelarts=True"
|
||||
# 在 default_config.yaml 文件中设置 "datapath='/cache/data/amazon_beauty/data_mr'"
|
||||
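# 在 default_config.yaml 文件中设置 "ckptpath='/cache/checkpoint_path'"
|
||||
# 在 default_config.yaml 文件中设置 "checkpoint_url='s3://dir_to_your_trained_ckpt/'"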
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "device_target='GPU'"
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "num_epoch=680"
|
||||
# (可选)如果选择GPU运行,在 default_config.yaml 文件中设置 "dist_reg=0"
|
||||
# 在 default_config.yaml 文件中设置 其他参数
|
||||
# b. 在网页上设置 "enable_modelarts=True"
|
||||
# 在网页上设置 "datapath=/cache/data/amazon_beauty/data_mr"
|
||||
# 在网页上设置 "ckptpath='/cache/checkpoint_path'"
|
||||
# 在网页上设置 "checkpoint_url='s3://dir_to_your_trained_ckpt/'"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "device_target=GPU"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "num_epoch=680"
|
||||
# (可选)如果选择GPU运行,在网页上设置 "dist_reg=0"
|
||||
# 在网页上设置 其他参数
|
||||
# (2) 在本地准备转换好的数据集并将其压缩为一个文件,如:"amazon_beauty.zip" (数据集转换代码可以参考上面的Dataset章节)
|
||||
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||
# (4) 在网页上设置你的代码路径为 "/path/bgcf"
|
||||
# (5) 在网页上设置启动文件为 "eval.py"
|
||||
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||
# (7) 创建训练作业
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 脚本及样例代码
|
||||
|
||||
```shell
|
||||
|
||||
└─bgcf
|
||||
├─README.md
|
||||
├─README_CN.md
|
||||
├─model_utils
|
||||
| ├─__init__.py # 初始化文件
|
||||
| ├─config.py # 参数获取文件
|
||||
| ├─device_adapter.py # modelarts 设备适配文件
|
||||
| ├─local_adapter.py # 本地适配文件
|
||||
| └─moxing_adapter.py # modelarts 模型适配文件
|
||||
├─scripts
|
||||
| ├─run_eval_ascend.sh # Ascend启动评估
|
||||
| ├─run_eval_gpu.sh # GPU启动评估
|
||||
| ├─run_process_data_ascend.sh # 生成MindRecord格式的数据集
|
||||
| ├─run_train_ascend.sh # Ascend启动训练
|
||||
| └─run_train_gpu.sh # GPU启动训练
|
||||
|
|
||||
├─src
|
||||
| ├─bgcf.py # BGCF模型
|
||||
| ├─callback.py # 回调函数
|
||||
| ├─config.py # 训练配置
|
||||
| ├─dataset.py # 数据预处理
|
||||
| ├─metrics.py # 推荐指标
|
||||
| └─utils.py # 训练BGCF的工具
|
||||
|
|
||||
├─default_config.yaml # 参数配置文件
|
||||
├─mindspore_hub_conf.py # Mindspore hub文件
|
||||
├─export.py # 导出网络
|
||||
├─eval.py # 评估网络
|
||||
└─train.py # 训练网络
|
||||
|
||||
```
|
||||
|
||||
### 脚本参数
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
在 default_config.yaml 中可以同时配置训练参数和评估参数。
|
||||
|
||||
- BGCF数据集配置
|
||||
|
||||
|
@ -171,7 +230,7 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
|
||||
```
|
||||
|
||||
在config.py中以获取更多配置。
|
||||
更多配置请参见 default_config.yaml。
|
||||
|
||||
### 训练过程
|
||||
|
||||
|
@ -181,7 +240,7 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
|
||||
```python
|
||||
|
||||
sh run_train_ascend.sh
|
||||
sh run_train_ascend.sh dataset_path
|
||||
|
||||
```
|
||||
|
||||
|
@ -229,7 +288,7 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
|
||||
```python
|
||||
|
||||
sh run_eval_ascend.sh
|
||||
sh run_eval_ascend.sh dataset_path
|
||||
|
||||
```
|
||||
|
||||
|
@ -305,7 +364,7 @@ BGCF包含两个主要模块。首先是抽样,它生成基于节点复制的
|
|||
|
||||
## 随机情况说明
|
||||
|
||||
BGCF模型中有很多的dropout操作,如果想关闭dropout,可以在src/config.py中将neighbor_dropout设置为[0.0, 0.0, 0.0] 。
|
||||
BGCF模型中有很多的dropout操作,如果想关闭dropout,可以在 ./default_config.yaml 中将neighbor_dropout设置为[0.0, 0.0, 0.0] 。
|
||||
|
||||
## ModelZoo主页
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "amazon_beauty"
|
||||
|
||||
# ==============================================================================
|
||||
# options
|
||||
dataset: "Beauty"
|
||||
datapath: "./scripts/data_mr"
|
||||
Ks: [5, 10, 20, 100]
|
||||
workers: 8
|
||||
ckptpath: "../ckpts"
|
||||
epsilon: 0.00000001 # 1e-8
|
||||
learning_rate: 0.001 # 1e-3
|
||||
l2: 0.03
|
||||
activation: "tanh"
|
||||
neighbor_dropout: [0.0, 0.2, 0.3]
|
||||
log_name: "test"
|
||||
num_epoch: 600
|
||||
input_dim: 64
|
||||
batch_pairs: 5000
|
||||
eval_interval: 20
|
||||
num_neg: 10
|
||||
raw_neighs: 40
|
||||
gnew_neighs: 20
|
||||
embedded_dimension: 64
|
||||
dist_reg: 0.003
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
dataset: "choose which dataset"
|
||||
datapath: "minddata path"
|
||||
Ks: "top K"
|
||||
workers: "number of process to generate data"
|
||||
ckptpath: "checkpoint path"
|
||||
epsilon: "optimizer parameter"
|
||||
learning_rate: "learning rate"
|
||||
l2: "l2 coefficient"
|
||||
activation: "activation function, choices in ['relu', 'tanh']."
|
||||
neighbor_dropout: "dropout ratio for different aggregation layer"
|
||||
log_name: "log name"
|
||||
num_epoch: "epoch sizes for training"
|
||||
input_dim: "user and item embedding dimension, choices in [64, 128]"
|
||||
batch_pairs: "batch size"
|
||||
eval_interval: "evaluation interval"
|
||||
num_neg: "negative sampling rate "
|
||||
raw_neighs: "num of sampling neighbors in raw graph"
|
||||
gnew_neighs: "num of sampling neighbors in sample graph"
|
||||
embedded_dimension: "output embedding dim"
|
||||
dist_reg: "distance loss coefficient"
|
||||
device_target: "device target, choices in ['Ascend', GPU]"
|
|
@ -15,6 +15,8 @@
|
|||
"""
|
||||
BGCF evaluation script.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import mindspore.context as context
|
||||
|
@ -23,38 +25,117 @@ from mindspore.common import set_seed
|
|||
|
||||
from src.bgcf import BGCF
|
||||
from src.utils import BGCFLogger
|
||||
from src.config import parser_args
|
||||
from src.metrics import BGCFEvaluate
|
||||
from src.callback import ForwardBGCF, TestBGCF
|
||||
from src.dataset import TestGraphDataset, load_graph
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def modelarts_pre_process():
    """Unzip the ModelArts dataset archive and wait until extraction is finished.

    Does nothing unless ``config.need_modelarts_dataset_unzip`` is set. When it
    is, the archive ``<config.data_path>/<config.modelarts_dataset_unzip_name>.zip``
    is extracted into ``config.data_path`` by the device whose id is a multiple
    of ``min(get_device_num(), 8)``; that device then creates a lock file, and
    every caller polls for the lock file before returning.
    """
    def unzip(zip_file, save_dir):
        """Extract ``zip_file`` into ``save_dir``, printing progress along the way.

        Extraction is skipped when ``save_dir`` already contains an entry named
        ``config.modelarts_dataset_unzip_name`` or when ``zip_file`` is not a
        valid zip archive.
        """
        import zipfile
        s_time = time.time()
        if os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            print("Zip has been extracted.")
            return
        if not zipfile.is_zipfile(zip_file):
            print("This is not zip.")
            return

        # The context manager closes the archive handle even if extraction
        # fails part-way; the original never closed it.
        with zipfile.ZipFile(zip_file, 'r') as fz:
            members = fz.namelist()
            data_num = len(members)
            print("Extract Start...")
            print("unzip file num: {}".format(data_num))
            # Report progress roughly once per percent for large archives.
            data_print = int(data_num / 100) if data_num > 100 else 1
            for i, member in enumerate(members):
                if i % data_print == 0:
                    print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                fz.extract(member, save_dir)
        print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                             int(int(time.time() - s_time) % 60)))
        print("Extract Done.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains at most 8 devices, so one device per server extracts.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                # NOTE(review): a failure here other than "already exists" leaves
                # the polling loop below waiting forever - confirm this is acceptable.
                pass

        # Block until the extracting device has published the lock file.
        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def evaluation():
|
||||
"""evaluation"""
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=config.device_target,
|
||||
save_graphs=False)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(config.datapath)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=config.raw_neighs,
|
||||
num_bgcn_neigh=config.gnew_neighs,
|
||||
num_neg=config.num_neg)
|
||||
|
||||
if config.log_name:
|
||||
now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
|
||||
name = "bgcf" + '-' + config.log_name + '-' + config.dataset
|
||||
log_save_path = './log-files/' + name + '/' + now
|
||||
log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
|
||||
log.open(log_save_path + '/log.train.txt', mode='a')
|
||||
for arg in vars(config):
|
||||
log.write(arg + '=' + str(getattr(config, arg)) + '\n')
|
||||
else:
|
||||
for arg in vars(config):
|
||||
print(arg + '=' + str(getattr(config, arg)))
|
||||
|
||||
num_user = train_graph.graph_info()["node_num"][0]
|
||||
num_item = train_graph.graph_info()["node_num"][1]
|
||||
|
||||
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
|
||||
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval) \
|
||||
if parser.device_target == "Ascend" else range(parser.num_epoch, parser.num_epoch+1):
|
||||
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
eval_class = BGCFEvaluate(config, train_graph, test_graph, config.Ks)
|
||||
for _epoch in range(config.eval_interval, config.num_epoch+1, config.eval_interval) \
|
||||
if config.device_target == "Ascend" else range(config.num_epoch, config.num_epoch+1):
|
||||
bgcfnet_test = BGCF([config.input_dim, num_user, num_item],
|
||||
config.embedded_dimension,
|
||||
config.activation,
|
||||
[0.0, 0.0, 0.0],
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
config.input_dim)
|
||||
|
||||
load_checkpoint(parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch), net=bgcfnet_test)
|
||||
load_checkpoint(config.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch), net=bgcfnet_test)
|
||||
|
||||
forward_net = ForwardBGCF(bgcfnet_test)
|
||||
user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset)
|
||||
user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, config.input_dim, test_graph_dataset)
|
||||
|
||||
test_recall_bgcf, test_ndcg_bgcf, \
|
||||
test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser)
|
||||
test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, config)
|
||||
|
||||
if parser.log_name:
|
||||
if config.log_name:
|
||||
log.write(
|
||||
'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, '
|
||||
'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch,
|
||||
|
@ -80,28 +161,4 @@ def evaluation():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = parser_args()
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=parser.device_target,
|
||||
save_graphs=False)
|
||||
if parser.device_target == "Ascend":
|
||||
context.set_context(device_id=int(parser.device))
|
||||
|
||||
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
|
||||
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,
|
||||
num_bgcn_neigh=parser.gnew_neighs,
|
||||
num_neg=parser.num_neg)
|
||||
|
||||
if parser.log_name:
|
||||
now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S")
|
||||
name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset
|
||||
log_save_path = './log-files/' + name + '/' + now
|
||||
log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False)
|
||||
log.open(log_save_path + '/log.train.txt', mode='a')
|
||||
for arg in vars(parser):
|
||||
log.write(arg + '=' + str(getattr(parser, arg)) + '\n')
|
||||
else:
|
||||
for arg in vars(parser):
|
||||
print(arg + '=' + str(getattr(parser, arg)))
|
||||
|
||||
evaluation()
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pformat
|
||||
import yaml
|
||||
|
||||
class Config:
    """
    Attribute-style view over a configuration dictionary.

    Each key of ``cfg_dict`` becomes an attribute. Dictionary values are wrapped
    in nested ``Config`` objects; list and tuple values become lists whose
    dictionary elements are wrapped the same way.
    """
    def __init__(self, cfg_dict):
        def wrap(entry):
            # Only dictionaries are converted; every other value is kept as is.
            return Config(entry) if isinstance(entry, dict) else entry

        for key, value in cfg_dict.items():
            if isinstance(value, (list, tuple)):
                converted = [wrap(entry) for entry in value]
            else:
                converted = wrap(value)
            setattr(self, key, converted)

    def __str__(self):
        # Pretty-print the attribute dictionary.
        return pformat(vars(self))

    def __repr__(self):
        return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    One ``--<key>`` option is registered for every entry of ``cfg`` whose value
    is neither a list nor a dict, using the yaml value as the default.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description, keyed by configuration name.
        choices: Allowed values, keyed by configuration name.
        cfg_path: Path to the default yaml config.

    Returns:
        The namespace produced by ``parse_args()`` on the extended parser.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item, default in cfg.items():
        # Lists and dicts cannot be expressed as a single CLI value; keep the yaml value.
        if isinstance(default, (list, dict)):
            continue
        help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
        choice = choices[item] if item in choices else None
        if isinstance(default, bool):
            # bool("False") is True, so booleans are parsed as Python literals instead.
            arg_type = ast.literal_eval
        elif default is None:
            # type(None) is not a usable converter: calling it with the CLI string
            # raises TypeError. Treat null yaml defaults as plain strings.
            arg_type = str
        else:
            arg_type = type(default)
        parser.add_argument("--" + item, type=arg_type, default=default, choices=choice,
                            help=help_description)
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    The file may hold up to three yaml documents, in this order: the
    configuration itself, help descriptions, and allowed choices. Missing
    trailing documents are returned as empty dictionaries.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        A ``(cfg, cfg_helper, cfg_choices)`` tuple.

    Raises:
        ValueError: If the file is not valid yaml, contains no document, or
            contains more than three documents.
    """
    with open(yaml_path, 'r') as fin:
        try:
            # NOTE(review): FullLoader can construct more than plain data types;
            # only load config files from trusted sources.
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Only yaml errors are translated; the original bare ``except`` also
            # swallowed KeyboardInterrupt and hid the document-count error below.
            raise ValueError("Failed to parse yaml") from err

    if not cfgs:
        raise ValueError("Failed to parse yaml")
    if len(cfgs) > 3:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")

    # Pad with empty dicts so that one- and two-document files unpack uniformly.
    cfg, cfg_helper, cfg_choices = (cfgs + [{}, {}])[:3]
    return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Every attribute of ``args`` is written into ``cfg`` under the same name,
    overwriting any existing entry. ``cfg`` is modified in place.

    Args:
        args: Command line arguments.
        cfg: Base configuration.

    Returns:
        The same ``cfg`` object, after the update.
    """
    for name, value in vars(args).items():
        cfg[name] = value
    return cfg
|
||||
|
||||
|
||||
def get_config():
    """
    Get Config according to the yaml file and cli arguments.

    The yaml path comes from ``--config_path`` and defaults to
    ``../default_config.yaml`` relative to this file's directory. Command line
    values are merged over the yaml values and the result is wrapped in ``Config``.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    default_yaml = os.path.join(here, "../default_config.yaml")

    parser = argparse.ArgumentParser(description="default name", add_help=False)
    parser.add_argument("--config_path", type=str, default=default_yaml,
                        help="Config file path")
    # Only --config_path is known at this stage; the remaining options are
    # registered once the yaml file has been read.
    path_args, _ = parser.parse_known_args()
    yaml_path = path_args.config_path

    default, helper, choices = parse_yaml(yaml_path)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=yaml_path)
    return Config(merge(args, default))
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the fixed job identifier used for local runs."""
    local_job_id = "Local Job"
    return local_job_id
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
    """Read ``DEVICE_ID`` from the environment and return it as an int (default 0)."""
    raw_value = os.environ.get('DEVICE_ID', '0')
    return int(raw_value)
|
||||
|
||||
|
||||
def get_device_num():
    """Return the device count from the ``RANK_SIZE`` environment variable (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the rank from the ``RANK_ID`` environment variable (0 when unset)."""
    raw_rank = os.environ.get('RANK_ID', '0')
    return int(raw_rank)
|
||||
|
||||
|
||||
def get_job_id():
    """Return the job identifier from the ``JOB_ID`` environment variable.

    Returns:
        str: the value of ``JOB_ID``, or ``"default"`` when the variable is
        unset or empty.
    """
    # os.getenv returns None for an unset variable; `or` covers both None and
    # the empty string, so the caller always receives a usable string.
    return os.getenv('JOB_ID') or "default"
|
||||
|
||||
def sync_data(from_path, to_path):
    """
    Copy ``from_path`` to ``to_path`` with ``mox.file.copy_parallel`` and wait for completion.

    Only a device whose id is a multiple of ``min(device_num, 8)`` performs the copy
    (and only if the lock file for this call does not exist yet); it then creates the
    lock file. Every caller blocks, polling once per second, until that lock file exists.
    Each call uses its own lock file, numbered by a module-level counter.
    """
    import moxing as mox
    import time

    global _global_sync_count
    lock_file = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # One device per server does the copy (a server holds at most 8 devices).
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(lock_file):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            # The lock file signals the waiting devices that the copy is done.
            os.mknod(lock_file)
        except IOError:
            pass
        print("===save flag===")

    # Wait until the copying device has published the lock file.
    while not os.path.exists(lock_file):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
    """
    Build a decorator that surrounds a run function with ModelArts data transfers.

    When ``config.enable_modelarts`` is true, the decorated function is preceded by
    ``sync_data`` calls for ``data_url``, ``checkpoint_url`` and ``train_url`` (each
    only if set), the device fields of ``config`` are filled in, the output
    directory is created, and ``pre_process`` is invoked. After the function
    returns, ``post_process`` is invoked and the output directory is copied to
    ``train_url``. When ModelArts is disabled the function is simply called.

    Args:
        pre_process (callable, optional): zero-argument hook run before the function.
        post_process (callable, optional): zero-argument hook run after the function.

    Returns:
        callable: a decorator; the decorated function returns whatever the
        wrapped function returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Run the main function and keep its result so the decorator
            # does not silently turn every return value into None.
            result = run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)

            return result
        return wrapped_func
    return wrapper
|
|
@ -14,6 +14,13 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Exactly one argument is required: the dataset path.
if [ $# -ne 1 ]
then
    # Use $0 so the usage line always names the script actually being run.
    echo "Usage: sh $0 [DATASET_PATH]"
    exit 1
fi
DATASET_PATH=$1
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
|
@ -26,12 +33,14 @@ fi
|
|||
mkdir ./eval
|
||||
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../model_utils ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation"
|
||||
|
||||
python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
python eval.py --datapath=$DATASET_PATH --ckptpath=../ckpts &> log &
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -32,8 +32,10 @@ fi
|
|||
mkdir ./eval
|
||||
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../model_utils ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation"
|
||||
|
|
|
@ -14,6 +14,13 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Exactly one argument is required: the dataset path.
[ $# = 1 ] || {
    echo "Usage: sh run_train_ascend.sh [DATASET_PATH]"
    exit 1
}
DATASET_PATH=$1
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
|
@ -32,12 +39,14 @@ fi
|
|||
mkdir ./ckpts
|
||||
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training"
|
||||
|
||||
python train.py --datapath=../data_mr --ckptpath=../ckpts &> log &
|
||||
python train.py --datapath=$DATASET_PATH --ckptpath=../ckpts &> log &
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -36,8 +36,10 @@ fi
|
|||
mkdir ./ckpts
|
||||
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../model_utils ./train
|
||||
cd ./train || exit
|
||||
env > env.log
|
||||
echo "start training"
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""
|
||||
BGCF training script.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
|
||||
from mindspore import Tensor
|
||||
|
@ -24,35 +25,107 @@ from mindspore.train.serialization import save_checkpoint
|
|||
from mindspore.common import set_seed
|
||||
|
||||
from src.bgcf import BGCF
|
||||
from src.config import parser_args
|
||||
from src.utils import convert_item_id
|
||||
from src.callback import TrainBGCF
|
||||
from src.dataset import load_graph, create_dataset
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def train():
|
||||
|
||||
def modelarts_pre_process():
    '''Prepare the ModelArts workspace before training.

    If ``config.need_modelarts_dataset_unzip`` is set, one device per server
    extracts ``<data_path>/<modelarts_dataset_unzip_name>.zip`` into
    ``config.data_path`` and then creates a lock file; every device waits for
    that lock file before continuing. Afterwards ``config.ckptpath`` is rebased
    under ``config.output_path`` and created if it is missing.
    '''
    def unzip(zip_file, save_dir):
        '''Extract zip_file into save_dir, printing progress roughly every 1%.'''
        import zipfile
        s_time = time.time()
        if os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            print("Zip has been extracted.")
            return
        if not zipfile.is_zipfile(zip_file):
            print("This is not zip.")
            return

        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(zip_file, 'r') as fz:
            members = fz.namelist()
            data_num = len(members)
            print("Extract Start...")
            print("unzip file num: {}".format(data_num))
            # Report about 100 times for big archives, every file for small ones.
            data_print = int(data_num / 100) if data_num > 100 else 1
            for i, member in enumerate(members):
                if i % data_print == 0:
                    print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                fz.extract(member, save_dir)
        minutes, seconds = divmod(int(time.time() - s_time), 60)
        print("cost time: {}min:{}s.".format(minutes, seconds))
        print("Extract Done.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices at most; only one of them extracts.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                # The lock file tells the waiting devices that extraction is done.
                os.mknod(sync_lock)
            except IOError:
                pass

        # Block until the extracting device has published the lock file.
        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    config.ckptpath = os.path.join(config.output_path, config.ckptpath)
    if not os.path.isdir(config.ckptpath):
        os.makedirs(config.ckptpath)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
"""Train"""
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=config.device_target,
|
||||
save_graphs=False)
|
||||
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
train_graph, _, sampled_graph_list = load_graph(config.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, config.workers, batch_size=config.batch_pairs,
|
||||
num_samples=config.raw_neighs, num_bgcn_neigh=config.gnew_neighs, num_neg=config.num_neg)
|
||||
|
||||
num_user = train_graph.graph_info()["node_num"][0]
|
||||
num_item = train_graph.graph_info()["node_num"][1]
|
||||
num_pairs = train_graph.graph_info()['edge_num'][0]
|
||||
|
||||
bgcfnet = BGCF([parser.input_dim, num_user, num_item],
|
||||
parser.embedded_dimension,
|
||||
parser.activation,
|
||||
parser.neighbor_dropout,
|
||||
bgcfnet = BGCF([config.input_dim, num_user, num_item],
|
||||
config.embedded_dimension,
|
||||
config.activation,
|
||||
config.neighbor_dropout,
|
||||
num_user,
|
||||
num_item,
|
||||
parser.input_dim)
|
||||
config.input_dim)
|
||||
|
||||
train_net = TrainBGCF(bgcfnet, parser.num_neg, parser.l2, parser.learning_rate,
|
||||
parser.epsilon, parser.dist_reg)
|
||||
train_net = TrainBGCF(bgcfnet, config.num_neg, config.l2, config.learning_rate,
|
||||
config.epsilon, config.dist_reg)
|
||||
train_net.set_train(True)
|
||||
|
||||
itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True)
|
||||
num_iter = int(num_pairs / parser.batch_pairs)
|
||||
itr = train_ds.create_dict_iterator(config.num_epoch, output_numpy=True)
|
||||
num_iter = int(num_pairs / config.batch_pairs)
|
||||
|
||||
for _epoch in range(1, parser.num_epoch + 1):
|
||||
for _epoch in range(1, config.num_epoch + 1):
|
||||
|
||||
epoch_start = time.time()
|
||||
iter_num = 1
|
||||
|
@ -98,22 +171,9 @@ def train():
|
|||
'{}, cost:{:.4f}'.format(train_loss, time.time() - epoch_start))
|
||||
iter_num += 1
|
||||
|
||||
if _epoch % parser.eval_interval == 0:
|
||||
save_checkpoint(bgcfnet, parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch))
|
||||
if _epoch % config.eval_interval == 0:
|
||||
save_checkpoint(bgcfnet, config.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = parser_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=parser.device_target,
|
||||
save_graphs=False)
|
||||
|
||||
if parser.device_target == "Ascend":
|
||||
context.set_context(device_id=int(parser.device))
|
||||
|
||||
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
|
||||
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
|
||||
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)
|
||||
|
||||
train()
|
||||
run_train()
|
||||
|
|
|
@ -29,10 +29,10 @@ def test_BGCF_amazon_beauty():
|
|||
utils.copy_files(model_path, cur_path, model_name)
|
||||
cur_model_path = os.path.join(cur_path, model_name)
|
||||
|
||||
old_list = ["default=600,"]
|
||||
new_list = ["default=50,"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/config.py"))
|
||||
old_list = ["context.set_context(device_id=int(parser.device))",
|
||||
old_list = ["num_epoch: 600"]
|
||||
new_list = ["num_epoch: 50"]
|
||||
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "default_config.yaml"))
|
||||
old_list = ["context.set_context(device_id=get_device_id())",
|
||||
"save_checkpoint("]
|
||||
new_list = ["context.set_context()",
|
||||
"pass \\# save_checkpoint("]
|
||||
|
|
Loading…
Reference in New Issue