data_parallel_grad_reducer

This commit is contained in:
yao_yf 2020-06-04 09:27:42 +08:00
parent fcdc88cca9
commit ce03ce5af2
16 changed files with 1086 additions and 16 deletions

View File

@ -25,6 +25,8 @@ The entire code structure is as following:
WideDeep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
|--- script/ "run shell dir"
run_multinpu_train.sh "run data parallel"
```
### Train and evaluate model

View File

@ -17,16 +17,18 @@
# bash run_multinpu_train.sh
execute_path=$(pwd)
# export RANK_TABLE_FILE=${execute_path}/rank_table_8p.json
# export RANK_SIZE=8
# export MINDSPORE_HCCL_CONFIG_PATH=${execute_path}/rank_table_8p.json
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<=7;i++));
for((i=0;i<=$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
pytest -s ${execute_path}/train_and_test_multinpu.py >train_deep$i.log 2>&1 &
pytest -s ${execute_path}/train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done

View File

@ -17,6 +17,7 @@ callbacks
import time
from mindspore.train.callback import Callback
from mindspore import context
from mindspore.train import ParallelMode
def add_write(file_path, out_str):
"""
@ -85,14 +86,17 @@ class EvalCallBack(Callback):
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
self.eval_values = []
def epoch_name(self, run_context):
def epoch_end(self, run_context):
"""
epoch name
epoch end
"""
self.aucMetric.clear()
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
start_time = time.time()
out = self.model.eval(self.eval_dataset)
end_time = time.time()
@ -101,4 +105,5 @@ class EvalCallBack(Callback):
time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
print(out_str)
self.eval_values = out.values()
add_write(self.eval_file_name, out_str)

View File

@ -17,11 +17,20 @@
import os
import math
from enum import Enum
import numpy as np
import pandas as pd
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
class H5Dataset():
"""
@ -193,15 +202,60 @@ def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
is_tf_dataset=True, line_per_sample=1000, rank_size=None, rank_id=None):
data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
"""
create_dataset
"""
if is_tf_dataset:
if data_type == DataType.TFRECORD:
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id)
if data_type == DataType.MINDRECORD:
return _get_mindrecord_dataset(data_dir, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)
if rank_size > 1:
raise RuntimeError("please use tfrecord dataset.")
return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)

View File

@ -14,7 +14,7 @@
# ============================================================================
"""wide and deep model"""
from mindspore import nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
@ -24,6 +24,10 @@ from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
import numpy as np
np_type = np.float32
@ -42,8 +46,7 @@ def init_method(method, shape, name, max_val=1.0):
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(Tensor(np.random.normal(
loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
params = Parameter(initializer("normal", shape, ms_type), name=name)
return params
@ -66,8 +69,8 @@ def init_var_dict(init_args, in_vars):
var_map[key] = Parameter(initializer(
"zeros", shape, ms_type), name=key)
elif method == 'normal':
var_map[key] = Parameter(Tensor(np.random.normal(
loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=key)
var_map[key] = Parameter(initializer(
"normal", shape, ms_type), name=key)
return var_map
@ -132,6 +135,9 @@ class WideDeepModel(nn.Cell):
def __init__(self, config):
super(WideDeepModel, self).__init__()
self.batch_size = config.batch_size
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self.batch_size = self.batch_size * get_group_size()
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
@ -285,6 +291,18 @@ class TrainStepWrap(nn.Cell):
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
def construct(self, batch_ids, batch_wts, label):
weights_w = self.weights_w
weights_d = self.weights_d
@ -295,6 +313,9 @@ class TrainStepWrap(nn.Cell):
label, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
label, sens_d)
if self.reducer_flag:
grads_w = self.grad_reducer_w(grads_w)
grads_d = self.grad_reducer_d(grads_d)
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
self.optimizer_d(grads_d))

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
if __name__ == "__main__":
test_train_eval()

View File

@ -11,3 +11,4 @@ decorator >= 4.4.0
setuptools >= 40.8.0
matplotlib >= 3.1.3 # for ut test
opencv-python >= 4.2.0.32 # for ut test
sklearn >= 0.0 # for st test

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
LOCAL_HIAI=/usr/local/Ascend
export TBE_IMPL_PATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/impl/:${TBE_IMPL_PATH}
export LD_LIBRARY_PATH=${LOCAL_HIAI}/runtime/lib64/:${LOCAL_HIAI}/add-ons/:${LD_LIBRARY_PATH}
export PATH=${LOCAL_HIAI}/runtime/ccec_compiler/bin/:${PATH}
export PYTHONPATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/:${PYTHONPATH}
export DEVICE_MEMORY_CAPACITY=1073741824000
export NOT_FULLY_USE_DEVICES=off

View File

@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" config. """
import argparse
def argparse_init():
"""
argparse_init
"""
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
parser.add_argument("--vocab_size", type=int, default=184965)
parser.add_argument("--emb_dim", type=int, default=80)
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
parser.add_argument("--deep_layer_act", type=str, default='relu')
parser.add_argument("--keep_prob", type=float, default=1.0)
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
parser.add_argument("--loss_file_name", type=str, default="loss.log")
return parser
class WideDeepConfig():
"""
WideDeepConfig
"""
def __init__(self):
self.data_path = "/home/workspace/mindspore_dataset/criteo_data/mindrecord"
self.epochs = 1
self.batch_size = 16000
self.eval_batch_size = 16000
self.field_size = 39
self.vocab_size = 184968
self.emb_dim = 64
self.deep_layer_dim = [1024, 512, 256, 128]
self.deep_layer_act = 'relu'
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.keep_prob = 1.0
self.l2_coef = 8e-5
self.output_path = "./output"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
def argparse_init(self):
"""
argparse_init
"""
parser = argparse_init()
args, _ = parser.parse_known_args()
self.data_path = args.data_path
self.epochs = args.epochs
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size
self.vocab_size = args.vocab_size
self.emb_dim = args.emb_dim
self.deep_layer_dim = args.deep_layer_dim
self.deep_layer_act = args.deep_layer_act
self.keep_prob = args.keep_prob
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.l2_coef = 8e-5
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
from enum import Enum
import numpy as np
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
get_tf_dataset
"""
dataset_files = []
file_prefix_name = 'train' if train_mode else 'test'
shuffle = train_mode
for (dirpath, _, filenames) in os.walk(data_dir):
for filename in filenames:
if file_prefix_name in filename and "tfrecord" in filename:
dataset_files.append(os.path.join(dirpath, filename))
schema = de.Schema()
schema.add_column('feat_ids', de_type=mstype.int32)
schema.add_column('feat_vals', de_type=mstype.float32)
schema.add_column('label', de_type=mstype.float32)
if rank_size is not None and rank_id is not None:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
else:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample),
drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (
np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
#if train_mode:
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
"""
create_dataset
"""
if data_type == DataType.TFRECORD:
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id)
return _get_mindrecord_dataset(data_dir, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback])
eval_values = list(eval_callback.eval_values)
assert eval_values[0] > 0.78
if __name__ == "__main__":
test_train_eval()

View File

@ -0,0 +1,333 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wide and deep model"""
from mindspore import nn
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
import numpy as np
np_type = np.float32
ms_type = mstype.float32
def init_method(method, shape, name, max_val=1.0):
'''
parameter init method
'''
if method in ['uniform']:
params = Parameter(initializer(
Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer("normal", shape, ms_type), name=name)
return params
def init_var_dict(init_args, in_vars):
'''
var init function
'''
var_map = {}
_, _max_val = init_args
for _, iterm in enumerate(in_vars):
key, shape, method = iterm
if key not in var_map.keys():
if method in ['random', 'uniform']:
var_map[key] = Parameter(initializer(
Uniform(_max_val), shape, ms_type), name=key)
elif method == "one":
var_map[key] = Parameter(initializer(
"ones", shape, ms_type), name=key)
elif method == "zero":
var_map[key] = Parameter(initializer(
"zeros", shape, ms_type), name=key)
elif method == 'normal':
var_map[key] = Parameter(initializer(
"normal", shape, ms_type), name=key)
return var_map
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of WideDeep Model;
Containing: activation, matmul, bias_add;
Args:
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(
weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
#self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
self.convert_dtype = convert_dtype
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
x = self.act_func(x)
# if self.training:
# x = self.dropout(x)
x = self.mul(x, self.scale_coef)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
wx = self.matmul(x, weight)
wx = self.cast(wx, mstype.float32)
else:
wx = self.matmul(x, self.weight)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
class WideDeepModel(nn.Cell):
"""
From paper: " Wide & Deep Learning for Recommender Systems"
Args:
config (Class): The default config of Wide&Deep
"""
def __init__(self, config):
super(WideDeepModel, self).__init__()
self.batch_size = config.batch_size
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self.batch_size = self.batch_size * get_group_size()
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.deep_layer_dims_list = config.deep_layer_dim
self.deep_layer_act = config.deep_layer_act
self.init_args = config.init_args
self.weight_init, self.bias_init = config.weight_bias_init
self.weight_bias_init = config.weight_bias_init
self.emb_init = config.emb_init
self.drop_out = config.dropout_flag
self.keep_prob = config.keep_prob
self.deep_input_dims = self.field_size * self.emb_dim
self.layer_dims = self.deep_layer_dims_list + [1]
self.all_dim_list = [self.deep_input_dims] + self.layer_dims
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
('Wide_b', [1], self.emb_init)]
var_map = init_var_dict(self.init_args, init_acts)
self.wide_w = var_map["Wide_w"]
self.wide_b = var_map["Wide_b"]
self.embedding_table = var_map["V_l2"]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
self.all_dim_list[1],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
self.all_dim_list[2],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
self.all_dim_list[3],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
self.all_dim_list[4],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
self.all_dim_list[5],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.gather_v2 = P.GatherV2().set_strategy(((1, 8), (1, 1)))
self.gather_v2_1 = P.GatherV2()
self.mul = P.Mul()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reshape = P.Reshape()
self.square = P.Square()
self.shape = P.Shape()
self.tile = P.Tile()
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids;
wt_hldr: batch weights;
"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wide_id_weight = self.gather_v2_1(self.wide_w, id_hldr, 0)
wx = self.mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer
deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
vx = self.mul(deep_id_embs, mask)
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_in = self.dense_layer_4(deep_in)
deep_out = self.dense_layer_5(deep_in)
out = wide_out + deep_out
return out, self.embedding_table
class NetWithLossClass(nn.Cell):
""""
Provide WideDeep training loss through network.
Args:
network (Cell): The training network
config (Class): WideDeep config
"""
def __init__(self, network, config):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.network = network
self.l2_coef = config.l2_coef
self.loss = P.SigmoidCrossEntropyWithLogits()
self.square = P.Square().set_strategy(((1, get_group_size()),))
self.reduceMean_false = P.ReduceMean(keep_dims=False)
self.reduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict, embedding_table = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
wide_loss = self.reduceMean_false(log_loss)
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
return wide_loss, deep_loss
class IthOutputCell(nn.Cell):
def __init__(self, network, output_index):
super(IthOutputCell, self).__init__()
self.network = network
self.output_index = output_index
def construct(self, x1, x2, x3):
predict = self.network(x1, x2, x3)[self.output_index]
return predict
class TrainStepWrap(nn.Cell):
"""
Encapsulation class of WideDeep network training.
Append Adam and FTRL optimizers to the training network after that construct
function can be called to create the backward graph.
Args:
network (Cell): the training network. Note that loss function should have been added.
sens (Number): The adjust parameter. Default: 1000.0
"""
def __init__(self, network, sens=1000.0):
super(TrainStepWrap, self).__init__()
self.network = network
self.network.set_train()
self.trainable_params = network.trainable_params()
weights_w = []
weights_d = []
for params in self.trainable_params:
if 'wide' in params.name:
weights_w.append(params)
else:
weights_d.append(params)
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0)
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True,
sens_param=True)
self.grad_d = C.GradOperation('grad_d', get_by_list=True,
sens_param=True)
self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
def construct(self, batch_ids, batch_wts, label):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(batch_ids, batch_wts, label)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts,
label, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
label, sens_d)
if self.reducer_flag:
grads_w = self.grad_reducer_w(grads_w)
grads_d = self.grad_reducer_d(grads_d)
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
self.optimizer_d(grads_d))
class PredictWithSigmoid(nn.Cell):
def __init__(self, network):
super(PredictWithSigmoid, self).__init__()
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels

View File

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
BASE_PATH=$(cd "$(dirname $0)"; pwd)
CONFIG_PATH=/home/workspace/mindspore_config
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
CODE_DIR="./"
if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep
elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep
else
echo "[ERROR] code dir is not found"
fi
echo $CODE_DIR
rm -rf ${BASE_PATH}/wide_and_deep
cp -r ${CODE_DIR} ${BASE_PATH}/wide_and_deep
cp -f ${BASE_PATH}/python_file_for_ci/train_and_test_multinpu_ci.py ${BASE_PATH}/wide_and_deep/train_and_test_multinpu_ci.py
cp -f ${BASE_PATH}/python_file_for_ci/__init__.py ${BASE_PATH}/wide_and_deep/__init__.py
cp -f ${BASE_PATH}/python_file_for_ci/config.py ${BASE_PATH}/wide_and_deep/src/config.py
cp -f ${BASE_PATH}/python_file_for_ci/datasets.py ${BASE_PATH}/wide_and_deep/src/datasets.py
cp -f ${BASE_PATH}/python_file_for_ci/wide_and_deep.py ${BASE_PATH}/wide_and_deep/src/wide_and_deep.py
source ${BASE_PATH}/env.sh
export PYTHONPATH=${BASE_PATH}/wide_and_deep/:$PYTHONPATH
process_pid=()
for((i=0; i<$DEVICE_NUM; i++)); do
rm -rf ${BASE_PATH}/wide_and_deep_auto_parallel${i}
mkdir ${BASE_PATH}/wide_and_deep_auto_parallel${i}
cd ${BASE_PATH}/wide_and_deep_auto_parallel${i}
export RANK_ID=${i}
export DEVICE_ID=${i}
echo "start training for device $i"
env > env$i.log
pytest -s -v ../wide_and_deep/train_and_test_multinpu_ci.py > train_and_test_multinpu_ci$i.log 2>&1 &
process_pid[${i}]=`echo $!`
done
for((i=0; i<${DEVICE_NUM}; i++)); do
wait ${process_pid[i]}
status=`echo $?`
if [ "${status}" != "0" ]; then
echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}"
exit 1
else
echo "[INFO] test wide_and_deep semi auto parallel success."
fi
done
exit 0

View File

@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_wide_and_deep():
sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh")
os.system(f"grep -E 'ERROR|error' {sh_path}/wide_and_deep_auto_parallel*/train*log -C 3")
assert ret == 0

View File

@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
np.random.seed(1000)
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
expect_out0 = [0.792634,0.799862,0.803324]
expect_out6 = [0.796580,0.803908,0.807262]
if get_rank() == 0:
assert np.allclose(eval_callback.eval_values, expect_out0)
if get_rank() == 6:
assert np.allclose(eval_callback.eval_values, expect_out6)
if __name__ == "__main__":
test_train_eval()