!12538 add naml to modelzoo
From: @zhao_ting_v Reviewed-by: @guoqi1024,@wuxuejian Signed-off-by: @wuxuejian
commit 5fadc61222
@@ -52,6 +52,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
|
||||
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
|
||||
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
|
||||
- [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
|
||||
- [Graph Neural Networks](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
|
|
|
@@ -52,6 +52,7 @@
|
|||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
|
||||
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
|
||||
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
|
||||
- [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
|
||||
- [Graph Neural Networks](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
|
|
|
@@ -0,0 +1,135 @@
|
|||
# Contents
|
||||
|
||||
- [NAML Description](#naml-description)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Model Export](#model-export)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
        - [Inference Performance](#inference-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [NAML Description](#contents)
|
||||
|
||||
NAML is a multi-view news recommendation approach. The core of NAML is a news encoder and a user encoder. The news encoder is composed of a title encoder, an abstract encoder, a category encoder and a subcategory encoder. The user encoder learns user representations from the news they have browsed. In addition, additive attention is applied to learn more informative news and user representations by selecting important words and news.
|
||||
|
||||
[Paper](https://arxiv.org/abs/1907.05576) Chuhan Wu, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang and Xing Xie: Neural News Recommendation with Attentive Multi-View Learning, IJCAI 2019
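
Both encoders rely on additive attention. As a sketch of the formulation from the paper (mirrored by the `Attention` cell in `src/naml.py`), for input vectors $h_i$ with a learnable projection $W$, bias $b$ and query vector $q$:

$$a_i = \frac{\exp\left(q^\top \tanh(W h_i + b)\right)}{\sum_j \exp\left(q^\top \tanh(W h_j + b)\right)}, \qquad r = \sum_i a_i h_i$$

where $r$ is the resulting news (or user) representation.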
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Dataset used: [MIND](https://msnews.github.io/)
|
||||
|
||||
MIND contains about 160k English news articles and more than 15 million impression logs generated by 1 million users.
|
||||
|
||||
You can download the dataset and organize the directories in the following structure:
|
||||
|
||||
```path
|
||||
└─MINDlarge
|
||||
├─MINDlarge_train
|
||||
├─MINDlarge_dev
|
||||
└─MINDlarge_utils
|
||||
```
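
The `MINDlarge_utils` directory is expected to hold the preprocessed artifacts referenced by `src/config.py`; the file names below are taken from that config, while the exact layout is an assumption:

```path
└─MINDlarge_utils
  ├─embedding_all.npy       # pre-trained word embedding matrix
  ├─word_dict_all.pkl       # word -> index dictionary
  ├─vert_dict.pkl           # category -> index dictionary
  ├─subvert_dict.pkl        # subcategory -> index dictionary
  └─uid2index.pkl           # user id -> index dictionary
```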
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend/GPU)
|
||||
    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```path
|
||||
├── naml
|
||||
├── README.md # descriptions about NAML
|
||||
├── scripts
|
||||
│ ├──run_train.sh # shell script for training
|
||||
│ ├──run_eval.sh # shell script for evaluation
|
||||
├── src
|
||||
│ ├──option.py # parse args
|
||||
│ ├──callback.py # callback file
|
||||
│ ├──dataset.py # creating dataset
|
||||
│ ├──naml.py # NAML architecture
|
||||
│ ├──config.py # config file
|
||||
│ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn
|
||||
├── train.py # training script
|
||||
├── eval.py # evaluation script
|
||||
├── export.py # export mindir script
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
You can start training using Python or shell scripts. The usage of the shell scripts is as follows; a direct Python invocation is sketched after the parameter list below.
|
||||
|
||||
```shell
|
||||
bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]
|
||||
bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
- `PLATFORM` should be Ascend.
|
||||
- `DEVICE_ID` is the ID of the device on which you want to run the network.
|
||||
- `DATASET` is the MIND dataset version; `large`, `small` and `demo` are supported.
|
||||
- `DATASET_PATH` is the dataset path, structured as shown in [Dataset](#dataset).
|
||||
- `CHECKPOINT_PATH` is a pre-trained checkpoint path.
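
Training can also be started directly with Python. The command below is a minimal sketch of what `run_train.sh` runs internally (default hyper-parameters come from `src/option.py` and `src/config.py`; the dataset path is a placeholder):

```shell
python train.py --platform=Ascend --device_id=0 --dataset=large --dataset_path=/path/MINDlarge
```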
|
||||
|
||||
## [Model Export](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --platform [PLATFORM] --checkpoint_path [CHECKPOINT_PATH] --file_format [EXPORT_FORMAT] --batch_size [BATCH_SIZE]
|
||||
```
|
||||
|
||||
- `EXPORT_FORMAT` should be in ["AIR", "ONNX", "MINDIR"]
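
For example, a sketch of exporting MINDIR files from the checkpoint saved by `run_train.sh` (the paths are placeholders; `export.py` writes one file for the news encoder and one for the user encoder):

```shell
python export.py --platform Ascend --checkpoint_path ./checkpoint/naml_last.ckpt --file_format MINDIR --batch_size 64
```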
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
| Model Version | NAML |
|
||||
| Resource                   | Ascend 910; CPU 2.60 GHz, 56 cores; Memory, 314 GB           |
|
||||
| Uploaded Date              | 02/23/2021 (month/day/year)                                   |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | MINDlarge |
|
||||
| Training Parameters | epoch=1, steps=52869, batch_size=64, lr=0.001 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 1pc: 62 ms/step |
|
||||
| Total time | 1pc: 54 mins |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | NAML |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 02/23/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | MINDlarge |
|
||||
| batch_size | 64 |
|
||||
| outputs | probability |
|
||||
| Accuracy | AUC: 0.66 |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
|
||||
In train.py, we set the seed which is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation NAML."""
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from src.naml import NAML, NAMLWithLossCell
|
||||
from src.option import get_args
|
||||
from src.dataset import MINDPreprocess
|
||||
from src.utils import NAMLMetric, get_metric
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args("eval")
|
||||
set_seed(args.seed)
|
||||
net = NAML(args)
|
||||
net.set_train(False)
|
||||
net_with_loss = NAMLWithLossCell(net)
|
||||
load_checkpoint(args.checkpoint_path, net_with_loss)
|
||||
news_encoder = net.news_encoder
|
||||
user_encoder = net.user_encoder
|
||||
metric = NAMLMetric()
|
||||
mindpreprocess = MINDPreprocess(vars(args), dataset_path=args.eval_dataset_path)
|
||||
get_metric(args, mindpreprocess, news_encoder, user_encoder, metric)
|
|
@@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""NAML export."""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, export
|
||||
from src.naml import NAML, NAMLWithLossCell
|
||||
from src.option import get_args
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
args = get_args("export")
|
||||
net = NAML(args)
|
||||
net.set_train(False)
|
||||
net_with_loss = NAMLWithLossCell(net)
|
||||
load_checkpoint(args.checkpoint_path, net_with_loss)
|
||||
news_encoder = net.news_encoder
|
||||
user_encoder = net.user_encoder
|
||||
bs = args.batch_size
|
||||
category = Tensor(np.zeros([bs, 1], np.int32))
|
||||
subcategory = Tensor(np.zeros([bs, 1], np.int32))
|
||||
title = Tensor(np.zeros([bs, args.n_words_title], np.int32))
|
||||
abstract = Tensor(np.zeros([bs, args.n_words_abstract], np.int32))
|
||||
|
||||
news_input_data = [category, subcategory, title, abstract]
|
||||
export(news_encoder, *news_input_data, file_name=f"naml_news_encoder_bs_{bs}", file_format=args.file_format)
|
||||
|
||||
browsed_news = Tensor(np.zeros([bs, args.n_browsed_news, args.n_filters], np.float32))
|
||||
export(user_encoder, browsed_news, file_name=f"naml_user_encoder_bs_{bs}", file_format=args.file_format)
|
|
@@ -0,0 +1,35 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
echo "for example: bash run_eval.sh Ascend 0 large /path/MINDlarge ./checkpoint/naml_last.ckpt"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
PLATFORM=$1
|
||||
DEVICE_ID=$2
|
||||
DATASET=$3
|
||||
DATASET_PATH=$4
|
||||
CHECKPOINT_PATH=$5
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
python ${PROJECT_DIR}/../eval.py \
|
||||
--platform=${PLATFORM} \
|
||||
--device_id=${DEVICE_ID} \
|
||||
--dataset=${DATASET} \
|
||||
--dataset_path=${DATASET_PATH} \
|
||||
--checkpoint_path=${CHECKPOINT_PATH}
|
|
@@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]"
|
||||
echo "for example: bash run_train.sh Ascend 0 large /path/MINDlarge"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
PLATFORM=$1
|
||||
DEVICE_ID=$2
|
||||
DATASET=$3
|
||||
DATASET_PATH=$4
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CHECKPOINT_PATH="./checkpoint"
|
||||
python ${PROJECT_DIR}/../train.py \
|
||||
--platform=${PLATFORM} \
|
||||
--device_id=${DEVICE_ID} \
|
||||
--dataset=${DATASET} \
|
||||
--dataset_path=${DATASET_PATH} \
|
||||
--save_checkpoint_path=${CHECKPOINT_PATH} \
|
||||
--weight_decay=False \
|
||||
--sink_mode=True
|
||||
|
||||
python ${PROJECT_DIR}/../eval.py \
|
||||
--platform=${PLATFORM} \
|
||||
--device_id=${DEVICE_ID} \
|
||||
--dataset=${DATASET} \
|
||||
--dataset_path=${DATASET_PATH} \
|
||||
--checkpoint_path=${CHECKPOINT_PATH}/naml_last.ckpt
|
|
@@ -0,0 +1,111 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""NAML Callback"""
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import Tensor, save_checkpoint
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
||||
Args:
|
||||
        args (argparse.Namespace): parsed training options (epochs, print_times, sink_mode, dataset_size, checkpoint settings).
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(args)
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super(Monitor, self).__init__()
|
||||
self.cur_step = 1
|
||||
self.cur_epoch = 1
|
||||
self.epochs = args.epochs
|
||||
self.sink_size = args.print_times
|
||||
self.sink_mode = args.sink_mode
|
||||
self.dataset_size = args.dataset_size
|
||||
self.save_checkpoint_path = args.save_checkpoint_path
|
||||
self.save_checkpoint = args.save_checkpoint
|
||||
self.losses = []
|
||||
if args.sink_mode:
|
||||
self.epoch_steps = self.sink_size
|
||||
else:
|
||||
self.epoch_steps = args.dataset_size
|
||||
if self.save_checkpoint and not os.path.isdir(self.save_checkpoint_path):
|
||||
os.makedirs(self.save_checkpoint_path)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Callback when epoch end."""
|
||||
epoch_end_f = True
|
||||
if self.sink_mode:
|
||||
self.cur_step += self.epoch_steps
|
||||
epoch_end_f = False
|
||||
if self.cur_step >= self.dataset_size:
|
||||
epoch_end_f = True
|
||||
self.cur_step = self.cur_step % self.dataset_size
|
||||
cb_params = run_context.original_args()
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / cb_params.batch_num
|
||||
step_loss = cb_params.net_outputs
|
||||
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
|
||||
step_loss = step_loss[0]
|
||||
if isinstance(step_loss, Tensor):
|
||||
step_loss = np.mean(step_loss.asnumpy())
|
||||
self.losses.append(step_loss)
|
||||
if epoch_end_f:
|
||||
print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
|
||||
self.cur_epoch, self.epochs, np.mean(self.losses)), flush=True)
|
||||
self.losses = []
|
||||
self.cur_epoch += 1
|
||||
if self.sink_mode:
|
||||
print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms".format(
|
||||
self.cur_epoch, self.epochs, self.cur_step, self.dataset_size, step_loss, per_step_mseconds),
|
||||
flush=True)
|
||||
if epoch_end_f and self.save_checkpoint:
|
||||
save_checkpoint(cb_params.train_network,
|
||||
os.path.join(self.save_checkpoint_path, f"naml_{self.cur_epoch-1}.ckpt"))
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Callback when step end."""
|
||||
if not self.sink_mode:
|
||||
cb_params = run_context.original_args()
|
||||
self.cur_step += 1
|
||||
self.cur_step = self.cur_step % self.dataset_size
|
||||
step_loss = cb_params.net_outputs
|
||||
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
|
||||
step_loss = step_loss[0]
|
||||
if isinstance(step_loss, Tensor):
|
||||
step_loss = np.mean(step_loss.asnumpy())
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms".format(
|
||||
self.cur_epoch, self.epochs, self.cur_step, self.dataset_size, step_loss, step_mseconds), flush=True)
|
||||
|
||||
def end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
|
||||
self.epochs, self.epochs, np.mean(self.losses)), flush=True)
|
||||
if self.save_checkpoint:
|
||||
            save_checkpoint(cb_params.train_network, os.path.join(self.save_checkpoint_path, "naml_last.ckpt"))
|
|
@@ -0,0 +1,75 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#===================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
The parameters are usually multiples of 16 in order to fit the Ascend hardware.
|
||||
"""
|
||||
|
||||
class MINDlarge:
|
||||
"""MIND large config."""
|
||||
n_categories = 19
|
||||
n_sub_categories = 286
|
||||
n_words = 74308
|
||||
epochs = 1
|
||||
lr = 0.001
|
||||
print_times = 1000
|
||||
embedding_file = "{}/MINDlarge_utils/embedding_all.npy"
|
||||
word_dict_path = "{}/MINDlarge_utils/word_dict_all.pkl"
|
||||
category_dict_path = "{}/MINDlarge_utils/vert_dict.pkl"
|
||||
subcategory_dict_path = "{}/MINDlarge_utils/subvert_dict.pkl"
|
||||
uid2index_path = "{}/MINDlarge_utils/uid2index.pkl"
|
||||
train_dataset_path = "{}/MINDlarge_train"
|
||||
eval_dataset_path = "{}/MINDlarge_dev"
|
||||
|
||||
class MINDsmall:
|
||||
"""MIND small config."""
|
||||
n_categories = 19
|
||||
n_sub_categories = 271
|
||||
n_words = 60993
|
||||
epochs = 3
|
||||
lr = 0.0005
|
||||
print_times = 500
|
||||
embedding_file = "{}/MINDsmall_utils/embedding_all.npy"
|
||||
word_dict_path = "{}/MINDsmall_utils/word_dict_all.pkl"
|
||||
category_dict_path = "{}/MINDsmall_utils/vert_dict.pkl"
|
||||
subcategory_dict_path = "{}/MINDsmall_utils/subvert_dict.pkl"
|
||||
uid2index_path = "{}/MINDsmall_utils/uid2index.pkl"
|
||||
train_dataset_path = "{}/MINDsmall_train"
|
||||
eval_dataset_path = "{}/MINDsmall_dev"
|
||||
|
||||
class MINDdemo:
|
||||
"""MIND small config."""
|
||||
n_categories = 18
|
||||
n_sub_categories = 237
|
||||
n_words = 41059
|
||||
epochs = 10
|
||||
lr = 0.0005
|
||||
print_times = 100
|
||||
embedding_file = "{}/MINDdemo_utils/embedding_all.npy"
|
||||
word_dict_path = "{}/MINDdemo_utils/word_dict_all.pkl"
|
||||
category_dict_path = "{}/MINDdemo_utils/vert_dict.pkl"
|
||||
subcategory_dict_path = "{}/MINDdemo_utils/subvert_dict.pkl"
|
||||
uid2index_path = "{}/MINDdemo_utils/uid2index.pkl"
|
||||
train_dataset_path = "{}/MINDdemo_train"
|
||||
eval_dataset_path = "{}/MINDdemo_dev"
|
||||
|
||||
def get_dataset_config(dataset):
|
||||
if dataset == "large":
|
||||
return MINDlarge
|
||||
if dataset == "small":
|
||||
return MINDsmall
|
||||
if dataset == "demo":
|
||||
return MINDdemo
|
||||
    raise ValueError("Unsupported dataset, only 'large' (MINDlarge), 'small' (MINDsmall) and 'demo' (MINDdemo) are supported")
|
|
@@ -0,0 +1,296 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset loading, creation and processing"""
|
||||
import re
|
||||
import os
|
||||
import random
|
||||
from collections import namedtuple
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
|
||||
ds.config.set_prefetch_size(8)
|
||||
|
||||
NEWS_ITEMS = ['category', 'subcategory', 'title', 'abstract']
|
||||
News = namedtuple('News', NEWS_ITEMS)
|
||||
|
||||
class MINDPreprocess:
|
||||
"""
|
||||
MIND dataset Preprocess class.
|
||||
    When training, neg_sample=4; when evaluating, neg_sample=-1.
|
||||
"""
|
||||
def __init__(self, config, dataset_path=""):
|
||||
self.config = config
|
||||
self.dataset_dir = dataset_path
|
||||
self.word_dict_path = config['word_dict_path']
|
||||
self.category_dict_path = config['category_dict_path']
|
||||
self.subcategory_dict_path = config['subcategory_dict_path']
|
||||
self.uid2index_path = config['uid2index_path']
|
||||
|
||||
        # behaviors config
|
||||
self.neg_sample = config['neg_sample']
|
||||
self.n_words_title = config['n_words_title']
|
||||
self.n_words_abstract = config['n_words_abstract']
|
||||
self.n_browsed_news = config['n_browsed_news']
|
||||
|
||||
# news config
|
||||
|
||||
self._tokenize = 're'
|
||||
|
||||
self._is_init_data = False
|
||||
self._init_data()
|
||||
|
||||
self._index = 0
|
||||
self._sample_store = []
|
||||
|
||||
self._diff = 0
|
||||
|
||||
def _load_pickle(self, file_path):
|
||||
with open(file_path, 'rb') as fp:
|
||||
return pickle.load(fp)
|
||||
|
||||
def _init_news(self, news_path):
|
||||
"""News info initialization."""
|
||||
print(f"Start to init news, news path: {news_path}")
|
||||
|
||||
category_dict = self._load_pickle(file_path=self.category_dict_path)
|
||||
word_dict = self._load_pickle(file_path=self.word_dict_path)
|
||||
subcategory_dict = self._load_pickle(
|
||||
file_path=self.subcategory_dict_path)
|
||||
|
||||
self.nid_map_index = {}
|
||||
title_list = []
|
||||
category_list = []
|
||||
subcategory_list = []
|
||||
abstract_list = []
|
||||
|
||||
with open(news_path) as file_handler:
|
||||
for line in file_handler:
|
||||
nid, category, subcategory, title, abstract, _ = line.strip("\n").split('\t')[:6]
|
||||
|
||||
if nid in self.nid_map_index:
|
||||
continue
|
||||
|
||||
self.nid_map_index[nid] = len(self.nid_map_index)
|
||||
title_list.append(self._word_tokenize(title))
|
||||
category_list.append(category)
|
||||
subcategory_list.append(subcategory)
|
||||
abstract_list.append(self._word_tokenize(abstract))
|
||||
|
||||
news_len = len(title_list)
|
||||
self.news_title_index = np.zeros((news_len, self.n_words_title), dtype=np.int32)
|
||||
self.news_abstract_index = np.zeros((news_len, self.n_words_abstract), dtype=np.int32)
|
||||
self.news_category_index = np.zeros((news_len, 1), dtype=np.int32)
|
||||
self.news_subcategory_index = np.zeros((news_len, 1), dtype=np.int32)
|
||||
self.news_ids = np.zeros((news_len, 1), dtype=np.int32)
|
||||
|
||||
for news_index in range(news_len):
|
||||
title = title_list[news_index]
|
||||
title_index_list = [word_dict.get(word.lower(), 0) for word in title[:self.n_words_title]]
|
||||
self.news_title_index[news_index, list(range(len(title_index_list)))] = title_index_list
|
||||
|
||||
abstract = abstract_list[news_index]
|
||||
abstract_index_list = [word_dict.get(word.lower(), 0) for word in abstract[:self.n_words_abstract]]
|
||||
self.news_abstract_index[news_index, list(range(len(abstract_index_list)))] = abstract_index_list
|
||||
|
||||
category = category_list[news_index]
|
||||
self.news_category_index[news_index, 0] = category_dict.get(category, 0)
|
||||
|
||||
subcategory = subcategory_list[news_index]
|
||||
self.news_subcategory_index[news_index, 0] = subcategory_dict.get(subcategory, 0)
|
||||
|
||||
self.news_ids[news_index, 0] = news_index
|
||||
|
||||
def _init_behaviors(self, behaviors_path):
|
||||
"""Behaviors info initialization."""
|
||||
print(f"Start to init behaviors, path: {behaviors_path}")
|
||||
|
||||
self.history_list = []
|
||||
self.impression_list = []
|
||||
self.label_list = []
|
||||
self.impression_index_list = []
|
||||
self.uid_list = []
|
||||
self.poss = []
|
||||
self.negs = []
|
||||
self.index_map = {}
|
||||
|
||||
self.total_count = 0
|
||||
uid2index = self._load_pickle(self.uid2index_path)
|
||||
|
||||
with open(behaviors_path) as file_handler:
|
||||
for index, line in enumerate(file_handler):
|
||||
uid, _, history, impressions = line.strip("\n").split('\t')[-4:]
|
||||
negs = []
|
||||
history = [self.nid_map_index[i] for i in history.split()]
|
||||
random.shuffle(history)
|
||||
history = [0] * (self.n_browsed_news - len(history)) + history[:self.n_browsed_news]
|
||||
user_id = uid2index.get(uid, 0)
|
||||
|
||||
if self.neg_sample > 0:
|
||||
for item in impressions.split():
|
||||
nid, label = item.split('-')
|
||||
nid = self.nid_map_index[nid]
|
||||
if label == '1':
|
||||
self.poss.append(nid)
|
||||
self.index_map[self.total_count] = index
|
||||
self.total_count += 1
|
||||
else:
|
||||
negs.append(nid)
|
||||
else:
|
||||
nids = []
|
||||
labels = []
|
||||
for item in impressions.split():
|
||||
nid, label = item.split('-')
|
||||
nids.append(self.nid_map_index[nid])
|
||||
labels.append(int(label))
|
||||
self.impression_list.append((np.array(nids, dtype=np.int32), np.array(labels, dtype=np.int32)))
|
||||
self.total_count += 1
|
||||
|
||||
self.history_list.append(history)
|
||||
self.negs.append(negs)
|
||||
self.uid_list.append(user_id)
|
||||
|
||||
def _init_data(self):
|
||||
news_path = os.path.join(self.dataset_dir, 'news.tsv')
|
||||
behavior_path = os.path.join(self.dataset_dir, 'behaviors.tsv')
|
||||
if not self._is_init_data:
|
||||
self._init_news(news_path)
|
||||
self._init_behaviors(behavior_path)
|
||||
self._is_init_data = True
|
||||
print(f'init data end, count: {self.total_count}')
|
||||
|
||||
def _word_tokenize(self, sent):
|
||||
"""
|
||||
Split sentence into word list using regex.
|
||||
Args:
|
||||
sent (str): Input sentence
|
||||
|
||||
Return:
|
||||
list: word list
|
||||
"""
|
||||
pat = re.compile(r"[\w]+|[.,!?;|]")
|
||||
if isinstance(sent, str):
|
||||
return pat.findall(sent.lower())
|
||||
return []
|
||||
|
||||
def __getitem__(self, index):
|
||||
uid_index = self.index_map[index]
|
||||
if self.neg_sample >= 0:
|
||||
negs = self.negs[uid_index]
|
||||
nid = self.poss[index]
|
||||
random.shuffle(negs)
|
||||
neg_samples = (negs + [0] * (self.neg_sample - len(negs))) if self.neg_sample > len(negs) \
|
||||
else random.sample(negs, self.neg_sample)
|
||||
candidate_samples = [nid] + neg_samples
|
||||
labels = [1] + [0] * self.neg_sample
|
||||
|
||||
else:
|
||||
            candidate_samples, labels = self.impression_list[index]
|
||||
browsed_samples = self.history_list[uid_index]
|
||||
browsed_category = np.array(self.news_category_index[browsed_samples], dtype=np.int32)
|
||||
browsed_subcategory = np.array(self.news_subcategory_index[browsed_samples], dtype=np.int32)
|
||||
browsed_title = np.array(self.news_title_index[browsed_samples], dtype=np.int32)
|
||||
browsed_abstract = np.array(self.news_abstract_index[browsed_samples], dtype=np.int32)
|
||||
candidate_category = np.array(self.news_category_index[candidate_samples], dtype=np.int32)
|
||||
candidate_subcategory = np.array(self.news_subcategory_index[candidate_samples], dtype=np.int32)
|
||||
candidate_title = np.array(self.news_title_index[candidate_samples], dtype=np.int32)
|
||||
candidate_abstract = np.array(self.news_abstract_index[candidate_samples], dtype=np.int32)
|
||||
labels = np.array(labels, dtype=np.int32)
|
||||
return browsed_category, browsed_subcategory, browsed_title, browsed_abstract,\
|
||||
candidate_category, candidate_subcategory, candidate_title, candidate_abstract, labels
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
news_column_names = ['category', 'subcategory', 'title', 'abstract']
|
||||
column_names = ['browsed_' + item for item in news_column_names]
|
||||
column_names += ['candidate_' + item for item in news_column_names]
|
||||
column_names += ['labels']
|
||||
return column_names
|
||||
|
||||
def __len__(self):
|
||||
return self.total_count
|
||||
|
||||
|
||||
class EvalDatasetBase:
|
||||
"""Base evaluation Datase class."""
|
||||
def __init__(self, preprocess: MINDPreprocess):
|
||||
self.preprocess = preprocess
|
||||
|
||||
|
||||
class EvalNews(EvalDatasetBase):
|
||||
"""Generator dataset for all news."""
|
||||
def __len__(self):
|
||||
return len(self.preprocess.news_title_index)
|
||||
|
||||
def __getitem__(self, index):
|
||||
news_id = self.preprocess.news_ids[index]
|
||||
title = self.preprocess.news_title_index[index]
|
||||
category = self.preprocess.news_category_index[index]
|
||||
subcategory = self.preprocess.news_subcategory_index[index]
|
||||
abstract = self.preprocess.news_abstract_index[index]
|
||||
return news_id.reshape(-1), category.reshape(-1), subcategory.reshape(-1), title.reshape(-1),\
|
||||
abstract.reshape(-1)
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
return ['news_id', 'category', 'subcategory', 'title', 'abstract']
|
||||
|
||||
|
||||
class EvalUsers(EvalDatasetBase):
|
||||
"""Generator dataset for all user."""
|
||||
def __len__(self):
|
||||
return len(self.preprocess.uid_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
uid = np.array(self.preprocess.uid_list[index], dtype=np.int32)
|
||||
history = np.array(self.preprocess.history_list[index], dtype=np.int32)
|
||||
return uid, history.reshape(50, 1)
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
return ['uid', 'history']
|
||||
|
||||
|
||||
class EvalCandidateNews(EvalDatasetBase):
|
||||
"""Generator dataset for all candidate news."""
|
||||
@property
|
||||
def column_names(self):
|
||||
return ['uid', 'candidate_nid', 'labels']
|
||||
|
||||
def __len__(self):
|
||||
return self.preprocess.total_count
|
||||
|
||||
def __getitem__(self, index):
|
||||
uid = np.array(self.preprocess.uid_list[index], dtype=np.int32)
|
||||
nid, label = self.preprocess.impression_list[index]
|
||||
return uid, nid, label
|
||||
|
||||
|
||||
def create_dataset(mindpreprocess, batch_size=64):
|
||||
"""Get generator dataset when training."""
|
||||
dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, shuffle=True)
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
return dataset
|
||||
|
||||
|
||||
def create_eval_dataset(mindpreprocess, eval_cls, batch_size=64):
|
||||
"""Get generator dataset when evaluation."""
|
||||
eval_instance = eval_cls(mindpreprocess)
|
||||
dataset = ds.GeneratorDataset(eval_instance, eval_instance.column_names, shuffle=False)
|
||||
if not isinstance(eval_instance, EvalCandidateNews):
|
||||
dataset = dataset.batch(batch_size)
|
||||
return dataset
|
|
@@ -0,0 +1,246 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""NAML network."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import initializer as init
|
||||
import mindspore.ops as ops
|
||||
|
||||
class Attention(nn.Cell):
|
||||
"""
|
||||
Softmax attention implement.
|
||||
|
||||
Args:
|
||||
query_vector_dim (int): dimension of the query vector in attention.
|
||||
input_vector_dim (int): dimension of the input vector in attention.
|
||||
|
||||
Input:
|
||||
input (Tensor): input tensor, shape is (batch_size, n_input_vector, input_vector_dim)
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor, shape is (batch_size, n_input_vector).
|
||||
|
||||
Examples:
|
||||
>>> Attention(query_vector_dim, input_vector_dim)
|
||||
"""
|
||||
def __init__(self, query_vector_dim, input_vector_dim):
|
||||
super(Attention, self).__init__()
|
||||
self.dense1 = nn.Dense(input_vector_dim, query_vector_dim, has_bias=True, activation='tanh')
|
||||
self.dense2 = nn.Dense(query_vector_dim, 1, has_bias=False)
|
||||
self.softmax = nn.Softmax()
|
||||
self.sum_keep_dims = ops.ReduceSum(keep_dims=True)
|
||||
self.sum = ops.ReduceSum(keep_dims=False)
|
||||
|
||||
def construct(self, x):
|
||||
dtype = ops.dtype(x)
|
||||
batch_size, n_input_vector, input_vector_dim = ops.shape(x)
|
||||
feature = ops.reshape(x, (-1, input_vector_dim))
|
||||
attention = ops.reshape(self.dense2(self.dense1(feature)), (batch_size, n_input_vector))
|
||||
attention_weight = ops.cast(self.softmax(attention), dtype)
|
||||
weighted_input = x * ops.expand_dims(attention_weight, 2)
|
||||
return self.sum(weighted_input, 1)
|
||||
|
||||
class NewsEncoder(nn.Cell):
|
||||
"""
|
||||
The main function to create news encoder of NAML.
|
||||
|
||||
Args:
|
||||
args (class): global hyper-parameters.
|
||||
word_embedding (Tensor): parameter of word embedding.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> NewsEncoder(args, embedding_table)
|
||||
"""
|
||||
def __init__(self, args, embedding_table=None):
|
||||
super(NewsEncoder, self).__init__()
|
||||
# categories
|
||||
self.category_embedding = nn.Embedding(args.n_categories, args.category_embedding_dim)
|
||||
self.category_dense = nn.Dense(args.category_embedding_dim, args.n_filters, has_bias=True, activation="relu")
|
||||
|
||||
self.sub_category_embedding = nn.Embedding(args.n_sub_categories, args.category_embedding_dim)
|
||||
self.subcategory_dense = nn.Dense(args.category_embedding_dim, args.n_filters, has_bias=True, activation="relu")
|
||||
|
||||
# title and abstract
|
||||
if embedding_table is None:
|
||||
word_embedding = [nn.Embedding(args.n_words, args.word_embedding_dim)]
|
||||
else:
|
||||
word_embedding = [nn.Embedding(args.n_words, args.word_embedding_dim, embedding_table=embedding_table)]
|
||||
title_CNN = [
|
||||
nn.Conv1d(args.word_embedding_dim, args.n_filters, kernel_size=args.window_size, pad_mode='same',
|
||||
has_bias=True),
|
||||
nn.ReLU()
|
||||
]
|
||||
abstract_CNN = [
|
||||
nn.Conv1d(args.word_embedding_dim, args.n_filters, kernel_size=args.window_size, pad_mode='same',
|
||||
has_bias=True),
|
||||
nn.ReLU()
|
||||
]
|
||||
if args.phase == "train":
|
||||
word_embedding.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
|
||||
title_CNN.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
|
||||
abstract_CNN.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
|
||||
self.word_embedding = nn.SequentialCell(word_embedding)
|
||||
self.title_CNN = nn.SequentialCell(title_CNN)
|
||||
self.abstract_CNN = nn.SequentialCell(abstract_CNN)
|
||||
self.title_attention = Attention(args.query_vector_dim, args.n_filters)
|
||||
self.abstract_attention = Attention(args.query_vector_dim, args.n_filters)
|
||||
self.total_attention = Attention(args.query_vector_dim, args.n_filters)
|
||||
self.pack = ops.Stack(axis=1)
|
||||
self.title_shape = (-1, args.n_words_title)
|
||||
self.abstract_shape = (-1, args.n_words_abstract)
|
||||
|
||||
def construct(self, category, subcategory, title, abstract):
|
||||
"""
|
||||
The news encoder is composed of title encoder, abstract encoder, category encoder and subcategory encoder.
|
||||
"""
|
||||
# Categories
|
||||
category_embedded = self.category_embedding(ops.reshape(category, (-1,)))
|
||||
category_vector = self.category_dense(category_embedded)
|
||||
subcategory_embedded = self.sub_category_embedding(ops.reshape(subcategory, (-1,)))
|
||||
subcategory_vector = self.subcategory_dense(subcategory_embedded)
|
||||
# title
|
||||
title_embedded = self.word_embedding(ops.reshape(title, self.title_shape))
|
||||
title_feature = self.title_CNN(ops.Transpose()(title_embedded, (0, 2, 1)))
|
||||
title_vector = self.title_attention(ops.Transpose()(title_feature, (0, 2, 1)))
|
||||
# abstract
|
||||
abstract_embedded = self.word_embedding(ops.reshape(abstract, self.abstract_shape))
|
||||
abstract_feature = self.abstract_CNN(ops.Transpose()(abstract_embedded, (0, 2, 1)))
|
||||
abstract_vector = self.abstract_attention(ops.Transpose()(abstract_feature, (0, 2, 1)))
|
||||
# total
|
||||
news_vector = self.total_attention(
|
||||
self.pack((category_vector, subcategory_vector, title_vector, abstract_vector)))
|
||||
return news_vector
|
||||
|
||||
class UserEncoder(nn.Cell):
|
||||
"""
|
||||
The main function to create user encoder of NAML.
|
||||
|
||||
Args:
|
||||
args (class): global hyper-parameters.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> UserEncoder(args)
|
||||
"""
|
||||
def __init__(self, args):
|
||||
super(UserEncoder, self).__init__()
|
||||
self.news_attention = Attention(args.query_vector_dim, args.n_filters)
|
||||
|
||||
def construct(self, news_vectors):
|
||||
user_vector = self.news_attention(news_vectors)
|
||||
return user_vector
|
||||
|
||||
class ClickPredictor(nn.Cell):
|
||||
"""
|
||||
Click predictor by user encoder and news encoder.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ClickPredictor()
|
||||
"""
|
||||
def __init__(self):
|
||||
super(ClickPredictor, self).__init__()
|
||||
self.matmul = ops.BatchMatMul()
|
||||
|
||||
def construct(self, news_vector, user_vector):
|
||||
predict = ops.Flatten()(self.matmul(news_vector, ops.expand_dims(user_vector, 2)))
|
||||
return predict
|
||||
|
||||
class NAML(nn.Cell):
|
||||
"""
|
||||
    NAML model (Neural News Recommendation with Attentive Multi-View Learning).
|
||||
|
||||
Args:
|
||||
args (class): global hyper-parameters.
|
||||
word_embedding (Tensor): parameter of word embedding.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
        >>> NAML(args, embedding_table)
|
||||
"""
|
||||
def __init__(self, args, embedding_table=None):
|
||||
super(NAML, self).__init__()
|
||||
self.args = args
|
||||
self.news_encoder = NewsEncoder(args, embedding_table)
|
||||
self.user_encoder = UserEncoder(args)
|
||||
self.click_predictor = ClickPredictor()
|
||||
self.browsed_vector_shape = (args.batch_size, args.n_browsed_news, args.n_filters)
|
||||
self.candidate_vector_shape = (args.batch_size, args.neg_sample + 1, args.n_filters)
|
||||
        if embedding_table is not None:
|
||||
self.word_embedding_shape = embedding_table.shape
|
||||
else:
|
||||
self.word_embedding_shape = ()
|
||||
self._initialize_weights()
|
||||
|
||||
def construct(self, category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c, abstract_c):
|
||||
browsed_news_vectors = ops.reshape(self.news_encoder(category_b, subcategory_b, title_b, abstract_b),
|
||||
self.browsed_vector_shape)
|
||||
user_vector = self.user_encoder(browsed_news_vectors)
|
||||
candidate_news_vector = ops.reshape(self.news_encoder(category_c, subcategory_c, title_c, abstract_c),
|
||||
self.candidate_vector_shape)
|
||||
predict = self.click_predictor(candidate_news_vector, user_vector)
|
||||
return predict
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""Weights initialize."""
|
||||
self.init_parameters_data()
|
||||
for _, cell in self.cells_and_names():
|
||||
if isinstance(cell, nn.Conv1d):
|
||||
cell.weight.set_data(init.initializer("XavierUniform",
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
elif isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(init.initializer("XavierUniform",
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
elif isinstance(cell, nn.Embedding) and cell.embedding_table.shape != self.word_embedding_shape:
|
||||
cell.embedding_table.set_data(init.initializer("uniform",
|
||||
cell.embedding_table.shape,
|
||||
cell.embedding_table.dtype))
|
||||
|
||||
class NAMLWithLossCell(nn.Cell):
|
||||
"""
|
||||
NAML add loss Cell.
|
||||
|
||||
Args:
|
||||
network (Cell): naml network.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
        >>> NAMLWithLossCell(NAML(args, word_embedding))
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(NAMLWithLossCell, self).__init__()
|
||||
self.network = network
|
||||
self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
|
||||
|
||||
def construct(self, category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c, abstract_c,
|
||||
label):
|
||||
predict = self.network(category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c,
|
||||
abstract_c)
|
||||
dtype = ops.dtype(predict)
|
||||
shp = ops.shape(predict)
|
||||
loss = self.loss(predict, ops.reshape(ops.cast(label, dtype), shp))
|
||||
return loss
|
|
@@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""parse args"""
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from .config import get_dataset_config
|
||||
|
||||
def get_args(phase):
|
||||
"""Define the common options that are used in both training and test."""
|
||||
parser = argparse.ArgumentParser(description='Configuration')
|
||||
|
||||
# Hardware specifications
|
||||
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
|
||||
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
|
||||
parser.add_argument('--platform', type=str, default="Ascend", \
|
||||
help='run platform, only support Ascend')
|
||||
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
|
||||
help='whether save graphs, default is False.')
|
||||
parser.add_argument('--dataset', type=str, default="large", choices=("large", "small", "demo"), \
|
||||
help='MIND dataset, support large, small and demo.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='MIND dataset path.')
|
||||
|
||||
# Model specifications
|
||||
parser.add_argument('--n_browsed_news', type=int, default=50, help='number of browsed news per user')
|
||||
parser.add_argument('--n_words_title', type=int, default=16, help='number of words per title')
|
||||
parser.add_argument('--n_words_abstract', type=int, default=48, help='number of words per abstract')
|
||||
parser.add_argument('--word_embedding_dim', type=int, default=304, help='dimension of word embedding vector')
|
||||
parser.add_argument('--category_embedding_dim', type=int, default=112, \
|
||||
help='dimension of category embedding vector')
|
||||
parser.add_argument('--query_vector_dim', type=int, default=208, help='dimension of the query vector in attention')
|
||||
parser.add_argument('--n_filters', type=int, default=400, help='number of filters in CNN')
|
||||
parser.add_argument('--window_size', type=int, default=3, help='size of filter in CNN')
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, \
|
||||
help="Pre trained checkpoint path, default is None.")
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='size of each batch')
|
||||
# Training specifications
|
||||
if phase == "train":
|
||||
parser.add_argument('--train_dataset_path', type=str, default=None, help='training set directory')
|
||||
parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
|
||||
parser.add_argument('--lr', type=float, default=None, help='learning rate')
|
||||
parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
|
||||
parser.add_argument('--beta2', type=float, default=0.999, help='ADAM beta2')
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability')
|
||||
parser.add_argument('--neg_sample', type=int, default=4, help='number of negative samples in negative sampling')
|
||||
parser.add_argument("--mixed", type=ast.literal_eval, default=True, \
|
||||
help="whether use mixed precision, default is True.")
|
||||
parser.add_argument("--sink_mode", type=ast.literal_eval, default=True, \
|
||||
help="whether use dataset sink, default is True.")
|
||||
parser.add_argument('--print_times', type=int, default=None, help='number of print times, default is None')
|
||||
parser.add_argument("--weight_decay", type=ast.literal_eval, default=True, \
|
||||
help="whether use weight decay, default is True.")
|
||||
parser.add_argument('--save_checkpoint', type=ast.literal_eval, default=True, \
|
||||
help='whether save checkpoint, default is True.')
|
||||
parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint", \
|
||||
help="Save checkpoint path, default is checkpoint.")
|
||||
parser.add_argument('--dropout_ratio', type=float, default=0.2, help='ratio of dropout')
|
||||
if phase == "eval":
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None)
|
||||
parser.add_argument('--neg_sample', type=int, default=-1, \
|
||||
help='number of negative samples in negative sampling')
|
||||
if phase == "export":
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
|
||||
help='file format')
|
||||
parser.add_argument('--neg_sample', type=int, default=-1, \
|
||||
help='number of negative samples in negative sampling')
|
||||
args = parser.parse_args()
|
||||
if args.device_num > 1:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=args.device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.save_checkpoint_path = os.path.join(args.save_checkpoint_path, "ckpt_" + str(args.rank))
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, device_id=args.device_id,
|
||||
save_graphs=args.save_graphs, save_graphs_path="naml_ir")
|
||||
args.rank = 0
|
||||
args.device_num = 1
|
||||
args.phase = phase
|
||||
cfg = get_dataset_config(args.dataset)
|
||||
args.n_categories = cfg.n_categories
|
||||
args.n_sub_categories = cfg.n_sub_categories
|
||||
args.n_words = cfg.n_words
|
||||
if phase == "train":
|
||||
args.epochs = cfg.epochs if args.epochs is None else args.epochs
|
||||
args.lr = cfg.lr if args.lr is None else args.lr
|
||||
args.print_times = cfg.print_times if args.print_times is None else args.print_times
|
||||
args.embedding_file = cfg.embedding_file.format(args.dataset_path)
|
||||
args.word_dict_path = cfg.word_dict_path.format(args.dataset_path)
|
||||
args.category_dict_path = cfg.category_dict_path.format(args.dataset_path)
|
||||
args.subcategory_dict_path = cfg.subcategory_dict_path.format(args.dataset_path)
|
||||
args.uid2index_path = cfg.uid2index_path.format(args.dataset_path)
|
||||
args.train_dataset_path = cfg.train_dataset_path.format(args.dataset_path)
|
||||
args.eval_dataset_path = cfg.eval_dataset_path.format(args.dataset_path)
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
print('--> {}:{}'.format(key, args_dict[key]), flush=True)
|
||||
return args
|
|
@@ -0,0 +1,132 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for NAML."""
|
||||
import time
|
||||
import numpy as np
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from mindspore import Tensor
|
||||
|
||||
from .dataset import create_eval_dataset, EvalNews, EvalUsers, EvalCandidateNews
|
||||
|
||||
def get_metric(args, mindpreprocess, news_encoder, user_encoder, metric):
|
||||
"""Calculate metrics."""
|
||||
start = time.time()
|
||||
news_dict = {}
|
||||
user_dict = {}
|
||||
dataset = create_eval_dataset(mindpreprocess, EvalNews, batch_size=args.batch_size)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
iterator = dataset.create_dict_iterator(output_numpy=True)
|
||||
for count, data in enumerate(iterator):
|
||||
news_vector = news_encoder(Tensor(data["category"]), Tensor(data["subcategory"]),
|
||||
Tensor(data["title"]), Tensor(data["abstract"])).asnumpy()
|
||||
for i, nid in enumerate(data["news_id"]):
|
||||
news_dict[str(nid[0])] = news_vector[i]
|
||||
print(f"===Generate News vector==== [ {count} / {dataset_size} ]", end='\r')
|
||||
print(f"===Generate News vector==== [ {dataset_size} / {dataset_size} ]")
|
||||
dataset = create_eval_dataset(mindpreprocess, EvalUsers, batch_size=args.batch_size)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
iterator = dataset.create_dict_iterator(output_numpy=True)
|
||||
for count, data in enumerate(iterator):
|
||||
browsed_news = []
|
||||
for newses in data["history"]:
|
||||
news_list = []
|
||||
for nid in newses:
|
||||
news_list.append(news_dict[str(nid[0])])
|
||||
browsed_news.append(np.array(news_list))
|
||||
browsed_news = np.array(browsed_news)
|
||||
user_vector = user_encoder(Tensor(browsed_news)).asnumpy()
|
||||
for i, uid in enumerate(data["uid"]):
|
||||
user_dict[str(uid)] = user_vector[i]
|
||||
print(f"===Generate Users vector==== [ {count} / {dataset_size} ]", end='\r')
|
||||
print(f"===Generate Users vector==== [ {dataset_size} / {dataset_size} ]")
|
||||
dataset = create_eval_dataset(mindpreprocess, EvalCandidateNews, batch_size=args.batch_size)
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
iterator = dataset.create_dict_iterator(output_numpy=True)
|
||||
for count, data in enumerate(iterator):
|
||||
pred = np.dot(
|
||||
np.stack([news_dict[str(nid)] for nid in data["candidate_nid"]], axis=0),
|
||||
user_dict[str(data["uid"])]
|
||||
)
|
||||
metric.update(pred, data["labels"])
|
||||
print(f"===Click Prediction==== [ {count} / {dataset_size} ]", end='\r')
|
||||
print(f"===Click Prediction==== [ {dataset_size} / {dataset_size} ]")
|
||||
auc = metric.eval()
|
||||
total_cost = time.time() - start
|
||||
print(f"Eval total cost: {total_cost} s")
|
||||
return auc
|
||||
|
||||
def process_data(args):
|
||||
word_embedding = np.load(args.embedding_file)
|
||||
_, h = word_embedding.shape
|
||||
if h < args.word_embedding_dim:
|
||||
        word_embedding = np.pad(word_embedding, ((0, 0), (0, args.word_embedding_dim - h)), 'constant',
|
||||
constant_values=0)
|
||||
elif h > args.word_embedding_dim:
|
||||
word_embedding = word_embedding[:, :args.word_embedding_dim]
|
||||
print("Load word_embedding", word_embedding.shape)
|
||||
return Tensor(word_embedding.astype(np.float32))
|
||||
|
||||
def AUC(y_true, y_pred):
|
||||
return roc_auc_score(y_true, y_pred)
|
||||
|
||||
def MRR(y_true, y_pred):
|
||||
index = np.argsort(y_pred)[::-1]
|
||||
y_true = np.take(y_true, index)
|
||||
score = y_true / (np.arange(len(y_true)) + 1)
|
||||
return np.sum(score) / np.sum(y_true)
|
||||
|
||||
def DCG(y_true, y_pred, n):
|
||||
index = np.argsort(y_pred)[::-1]
|
||||
y_true = np.take(y_true, index[:n])
|
||||
score = (2 ** y_true - 1) / np.log2(np.arange(len(y_true)) + 2)
|
||||
return np.sum(score)
|
||||
|
||||
def nDCG(y_true, y_pred, n):
|
||||
return DCG(y_true, y_pred, n) / DCG(y_true, y_true, n)
|
||||
|
||||
class NAMLMetric:
|
||||
"""
|
||||
Metric method
|
||||
"""
|
||||
def __init__(self):
|
||||
super(NAMLMetric, self).__init__()
|
||||
self.AUC_list = []
|
||||
self.MRR_list = []
|
||||
self.nDCG5_list = []
|
||||
self.nDCG10_list = []
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
self.AUC_list = []
|
||||
self.MRR_list = []
|
||||
self.nDCG5_list = []
|
||||
self.nDCG10_list = []
|
||||
|
||||
def update(self, predict, y_true):
|
||||
predict = predict.flatten()
|
||||
y_true = y_true.flatten()
|
||||
# predict = np.interp(predict, (predict.min(), predict.max()), (0, 1))
|
||||
self.AUC_list.append(AUC(y_true, predict))
|
||||
self.MRR_list.append(MRR(y_true, predict))
|
||||
self.nDCG5_list.append(nDCG(y_true, predict, 5))
|
||||
self.nDCG10_list.append(nDCG(y_true, predict, 10))
|
||||
|
||||
def eval(self):
|
||||
auc = np.mean(self.AUC_list)
|
||||
print('AUC:', auc)
|
||||
print('MRR:', np.mean(self.MRR_list))
|
||||
print('nDCG@5:', np.mean(self.nDCG5_list))
|
||||
print('nDCG@10:', np.mean(self.nDCG10_list))
|
||||
return auc
|
|
@@ -0,0 +1,71 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train NAML."""
|
||||
import time
|
||||
from mindspore import nn, load_checkpoint
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from src.naml import NAML, NAMLWithLossCell
|
||||
from src.option import get_args
|
||||
from src.dataset import create_dataset, MINDPreprocess
|
||||
from src.utils import process_data
|
||||
from src.callback import Monitor
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args("train")
|
||||
set_seed(args.seed)
|
||||
word_embedding = process_data(args)
|
||||
net = NAML(args, word_embedding)
|
||||
net_with_loss = NAMLWithLossCell(net)
|
||||
if args.checkpoint_path is not None:
|
||||
        load_checkpoint(args.checkpoint_path, net_with_loss)
|
||||
mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
|
||||
dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size)
|
||||
args.dataset_size = dataset.get_dataset_size()
|
||||
args.print_times = min(args.dataset_size, args.print_times)
|
||||
if args.weight_decay:
|
||||
weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
|
||||
other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
|
||||
group_params = [{'params': weight_params, 'weight_decay': 1e-3},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': net.trainable_params()}]
|
||||
opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
|
||||
else:
|
||||
opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
|
||||
if args.mixed:
|
||||
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
|
||||
net_with_loss.to_float(mstype.float16)
|
||||
for _, cell in net_with_loss.cells_and_names():
|
||||
if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
|
||||
cell.to_float(mstype.float32)
|
||||
model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
|
||||
else:
|
||||
model = Model(net_with_loss, optimizer=opt)
|
||||
cb = [Monitor(args)]
|
||||
epochs = args.epochs
|
||||
if args.sink_mode:
|
||||
epochs = int(args.epochs * args.dataset_size / args.print_times)
|
||||
start_time = time.time()
|
||||
print("======================= Start Train ==========================", flush=True)
|
||||
model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
|
||||
end_time = time.time()
|
||||
print("==============================================================")
|
||||
print("processor_name: {}".format(args.platform))
|
||||
print("test_name: NAML")
|
||||
print(f"model_name: NAML MIND{args.dataset}")
|
||||
print("batch_size: {}".format(args.batch_size))
|
||||
print("latency: {} s".format(end_time - start_time))
|