data_parallel_grad_reducer

This commit is contained in:
yao_yf 2020-06-04 09:27:42 +08:00
parent fcdc88cca9
commit ce03ce5af2
16 changed files with 1086 additions and 16 deletions

View File

@ -25,6 +25,8 @@ The entire code structure is as following:
WideDeep.py "Model structure" WideDeep.py "Model structure"
callbacks.py "Callback class for training and evaluation" callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class" metrics.py "Metric class"
|--- script/ "run shell dir"
run_multinpu_train.sh "run data parallel"
``` ```
### Train and evaluate model ### Train and evaluate model

View File

@ -17,16 +17,18 @@
# bash run_multinpu_train.sh # bash run_multinpu_train.sh
execute_path=$(pwd) execute_path=$(pwd)
# export RANK_TABLE_FILE=${execute_path}/rank_table_8p.json export RANK_SIZE=$1
# export RANK_SIZE=8 export EPOCH_SIZE=$2
# export MINDSPORE_HCCL_CONFIG_PATH=${execute_path}/rank_table_8p.json export DATASET=$3
export RANK_TABLE_FILE=$4
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<=7;i++)); for((i=0;i<=$RANK_SIZE;i++));
do do
rm -rf ${execute_path}/device_$i/ rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/ mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i export RANK_ID=$i
export DEVICE_ID=$i export DEVICE_ID=$i
pytest -s ${execute_path}/train_and_test_multinpu.py >train_deep$i.log 2>&1 & pytest -s ${execute_path}/train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done done

View File

@ -17,6 +17,7 @@ callbacks
import time import time
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore import context from mindspore import context
from mindspore.train import ParallelMode
def add_write(file_path, out_str): def add_write(file_path, out_str):
""" """
@ -85,14 +86,17 @@ class EvalCallBack(Callback):
self.aucMetric = auc_metric self.aucMetric = auc_metric
self.aucMetric.clear() self.aucMetric.clear()
self.eval_file_name = config.eval_file_name self.eval_file_name = config.eval_file_name
self.eval_values = []
def epoch_name(self, run_context): def epoch_end(self, run_context):
""" """
epoch name epoch end
""" """
self.aucMetric.clear() self.aucMetric.clear()
context.set_auto_parallel_context(strategy_ckpt_save_file="", parallel_mode = context.get_auto_parallel_context("parallel_mode")
strategy_ckpt_load_file="./strategy_train.ckpt") if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
start_time = time.time() start_time = time.time()
out = self.model.eval(self.eval_dataset) out = self.model.eval(self.eval_dataset)
end_time = time.time() end_time = time.time()
@ -101,4 +105,5 @@ class EvalCallBack(Callback):
time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
print(out_str) print(out_str)
self.eval_values = out.values()
add_write(self.eval_file_name, out_str) add_write(self.eval_file_name, out_str)

View File

