This commit is contained in:
zhaoting 2021-02-22 17:15:45 +08:00
parent 3805f0dfeb
commit 267466ca01
15 changed files with 1338 additions and 0 deletions

View File

@ -52,6 +52,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
- [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
- [Graph Neural Networks](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)

View File

@ -52,6 +52,7 @@
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
- [推荐系统](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
- [NAML](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/naml/README.md)
- [Wide&Deep[基准]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
- [图神经网络](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)

View File

@ -0,0 +1,135 @@
# Contents
- [NAML Description](#NAML-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Model Export](#model-export)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [NAML Description](#contents)
NAML is a multi-view news recommendation approach. The core of NAML is a news encoder and a user encoder. The newsencoder is composed of a title encoder, a abstract encoder, a category encoder and a subcategory encoder. In the user encoder, we learn representations of users from their browsed news. Besides, we apply additive attention to learn more informative news and user representations by selecting important words and news.
[Paper](https://arxiv.org/abs/1907.05576) Chuhan Wu, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang and Xing Xie: Neural News Recommendation with Attentive Multi-View Learning, IJCAI 2019
# [Dataset](#contents)
Dataset used: [MIND](https://msnews.github.io/)
MIND contains about 160k English news articles and more than 15 million impression logs generated by 1 million users.
You can download the dataset and put the directory in structure as follows:
```path
└─MINDlarge
├─MINDlarge_train
├─MINDlarge_dev
└─MINDlarge_utils
```
# [Environment Requirements](#contents)
- HardwareAscend/GPU
- Prepare hardware environment with Ascend, GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```path
├── naml
├── README.md # descriptions about NAML
├── scripts
│ ├──run_train.sh # shell script for training
│ ├──run_eval.sh # shell script for evaluation
├── src
│ ├──option.py # parse args
│ ├──callback.py # callback file
│ ├──dataset.py # creating dataset
│ ├──naml.py # NAML architecture
│ ├──config.py # config file
│ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn
├── train.py # training script
├── eval.py # evaluation script
├── export.py # export mindir script
```
## [Training process](#contents)
### Usage
You can start training using python or shell scripts. The usage of shell scripts as follows:
```shell
bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]
bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
```
- `PLATFORM` should be Ascend.
- `DEVICE_ID` is the device id you want to run the network.
- `DATASET` MIND dataset, support large, small and demo.
- `DATASET_PATH` is the dataset path, the structure as [Dataset](#dataset).
- `CHECKPOINT_PATH` is a pre-trained checkpoint path.
## [Model Export](#contents)
```shell
python export.py --platform [PLATFORM] --checkpoint_path [CHECKPOINT_PATH] --file_format [EXPORT_FORMAT] --batch_size [BATCH_SIZE]
```
- `EXPORT_FORMAT` should be in ["AIR", "ONNX", "MINDIR"]
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | NAML |
| Resource | Ascend 910 CPU 2.60GHz56coresMemory314G |
| uploaded Date | 02/23/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | MINDlarge |
| Training Parameters | epoch=1, steps=52869, batch_size=64, lr=0.001 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 1pc: 62 ms/step |
| Total time | 1pc: 54 mins |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | NAML |
| Resource | Ascend 910 |
| Uploaded Date | 02/23/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | MINDlarge |
| batch_size | 64 |
| outputs | probability |
| Accuracy | AUC: 0.66 |
# [Description of Random Situation](#contents)
<!-- In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. -->
In train.py, we set the seed which is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation NAML."""
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import MINDPreprocess
from src.utils import NAMLMetric, get_metric
if __name__ == '__main__':
args = get_args("eval")
set_seed(args.seed)
net = NAML(args)
net.set_train(False)
net_with_loss = NAMLWithLossCell(net)
load_checkpoint(args.checkpoint_path, net_with_loss)
news_encoder = net.news_encoder
user_encoder = net.user_encoder
metric = NAMLMetric()
mindpreprocess = MINDPreprocess(vars(args), dataset_path=args.eval_dataset_path)
get_metric(args, mindpreprocess, news_encoder, user_encoder, metric)

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NAML export."""
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, export
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
if __name__ == '__main__':
args = get_args("export")
net = NAML(args)
net.set_train(False)
net_with_loss = NAMLWithLossCell(net)
load_checkpoint(args.checkpoint_path, net_with_loss)
news_encoder = net.news_encoder
user_encoder = net.user_encoder
bs = args.batch_size
category = Tensor(np.zeros([bs, 1], np.int32))
subcategory = Tensor(np.zeros([bs, 1], np.int32))
title = Tensor(np.zeros([bs, args.n_words_title], np.int32))
abstract = Tensor(np.zeros([bs, args.n_words_abstract], np.int32))
news_input_data = [category, subcategory, title, abstract]
export(news_encoder, *news_input_data, file_name=f"naml_news_encoder_bs_{bs}", file_format=args.file_format)
browsed_news = Tensor(np.zeros([bs, args.n_browsed_news, args.n_filters], np.float32))
export(user_encoder, browsed_news, file_name=f"naml_user_encoder_bs_{bs}", file_format=args.file_format)

View File

@ -0,0 +1,35 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "for example: bash run_eval.sh Ascend 0 large /path/MINDlarge ./checkpoint/naml_last.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="
PLATFORM=$1
DEVICE_ID=$2
DATASET=$3
DATASET_PATH=$4
CHECKPOINT_PATH=$5
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
python ${PROJECT_DIR}/../eval.py \
--platform=${PLATFORM} \
--device_id=${DEVICE_ID} \
--dataset=${DATASET} \
--dataset_path=${DATASET_PATH} \
--checkpoint_path=${CHECKPOINT_PATH}

View File

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]"
echo "for example: bash run_train.sh Ascend 0 large /path/MINDlarge"
echo "It is better to use absolute path."
echo "=============================================================================================================="
PLATFORM=$1
DEVICE_ID=$2
DATASET=$3
DATASET_PATH=$4
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CHECKPOINT_PATH="./checkpoint"
python ${PROJECT_DIR}/../train.py \
--platform=${PLATFORM} \
--device_id=${DEVICE_ID} \
--dataset=${DATASET} \
--dataset_path=${DATASET_PATH} \
--save_checkpoint_path=${CHECKPOINT_PATH} \
--weight_decay=False \
--sink_mode=True
python ${PROJECT_DIR}/../eval.py \
--platform=${PLATFORM} \
--device_id=${DEVICE_ID} \
--dataset=${DATASET} \
--dataset_path=${DATASET_PATH} \
--checkpoint_path=${CHECKPOINT_PATH}/naml_last.ckpt

View File

@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NAML Callback"""
import os
import time
import numpy as np
from mindspore import Tensor, save_checkpoint
from mindspore.train.callback import Callback
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(args)
"""
def __init__(self, args):
super(Monitor, self).__init__()
self.cur_step = 1
self.cur_epoch = 1
self.epochs = args.epochs
self.sink_size = args.print_times
self.sink_mode = args.sink_mode
self.dataset_size = args.dataset_size
self.save_checkpoint_path = args.save_checkpoint_path
self.save_checkpoint = args.save_checkpoint
self.losses = []
if args.sink_mode:
self.epoch_steps = self.sink_size
else:
self.epoch_steps = args.dataset_size
if self.save_checkpoint and not os.path.isdir(self.save_checkpoint_path):
os.makedirs(self.save_checkpoint_path)
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
"""Callback when epoch end."""
epoch_end_f = True
if self.sink_mode:
self.cur_step += self.epoch_steps
epoch_end_f = False
if self.cur_step >= self.dataset_size:
epoch_end_f = True
self.cur_step = self.cur_step % self.dataset_size
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
if epoch_end_f:
print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
self.cur_epoch, self.epochs, np.mean(self.losses)), flush=True)
self.losses = []
self.cur_epoch += 1
if self.sink_mode:
print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms".format(
self.cur_epoch, self.epochs, self.cur_step, self.dataset_size, step_loss, per_step_mseconds),
flush=True)
if epoch_end_f and self.save_checkpoint:
save_checkpoint(cb_params.train_network,
os.path.join(self.save_checkpoint_path, f"naml_{self.cur_epoch-1}.ckpt"))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
"""Callback when step end."""
if not self.sink_mode:
cb_params = run_context.original_args()
self.cur_step += 1
self.cur_step = self.cur_step % self.dataset_size
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
step_mseconds = (time.time() - self.step_time) * 1000
print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms".format(
self.cur_epoch, self.epochs, self.cur_step, self.dataset_size, step_loss, step_mseconds), flush=True)
def end(self, run_context):
cb_params = run_context.original_args()
print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
self.epochs, self.epochs, np.mean(self.losses)), flush=True)
if self.save_checkpoint:
save_checkpoint(cb_params.train_network, os.path.join(self.save_checkpoint_path, f"naml_last.ckpt"))

