forked from mindspore-Ecosystem/mindspore
!5309 wide&deep field slice mode
Merge pull request !5309 from yao_yf/wide_and_deep_field_slice
commit 816ea95d5d

@@ -41,6 +41,8 @@ do
cd ${execute_path}/device_$RANK_ID || exit
if [ $MODE == "host_device_mix" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
elif [ $MODE == "field_slice_host_device_mix" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 &
else
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
fi

@@ -38,7 +38,7 @@ do
user=$(get_node_user ${cluster_config_path} ${node})
passwd=$(get_node_passwd ${cluster_config_path} ${node})
echo "------------------${user}@${node}---------------------"
if [ $MODE == "host_device_mix" ]; then
if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ]; then
ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}"
else
echo "[ERROR] mode is wrong"

@@ -25,7 +25,7 @@ def argparse_init():
    parser.add_argument("--data_path", type=str, default="./test_raw_data/",
                        help="This should be set to the same directory given to the data_download's data_dir argument")
    parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
    parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
    parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
    parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
    parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
    parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
@@ -46,6 +46,7 @@ def argparse_init():
    parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5")
    parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server or not")
    parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not")
    return parser

@@ -81,6 +82,8 @@ class WideDeepConfig():
        self.host_device_mix = 0
        self.dataset_type = "tfrecord"
        self.parameter_server = 0
        self.field_slice = False
        self.manual_shape = None

    def argparse_init(self):
        """
@@ -91,7 +94,7 @@ class WideDeepConfig():
        self.device_target = args.device_target
        self.data_path = args.data_path
        self.epochs = args.epochs
        self.full_batch = args.full_batch
        self.full_batch = bool(args.full_batch)
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.field_size = args.field_size
@@ -114,3 +117,4 @@
        self.host_device_mix = args.host_device_mix
        self.dataset_type = args.dataset_type
        self.parameter_server = args.parameter_server
        self.field_slice = bool(args.field_slice)

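The switch of --full_batch from type=bool to type=int (mirrored by the new --field_slice flag) avoids a common argparse pitfall: with type=bool, any non-empty string, including "0" and "False", parses as True. A minimal standalone sketch of the intended usage, not taken from the repository, showing the integer flags being coerced to booleans the same way WideDeepConfig does:

# Minimal sketch (illustrative only): integer command-line flags converted to
# booleans, as WideDeepConfig does with bool(args.full_batch) / bool(args.field_slice).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch")
parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode")
args = parser.parse_args(["--full_batch", "1", "--field_slice", "1"])

full_batch = bool(args.full_batch)    # True
field_slice = bool(args.field_slice)  # True
print(full_batch, field_slice)
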
@@ -23,6 +23,7 @@ import pandas as pd
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype


class DataType(Enum):
    """
    Enumerate supported dataset format.
@@ -83,9 +84,9 @@ class H5Dataset():
            yield os.path.join(self._hdf_data_dir,
                               self._file_prefix + '_input_part_' + str(
                                   p) + '.h5'), \
                os.path.join(self._hdf_data_dir,
                             self._file_prefix + '_output_part_' + str(
                                 p) + '.h5'), i + 1 == len(parts)
                os.path.join(self._hdf_data_dir,
                             self._file_prefix + '_output_part_' + str(
                                 p) + '.h5'), i + 1 == len(parts)

    def _generator(self, X, y, batch_size, shuffle=True):
        """
@@ -169,8 +170,41 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
    return ds


def _padding_func(batch_size, manual_shape, target_column, field_size=39):
    """
    get padding_func
    """
    if manual_shape:
        generate_concat_offset = [item[0]+item[1] for item in manual_shape]
        part_size = int(target_column / len(generate_concat_offset))
        filled_value = []
        for i in range(field_size, target_column):
            filled_value.append(generate_concat_offset[i//part_size]-1)
        print("Field Value:", filled_value)

        def padding_func(x, y, z):
            x = np.array(x).flatten().reshape(batch_size, field_size)
            y = np.array(y).flatten().reshape(batch_size, field_size)
            z = np.array(z).flatten().reshape(batch_size, 1)

            x_id = np.ones((batch_size, target_column - field_size),
                           dtype=np.int32) * filled_value
            x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
            mask = np.concatenate(
                [y, np.zeros((batch_size, target_column - field_size), dtype=np.float32)], axis=1)
            return (x_id, mask, z)
    else:
        def padding_func(x, y, z):
            x = np.array(x).flatten().reshape(batch_size, field_size)
            y = np.array(y).flatten().reshape(batch_size, field_size)
            z = np.array(z).flatten().reshape(batch_size, 1)
            return (x, y, z)
    return padding_func

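To see what _padding_func produces when manual_shape is set, here is a small self-contained walk-through with toy numbers (4 real fields padded up to target_column=6 across 3 parts; the shapes are illustrative assumptions, not values from the diff): the extra id columns are filled with the last valid row index of the part that owns them, and the matching mask columns are zero so the padding contributes nothing.

# Toy illustration (assumed shapes) of the padding performed by _padding_func.
import numpy as np

batch_size, field_size, target_column = 2, 4, 6
# manual_shape entries are (part_vocab_size, part_index_offset) pairs
manual_shape = [(8, 0), (8, 8), (8, 16)]

generate_concat_offset = [size + offset for size, offset in manual_shape]  # [8, 16, 24]
part_size = int(target_column / len(generate_concat_offset))               # 2 columns per part
filled_value = [generate_concat_offset[i // part_size] - 1                 # last id of the owning part
                for i in range(field_size, target_column)]                 # [23, 23]

x = np.arange(batch_size * field_size, dtype=np.int32).reshape(batch_size, field_size)
pad_ids = np.ones((batch_size, target_column - field_size), dtype=np.int32) * filled_value
x_id = np.concatenate([x, pad_ids.astype(np.int32)], axis=1)
mask = np.concatenate([np.ones((batch_size, field_size), dtype=np.float32),
                       np.zeros((batch_size, target_column - field_size), dtype=np.float32)], axis=1)
print(x_id.shape, mask.shape)  # (2, 6) (2, 6)
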
def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
                    line_per_sample=1000, rank_size=None, rank_id=None):
                    line_per_sample=1000, rank_size=None, rank_id=None,
                    manual_shape=None, target_column=40):
    """
    get_tf_dataset
    """
@@ -189,21 +223,22 @@
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
                                num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
    else:
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
        ds = de.TFRecordDataset(dataset_files=dataset_files,
                                shuffle=shuffle, schema=schema, num_parallel_workers=8)
    ds = ds.batch(int(batch_size / line_per_sample),
                  drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z: (
        np.array(x).flatten().reshape(batch_size, 39),
        np.array(y).flatten().reshape(batch_size, 39),
        np.array(z).flatten().reshape(batch_size, 1))),

    ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
    #if train_mode:
    # if train_mode:
    ds = ds.repeat(epochs)
    return ds


def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
                            line_per_sample=1000, rank_size=None, rank_id=None):
                            line_per_sample=1000, rank_size=None, rank_id=None,
                            manual_shape=None, target_column=40):
    """
    Get dataset with mindrecord format.
@@ -233,9 +268,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
                        columns_list=['feat_ids', 'feat_vals', 'label'],
                        shuffle=shuffle, num_parallel_workers=8)
    ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
                                             np.array(y).flatten().reshape(batch_size, 39),
                                             np.array(z).flatten().reshape(batch_size, 1))),
    ds = ds.map(_padding_func(batch_size, manual_shape, target_column),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                columns_order=['feat_ids', 'feat_vals', 'label'],
                num_parallel_workers=8)
@@ -243,18 +276,84 @@
    return ds


def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
    """
    get_vocab_size
    """
    # Vocabulary sizes of the 39 original fields
    individual_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
                         10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
                         5, 21762, 14, 15, 15030, 61, 12220]

    new_vocabs = individual_vocabs + [1] * \
        (target_column_number - len(individual_vocabs))
    part_size = int(target_column_number / worker_size)

    # Merge consecutive fields into parts, one part per worker
    new_vocab_size = []
    for i in range(0, target_column_number, part_size):
        new_vocab_size.append(sum(new_vocabs[i: i + part_size]))

    index_offsets = [0]

    # The original per-part feature counts are used to calculate the offsets
    features = [item for item in new_vocab_size]

    # If per_vocab_size is given, use it as the vocab size of every part
    if per_vocab_size is not None:
        new_vocab_size = [per_vocab_size] * worker_size
    else:
        # Expand the vocabulary of each part by the multiplier
        if multiply is True:
            cur_sum = sum(new_vocab_size)
            k = total_vocab_size/cur_sum
            new_vocab_size = [
                math.ceil(int(item*k)/worker_size)*worker_size for item in new_vocab_size]
            new_vocab_size = [(item // 8 + 1)*8 for item in new_vocab_size]

        else:
            if total_vocab_size > sum(new_vocab_size):
                new_vocab_size[-1] = total_vocab_size - \
                    sum(new_vocab_size[:-1])
                new_vocab_size = [item for item in new_vocab_size]
            else:
                raise ValueError(
                    "Please provide the correct vocab size, now is {}".format(total_vocab_size))

    for i in range(worker_size-1):
        off = index_offsets[i] + features[i]
        index_offsets.append(off)

    print("the offset: ", index_offsets)
    manual_shape = tuple(
        ((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
    vocab_total = sum(new_vocab_size)
    return manual_shape, vocab_total

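The core of _get_vocab_size is easier to see with small numbers. A toy walk-through (8 field columns, 2 workers; the per-field vocabulary sizes are made up for illustration) showing how consecutive fields are merged into per-worker parts and how the (part_vocab_size, index_offset) pairs of manual_shape fall out:

# Toy walk-through (assumed field/vocab numbers, not from the diff).
field_vocabs = [5, 3, 7, 2, 4, 6, 1, 1]       # per-field vocabulary sizes (toy values)
worker_size = 2
part_size = len(field_vocabs) // worker_size   # 4 fields per worker

part_vocab = [sum(field_vocabs[i:i + part_size])
              for i in range(0, len(field_vocabs), part_size)]   # [17, 12]

index_offsets = [0]
for i in range(worker_size - 1):
    index_offsets.append(index_offsets[i] + part_vocab[i])       # [0, 17]

manual_shape = tuple(zip(part_vocab, index_offsets))             # ((17, 0), (12, 17))
vocab_total = sum(part_vocab)                                    # 29
print(manual_shape, vocab_total)
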
def compute_manual_shape(config, worker_size):
    target_column = (config.field_size // worker_size + 1) * worker_size
    config.field_size = target_column
    manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
                                                per_vocab_size=None, multiply=False)
    config.manual_shape = manual_shape
    config.vocab_size = int(vocab_total)

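compute_manual_shape first rounds field_size up so it is divisible by the worker count (always adding at least one extra column), so that every worker owns the same number of field columns; with the default 39 fields and, say, 8 workers this gives target_column = 40, i.e. 5 columns per worker. A one-line check of that arithmetic (the worker count is chosen for illustration):

# Quick arithmetic check of the padding in compute_manual_shape (toy worker count).
field_size, worker_size = 39, 8
target_column = (field_size // worker_size + 1) * worker_size  # (4 + 1) * 8 = 40
print(target_column, target_column // worker_size)             # 40 5
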
def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
                   data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
                   data_type=DataType.TFRECORD, line_per_sample=1000,
                   rank_size=None, rank_id=None, manual_shape=None, target_column=40):
    """
    create_dataset
    """
    if data_type == DataType.TFRECORD:
        return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
                               line_per_sample, rank_size=rank_size, rank_id=rank_id)
                               line_per_sample, rank_size=rank_size, rank_id=rank_id,
                               manual_shape=manual_shape, target_column=target_column)
    if data_type == DataType.MINDRECORD:
        return _get_mindrecord_dataset(data_dir, train_mode, epochs,
                                       batch_size, line_per_sample,
                                       rank_size, rank_id)
        return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
                                       line_per_sample, rank_size=rank_size, rank_id=rank_id,
                                       manual_shape=manual_shape, target_column=target_column)

    if rank_size > 1:
        raise RuntimeError("please use tfrecord dataset.")

@@ -143,6 +143,7 @@ class WideDeepModel(nn.Cell):
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if is_auto_parallel:
            self.batch_size = self.batch_size * get_group_size()
        is_field_slice = config.field_slice
        self.field_size = config.field_size
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
@@ -196,11 +197,10 @@ class WideDeepModel(nn.Cell):
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)
        self.cast = P.Cast()
        if is_auto_parallel and host_device_mix:
        if is_auto_parallel and host_device_mix and not is_field_slice:
            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
            self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@@ -209,9 +209,20 @@ class WideDeepModel(nn.Cell):
            self.deep_reshape.add_prim_attr("skip_redistribution", True)
            self.reduce_sum.add_prim_attr("cross_batch", True)
            self.embedding_table = self.deep_embeddinglookup.embedding_table
        elif host_device_mix:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
        elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
            manual_shapes = tuple((s[0] for s in config.manual_shape))
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
                                                           manual_shapes=manual_shapes)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
                                                           manual_shapes=manual_shapes)
            self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
            self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
            self.reduce_sum.set_strategy(((1, get_group_size(), 1),))
            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
            self.embedding_table = self.deep_embeddinglookup.embedding_table
        elif parameter_server:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
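In the new field-slice branch above, the embedding tables receive only the per-part vocabulary sizes: manual_shapes keeps the first element of each (size, offset) pair in config.manual_shape. Reusing the toy manual_shape from the dataset example (an illustrative value, not from the diff):

# Toy extraction of manual_shapes as done in the field-slice branch.
manual_shape = ((17, 0), (12, 17))                 # (part_vocab_size, index_offset) pairs
manual_shapes = tuple(s[0] for s in manual_shape)  # sizes only, as passed to nn.EmbeddingLookup
print(manual_shapes)                               # (17, 12)
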
@@ -263,7 +274,7 @@ class NetWithLossClass(nn.Cell):
        parameter_server = bool(config.parameter_server)
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
        self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server)
        self.network = network
        self.l2_coef = config.l2_coef
        self.loss = P.SigmoidCrossEntropyWithLogits()
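The updated no_l2loss expression reads a little densely; in plain terms, when host-device mix or field slice is enabled the L2 penalty is dropped exactly when auto/semi-auto parallel is active, and otherwise it is dropped exactly when the parameter server is used. A small truth-check of just that expression (plain booleans, independent of MindSpore):

# Same conditional expression as in NetWithLossClass.__init__, isolated for checking.
def no_l2loss(is_auto_parallel, host_device_mix, field_slice, parameter_server):
    return is_auto_parallel if (host_device_mix or field_slice) else parameter_server

print(no_l2loss(True, True, False, False))   # True:  host_device_mix under auto parallel
print(no_l2loss(True, False, True, False))   # True:  field slice under auto parallel
print(no_l2loss(False, False, False, True))  # True:  parameter server path
print(no_l2loss(True, False, False, False))  # False: keep the L2 term
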
@@ -27,12 +27,13 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple

from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, DataType
from src.datasets import create_dataset, DataType, compute_manual_shape
from src.metrics import AUCMetric
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def get_WideDeep_net(config):
    """
    Get network of wide&deep model.
@@ -40,7 +41,8 @@ def get_WideDeep_net(config):
    WideDeep_net = WideDeepModel(config)
    loss_net = NetWithLossClass(WideDeep_net, config)
    loss_net = VirtualDatasetCellTriple(loss_net)
    train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
    train_net = TrainStepWrap(
        loss_net, host_device_mix=bool(config.host_device_mix))
    eval_net = PredictWithSigmoid(WideDeep_net)
    eval_net = VirtualDatasetCellTriple(eval_net)
    return train_net, eval_net
@@ -50,6 +52,7 @@ class ModelBuilder():
    """
    ModelBuilder
    """

    def __init__(self):
        pass

@@ -86,10 +89,19 @@ def train_and_eval(config):
    if config.full_batch:
        context.set_auto_parallel_context(full_batch=True)
        de.config.set_seed(1)
        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                 batch_size=batch_size*get_group_size(), data_type=dataset_type)
        if config.field_slice:
            compute_manual_shape(config, get_group_size())
            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                      batch_size=batch_size*get_group_size(), data_type=dataset_type,
                                      manual_shape=config.manual_shape, target_column=config.field_size)
            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                     batch_size=batch_size*get_group_size(), data_type=dataset_type,
                                     manual_shape=config.manual_shape, target_column=config.field_size)
        else:
            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                      batch_size=batch_size*get_group_size(), data_type=dataset_type)
            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                     batch_size=batch_size*get_group_size(), data_type=dataset_type)
    else:
        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                  batch_size=batch_size, rank_id=get_rank(),
@@ -106,9 +118,11 @@ def train_and_eval(config):
    train_net.set_train()
    auc_metric = AUCMetric()

    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
    model = Model(train_net, eval_network=eval_net,
                  metrics={"auc": auc_metric})

    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
    eval_callback = EvalCallBack(
        model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)

    callback = LossCallBack(config=config, per_print_times=20)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
@@ -116,16 +130,19 @@ def train_and_eval(config):
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path, config=ckptconfig)
    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
    callback_list = [TimeMonitor(
        ds_train.get_dataset_size()), eval_callback, callback]
    if not host_device_mix:
        callback_list.append(ckpoint_cb)
    model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
    model.train(epochs, ds_train, callbacks=callback_list,
                dataset_sink_mode=(not host_device_mix))


if __name__ == "__main__":
    wide_deep_config = WideDeepConfig()
    wide_deep_config.argparse_init()
    context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=wide_deep_config.device_target, save_graphs=True)
    context.set_context(variable_memory_max_size="24GB")
    context.set_context(enable_sparse=True)
    set_multi_subgraphs()
@@ -134,7 +151,9 @@ if __name__ == "__main__":
    elif wide_deep_config.device_target == "GPU":
        init("nccl")
    if wide_deep_config.host_device_mix == 1:
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
    else:
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    train_and_eval(wide_deep_config)

@@ -101,12 +101,8 @@ def train_and_eval(config):

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    if config.device_target == "Ascend":
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                     directory=config.ckpt_path, config=ckptconfig)
    elif config.device_target == "GPU":
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
                                     directory=config.ckpt_path, config=ckptconfig)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path, config=ckptconfig)
    out = model.eval(ds_eval)
    print("=====" * 5 + "model.eval() initialized: {}".format(out))
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]

@@ -103,14 +103,13 @@ def train_and_eval(config):

    callback = LossCallBack(config=config)
    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
    if config.device_target == "Ascend":
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                     directory=config.ckpt_path, config=ckptconfig)
    elif config.device_target == "GPU":
        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
                                     directory=config.ckpt_path, config=ckptconfig)
    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                 directory=config.ckpt_path, config=ckptconfig)
    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
    if get_rank() == 0:
        callback_list.append(ckpoint_cb)
    model.train(epochs, ds_train,
                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
                callbacks=callback_list,
                dataset_sink_mode=(not parameter_server))