Support loading the full dataset on each device

This commit is contained in:
Yi Huaijie 2020-05-29 16:29:11 +08:00
parent 8de8289cfd
commit e5c351690b
14 changed files with 229 additions and 79 deletions

View File

@ -48,6 +48,7 @@ ParallelContext::ParallelContext() { Reset(); }
void ParallelContext::Reset() {
mirror_mean_ = false;
full_batch_ = false;
cast_before_mirror_ = true;
loss_repeated_mean_ = true;
device_num_ = 1;
@ -75,6 +76,8 @@ void ParallelContext::set_global_rank(int32_t global_rank) {
void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; }
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

View File

@ -55,6 +55,9 @@ class ParallelContext {
void set_mirror_mean(bool mirror_mean);
bool mirror_mean() const { return mirror_mean_; }
void set_full_batch(bool full_batch);
bool full_batch() const { return full_batch_; }
void set_cast_before_mirror(bool cast_before_mirror);
bool cast_before_mirror() const { return cast_before_mirror_; }
@ -103,6 +106,7 @@ class ParallelContext {
ParallelContext();
static std::shared_ptr<ParallelContext> inst_context_;
bool mirror_mean_;
bool full_batch_;
bool cast_before_mirror_;
bool loss_repeated_mean_;
int32_t device_num_;

View File

@ -24,15 +24,23 @@
#include "ir/value.h"
#include "parallel/device_matrix.h"
#include "parallel/strategy.h"
#include "parallel/context.h"
#include "parallel/tensor_layout/tensor_redistribution.h"
namespace mindspore {
namespace parallel {
Status GetNextInfo::InferTensorMap() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool full_batch = ParallelContext::GetInstance()->full_batch();
for (auto shp : shapes_) {
TensorMap out_tensor_map;
for (size_t i = 0; i < shp.size(); ++i) {
out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
if (full_batch) {
out_tensor_map.push_back(MAP_NONE);
} else {
out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
}
}
outputs_tensor_map_.push_back(out_tensor_map);
}
@ -190,6 +198,9 @@ Status GetNextInfo::GetAttrs() {
}
Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool full_batch = ParallelContext::GetInstance()->full_batch();
Shapes out_shapes = outputs_shape_;
for (size_t i = 0; i < out_shapes.size(); ++i) {
if (dev_num_ <= 0) {
@ -200,7 +211,9 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
return FAILED;
}
out_shapes[i][0] = out_shapes[i][0] / dev_num_;
if (!full_batch) {
out_shapes[i][0] = out_shapes[i][0] / dev_num_;
}
}
ValuePtr new_shapes = MakeValue(out_shapes);
Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]);

View File

@ -23,6 +23,7 @@
#include "parallel/device_manager.h"
#include "parallel/device_matrix.h"
#include "parallel/step_parallel.h"
#include "parallel/context.h"
#include "utils/log_adapter.h"
namespace mindspore {
@ -93,59 +94,21 @@ Status VirtualDatasetInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status VirtualDatasetInfo::InferMirrorOps() {
mirror_ops_.clear();
int32_t stage = strategy_->GetInputStage();
CheckGlobalDeviceManager();
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage);
if (dev_list.empty()) {
MS_LOG(ERROR) << name_ << ": The current stage is empty!";
return Status::FAILED;
}
if (dev_list.size() == 1) {
MS_LOG(INFO) << name_ << ": No need mirror ops.";
return Status::SUCCESS;
}
OperatorName operator_name = BROADCAST;
ValuePtr attr0_value = MakeValue(dev_list.front());
std::vector<Group> group_list;
if (CreateGroupByDim(dev_matrix_shape_.size() - 1, &group_list) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer mirror ops, create group failed.";
return FAILED;
} else if (group_list.empty()) {
MS_LOG(INFO) << name_ << ": No need mirror ops.";
return SUCCESS;
}
std::string group = group_list[0].name();
ValuePtr attr1_value = MakeValue(group);
Attr attr0 = std::make_pair(SRC, attr0_value);
Attr attr1 = std::make_pair(GROUP, attr1_value);
OperatorAttrs operator_attrs = {attr0, attr1};
OperatorParams operator_param;
OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_args);
OperatorVector op_vector = {op};
size_t size = inputs_shape_.size();
for (size_t i = 0; i < size; ++i) {
mirror_ops_.push_back(op_vector);
}
mirror_ops_.clear();
return SUCCESS;
}
// No mirror operators are inferred here: the function does nothing and reports SUCCESS.
Status VirtualDatasetInfo::InferMirrorOps() {
  return SUCCESS;
}

// No forward communication is inferred here either; always reports SUCCESS.
Status VirtualDatasetInfo::InferForwardCommunication() {
  return SUCCESS;
}
Status VirtualDatasetInfo::InferTensorMap() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool full_batch = ParallelContext::GetInstance()->full_batch();
for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
std::vector<int32_t> tensor_map_index;
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
if (full_batch) {
tensor_map_index.push_back(MAP_NONE);
} else {
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
}
for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) {
tensor_map_index.push_back(MAP_NONE);
}
@ -213,6 +176,10 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
}
Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool full_batch = ParallelContext::GetInstance()->full_batch();
size_t total_dev_num;
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GetAttrs failed";
return FAILED;
@ -220,7 +187,11 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
CheckGlobalDeviceManager();
is_auto_parallel_ = true;
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
if (full_batch) {
total_dev_num = 1;
} else {
total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
}
StrategyPtr sp;
std::vector<Dimensions> strategy;
for (auto &shape : inputs_shape_) {
@ -232,10 +203,18 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
sp = std::make_shared<Strategy>(stage_id, strategy);
if (SetCostUnderStrategy(sp) == SUCCESS) {
MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
if (full_batch) {
MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy.";
} else {
MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
}
PrintStrategy(sp);
} else {
MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
if (full_batch) {
MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed.";
} else {
MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
}
return FAILED;
}
return SUCCESS;

View File

@ -1375,11 +1375,19 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
void SetVirtualDatasetStrategy(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool full_batch = ParallelContext::GetInstance()->full_batch();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == VIRTUAL_DATA_SET) {
CheckGlobalDeviceManager();
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
int32_t dev_num;
if (full_batch) {
dev_num = 1;
} else {
dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
}
auto attrs_temp = prim->attrs();
std::vector<Shapes> shape_list = ExtractShape(node);
if (shape_list.empty()) {

View File

@ -187,6 +187,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Set strategy checkpoint save file.")
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

View File

@ -367,7 +367,8 @@ def _context():
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str,
full_batch=bool)
def set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@ -404,6 +405,7 @@ def set_auto_parallel_context(**kwargs):
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.

View File

@ -225,6 +225,21 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_strategy_ckpt_load_file()
def set_full_batch(self, full_batch):
    """
    Set whether to load the full batch on each device.

    Args:
        full_batch (bool): True to load the full batch on each device.
    """
    # Validate the handle first, exactly as the other accessors of this class do.
    self.check_context_handle()
    handle = self._context_handle
    handle.set_full_batch(full_batch)
def get_full_batch(self):
    """Return whether the full batch is loaded on each device."""
    # Validate the handle before reading from it.
    self.check_context_handle()
    handle = self._context_handle
    return handle.get_full_batch()
def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
"""
Set strategy checkpoint save path.
@ -415,7 +430,8 @@ _set_auto_parallel_context_func_map = {
"parallel_mode": auto_parallel_context().set_parallel_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch}
_get_auto_parallel_context_func_map = {
@ -427,12 +443,13 @@ _get_auto_parallel_context_func_map = {
"parallel_mode": auto_parallel_context().get_parallel_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@ -465,6 +482,7 @@ def _set_auto_parallel_context(**kwargs):
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.

View File

@ -20,10 +20,26 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
def _get_parallel_mode():
    """Return the parallel mode stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_parallel_mode()
def _get_full_batch():
    """Return the full_batch flag stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_full_batch()
def _need_to_full():
    """
    Tell whether inputs must be converted to full shapes or full tensors.

    Returns:
        bool, True only when the parallel mode is "semi_auto_parallel" or "auto_parallel"
        and full_batch is not enabled.
    """
    # Both settings are always queried, in the same order as before.
    mode = _get_parallel_mode()
    use_full_batch = _get_full_batch()
    if use_full_batch:
        return False
    return mode in ("semi_auto_parallel", "auto_parallel")
def _get_mirror_mean():
    """Return the mirror_mean flag stored in the auto parallel context."""
    ctx = auto_parallel_context()
    return ctx.get_mirror_mean()

View File

@ -17,11 +17,10 @@ import math
from mindspore._checkparam import check_bool
from .. import context
from .parallel_utils import ParallelMode
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor
from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
class DatasetHelper:
@ -118,10 +117,10 @@ class _DatasetIterMSLoopSink(_DatasetIter):
def __init__(self, dataset):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run. Now only support LoopSink.
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
if _need_to_full():
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
@ -146,10 +145,8 @@ class _DatasetIterGE(_DatasetIter):
def __init__(self, dataset):
super(_DatasetIterGE, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
parallel_mode = _get_parallel_mode()
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
batch_expand_num = 1
if self.need_to_full:
if _need_to_full():
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
@ -170,9 +167,6 @@ class _DatasetIterFeed:
self.loop_count = dataset.get_dataset_size()
self.ind = 0
parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
def __iter__(self):
if self.repeat_ind % self.repeat_count == 0:
self.iter = self.dataset.__iter__()
@ -186,6 +180,6 @@ class _DatasetIterFeed:
raise StopIteration()
self.ind += 1
data = self.iter.__next__()
if self.need_to_full:
if _need_to_full():
return _to_full_tensor(data, self.device_num, self.global_rank)
return _to_tensor(data)

View File

@ -22,6 +22,7 @@ def argparse_init():
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
# argparse applies `type` to the raw command-line string, and bool("False") is True,
# so `type=bool` could never switch the flag off. Parse the text explicitly instead:
# explicit negatives (and the empty string) give False, anything else stays True as before.
parser.add_argument("--full_batch",
                    type=lambda text: str(text).strip().lower() not in ("", "0", "false", "no", "n", "off"),
                    default=False)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
@ -44,6 +45,7 @@ class WideDeepConfig():
"""
def __init__(self):
self.data_path = "./test_raw_data/"
self.full_batch = False
self.epochs = 15
self.batch_size = 16000
self.eval_batch_size = 16000
@ -72,6 +74,7 @@ class WideDeepConfig():
args, _ = parser.parse_known_args()
self.data_path = args.data_path
self.epochs = args.epochs
self.full_batch = args.full_batch
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size

View File

@ -17,8 +17,10 @@
Area under curve metric
"""
from mindspore.nn.metrics import Metric
from sklearn.metrics import roc_auc_score
from mindspore import context
from mindspore.nn.metrics import Metric
from mindspore.communication.management import get_rank, get_group_size
class AUCMetric(Metric):
"""
@ -28,6 +30,7 @@ class AUCMetric(Metric):
def __init__(self):
    super(AUCMetric, self).__init__()
    self.clear()
    # Read the full_batch setting of the auto parallel context once, at construction time.
    full_batch = context.get_auto_parallel_context("full_batch")
    self.full_batch = full_batch
def clear(self):
"""Clear the internal evaluation result."""
@ -35,10 +38,17 @@ class AUCMetric(Metric):
self.pred_probs = []
def update(self, *inputs): # inputs
all_predict = inputs[1].asnumpy() # predict
all_label = inputs[2].asnumpy() # label
self.true_labels.extend(all_label.flatten().tolist())
self.pred_probs.extend(all_predict.flatten().tolist())
"""Update list of predicts and labels."""
all_predict = inputs[1].asnumpy().flatten().tolist() # predict
all_label = inputs[2].asnumpy().flatten().tolist() # label
self.pred_probs.extend(all_predict)
if self.full_batch:
rank_id = get_rank()
group_size = get_group_size()
gap = len(all_label) // group_size
self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
else:
self.true_labels.extend(all_label)
def eval(self):
if len(self.true_labels) != len(self.pred_probs):

View File

@ -17,6 +17,7 @@
import os
import sys
import mindspore.dataset.engine as de
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
@ -79,10 +80,18 @@ def test_train_eval():
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
de.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size*get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size*get_group_size())
else:
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
class Dataset(MindData):
    """Mock dataset that yields the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        # The object is its own iterator; the position is only rewound by reset().
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind the iteration position to the start."""
        self.index = 0
class AllToAllNet(nn.Cell):
    """MatMul against a fixed [128, 256] weight, followed by a (1, 0) Transpose."""

    def __init__(self, strategy1):
        super(AllToAllNet, self).__init__()
        self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
        weight_init = Tensor(np.ones([128, 256]), dtype=ms.float32)
        self.matmul_weight = Parameter(weight_init, name="weight")
        # Only the transpose strategy is supplied by the caller.
        self.transpose1 = P.Transpose().set_strategy(strategy1)

    def construct(self, x):
        product = self.matmul(x, self.matmul_weight)
        return self.transpose1(product, (1, 0))
def all_to_all_net(strategy1):
    """Build an AllToAllNet that uses `strategy1` for its transpose."""
    net = AllToAllNet(strategy1=strategy1)
    return net
def all_to_all_common(strategy1):
    """Train AllToAllNet for two epochs in semi-auto-parallel mode with device_num=8 and full_batch=True."""
    lr = 0.1
    momentum_factor = 0.9
    epochs = 2

    # Start from a clean auto parallel context before enabling full_batch.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8, full_batch=True)

    inputs = Tensor(np.ones([256, 128]), dtype=ms.float32)
    labels = Tensor(np.ones([256]), dtype=ms.int32)
    train_data = Dataset(inputs, labels, 2)

    network = all_to_all_net(strategy1)
    loss_fn = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    loss_fn.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
    loss_fn.one_hot.set_strategy(((8, 1), (), ()))
    optimizer = Momentum(network.trainable_params(), lr, momentum_factor)

    trainer = Model(network, loss_fn, optimizer)
    trainer.train(epochs, train_data, dataset_sink_mode=False)
def test_all_to_all():
    """Run the full_batch scenario with an (8, 1) transpose strategy."""
    transpose_strategy = ((8, 1),)
    _reset_op_id()
    all_to_all_common(transpose_strategy)