@ -17,11 +17,20 @@
import os import os
import math import math
from enum import Enum
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
class H5Dataset(): class H5Dataset():
""" """
@ -193,15 +202,60 @@ def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
ds = ds.repeat(epochs) ds = ds.repeat(epochs)
return ds return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000, def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
is_tf_dataset=True, line_per_sample=1000, rank_size=None, rank_id=None): data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
""" """
create_dataset create_dataset
""" """
if is_tf_dataset: if data_type == DataType.TFRECORD:
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size, return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id) line_per_sample, rank_size=rank_size, rank_id=rank_id)
if data_type == DataType.MINDRECORD:
return _get_mindrecord_dataset(data_dir, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)
if rank_size > 1: if rank_size > 1:
raise RuntimeError("please use tfrecord dataset.") raise RuntimeError("please use tfrecord dataset.")
return _get_h5_dataset(data_dir, train_mode, epochs, batch_size) return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""wide and deep model""" """wide and deep model"""
from mindspore import nn from mindspore import nn
from mindspore import Tensor, Parameter, ParameterTuple from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import composite as C from mindspore.ops import composite as C
@ -24,6 +24,10 @@ from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric # from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
import numpy as np import numpy as np
np_type = np.float32 np_type = np.float32
@ -42,8 +46,7 @@ def init_method(method, shape, name, max_val=1.0):
elif method == 'zero': elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name) params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal": elif method == "normal":
params = Parameter(Tensor(np.random.normal( params = Parameter(initializer("normal", shape, ms_type), name=name)
loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
return params return params
@ -66,8 +69,8 @@ def init_var_dict(init_args, in_vars):
var_map[key] = Parameter(initializer( var_map[key] = Parameter(initializer(
"zeros", shape, ms_type), name=key) "zeros", shape, ms_type), name=key)
elif method == 'normal': elif method == 'normal':
var_map[key] = Parameter(Tensor(np.random.normal( var_map[key] = Parameter(initializer(
loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=key) "normal", shape, ms_type), name=key)
return var_map return var_map
@ -132,6 +135,9 @@ class WideDeepModel(nn.Cell):
def __init__(self, config): def __init__(self, config):
super(WideDeepModel, self).__init__() super(WideDeepModel, self).__init__()
self.batch_size = config.batch_size self.batch_size = config.batch_size
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self.batch_size = self.batch_size * get_group_size()
self.field_size = config.field_size self.field_size = config.field_size
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim self.emb_dim = config.emb_dim
@ -285,6 +291,18 @@ class TrainStepWrap(nn.Cell):
self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1) self.loss_net_d = IthOutputCell(network, output_index=1)
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
def construct(self, batch_ids, batch_wts, label): def construct(self, batch_ids, batch_wts, label):
weights_w = self.weights_w weights_w = self.weights_w
weights_d = self.weights_d weights_d = self.weights_d
@ -295,6 +313,9 @@ class TrainStepWrap(nn.Cell):
label, sens_w) label, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts, grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
label, sens_d) label, sens_d)
if self.reducer_flag:
grads_w = self.grad_reducer_w(grads_w)
grads_d = self.grad_reducer_d(grads_d)
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
self.optimizer_d(grads_d)) self.optimizer_d(grads_d))

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
if __name__ == "__main__":
test_train_eval()

View File

@ -11,3 +11,4 @@ decorator >= 4.4.0
setuptools >= 40.8.0 setuptools >= 40.8.0
matplotlib >= 3.1.3 # for ut test matplotlib >= 3.1.3 # for ut test
opencv-python >= 4.2.0.32 # for ut test opencv-python >= 4.2.0.32 # for ut test
sklearn >= 0.0 # for st test

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
LOCAL_HIAI=/usr/local/Ascend
export TBE_IMPL_PATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/impl/:${TBE_IMPL_PATH}
export LD_LIBRARY_PATH=${LOCAL_HIAI}/runtime/lib64/:${LOCAL_HIAI}/add-ons/:${LD_LIBRARY_PATH}
export PATH=${LOCAL_HIAI}/runtime/ccec_compiler/bin/:${PATH}
export PYTHONPATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/:${PYTHONPATH}
export DEVICE_MEMORY_CAPACITY=1073741824000
export NOT_FULLY_USE_DEVICES=off

View File

@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" config. """
import argparse
def argparse_init():
"""
argparse_init
"""
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
parser.add_argument("--vocab_size", type=int, default=184965)
parser.add_argument("--emb_dim", type=int, default=80)
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
parser.add_argument("--deep_layer_act", type=str, default='relu')
parser.add_argument("--keep_prob", type=float, default=1.0)
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
parser.add_argument("--loss_file_name", type=str, default="loss.log")
return parser
class WideDeepConfig():
"""
WideDeepConfig
"""
def __init__(self):
self.data_path = "/home/workspace/mindspore_dataset/criteo_data/mindrecord"
self.epochs = 1
self.batch_size = 16000
self.eval_batch_size = 16000
self.field_size = 39
self.vocab_size = 184968
self.emb_dim = 64
self.deep_layer_dim = [1024, 512, 256, 128]
self.deep_layer_act = 'relu'
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.keep_prob = 1.0
self.l2_coef = 8e-5
self.output_path = "./output"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
def argparse_init(self):
"""
argparse_init
"""
parser = argparse_init()
args, _ = parser.parse_known_args()
self.data_path = args.data_path
self.epochs = args.epochs
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size
self.vocab_size = args.vocab_size
self.emb_dim = args.emb_dim
self.deep_layer_dim = args.deep_layer_dim
self.deep_layer_act = args.deep_layer_act
self.keep_prob = args.keep_prob
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.l2_coef = 8e-5
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
from enum import Enum
import numpy as np
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype
class DataType(Enum):
"""
Enumerate supported dataset format.
"""
MINDRECORD = 1
TFRECORD = 2
H5 = 3
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
get_tf_dataset
"""
dataset_files = []
file_prefix_name = 'train' if train_mode else 'test'
shuffle = train_mode
for (dirpath, _, filenames) in os.walk(data_dir):
for filename in filenames:
if file_prefix_name in filename and "tfrecord" in filename:
dataset_files.append(os.path.join(dirpath, filename))
schema = de.Schema()
schema.add_column('feat_ids', de_type=mstype.int32)
schema.add_column('feat_vals', de_type=mstype.float32)
schema.add_column('label', de_type=mstype.float32)
if rank_size is not None and rank_id is not None:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
else:
ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample),
drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (
np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
#if train_mode:
ds = ds.repeat(epochs)
return ds
def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
line_per_sample=1000, rank_size=None, rank_id=None):
"""
Get dataset with mindrecord format.
Args:
directory (str): Dataset directory.
train_mode (bool): Whether dataset is use for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of sample per line (default=1000).
rank_size (int): The number of device, not necessary for single device (default=None).
rank_id (int): Id of device, not necessary for single device (default=None).
Returns:
Dataset.
"""
file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
file_suffix_name = '00' if train_mode else '0'
shuffle = train_mode
if rank_size is not None and rank_id is not None:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
num_parallel_workers=8)
else:
ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
columns_list=['feat_ids', 'feat_vals', 'label'],
shuffle=shuffle, num_parallel_workers=8)
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
np.array(y).flatten().reshape(batch_size, 39),
np.array(z).flatten().reshape(batch_size, 1))),
input_columns=['feat_ids', 'feat_vals', 'label'],
columns_order=['feat_ids', 'feat_vals', 'label'],
num_parallel_workers=8)
ds = ds.repeat(epochs)
return ds
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
"""
create_dataset
"""
if data_type == DataType.TFRECORD:
return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
line_per_sample, rank_size=rank_size, rank_id=rank_id)
return _get_mindrecord_dataset(data_dir, train_mode, epochs,
batch_size, line_per_sample,
rank_size, rank_id)

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback])
eval_values = list(eval_callback.eval_values)
assert eval_values[0] > 0.78
if __name__ == "__main__":
test_train_eval()

View File

@ -0,0 +1,333 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wide and deep model"""
from mindspore import nn
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
import numpy as np
np_type = np.float32
ms_type = mstype.float32
def init_method(method, shape, name, max_val=1.0):
'''
parameter init method
'''
if method in ['uniform']:
params = Parameter(initializer(
Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer("normal", shape, ms_type), name=name)
return params
def init_var_dict(init_args, in_vars):
'''
var init function
'''
var_map = {}
_, _max_val = init_args
for _, iterm in enumerate(in_vars):
key, shape, method = iterm
if key not in var_map.keys():
if method in ['random', 'uniform']:
var_map[key] = Parameter(initializer(
Uniform(_max_val), shape, ms_type), name=key)
elif method == "one":
var_map[key] = Parameter(initializer(
"ones", shape, ms_type), name=key)
elif method == "zero":
var_map[key] = Parameter(initializer(
"zeros", shape, ms_type), name=key)
elif method == 'normal':
var_map[key] = Parameter(initializer(
"normal", shape, ms_type), name=key)
return var_map
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of WideDeep Model;
Containing: activation, matmul, bias_add;
Args:
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(
weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
#self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
self.convert_dtype = convert_dtype
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
x = self.act_func(x)
# if self.training:
# x = self.dropout(x)
x = self.mul(x, self.scale_coef)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
wx = self.matmul(x, weight)
wx = self.cast(wx, mstype.float32)
else:
wx = self.matmul(x, self.weight)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
class WideDeepModel(nn.Cell):
"""
From paper: " Wide & Deep Learning for Recommender Systems"
Args:
config (Class): The default config of Wide&Deep
"""
def __init__(self, config):
super(WideDeepModel, self).__init__()
self.batch_size = config.batch_size
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self.batch_size = self.batch_size * get_group_size()
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.deep_layer_dims_list = config.deep_layer_dim
self.deep_layer_act = config.deep_layer_act
self.init_args = config.init_args
self.weight_init, self.bias_init = config.weight_bias_init
self.weight_bias_init = config.weight_bias_init
self.emb_init = config.emb_init
self.drop_out = config.dropout_flag
self.keep_prob = config.keep_prob
self.deep_input_dims = self.field_size * self.emb_dim
self.layer_dims = self.deep_layer_dims_list + [1]
self.all_dim_list = [self.deep_input_dims] + self.layer_dims
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
('Wide_b', [1], self.emb_init)]
var_map = init_var_dict(self.init_args, init_acts)
self.wide_w = var_map["Wide_w"]
self.wide_b = var_map["Wide_b"]
self.embedding_table = var_map["V_l2"]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
self.all_dim_list[1],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
self.all_dim_list[2],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
self.all_dim_list[3],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
self.all_dim_list[4],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
self.all_dim_list[5],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.gather_v2 = P.GatherV2().set_strategy(((1, 8), (1, 1)))
self.gather_v2_1 = P.GatherV2()
self.mul = P.Mul()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reshape = P.Reshape()
self.square = P.Square()
self.shape = P.Shape()
self.tile = P.Tile()
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids;
wt_hldr: batch weights;
"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wide_id_weight = self.gather_v2_1(self.wide_w, id_hldr, 0)
wx = self.mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer
deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
vx = self.mul(deep_id_embs, mask)
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_in = self.dense_layer_4(deep_in)
deep_out = self.dense_layer_5(deep_in)
out = wide_out + deep_out
return out, self.embedding_table
class NetWithLossClass(nn.Cell):
""""
Provide WideDeep training loss through network.
Args:
network (Cell): The training network
config (Class): WideDeep config
"""
def __init__(self, network, config):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.network = network
self.l2_coef = config.l2_coef
self.loss = P.SigmoidCrossEntropyWithLogits()
self.square = P.Square().set_strategy(((1, get_group_size()),))
self.reduceMean_false = P.ReduceMean(keep_dims=False)
self.reduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict, embedding_table = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
wide_loss = self.reduceMean_false(log_loss)
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
return wide_loss, deep_loss
class IthOutputCell(nn.Cell):
def __init__(self, network, output_index):
super(IthOutputCell, self).__init__()
self.network = network
self.output_index = output_index
def construct(self, x1, x2, x3):
predict = self.network(x1, x2, x3)[self.output_index]
return predict
class TrainStepWrap(nn.Cell):
"""
Encapsulation class of WideDeep network training.
Append Adam and FTRL optimizers to the training network after that construct
function can be called to create the backward graph.
Args:
network (Cell): the training network. Note that loss function should have been added.
sens (Number): The adjust parameter. Default: 1000.0
"""
def __init__(self, network, sens=1000.0):
super(TrainStepWrap, self).__init__()
self.network = network
self.network.set_train()
self.trainable_params = network.trainable_params()
weights_w = []
weights_d = []
for params in self.trainable_params:
if 'wide' in params.name:
weights_w.append(params)
else:
weights_d.append(params)
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0)
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True,
sens_param=True)
self.grad_d = C.GradOperation('grad_d', get_by_list=True,
sens_param=True)
self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
def construct(self, batch_ids, batch_wts, label):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(batch_ids, batch_wts, label)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts,
label, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
label, sens_d)
if self.reducer_flag:
grads_w = self.grad_reducer_w(grads_w)
grads_d = self.grad_reducer_d(grads_d)
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
self.optimizer_d(grads_d))
class PredictWithSigmoid(nn.Cell):
def __init__(self, network):
super(PredictWithSigmoid, self).__init__()
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels

