From e5c351690b9e9295c088a384a82292af5d2dbbd7 Mon Sep 17 00:00:00 2001 From: Yi Huaijie Date: Fri, 29 May 2020 16:29:11 +0800 Subject: [PATCH] support load full dataset on each device --- mindspore/ccsrc/parallel/context.cc | 3 + mindspore/ccsrc/parallel/context.h | 4 + .../ccsrc/parallel/ops_info/get_next_info.cc | 17 +++- .../parallel/ops_info/virtual_dataset_info.cc | 79 ++++++---------- mindspore/ccsrc/parallel/step_parallel.cc | 10 ++- mindspore/ccsrc/pipeline/init.cc | 2 + mindspore/context.py | 4 +- mindspore/parallel/_auto_parallel_context.py | 24 ++++- mindspore/parallel/_utils.py | 16 ++++ mindspore/train/dataset_helper.py | 20 ++--- model_zoo/wide_and_deep/src/config.py | 3 + model_zoo/wide_and_deep/src/metrics.py | 20 +++-- .../train_and_test_multinpu_auto_parallel.py | 17 +++- tests/ut/python/parallel/test_full_batch.py | 89 +++++++++++++++++++ 14 files changed, 229 insertions(+), 79 deletions(-) create mode 100644 tests/ut/python/parallel/test_full_batch.py diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index de92bba507c..6802292cb46 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -48,6 +48,7 @@ ParallelContext::ParallelContext() { Reset(); } void ParallelContext::Reset() { mirror_mean_ = false; + full_batch_ = false; cast_before_mirror_ = true; loss_repeated_mean_ = true; device_num_ = 1; @@ -75,6 +76,8 @@ void ParallelContext::set_global_rank(int32_t global_rank) { void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; } +void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } + void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 32f9838d6c6..efa528d1793 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -55,6 +55,9 @@ class ParallelContext { void set_mirror_mean(bool mirror_mean); bool mirror_mean() const { return mirror_mean_; } + void set_full_batch(bool full_batch); + bool full_batch() const { return full_batch_; } + void set_cast_before_mirror(bool cast_before_mirror); bool cast_before_mirror() const { return cast_before_mirror_; } @@ -103,6 +106,7 @@ class ParallelContext { ParallelContext(); static std::shared_ptr inst_context_; bool mirror_mean_; + bool full_batch_; bool cast_before_mirror_; bool loss_repeated_mean_; int32_t device_num_; diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc index 29d519fda8a..0fb49364f0a 100644 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc @@ -24,15 +24,23 @@ #include "ir/value.h" #include "parallel/device_matrix.h" #include "parallel/strategy.h" +#include "parallel/context.h" #include "parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { Status GetNextInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + for (auto shp : shapes_) { TensorMap out_tensor_map; for (size_t i = 0; i < shp.size(); ++i) { - out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); + if (full_batch) { + out_tensor_map.push_back(MAP_NONE); + } else { + out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); + } } outputs_tensor_map_.push_back(out_tensor_map); } @@ -190,6 +198,9 @@ Status GetNextInfo::GetAttrs() { } Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + Shapes out_shapes = outputs_shape_; for (size_t i = 0; i < out_shapes.size(); ++i) { if (dev_num_ <= 0) { @@ -200,7 +211,9 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; return FAILED; } - out_shapes[i][0] = out_shapes[i][0] / dev_num_; + if (!full_batch) { + out_shapes[i][0] = out_shapes[i][0] / dev_num_; + } } ValuePtr new_shapes = MakeValue(out_shapes); Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc index 4b695ba62d3..ce8b04d8028 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc @@ -23,6 +23,7 @@ #include "parallel/device_manager.h" #include "parallel/device_matrix.h" #include "parallel/step_parallel.h" +#include "parallel/context.h" #include "utils/log_adapter.h" namespace mindspore { @@ -93,59 +94,21 @@ Status VirtualDatasetInfo::InferDevMatrixShape() { return SUCCESS; } -Status VirtualDatasetInfo::InferMirrorOps() { - mirror_ops_.clear(); - - int32_t stage = strategy_->GetInputStage(); - CheckGlobalDeviceManager(); - RankList dev_list = g_device_manager->GetDeviceListByStageId(stage); - if (dev_list.empty()) { - MS_LOG(ERROR) << name_ << ": The current stage is empty!"; - return Status::FAILED; - } - if (dev_list.size() == 1) { - MS_LOG(INFO) << name_ << ": No need mirror ops."; - return Status::SUCCESS; - } - - OperatorName operator_name = BROADCAST; - ValuePtr attr0_value = MakeValue(dev_list.front()); - std::vector group_list; - if (CreateGroupByDim(dev_matrix_shape_.size() - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer mirror ops, create group failed."; - return FAILED; - } else if (group_list.empty()) { - MS_LOG(INFO) << name_ << ": No need mirror ops."; - return SUCCESS; - } - std::string group = group_list[0].name(); - ValuePtr attr1_value = MakeValue(group); - - Attr attr0 = std::make_pair(SRC, attr0_value); - Attr attr1 = std::make_pair(GROUP, attr1_value); - - OperatorAttrs operator_attrs = {attr0, attr1}; - - OperatorParams operator_param; - OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_args); - OperatorVector op_vector = {op}; - - size_t size = inputs_shape_.size(); - for (size_t i = 0; i < size; ++i) { - mirror_ops_.push_back(op_vector); - } - mirror_ops_.clear(); - return SUCCESS; -} +Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } Status VirtualDatasetInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { std::vector tensor_map_index; - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); + if (full_batch) { + tensor_map_index.push_back(MAP_NONE); + } else { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); + } for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { tensor_map_index.push_back(MAP_NONE); } @@ -213,6 +176,10 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { } Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + size_t total_dev_num; + if (GetAttrs() != SUCCESS) { MS_LOG(ERROR) << name_ << ": GetAttrs failed"; return FAILED; @@ -220,7 +187,11 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { CheckGlobalDeviceManager(); is_auto_parallel_ = true; - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + if (full_batch) { + total_dev_num = 1; + } else { + total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + } StrategyPtr sp; std::vector strategy; for (auto &shape : inputs_shape_) { @@ -232,10 +203,18 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { sp = std::make_shared(stage_id, strategy); if (SetCostUnderStrategy(sp) == SUCCESS) { - MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; + if (full_batch) { + MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; + } else { + MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; + } PrintStrategy(sp); } else { - MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; + if (full_batch) { + MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; + } return FAILED; } return SUCCESS; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index d11d78d9bd4..166ce6b0382 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1375,11 +1375,19 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { void SetVirtualDatasetStrategy(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + PrimitivePtr prim = GetValueNode(node->input(0)); MS_EXCEPTION_IF_NULL(prim); if (prim->name() == VIRTUAL_DATA_SET) { CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + int32_t dev_num; + if (full_batch) { + dev_num = 1; + } else { + dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + } auto attrs_temp = prim->attrs(); std::vector shape_list = ExtractShape(node); if (shape_list.empty()) { diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 1b9666a4005..37faf7decc6 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -187,6 +187,8 @@ PYBIND11_MODULE(_c_expression, m) { "Set strategy checkpoint save file.") .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") + .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") + .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); (void)py::class_>(m, "CostModelContext") diff --git a/mindspore/context.py b/mindspore/context.py index bf6439a7d59..1887363d5a3 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -367,7 +367,8 @@ def _context(): @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, - parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) + parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, + full_batch=bool) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -404,6 +405,7 @@ def set_auto_parallel_context(**kwargs): broadcast. Default: False. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' + full_batch (bool): Whether to load the whole batch on each device. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 02190290379..bfdc4f5c7c3 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -225,6 +225,21 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_strategy_ckpt_load_file() + def set_full_batch(self, full_batch): + """ + Set whether load full batch on each device. + + Args: + full_batch (bool): True if load full batch on each device. + """ + self.check_context_handle() + self._context_handle.set_full_batch(full_batch) + + def get_full_batch(self): + """Get whether load full batch on each device.""" + self.check_context_handle() + return self._context_handle.get_full_batch() + def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): """ Set strategy checkpoint save path. @@ -415,7 +430,8 @@ _set_auto_parallel_context_func_map = { "parallel_mode": auto_parallel_context().set_parallel_mode, "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, - "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file} + "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, + "full_batch": auto_parallel_context().set_full_batch} _get_auto_parallel_context_func_map = { @@ -427,12 +443,13 @@ _get_auto_parallel_context_func_map = { "parallel_mode": auto_parallel_context().get_parallel_mode, "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, - "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file} + "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, + "full_batch": auto_parallel_context().get_full_batch} @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool, - strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) + strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool) def _set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -465,6 +482,7 @@ def _set_auto_parallel_context(**kwargs): broadcast. Default: False. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' + full_batch (bool): Whether to load the whole batch on each device. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 3301c3c9707..c5b4d57702d 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -20,10 +20,26 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context def _get_parallel_mode(): + """Get parallel mode.""" return auto_parallel_context().get_parallel_mode() +def _get_full_batch(): + """Get whether to use full_batch.""" + return auto_parallel_context().get_full_batch() + + +def _need_to_full(): + """Check whether to convert input to full shape or tensor.""" + parallel_mode = _get_parallel_mode() + full_batch = _get_full_batch() + need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel")) + and (not full_batch)) + return need + + def _get_mirror_mean(): + """Get if using mirror_mean.""" return auto_parallel_context().get_mirror_mean() diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 083349e5a1c..6cee80cabbf 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -17,11 +17,10 @@ import math from mindspore._checkparam import check_bool from .. import context -from .parallel_utils import ParallelMode from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ _construct_tensor_list, _to_full_shapes, _to_full_tensor from ..nn.wrap import GetNextSingleOp -from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode +from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full class DatasetHelper: @@ -118,10 +117,10 @@ class _DatasetIterMSLoopSink(_DatasetIter): def __init__(self, dataset): super(_DatasetIterMSLoopSink, self).__init__(dataset) self.loop_count = self.get_loop_count(dataset) - # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to - # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number - # times the batch dimension of tensors for run. Now only support LoopSink. - if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, + # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for + # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. + if _need_to_full(): device_num = _get_device_num() self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) @@ -146,10 +145,8 @@ class _DatasetIterGE(_DatasetIter): def __init__(self, dataset): super(_DatasetIterGE, self).__init__(dataset) self.loop_count = self.get_loop_count(dataset) - parallel_mode = _get_parallel_mode() - self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) batch_expand_num = 1 - if self.need_to_full: + if _need_to_full(): batch_expand_num = _get_device_num() tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) @@ -170,9 +167,6 @@ class _DatasetIterFeed: self.loop_count = dataset.get_dataset_size() self.ind = 0 - parallel_mode = context.get_auto_parallel_context("parallel_mode") - self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - def __iter__(self): if self.repeat_ind % self.repeat_count == 0: self.iter = self.dataset.__iter__() @@ -186,6 +180,6 @@ class _DatasetIterFeed: raise StopIteration() self.ind += 1 data = self.iter.__next__() - if self.need_to_full: + if _need_to_full(): return _to_full_tensor(data, self.device_num, self.global_rank) return _to_tensor(data) diff --git a/model_zoo/wide_and_deep/src/config.py b/model_zoo/wide_and_deep/src/config.py index 3559e8bf23a..71031b7b95a 100644 --- a/model_zoo/wide_and_deep/src/config.py +++ b/model_zoo/wide_and_deep/src/config.py @@ -22,6 +22,7 @@ def argparse_init(): parser = argparse.ArgumentParser(description='WideDeep') parser.add_argument("--data_path", type=str, default="./test_raw_data/") parser.add_argument("--epochs", type=int, default=15) + parser.add_argument("--full_batch", type=bool, default=False) parser.add_argument("--batch_size", type=int, default=16000) parser.add_argument("--eval_batch_size", type=int, default=16000) parser.add_argument("--field_size", type=int, default=39) @@ -44,6 +45,7 @@ class WideDeepConfig(): """ def __init__(self): self.data_path = "./test_raw_data/" + self.full_batch = False self.epochs = 15 self.batch_size = 16000 self.eval_batch_size = 16000 @@ -72,6 +74,7 @@ class WideDeepConfig(): args, _ = parser.parse_known_args() self.data_path = args.data_path self.epochs = args.epochs + self.full_batch = args.full_batch self.batch_size = args.batch_size self.eval_batch_size = args.eval_batch_size self.field_size = args.field_size diff --git a/model_zoo/wide_and_deep/src/metrics.py b/model_zoo/wide_and_deep/src/metrics.py index 277d6744dc9..c89e9484053 100644 --- a/model_zoo/wide_and_deep/src/metrics.py +++ b/model_zoo/wide_and_deep/src/metrics.py @@ -17,8 +17,10 @@ Area under cure metric """ -from mindspore.nn.metrics import Metric from sklearn.metrics import roc_auc_score +from mindspore import context +from mindspore.nn.metrics import Metric +from mindspore.communication.management import get_rank, get_group_size class AUCMetric(Metric): """ @@ -28,6 +30,7 @@ class AUCMetric(Metric): def __init__(self): super(AUCMetric, self).__init__() self.clear() + self.full_batch = context.get_auto_parallel_context("full_batch") def clear(self): """Clear the internal evaluation result.""" @@ -35,10 +38,17 @@ class AUCMetric(Metric): self.pred_probs = [] def update(self, *inputs): # inputs - all_predict = inputs[1].asnumpy() # predict - all_label = inputs[2].asnumpy() # label - self.true_labels.extend(all_label.flatten().tolist()) - self.pred_probs.extend(all_predict.flatten().tolist()) + """Update list of predicts and labels.""" + all_predict = inputs[1].asnumpy().flatten().tolist() # predict + all_label = inputs[2].asnumpy().flatten().tolist() # label + self.pred_probs.extend(all_predict) + if self.full_batch: + rank_id = get_rank() + group_size = get_group_size() + gap = len(all_label) // group_size + self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap]) + else: + self.true_labels.extend(all_label) def eval(self): if len(self.true_labels) != len(self.pred_probs): diff --git a/model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py b/model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py index 6b6e1e33d12..9659d172231 100644 --- a/model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py +++ b/model_zoo/wide_and_deep/train_and_test_multinpu_auto_parallel.py @@ -17,6 +17,7 @@ import os import sys +import mindspore.dataset.engine as de from mindspore import Model, context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train import ParallelMode @@ -79,10 +80,18 @@ def test_train_eval(): batch_size = config.batch_size epochs = config.epochs print("epochs is {}".format(epochs)) - ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) - ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, - batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) + if config.full_batch: + context.set_auto_parallel_context(full_batch=True) + de.config.set_seed(1) + ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, + batch_size=batch_size*get_group_size()) + ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, + batch_size=batch_size*get_group_size()) + else: + ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, + batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) + ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, + batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size()) print("ds_train.size: {}".format(ds_train.get_dataset_size())) print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) diff --git a/tests/ut/python/parallel/test_full_batch.py b/tests/ut/python/parallel/test_full_batch.py new file mode 100644 index 00000000000..70a68a5b00c --- /dev/null +++ b/tests/ut/python/parallel/test_full_batch.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.parameter import Parameter +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.nn.optim.momentum import Momentum +from mindspore.ops import operations as P +from mindspore.parallel._utils import _reset_op_id +from mindspore.train import Model, ParallelMode +from tests.dataset_mock import MindData + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +class AllToAllNet(nn.Cell): + def __init__(self, strategy1): + super(AllToAllNet, self).__init__() + self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8))) + self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight") + self.transpose1 = P.Transpose().set_strategy(strategy1) + + def construct(self, x): + x = self.matmul(x, self.matmul_weight) + x = self.transpose1(x, (1, 0)) + return x + +def all_to_all_net(strategy1): + return AllToAllNet(strategy1=strategy1) + +def all_to_all_common(strategy1): + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + + context.set_context(mode=context.GRAPH_MODE, save_graphs=False) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8, full_batch=True) + predict = Tensor(np.ones([256, 128]), dtype=ms.float32) + label = Tensor(np.ones([256]), dtype=ms.int32) + dataset = Dataset(predict, label, 2) + net = all_to_all_net(strategy1) + + loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1))) + loss.one_hot.set_strategy(((8, 1), (), ())) + opt = Momentum(net.trainable_params(), learning_rate, momentum) + model = Model(net, loss, opt) + + model.train(epoch_size, dataset, dataset_sink_mode=False) + +def test_all_to_all(): + strategy1 = ((8, 1),) + _reset_op_id() + all_to_all_common(strategy1)