!10426 Add ldp_linucb into model zoo

From: @sergeywong
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-26 20:10:51 +08:00 committed by Gitee
commit 5653a0a970
7 changed files with 480 additions and 0 deletions

View File

@ -0,0 +1,132 @@
# Contents
- [LDP LinUCB Description](#ldp-linucb-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Launch](#launch)
- [Model Description](#model-description)
- [Performance](#performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [LDP LinUCB](#contents)
Locally Differentially Private (LDP) LinUCB is a variant of LinUCB bandit algorithm with local differential privacy guarantee, which can preserve users' personal data with theoretical guarantee.
[Paper](https://arxiv.org/abs/2006.00701): Kai Zheng, Tianle Cai, Weiran Huang, Zhenguo Li, Liwei Wang. "Locally Differentially Private (Contextual) Bandits Learning." *Advances in Neural Information Processing Systems*. 2020.
# [Model Architecture](#contents)
The server interacts with users in rounds. For a coming user, the server first transfers the current model parameters to the user. In the user side, the model chooses an action based on the user feature to play (e.g., choose a movie to recommend), and observes a reward (or loss) value from the user (e.g., rating of the movie). Then we perturb the data to be transfered by adding Gaussian noise. Finally, the server receives the perturbed data and updates the model. Details can be found in the [original paper](https://arxiv.org/abs/2006.00701).
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [MovieLens 100K](https://grouplens.org/datasets/movielens/100k/)
- Dataset size5MB, 100,000 ratings (1-5) from 943 users on 1682 movies.
- Data formatcsv/txt files
# [Environment Requirements](#contents)
- Hardware (Ascend/GPU)
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the[application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```console
├── model_zoo
├── README.md // descriptions about all the models
├── research
├── rl
├── ldp_linucb
├── README.md // descriptions about LDP LinUCB
├── scripts
│ ├── run_train_eval.sh // shell script for runing on Ascend
├── src
│ ├── dataset.py // dataset for movielens
│ ├── linucb.py // model
├── train_eval.py // training script
├── result1.png // experimental result
├── result2.png // experimental result
```
## [Script Parameters](#contents)
- Parameters for preparing MovieLens 100K dataset
```python
'num_actions': 20 # number of candidate movies to be recommended
'rank_k': 20 # rank of rating matrix completion
```
- Parameters for LDP LinUCB, MovieLens 100K dataset
```python
'epsilon': 8e5 # privacy parameter
'delta': 0.1 # privacy parameter
'alpha': 0.1 # failure probability
'iter_num': 1e6 # number of iterations
```
## [Launch](#contents)
- running on Ascend
```shell
python train_eval.py > result.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `result.log`.
The regret value will be achieved as follows:
```console
--> Step: 0, diff: 348.662, current_regret: 0.000, cumulative regret: 0.000
--> Step: 1, diff: 338.457, current_regret: 0.000, cumulative regret: 0.000
--> Step: 2, diff: 336.465, current_regret: 2.000, cumulative regret: 2.000
--> Step: 3, diff: 327.337, current_regret: 0.000, cumulative regret: 2.000
--> Step: 4, diff: 325.039, current_regret: 2.000, cumulative regret: 4.000
...
```
# [Model Description](#contents)
The [original paper](https://arxiv.org/abs/2006.00701) assumes that the norm of user features is bounded by 1 and the norm of rating scores is bounded by 2. For the MovieLens dataset, we normalize rating scores to [-1,1]. Thus, we set `sigma` in Algorithm 5 to be $$4/epsilon \* sqrt(2 \* ln(1.25/delta))$$.
## [Performance](#contents)
The performance for different privacy parameters:
- x: number of iterations
- y: cumulative regret
![Result1](result1.png)
The performance compared with optimal non-private regret O(sqrt(T)):
- x: number of iterations
- y: cumulative regret divided by sqrt(T)
![Result2](result2.png)
# [Description of Random Situation](#contents)
In `train_eval.py`, we randomly sample a user at each round. We also add Gaussian noise to the date being transfered.
# [ModelZoo Homepage](#contents)
Please check the official
[homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python3 train_eval.py > result.log 2>&1 &

View File

@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MovieLens Environment.
"""
import random
import numpy as np
from mindspore import Tensor
_MAX_NUM_ACTIONS = 1682
_NUM_USERS = 943
def load_movielens_data(data_file):
"""Loads the movielens data and returns the ratings matrix."""
ratings_matrix = np.zeros([_NUM_USERS, _MAX_NUM_ACTIONS])
with open(data_file, 'r') as f:
for line in f.readlines():
row_infos = line.strip().split()
user_id = int(row_infos[0])
item_id = int(row_infos[1])
rating = float(row_infos[2])
ratings_matrix[user_id - 1, item_id - 1] = rating
return ratings_matrix
class MovieLensEnv:
"""
MovieLens dataset environment for bandit algorithms.
Args:
data_file(str): path of movielens file, e.g. 'ua.base'.
num_movies(int): number of movies for choices.
rank_k(int): the dim of feature.
Returns:
Environment for bandit algorithms.
"""
def __init__(self, data_file, num_movies, rank_k):
# Initialization
self._num_actions = num_movies
self._context_dim = rank_k
# Load Movielens dataset
self._data_matrix = load_movielens_data(data_file)
# Keep only the first items
self._data_matrix = self._data_matrix[:, :num_movies]
# Filter the users with at least one rating score
nonzero_users = list(
np.nonzero(
np.sum(
self._data_matrix,
axis=1) > 0.0)[0])
self._data_matrix = self._data_matrix[nonzero_users, :]
# Normalize the data_matrix into -1~1
self._data_matrix = 0.4 * (self._data_matrix - 2.5)
# Compute the SVD # Only keep the largest rank_k singular values
u, s, vh = np.linalg.svd(self._data_matrix, full_matrices=False)
u_hat = u[:, :rank_k] * np.sqrt(s[:rank_k])
v_hat = np.transpose(np.transpose(
vh[:rank_k, :]) * np.sqrt(s[:rank_k]))
self._approx_ratings_matrix = np.matmul(
u_hat, v_hat).astype(np.float32)
# Prepare feature for user i and item j: u[i,:] * vh[:,j]
# (elementwise product of user feature and item feature)
self._ground_truth = s
self._current_user = 0
self._feature = np.expand_dims(u[:, :rank_k], axis=1) * \
np.expand_dims(np.transpose(vh[:rank_k, :]), axis=0)
self._feature = self._feature.astype(np.float32)
@property
def ground_truth(self):
return self._ground_truth
def observation(self):
"""random select a user and return its feature."""
sampled_user = random.randint(0, self._data_matrix.shape[0] - 1)
self._current_user = sampled_user
return Tensor(self._feature[sampled_user])
def current_rewards(self):
"""rewards for current user."""
return Tensor(self._approx_ratings_matrix[self._current_user])

View File

@ -0,0 +1,143 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Linear UCB with locally differentially private.
"""
import math
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class LinUCB(nn.Cell):
"""
Linear UCB with locally differentially private bandits learning.
Args:
context_dim(int): dim of input feature.
epsilon(float): epsilon for private parameter.
delta(float): delta for private parameter.
alpha(float): failure probability.
T(float/int): number of iterations.
Returns:
Tuple of Tensors: gradients to update parameters and optimal action.
"""
def __init__(self, context_dim, epsilon=100, delta=0.1, alpha=0.1, T=1e5):
super(LinUCB, self).__init__()
self.matmul = P.MatMul()
self.expand_dims = P.ExpandDims()
self.transpose = P.Transpose()
self.reduce_sum = P.ReduceSum()
self.squeeze = P.Squeeze(1)
self.argmax = P.Argmax()
self.reduce_max = P.ReduceMax()
# Basic variables
self._context_dim = context_dim
self._epsilon = epsilon
self._delta = delta
self._alpha = alpha
self._T = int(T)
# Parameters
self._V = Tensor(
np.zeros(
(context_dim,
context_dim),
dtype=np.float32))
self._u = Tensor(np.zeros((context_dim,), dtype=np.float32))
self._theta = Tensor(np.zeros((context_dim,), dtype=np.float32))
# \sigma = 4*\sqrt{2*\ln{\farc{1.25}{\delta}}}/\epsilon
self._sigma = 4 * \
math.sqrt(math.log(1.25 / self._delta)) / self._epsilon
self._c = 0.1
self._step = 1
self._regret = 0
self._current_regret = 0
self.inverse_matrix()
@property
def theta(self):
return self._theta
@property
def regret(self):
return self._regret
@property
def current_regret(self):
return self._current_regret
def inverse_matrix(self):
"""compute the inverse matrix of parameter matrix."""
Vc = self._V + Tensor(np.eye(self._context_dim,
dtype=np.float32)) * self._c
self._Vc_inv = Tensor(np.linalg.inv(Vc.asnumpy()), mindspore.float32)
def update_status(self, step):
"""update status variables."""
t = max(step, 1)
T = self._T
d = self._context_dim
alpha = self._alpha
sigma = self._sigma
gamma = sigma * \
math.sqrt(t) * (4 * math.sqrt(d) + 2 * math.log(2 * T / alpha))
self._c = 2 * gamma
self._beta = 2 * sigma * math.sqrt(d * math.log(T)) + (math.sqrt(
3 * gamma) + sigma * math.sqrt(d * t / gamma)) * d * math.log(T)
def construct(self, x, rewards):
"""compute the perturbed gradients for parameters."""
# Choose optimal action
x_transpose = self.transpose(x, (1, 0))
scores_a = self.squeeze(self.matmul(x, self.expand_dims(self._theta, 1)))
scores_b = x_transpose * self.matmul(self._Vc_inv, x_transpose)
scores_b = self.reduce_sum(scores_b, 0)
scores = scores_a + self._beta * scores_b
max_a = self.argmax(scores)
xa = x[max_a]
xaxat = self.matmul(self.expand_dims(xa, -1), self.expand_dims(xa, 0))
y = rewards[max_a]
y_max = self.reduce_max(rewards)
y_diff = y_max - y
self._current_regret = float(y_diff.asnumpy())
self._regret += self._current_regret
# Prepare noise
B = np.random.normal(0, self._sigma, size=xaxat.shape)
B = np.triu(B)
B += B.transpose() - np.diag(B.diagonal())
B = Tensor(B.astype(np.float32))
Xi = np.random.normal(0, self._sigma, size=xa.shape)
Xi = Tensor(Xi.astype(np.float32))
# Add noise and update parameters
return xaxat + B, xa * y + Xi, max_a
def server_update(self, xaxat, xay):
"""update parameters with perturbed gradients."""
self._V += xaxat
self._u += xay
self.inverse_matrix()
theta = self.matmul(self._Vc_inv, self.expand_dims(self._u, 1))
self._theta = self.squeeze(theta)

View File

@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train/eval.
"""
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
from src.dataset import MovieLensEnv
from src.linucb import LinUCB
def parse_args():
"""parse args"""
parser = argparse.ArgumentParser()
parser.add_argument('--data_file', type=str, default='ua.base',
help='data file for movielens')
parser.add_argument('--rank_k', type=int, default=20,
help='rank for data matrix')
parser.add_argument('--num_actions', type=int, default=20,
help='movie number for choices')
parser.add_argument('--epsilon', type=float, default=8e5,
help='epsilon for differentially private')
parser.add_argument('--delta', type=float, default=1e-1,
help='delta for differentially private')
parser.add_argument('--alpha', type=float, default=1e-1,
help='failure probability')
parser.add_argument('--iter_num', type=float, default=1e6,
help='iteration number for training')
args_opt = parser.parse_args()
return args_opt
if __name__ == '__main__':
# build environment
args = parse_args()
env = MovieLensEnv(args.data_file, args.num_actions, args.rank_k)
# Linear UCB
lin_ucb = LinUCB(
args.rank_k,
epsilon=args.epsilon,
delta=args.delta,
alpha=args.alpha,
T=args.iter_num)
print('start')
start_time = time.time()
cumulative_regrets = []
for i in range(int(args.iter_num)):
x = env.observation()
rewards = env.current_rewards()
lin_ucb.update_status(i + 1)
xaxat, xay, max_a = lin_ucb(x, rewards)
cumulative_regrets.append(float(lin_ucb.regret))
lin_ucb.server_update(xaxat, xay)
diff = np.abs(lin_ucb.theta.asnumpy() - env.ground_truth).sum()
print(
f'--> Step: {i}, diff: {diff:.3f},'
f'current_regret: {lin_ucb.current_regret:.3f},'
f'cumulative regret: {lin_ucb.regret:.3f}')
end_time = time.time()
print(f'Regret: {lin_ucb.regret}, cost time: {end_time-start_time:.3f}s')
print(f'theta: {lin_ucb.theta.asnumpy()}')
print(f' gt: {env.ground_truth}')
np.save(f'e_{args.epsilon:.1e}.npy', cumulative_regrets)
plt.plot(
range(len(cumulative_regrets)),
cumulative_regrets,
label=f'epsilon={args.epsilon:.1e}')
plt.legend()
plt.savefig(f'regret_{args.epsilon:.1e}.png')