View File

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
BASE_PATH=$(cd "$(dirname $0)"; pwd)
CONFIG_PATH=/home/workspace/mindspore_config
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
CODE_DIR="./"
if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep
elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep
else
echo "[ERROR] code dir is not found"
fi
echo $CODE_DIR
rm -rf ${BASE_PATH}/wide_and_deep
cp -r ${CODE_DIR} ${BASE_PATH}/wide_and_deep
cp -f ${BASE_PATH}/python_file_for_ci/train_and_test_multinpu_ci.py ${BASE_PATH}/wide_and_deep/train_and_test_multinpu_ci.py
cp -f ${BASE_PATH}/python_file_for_ci/__init__.py ${BASE_PATH}/wide_and_deep/__init__.py
cp -f ${BASE_PATH}/python_file_for_ci/config.py ${BASE_PATH}/wide_and_deep/src/config.py
cp -f ${BASE_PATH}/python_file_for_ci/datasets.py ${BASE_PATH}/wide_and_deep/src/datasets.py
cp -f ${BASE_PATH}/python_file_for_ci/wide_and_deep.py ${BASE_PATH}/wide_and_deep/src/wide_and_deep.py
source ${BASE_PATH}/env.sh
export PYTHONPATH=${BASE_PATH}/wide_and_deep/:$PYTHONPATH
process_pid=()
for((i=0; i<$DEVICE_NUM; i++)); do
rm -rf ${BASE_PATH}/wide_and_deep_auto_parallel${i}
mkdir ${BASE_PATH}/wide_and_deep_auto_parallel${i}
cd ${BASE_PATH}/wide_and_deep_auto_parallel${i}
export RANK_ID=${i}
export DEVICE_ID=${i}
echo "start training for device $i"
env > env$i.log
pytest -s -v ../wide_and_deep/train_and_test_multinpu_ci.py > train_and_test_multinpu_ci$i.log 2>&1 &
process_pid[${i}]=`echo $!`
done
for((i=0; i<${DEVICE_NUM}; i++)); do
wait ${process_pid[i]}
status=`echo $?`
if [ "${status}" != "0" ]; then
echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}"
exit 1
else
echo "[INFO] test wide_and_deep semi auto parallel success."
fi
done
exit 0

View File

@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_wide_and_deep():
sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_wide_and_deep_auto_parallel.sh")
os.system(f"grep -E 'ERROR|error' {sh_path}/wide_and_deep_auto_parallel*/train*log -C 3")
assert ret == 0

View File

@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""
import os
import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def test_train_eval():
"""
test_train_eval
"""
np.random.seed(1000)
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
expect_out0 = [0.792634,0.799862,0.803324]
expect_out6 = [0.796580,0.803908,0.807262]
if get_rank() == 0:
assert np.allclose(eval_callback.eval_values, expect_out0)
if get_rank() == 6:
assert np.allclose(eval_callback.eval_values, expect_out6)
if __name__ == "__main__":
test_train_eval()