View File

@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===================================================
"""
network config setting, will be used in train.py and eval.py
The parameter is usually a multiple of 16 in order to adapt to Ascend.
"""
class MINDlarge:
"""MIND large config."""
n_categories = 19
n_sub_categories = 286
n_words = 74308
epochs = 1
lr = 0.001
print_times = 1000
embedding_file = "{}/MINDlarge_utils/embedding_all.npy"
word_dict_path = "{}/MINDlarge_utils/word_dict_all.pkl"
category_dict_path = "{}/MINDlarge_utils/vert_dict.pkl"
subcategory_dict_path = "{}/MINDlarge_utils/subvert_dict.pkl"
uid2index_path = "{}/MINDlarge_utils/uid2index.pkl"
train_dataset_path = "{}/MINDlarge_train"
eval_dataset_path = "{}/MINDlarge_dev"
class MINDsmall:
"""MIND small config."""
n_categories = 19
n_sub_categories = 271
n_words = 60993
epochs = 3
lr = 0.0005
print_times = 500
embedding_file = "{}/MINDsmall_utils/embedding_all.npy"
word_dict_path = "{}/MINDsmall_utils/word_dict_all.pkl"
category_dict_path = "{}/MINDsmall_utils/vert_dict.pkl"
subcategory_dict_path = "{}/MINDsmall_utils/subvert_dict.pkl"
uid2index_path = "{}/MINDsmall_utils/uid2index.pkl"
train_dataset_path = "{}/MINDsmall_train"
eval_dataset_path = "{}/MINDsmall_dev"
class MINDdemo:
"""MIND small config."""
n_categories = 18
n_sub_categories = 237
n_words = 41059
epochs = 10
lr = 0.0005
print_times = 100
embedding_file = "{}/MINDdemo_utils/embedding_all.npy"
word_dict_path = "{}/MINDdemo_utils/word_dict_all.pkl"
category_dict_path = "{}/MINDdemo_utils/vert_dict.pkl"
subcategory_dict_path = "{}/MINDdemo_utils/subvert_dict.pkl"
uid2index_path = "{}/MINDdemo_utils/uid2index.pkl"
train_dataset_path = "{}/MINDdemo_train"
eval_dataset_path = "{}/MINDdemo_dev"
def get_dataset_config(dataset):
if dataset == "large":
return MINDlarge
if dataset == "small":
return MINDsmall
if dataset == "demo":
return MINDdemo
raise ValueError(f"Only support MINDlarge, MINDsmall and MINDdemo")

