forked from mindspore-Ecosystem/mindspore
wide_and_deep gpu host_device
parent ddd9121968
commit da7c8cbae1
@@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# bash run_multigpu_train.sh RANK_SIZE EPOCH_SIZE DATASET VOCAB_SIZE EMB_DIM
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
RANK_SIZE=$1
EPOCH_SIZE=$2
DATASET=$3
VOCAB_SIZE=$4
EMB_DIM=$5

mpirun --allow-run-as-root -n $RANK_SIZE \
    python -s ${self_path}/../train_and_eval_auto_parallel.py \
        --device_target="GPU" \
        --data_path=$DATASET \
        --epochs=$EPOCH_SIZE \
        --vocab_size=$VOCAB_SIZE \
        --emb_dim=$EMB_DIM \
        --dropout_flag=1 \
        --host_device_mix=1 > log.txt 2>&1 &
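For reference, a hypothetical invocation of the launch script added above could look like the following; the argument values and the dataset path are illustrative only and are not taken from this commit:

    bash run_multigpu_train.sh 8 15 ./data/mindrecord 184965 80

The five positional arguments map to RANK_SIZE, EPOCH_SIZE, DATASET, VOCAB_SIZE and EMB_DIM. Training output ends up in log.txt because the mpirun command redirects stdout and stderr there and is started in the background.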
@@ -18,6 +18,7 @@ import time
from mindspore.train.callback import Callback
from mindspore import context
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank

def add_write(file_path, out_str):
    """
@@ -52,7 +53,14 @@ class LossCallBack(Callback):
        wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num
        print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
        rank_id = 0
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                             ParallelMode.DATA_PARALLEL):
            rank_id = get_rank()

        print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
              wide_loss, deep_loss, flush=True)

        # raise ValueError
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
@@ -99,13 +107,18 @@ class EvalCallBack(Callback):
        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            context.set_auto_parallel_context(strategy_ckpt_save_file="",
                                              strategy_ckpt_load_file="./strategy_train.ckpt")
        rank_id = 0
        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                             ParallelMode.DATA_PARALLEL):
            rank_id = get_rank()
        start_time = time.time()
        out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
        end_time = time.time()
        eval_time = int(end_time - start_time)

        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
        out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\
            format(time_str, rank_id, out.values(), eval_time)
        print(out_str)
        self.eval_values = out.values()
        add_write(self.eval_file_name, out_str)
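Both callback hunks above add the same pattern: query the configured parallel mode and call get_rank() only in the modes where a rank is defined, falling back to rank 0 otherwise (for example in stand-alone runs). A minimal sketch of that pattern, using only the imports already shown in the hunks above (the helper name current_rank is hypothetical, not part of this commit):

from mindspore import context
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank

def current_rank():
    """Return the process rank, or 0 when no rank-aware parallel mode is configured."""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                         ParallelMode.DATA_PARALLEL):
        return get_rank()
    return 0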
@@ -201,6 +201,7 @@ class WideDeepModel(nn.Cell):
        self.cast = P.Cast()
        if is_auto_parallel and host_device_mix:
            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
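As the name suggests, the TABLE_COLUMN_SLICE mode used above shards the embedding table along its emb_dim (column) axis across the ranks in the group. The following NumPy toy sketches what such a column split means for a lookup, assuming an even split across group_size ranks; all sizes and names are illustrative and not taken from this commit:

import numpy as np

vocab_size, emb_dim, group_size = 16, 8, 4      # toy sizes, illustrative only
table = np.random.randn(vocab_size, emb_dim)    # the full embedding table
shards = np.split(table, group_size, axis=1)    # each rank keeps all rows but emb_dim/group_size columns
ids = np.array([3, 7, 11])                      # a toy batch of feature ids
partial = [shard[ids] for shard in shards]      # every rank looks up the same ids in its own column shard
# concatenating the per-rank pieces along the feature axis reproduces the full lookup
assert np.allclose(np.concatenate(partial, axis=1), table[ids])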
@@ -32,13 +32,6 @@ from src.metrics import AUCMetric
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()


def get_WideDeep_net(config):
    """
@@ -131,6 +124,14 @@ def train_and_eval(config):
if __name__ == "__main__":
    wide_deep_config = WideDeepConfig()
    wide_deep_config.argparse_init()
    context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
    context.set_context(variable_memory_max_size="24GB")
    context.set_context(enable_sparse=True)
    cost_model_context.set_cost_model_context(multi_subgraphs=True)
    if wide_deep_config.device_target == "Ascend":
        init("hccl")
    elif wide_deep_config.device_target == "GPU":
        init("nccl")
    if wide_deep_config.host_device_mix == 1:
        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
    else:
@@ -16,6 +16,7 @@

import os
import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
@@ -68,6 +69,7 @@ def train_and_eval(config):
    """
    train_and_eval
    """
    np.random.seed(1000)
    data_path = config.data_path
    epochs = config.epochs
    print("epochs is {}".format(epochs))