View File

@ -0,0 +1,296 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loading, creation and processing"""
import re
import os
import random
from collections import namedtuple
import pickle
import numpy as np
import mindspore.dataset as ds
ds.config.set_prefetch_size(8)
NEWS_ITEMS = ['category', 'subcategory', 'title', 'abstract']
News = namedtuple('News', NEWS_ITEMS)
class MINDPreprocess:
"""
MIND dataset Preprocess class.
when training, neg_sample=4, when test, neg_sample=-1
"""
def __init__(self, config, dataset_path=""):
self.config = config
self.dataset_dir = dataset_path
self.word_dict_path = config['word_dict_path']
self.category_dict_path = config['category_dict_path']
self.subcategory_dict_path = config['subcategory_dict_path']
self.uid2index_path = config['uid2index_path']
# behaviros config
self.neg_sample = config['neg_sample']
self.n_words_title = config['n_words_title']
self.n_words_abstract = config['n_words_abstract']
self.n_browsed_news = config['n_browsed_news']
# news config
self.n_words_title = config['n_words_title']
self.n_words_abstract = config['n_words_abstract']
self._tokenize = 're'
self._is_init_data = False
self._init_data()
self._index = 0
self._sample_store = []
self._diff = 0
def _load_pickle(self, file_path):
with open(file_path, 'rb') as fp:
return pickle.load(fp)
def _init_news(self, news_path):
"""News info initialization."""
print(f"Start to init news, news path: {news_path}")
category_dict = self._load_pickle(file_path=self.category_dict_path)
word_dict = self._load_pickle(file_path=self.word_dict_path)
subcategory_dict = self._load_pickle(
file_path=self.subcategory_dict_path)
self.nid_map_index = {}
title_list = []
category_list = []
subcategory_list = []
abstract_list = []
with open(news_path) as file_handler:
for line in file_handler:
nid, category, subcategory, title, abstract, _ = line.strip("\n").split('\t')[:6]
if nid in self.nid_map_index:
continue
self.nid_map_index[nid] = len(self.nid_map_index)
title_list.append(self._word_tokenize(title))
category_list.append(category)
subcategory_list.append(subcategory)
abstract_list.append(self._word_tokenize(abstract))
news_len = len(title_list)
self.news_title_index = np.zeros((news_len, self.n_words_title), dtype=np.int32)
self.news_abstract_index = np.zeros((news_len, self.n_words_abstract), dtype=np.int32)
self.news_category_index = np.zeros((news_len, 1), dtype=np.int32)
self.news_subcategory_index = np.zeros((news_len, 1), dtype=np.int32)
self.news_ids = np.zeros((news_len, 1), dtype=np.int32)
for news_index in range(news_len):
title = title_list[news_index]
title_index_list = [word_dict.get(word.lower(), 0) for word in title[:self.n_words_title]]
self.news_title_index[news_index, list(range(len(title_index_list)))] = title_index_list
abstract = abstract_list[news_index]
abstract_index_list = [word_dict.get(word.lower(), 0) for word in abstract[:self.n_words_abstract]]
self.news_abstract_index[news_index, list(range(len(abstract_index_list)))] = abstract_index_list
category = category_list[news_index]
self.news_category_index[news_index, 0] = category_dict.get(category, 0)
subcategory = subcategory_list[news_index]
self.news_subcategory_index[news_index, 0] = subcategory_dict.get(subcategory, 0)
self.news_ids[news_index, 0] = news_index
def _init_behaviors(self, behaviors_path):
"""Behaviors info initialization."""
print(f"Start to init behaviors, path: {behaviors_path}")
self.history_list = []
self.impression_list = []
self.label_list = []
self.impression_index_list = []
self.uid_list = []
self.poss = []
self.negs = []
self.index_map = {}
self.total_count = 0
uid2index = self._load_pickle(self.uid2index_path)
with open(behaviors_path) as file_handler:
for index, line in enumerate(file_handler):
uid, _, history, impressions = line.strip("\n").split('\t')[-4:]
negs = []
history = [self.nid_map_index[i] for i in history.split()]
random.shuffle(history)
history = [0] * (self.n_browsed_news - len(history)) + history[:self.n_browsed_news]
user_id = uid2index.get(uid, 0)
if self.neg_sample > 0:
for item in impressions.split():
nid, label = item.split('-')
nid = self.nid_map_index[nid]
if label == '1':
self.poss.append(nid)
self.index_map[self.total_count] = index
self.total_count += 1
else:
negs.append(nid)
else:
nids = []
labels = []
for item in impressions.split():
nid, label = item.split('-')
nids.append(self.nid_map_index[nid])
labels.append(int(label))
self.impression_list.append((np.array(nids, dtype=np.int32), np.array(labels, dtype=np.int32)))
self.total_count += 1
self.history_list.append(history)
self.negs.append(negs)
self.uid_list.append(user_id)
def _init_data(self):
news_path = os.path.join(self.dataset_dir, 'news.tsv')
behavior_path = os.path.join(self.dataset_dir, 'behaviors.tsv')
if not self._is_init_data:
self._init_news(news_path)
self._init_behaviors(behavior_path)
self._is_init_data = True
print(f'init data end, count: {self.total_count}')
def _word_tokenize(self, sent):
"""
Split sentence into word list using regex.
Args:
sent (str): Input sentence
Return:
list: word list
"""
pat = re.compile(r"[\w]+|[.,!?;|]")
if isinstance(sent, str):
return pat.findall(sent.lower())
return []
def __getitem__(self, index):
uid_index = self.index_map[index]
if self.neg_sample >= 0:
negs = self.negs[uid_index]
nid = self.poss[index]
random.shuffle(negs)
neg_samples = (negs + [0] * (self.neg_sample - len(negs))) if self.neg_sample > len(negs) \
else random.sample(negs, self.neg_sample)
candidate_samples = [nid] + neg_samples
labels = [1] + [0] * self.neg_sample
else:
candidate_samples, labels = self.preprocess.impression_list[index]
browsed_samples = self.history_list[uid_index]
browsed_category = np.array(self.news_category_index[browsed_samples], dtype=np.int32)
browsed_subcategory = np.array(self.news_subcategory_index[browsed_samples], dtype=np.int32)
browsed_title = np.array(self.news_title_index[browsed_samples], dtype=np.int32)
browsed_abstract = np.array(self.news_abstract_index[browsed_samples], dtype=np.int32)
candidate_category = np.array(self.news_category_index[candidate_samples], dtype=np.int32)
candidate_subcategory = np.array(self.news_subcategory_index[candidate_samples], dtype=np.int32)
candidate_title = np.array(self.news_title_index[candidate_samples], dtype=np.int32)
candidate_abstract = np.array(self.news_abstract_index[candidate_samples], dtype=np.int32)
labels = np.array(labels, dtype=np.int32)
return browsed_category, browsed_subcategory, browsed_title, browsed_abstract,\
candidate_category, candidate_subcategory, candidate_title, candidate_abstract, labels
@property
def column_names(self):
news_column_names = ['category', 'subcategory', 'title', 'abstract']
column_names = ['browsed_' + item for item in news_column_names]
column_names += ['candidate_' + item for item in news_column_names]
column_names += ['labels']
return column_names
def __len__(self):
return self.total_count
class EvalDatasetBase:
"""Base evaluation Datase class."""
def __init__(self, preprocess: MINDPreprocess):
self.preprocess = preprocess
class EvalNews(EvalDatasetBase):
"""Generator dataset for all news."""
def __len__(self):
return len(self.preprocess.news_title_index)
def __getitem__(self, index):
news_id = self.preprocess.news_ids[index]
title = self.preprocess.news_title_index[index]
category = self.preprocess.news_category_index[index]
subcategory = self.preprocess.news_subcategory_index[index]
abstract = self.preprocess.news_abstract_index[index]
return news_id.reshape(-1), category.reshape(-1), subcategory.reshape(-1), title.reshape(-1),\
abstract.reshape(-1)
@property
def column_names(self):
return ['news_id', 'category', 'subcategory', 'title', 'abstract']
class EvalUsers(EvalDatasetBase):
"""Generator dataset for all user."""
def __len__(self):
return len(self.preprocess.uid_list)
def __getitem__(self, index):
uid = np.array(self.preprocess.uid_list[index], dtype=np.int32)
history = np.array(self.preprocess.history_list[index], dtype=np.int32)
return uid, history.reshape(50, 1)
@property
def column_names(self):
return ['uid', 'history']
class EvalCandidateNews(EvalDatasetBase):
"""Generator dataset for all candidate news."""
@property
def column_names(self):
return ['uid', 'candidate_nid', 'labels']
def __len__(self):
return self.preprocess.total_count
def __getitem__(self, index):
uid = np.array(self.preprocess.uid_list[index], dtype=np.int32)
nid, label = self.preprocess.impression_list[index]
return uid, nid, label
def create_dataset(mindpreprocess, batch_size=64):
"""Get generator dataset when training."""
dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, shuffle=True)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
def create_eval_dataset(mindpreprocess, eval_cls, batch_size=64):
"""Get generator dataset when evaluation."""
eval_instance = eval_cls(mindpreprocess)
dataset = ds.GeneratorDataset(eval_instance, eval_instance.column_names, shuffle=False)
if not isinstance(eval_instance, EvalCandidateNews):
dataset = dataset.batch(batch_size)
return dataset

View File

@ -0,0 +1,246 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NAML network."""
import mindspore.nn as nn
from mindspore.common import initializer as init
import mindspore.ops as ops
class Attention(nn.Cell):
"""
Softmax attention implement.
Args:
query_vector_dim (int): dimension of the query vector in attention.
input_vector_dim (int): dimension of the input vector in attention.
Input:
input (Tensor): input tensor, shape is (batch_size, n_input_vector, input_vector_dim)
Returns:
Tensor, output tensor, shape is (batch_size, n_input_vector).
Examples:
>>> Attention(query_vector_dim, input_vector_dim)
"""
def __init__(self, query_vector_dim, input_vector_dim):
super(Attention, self).__init__()
self.dense1 = nn.Dense(input_vector_dim, query_vector_dim, has_bias=True, activation='tanh')
self.dense2 = nn.Dense(query_vector_dim, 1, has_bias=False)
self.softmax = nn.Softmax()
self.sum_keep_dims = ops.ReduceSum(keep_dims=True)
self.sum = ops.ReduceSum(keep_dims=False)
def construct(self, x):
dtype = ops.dtype(x)
batch_size, n_input_vector, input_vector_dim = ops.shape(x)
feature = ops.reshape(x, (-1, input_vector_dim))
attention = ops.reshape(self.dense2(self.dense1(feature)), (batch_size, n_input_vector))
attention_weight = ops.cast(self.softmax(attention), dtype)
weighted_input = x * ops.expand_dims(attention_weight, 2)
return self.sum(weighted_input, 1)
class NewsEncoder(nn.Cell):
"""
The main function to create news encoder of NAML.
Args:
args (class): global hyper-parameters.
word_embedding (Tensor): parameter of word embedding.
Returns:
Tensor, output tensor.
Examples:
>>> NewsEncoder(args, embedding_table)
"""
def __init__(self, args, embedding_table=None):
super(NewsEncoder, self).__init__()
# categories
self.category_embedding = nn.Embedding(args.n_categories, args.category_embedding_dim)
self.category_dense = nn.Dense(args.category_embedding_dim, args.n_filters, has_bias=True, activation="relu")
self.sub_category_embedding = nn.Embedding(args.n_sub_categories, args.category_embedding_dim)
self.subcategory_dense = nn.Dense(args.category_embedding_dim, args.n_filters, has_bias=True, activation="relu")
# title and abstract
if embedding_table is None:
word_embedding = [nn.Embedding(args.n_words, args.word_embedding_dim)]
else:
word_embedding = [nn.Embedding(args.n_words, args.word_embedding_dim, embedding_table=embedding_table)]
title_CNN = [
nn.Conv1d(args.word_embedding_dim, args.n_filters, kernel_size=args.window_size, pad_mode='same',
has_bias=True),
nn.ReLU()
]
abstract_CNN = [
nn.Conv1d(args.word_embedding_dim, args.n_filters, kernel_size=args.window_size, pad_mode='same',
has_bias=True),
nn.ReLU()
]
if args.phase == "train":
word_embedding.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
title_CNN.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
abstract_CNN.append(nn.Dropout(keep_prob=(1-args.dropout_ratio)))
self.word_embedding = nn.SequentialCell(word_embedding)
self.title_CNN = nn.SequentialCell(title_CNN)
self.abstract_CNN = nn.SequentialCell(abstract_CNN)
self.title_attention = Attention(args.query_vector_dim, args.n_filters)
self.abstract_attention = Attention(args.query_vector_dim, args.n_filters)
self.total_attention = Attention(args.query_vector_dim, args.n_filters)
self.pack = ops.Stack(axis=1)
self.title_shape = (-1, args.n_words_title)
self.abstract_shape = (-1, args.n_words_abstract)
def construct(self, category, subcategory, title, abstract):
"""
The news encoder is composed of title encoder, abstract encoder, category encoder and subcategory encoder.
"""
# Categories
category_embedded = self.category_embedding(ops.reshape(category, (-1,)))
category_vector = self.category_dense(category_embedded)
subcategory_embedded = self.sub_category_embedding(ops.reshape(subcategory, (-1,)))
subcategory_vector = self.subcategory_dense(subcategory_embedded)
# title
title_embedded = self.word_embedding(ops.reshape(title, self.title_shape))
title_feature = self.title_CNN(ops.Transpose()(title_embedded, (0, 2, 1)))
title_vector = self.title_attention(ops.Transpose()(title_feature, (0, 2, 1)))
# abstract
abstract_embedded = self.word_embedding(ops.reshape(abstract, self.abstract_shape))
abstract_feature = self.abstract_CNN(ops.Transpose()(abstract_embedded, (0, 2, 1)))
abstract_vector = self.abstract_attention(ops.Transpose()(abstract_feature, (0, 2, 1)))
# total
news_vector = self.total_attention(
self.pack((category_vector, subcategory_vector, title_vector, abstract_vector)))
return news_vector
class UserEncoder(nn.Cell):
"""
The main function to create user encoder of NAML.
Args:
args (class): global hyper-parameters.
Returns:
Tensor, output tensor.
Examples:
>>> UserEncoder(args)
"""
def __init__(self, args):
super(UserEncoder, self).__init__()
self.news_attention = Attention(args.query_vector_dim, args.n_filters)
def construct(self, news_vectors):
user_vector = self.news_attention(news_vectors)
return user_vector
class ClickPredictor(nn.Cell):
"""
Click predictor by user encoder and news encoder.
Returns:
Tensor, output tensor.
Examples:
>>> ClickPredictor()
"""
def __init__(self):
super(ClickPredictor, self).__init__()
self.matmul = ops.BatchMatMul()
def construct(self, news_vector, user_vector):
predict = ops.Flatten()(self.matmul(news_vector, ops.expand_dims(user_vector, 2)))
return predict
class NAML(nn.Cell):
"""
NAML model(Neural News Recommendation with Attentive Multi-View Learning).
Args:
args (class): global hyper-parameters.
word_embedding (Tensor): parameter of word embedding.
Returns:
Tensor, output tensor.
Examples:
>>> NAML(rgs, embedding_table)
"""
def __init__(self, args, embedding_table=None):
super(NAML, self).__init__()
self.args = args
self.news_encoder = NewsEncoder(args, embedding_table)
self.user_encoder = UserEncoder(args)
self.click_predictor = ClickPredictor()
self.browsed_vector_shape = (args.batch_size, args.n_browsed_news, args.n_filters)
self.candidate_vector_shape = (args.batch_size, args.neg_sample + 1, args.n_filters)
if not embedding_table is None:
self.word_embedding_shape = embedding_table.shape
else:
self.word_embedding_shape = ()
self._initialize_weights()
def construct(self, category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c, abstract_c):
browsed_news_vectors = ops.reshape(self.news_encoder(category_b, subcategory_b, title_b, abstract_b),
self.browsed_vector_shape)
user_vector = self.user_encoder(browsed_news_vectors)
candidate_news_vector = ops.reshape(self.news_encoder(category_c, subcategory_c, title_c, abstract_c),
self.candidate_vector_shape)
predict = self.click_predictor(candidate_news_vector, user_vector)
return predict
def _initialize_weights(self):
"""Weights initialize."""
self.init_parameters_data()
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv1d):
cell.weight.set_data(init.initializer("XavierUniform",
cell.weight.shape,
cell.weight.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer("XavierUniform",
cell.weight.shape,
cell.weight.dtype))
elif isinstance(cell, nn.Embedding) and cell.embedding_table.shape != self.word_embedding_shape:
cell.embedding_table.set_data(init.initializer("uniform",
cell.embedding_table.shape,
cell.embedding_table.dtype))
class NAMLWithLossCell(nn.Cell):
"""
NAML add loss Cell.
Args:
network (Cell): naml network.
Returns:
Tensor, output tensor.
Examples:
>>> NAMLWithLossCell(NAML(rgs, word_embedding))
"""
def __init__(self, network):
super(NAMLWithLossCell, self).__init__()
self.network = network
self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
def construct(self, category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c, abstract_c,
label):
predict = self.network(category_b, subcategory_b, title_b, abstract_b, category_c, subcategory_c, title_c,
abstract_c)
dtype = ops.dtype(predict)
shp = ops.shape(predict)
loss = self.loss(predict, ops.reshape(ops.cast(label, dtype), shp))
return loss

View File

@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""parse args"""
import argparse
import ast
import os
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init, get_rank
from .config import get_dataset_config
def get_args(phase):
"""Define the common options that are used in both training and test."""
parser = argparse.ArgumentParser(description='Configuration')
# Hardware specifications
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
parser.add_argument('--platform', type=str, default="Ascend", \
help='run platform, only support Ascend')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
help='whether save graphs, default is False.')
parser.add_argument('--dataset', type=str, default="large", choices=("large", "small", "demo"), \
help='MIND dataset, support large, small and demo.')
parser.add_argument('--dataset_path', type=str, default=None, help='MIND dataset path.')
# Model specifications
parser.add_argument('--n_browsed_news', type=int, default=50, help='number of browsed news per user')
parser.add_argument('--n_words_title', type=int, default=16, help='number of words per title')
parser.add_argument('--n_words_abstract', type=int, default=48, help='number of words per abstract')
parser.add_argument('--word_embedding_dim', type=int, default=304, help='dimension of word embedding vector')
parser.add_argument('--category_embedding_dim', type=int, default=112, \
help='dimension of category embedding vector')
parser.add_argument('--query_vector_dim', type=int, default=208, help='dimension of the query vector in attention')
parser.add_argument('--n_filters', type=int, default=400, help='number of filters in CNN')
parser.add_argument('--window_size', type=int, default=3, help='size of filter in CNN')
parser.add_argument("--checkpoint_path", type=str, default=None, \
help="Pre trained checkpoint path, default is None.")
parser.add_argument('--batch_size', type=int, default=64, help='size of each batch')
# Training specifications
if phase == "train":
parser.add_argument('--train_dataset_path', type=str, default=None, help='training set directory')
parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
parser.add_argument('--lr', type=float, default=None, help='learning rate')
parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999, help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability')
parser.add_argument('--neg_sample', type=int, default=4, help='number of negative samples in negative sampling')
parser.add_argument("--mixed", type=ast.literal_eval, default=True, \
help="whether use mixed precision, default is True.")
parser.add_argument("--sink_mode", type=ast.literal_eval, default=True, \
help="whether use dataset sink, default is True.")
parser.add_argument('--print_times', type=int, default=None, help='number of print times, default is None')
parser.add_argument("--weight_decay", type=ast.literal_eval, default=True, \
help="whether use weight decay, default is True.")
parser.add_argument('--save_checkpoint', type=ast.literal_eval, default=True, \
help='whether save checkpoint, default is True.')
parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint", \
help="Save checkpoint path, default is checkpoint.")
parser.add_argument('--dropout_ratio', type=float, default=0.2, help='ratio of dropout')
if phase == "eval":
parser.add_argument('--eval_dataset_path', type=str, default=None)
parser.add_argument('--neg_sample', type=int, default=-1, \
help='number of negative samples in negative sampling')
if phase == "export":
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
help='file format')
parser.add_argument('--neg_sample', type=int, default=-1, \
help='number of negative samples in negative sampling')
args = parser.parse_args()
if args.device_num > 1:
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=args.device_num)
init()
args.rank = get_rank()
args.save_checkpoint_path = os.path.join(args.save_checkpoint_path, "ckpt_" + str(args.rank))
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, device_id=args.device_id,
save_graphs=args.save_graphs, save_graphs_path="naml_ir")
args.rank = 0
args.device_num = 1
args.phase = phase
cfg = get_dataset_config(args.dataset)
args.n_categories = cfg.n_categories
args.n_sub_categories = cfg.n_sub_categories
args.n_words = cfg.n_words
if phase == "train":
args.epochs = cfg.epochs if args.epochs is None else args.epochs
args.lr = cfg.lr if args.lr is None else args.lr
args.print_times = cfg.print_times if args.print_times is None else args.print_times
args.embedding_file = cfg.embedding_file.format(args.dataset_path)
args.word_dict_path = cfg.word_dict_path.format(args.dataset_path)
args.category_dict_path = cfg.category_dict_path.format(args.dataset_path)
args.subcategory_dict_path = cfg.subcategory_dict_path.format(args.dataset_path)
args.uid2index_path = cfg.uid2index_path.format(args.dataset_path)
args.train_dataset_path = cfg.train_dataset_path.format(args.dataset_path)
args.eval_dataset_path = cfg.eval_dataset_path.format(args.dataset_path)
args_dict = vars(args)
for key in args_dict.keys():
print('--> {}:{}'.format(key, args_dict[key]), flush=True)
return args

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for NAML."""
import time
import numpy as np
from sklearn.metrics import roc_auc_score
from mindspore import Tensor
from .dataset import create_eval_dataset, EvalNews, EvalUsers, EvalCandidateNews
def get_metric(args, mindpreprocess, news_encoder, user_encoder, metric):
"""Calculate metrics."""
start = time.time()
news_dict = {}
user_dict = {}
dataset = create_eval_dataset(mindpreprocess, EvalNews, batch_size=args.batch_size)
dataset_size = dataset.get_dataset_size()
iterator = dataset.create_dict_iterator(output_numpy=True)
for count, data in enumerate(iterator):
news_vector = news_encoder(Tensor(data["category"]), Tensor(data["subcategory"]),
Tensor(data["title"]), Tensor(data["abstract"])).asnumpy()
for i, nid in enumerate(data["news_id"]):
news_dict[str(nid[0])] = news_vector[i]
print(f"===Generate News vector==== [ {count} / {dataset_size} ]", end='\r')
print(f"===Generate News vector==== [ {dataset_size} / {dataset_size} ]")
dataset = create_eval_dataset(mindpreprocess, EvalUsers, batch_size=args.batch_size)
dataset_size = dataset.get_dataset_size()
iterator = dataset.create_dict_iterator(output_numpy=True)
for count, data in enumerate(iterator):
browsed_news = []
for newses in data["history"]:
news_list = []
for nid in newses:
news_list.append(news_dict[str(nid[0])])
browsed_news.append(np.array(news_list))
browsed_news = np.array(browsed_news)
user_vector = user_encoder(Tensor(browsed_news)).asnumpy()
for i, uid in enumerate(data["uid"]):
user_dict[str(uid)] = user_vector[i]
print(f"===Generate Users vector==== [ {count} / {dataset_size} ]", end='\r')
print(f"===Generate Users vector==== [ {dataset_size} / {dataset_size} ]")
dataset = create_eval_dataset(mindpreprocess, EvalCandidateNews, batch_size=args.batch_size)
dataset_size = dataset.get_dataset_size()
iterator = dataset.create_dict_iterator(output_numpy=True)
for count, data in enumerate(iterator):
pred = np.dot(
np.stack([news_dict[str(nid)] for nid in data["candidate_nid"]], axis=0),
user_dict[str(data["uid"])]
)
metric.update(pred, data["labels"])
print(f"===Click Prediction==== [ {count} / {dataset_size} ]", end='\r')
print(f"===Click Prediction==== [ {dataset_size} / {dataset_size} ]")
auc = metric.eval()
total_cost = time.time() - start
print(f"Eval total cost: {total_cost} s")
return auc
def process_data(args):
word_embedding = np.load(args.embedding_file)
_, h = word_embedding.shape
if h < args.word_embedding_dim:
word_embedding = np.pad(word_embedding, ((0, 0), (0, args.word_embedding_dim - 300)), 'constant',
constant_values=0)
elif h > args.word_embedding_dim:
word_embedding = word_embedding[:, :args.word_embedding_dim]
print("Load word_embedding", word_embedding.shape)
return Tensor(word_embedding.astype(np.float32))
def AUC(y_true, y_pred):
return roc_auc_score(y_true, y_pred)
def MRR(y_true, y_pred):
index = np.argsort(y_pred)[::-1]
y_true = np.take(y_true, index)
score = y_true / (np.arange(len(y_true)) + 1)
return np.sum(score) / np.sum(y_true)
def DCG(y_true, y_pred, n):
index = np.argsort(y_pred)[::-1]
y_true = np.take(y_true, index[:n])
score = (2 ** y_true - 1) / np.log2(np.arange(len(y_true)) + 2)
return np.sum(score)
def nDCG(y_true, y_pred, n):
return DCG(y_true, y_pred, n) / DCG(y_true, y_true, n)
class NAMLMetric:
"""
Metric method
"""
def __init__(self):
super(NAMLMetric, self).__init__()
self.AUC_list = []
self.MRR_list = []
self.nDCG5_list = []
self.nDCG10_list = []
def clear(self):
"""Clear the internal evaluation result."""
self.AUC_list = []
self.MRR_list = []
self.nDCG5_list = []
self.nDCG10_list = []
def update(self, predict, y_true):
predict = predict.flatten()
y_true = y_true.flatten()
# predict = np.interp(predict, (predict.min(), predict.max()), (0, 1))
self.AUC_list.append(AUC(y_true, predict))
self.MRR_list.append(MRR(y_true, predict))
self.nDCG5_list.append(nDCG(y_true, predict, 5))
self.nDCG10_list.append(nDCG(y_true, predict, 10))
def eval(self):
auc = np.mean(self.AUC_list)
print('AUC:', auc)
print('MRR:', np.mean(self.MRR_list))
print('nDCG@5:', np.mean(self.nDCG5_list))
print('nDCG@10:', np.mean(self.nDCG10_list))
return auc

View File

@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train NAML."""
import time
from mindspore import nn, load_checkpoint
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import create_dataset, MINDPreprocess
from src.utils import process_data
from src.callback import Monitor
if __name__ == '__main__':
args = get_args("train")
set_seed(args.seed)
word_embedding = process_data(args)
net = NAML(args, word_embedding)
net_with_loss = NAMLWithLossCell(net)
if args.checkpoint_path is not None:
load_checkpoint(args.pretrain_checkpoint, net_with_loss)
mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size)
args.dataset_size = dataset.get_dataset_size()
args.print_times = min(args.dataset_size, args.print_times)
if args.weight_decay:
weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
group_params = [{'params': weight_params, 'weight_decay': 1e-3},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': net.trainable_params()}]
opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
else:
opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
if args.mixed:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
net_with_loss.to_float(mstype.float16)
for _, cell in net_with_loss.cells_and_names():
if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
cell.to_float(mstype.float32)
model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
else:
model = Model(net_with_loss, optimizer=opt)
cb = [Monitor(args)]
epochs = args.epochs
if args.sink_mode:
epochs = int(args.epochs * args.dataset_size / args.print_times)
start_time = time.time()
print("======================= Start Train ==========================", flush=True)
model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
end_time = time.time()
print("==============================================================")
print("processor_name: {}".format(args.platform))
print("test_name: NAML")
print(f"model_name: NAML MIND{args.dataset}")
print("batch_size: {}".format(args.batch_size))
print("latency: {} s".format(end_time - start_time))