forked from mindspore-Ecosystem/mindspore
!1672 support load full dataset on each device
Merge pull request !1672 from yihuaijie/dev
Commit 4df861cb62
@@ -48,6 +48,7 @@ ParallelContext::ParallelContext() { Reset(); }
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
+  full_batch_ = false;
   cast_before_mirror_ = true;
   loss_repeated_mean_ = true;
   device_num_ = 1;
@@ -75,6 +76,8 @@ void ParallelContext::set_global_rank(int32_t global_rank) {
 
 void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; }
 
+void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
+
 void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
 
 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

@@ -55,6 +55,9 @@ class ParallelContext {
   void set_mirror_mean(bool mirror_mean);
   bool mirror_mean() const { return mirror_mean_; }
 
+  void set_full_batch(bool full_batch);
+  bool full_batch() const { return full_batch_; }
+
   void set_cast_before_mirror(bool cast_before_mirror);
   bool cast_before_mirror() const { return cast_before_mirror_; }
 
@@ -103,6 +106,7 @@ class ParallelContext {
   ParallelContext();
   static std::shared_ptr<ParallelContext> inst_context_;
   bool mirror_mean_;
+  bool full_batch_;
   bool cast_before_mirror_;
   bool loss_repeated_mean_;
   int32_t device_num_;

@@ -24,15 +24,23 @@
 #include "ir/value.h"
 #include "parallel/device_matrix.h"
 #include "parallel/strategy.h"
+#include "parallel/context.h"
 #include "parallel/tensor_layout/tensor_redistribution.h"
 
 namespace mindspore {
 namespace parallel {
 Status GetNextInfo::InferTensorMap() {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   for (auto shp : shapes_) {
     TensorMap out_tensor_map;
     for (size_t i = 0; i < shp.size(); ++i) {
-      out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
+      if (full_batch) {
+        out_tensor_map.push_back(MAP_NONE);
+      } else {
+        out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1));
+      }
     }
     outputs_tensor_map_.push_back(out_tensor_map);
   }
@@ -190,6 +198,9 @@ Status GetNextInfo::GetAttrs() {
 }
 
 Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
     if (dev_num_ <= 0) {
@@ -200,7 +211,9 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
       MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
       return FAILED;
     }
-    out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+    if (!full_batch) {
+      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+    }
   }
   ValuePtr new_shapes = MakeValue(out_shapes);
   Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]);

@@ -23,6 +23,7 @@
 #include "parallel/device_manager.h"
 #include "parallel/device_matrix.h"
 #include "parallel/step_parallel.h"
+#include "parallel/context.h"
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -93,59 +94,21 @@ Status VirtualDatasetInfo::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-Status VirtualDatasetInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-
-  int32_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage);
-  if (dev_list.empty()) {
-    MS_LOG(ERROR) << name_ << ": The current stage is empty!";
-    return Status::FAILED;
-  }
-  if (dev_list.size() == 1) {
-    MS_LOG(INFO) << name_ << ": No need mirror ops.";
-    return Status::SUCCESS;
-  }
-
-  OperatorName operator_name = BROADCAST;
-  ValuePtr attr0_value = MakeValue(dev_list.front());
-  std::vector<Group> group_list;
-  if (CreateGroupByDim(dev_matrix_shape_.size() - 1, &group_list) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer mirror ops, create group failed.";
-    return FAILED;
-  } else if (group_list.empty()) {
-    MS_LOG(INFO) << name_ << ": No need mirror ops.";
-    return SUCCESS;
-  }
-  std::string group = group_list[0].name();
-  ValuePtr attr1_value = MakeValue(group);
-
-  Attr attr0 = std::make_pair(SRC, attr0_value);
-  Attr attr1 = std::make_pair(GROUP, attr1_value);
-
-  OperatorAttrs operator_attrs = {attr0, attr1};
-
-  OperatorParams operator_param;
-  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
-
-  Operator op = std::make_pair(operator_name, operator_args);
-  OperatorVector op_vector = {op};
-
-  size_t size = inputs_shape_.size();
-  for (size_t i = 0; i < size; ++i) {
-    mirror_ops_.push_back(op_vector);
-  }
-  mirror_ops_.clear();
-  return SUCCESS;
-}
+Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; }
 
 Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status VirtualDatasetInfo::InferTensorMap() {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
     std::vector<int32_t> tensor_map_index;
-    tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
+    if (full_batch) {
+      tensor_map_index.push_back(MAP_NONE);
+    } else {
+      tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size()))));
+    }
     for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) {
       tensor_map_index.push_back(MAP_NONE);
     }
@@ -213,6 +176,10 @@ Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }
 
 Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+  size_t total_dev_num;
+
   if (GetAttrs() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GetAttrs failed";
     return FAILED;
@@ -220,7 +187,11 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
 
   CheckGlobalDeviceManager();
   is_auto_parallel_ = true;
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+  if (full_batch) {
+    total_dev_num = 1;
+  } else {
+    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+  }
   StrategyPtr sp;
   std::vector<Dimensions> strategy;
   for (auto &shape : inputs_shape_) {
@@ -232,10 +203,18 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
   sp = std::make_shared<Strategy>(stage_id, strategy);
 
   if (SetCostUnderStrategy(sp) == SUCCESS) {
-    MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
+    if (full_batch) {
+      MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy.";
+    } else {
+      MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy.";
+    }
     PrintStrategy(sp);
   } else {
-    MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
+    if (full_batch) {
+      MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed.";
+    }
     return FAILED;
   }
   return SUCCESS;

@@ -1375,11 +1375,19 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
 
 void SetVirtualDatasetStrategy(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
   MS_EXCEPTION_IF_NULL(prim);
   if (prim->name() == VIRTUAL_DATA_SET) {
     CheckGlobalDeviceManager();
-    int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+    int32_t dev_num;
+    if (full_batch) {
+      dev_num = 1;
+    } else {
+      dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+    }
     auto attrs_temp = prim->attrs();
     std::vector<Shapes> shape_list = ExtractShape(node);
     if (shape_list.empty()) {

@@ -187,6 +187,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set strategy checkpoint save file.")
     .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
+    .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
+    .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

@@ -367,7 +367,8 @@ def _context():
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
-                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
+                 parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str,
+                 full_batch=bool)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -404,6 +405,7 @@ def set_auto_parallel_context(**kwargs):
             broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        full_batch (bool): Whether to load the whole batch on each device. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

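Note: the new option is driven entirely through this Python context API. A minimal usage sketch follows; it is illustrative rather than part of the change set, and it assumes an 8-device semi-auto-parallel job that has already been launched:

from mindspore import context

# Load the complete batch on every device instead of a per-device shard.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)

# The flag can be read back through the matching getter.
print(context.get_auto_parallel_context("full_batch"))  # True
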
@@ -225,6 +225,21 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_load_file()
 
+    def set_full_batch(self, full_batch):
+        """
+        Set whether load full batch on each device.
+
+        Args:
+            full_batch (bool): True if load full batch on each device.
+        """
+        self.check_context_handle()
+        self._context_handle.set_full_batch(full_batch)
+
+    def get_full_batch(self):
+        """Get whether load full batch on each device."""
+        self.check_context_handle()
+        return self._context_handle.get_full_batch()
+
     def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
         """
         Set strategy checkpoint save path.
@@ -409,7 +424,8 @@ _set_auto_parallel_context_func_map = {
    "parallel_mode": auto_parallel_context().set_parallel_mode,
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
-    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file}
+    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
+    "full_batch": auto_parallel_context().set_full_batch}
 
 
 _get_auto_parallel_context_func_map = {
@@ -421,12 +437,13 @@ _get_auto_parallel_context_func_map = {
    "parallel_mode": auto_parallel_context().get_parallel_mode,
    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
-    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file}
+    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
+    "full_batch": auto_parallel_context().get_full_batch}
 
 
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                  loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool,
-                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str)
+                 strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -459,6 +476,7 @@ def _set_auto_parallel_context(**kwargs):
             broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        full_batch (bool): Whether to load the whole batch on each device. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

@@ -20,10 +20,26 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 
 def _get_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode()
+
+
+def _get_full_batch():
+    """Get whether to use full_batch."""
+    return auto_parallel_context().get_full_batch()
+
+
+def _need_to_full():
+    """Check whether to convert input to full shape or tensor."""
+    parallel_mode = _get_parallel_mode()
+    full_batch = _get_full_batch()
+    need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel"))
+            and (not full_batch))
+    return need
+
+
 def _get_mirror_mean():
     """Get if using mirror_mean."""
     return auto_parallel_context().get_mirror_mean()
 
 
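The helper added above is the single switch the data pipeline consults. An illustrative re-implementation as a pure function (the real _need_to_full() reads the values from auto_parallel_context) makes the intended truth table explicit:

def need_to_full(parallel_mode, full_batch):
    """Pure-function sketch of _need_to_full(): inputs are expanded to the full
    (device_num x batch) shape only in (semi) auto parallel without full_batch."""
    return parallel_mode in ("semi_auto_parallel", "auto_parallel") and not full_batch

assert need_to_full("semi_auto_parallel", False)      # per-device shard, expand for compile
assert not need_to_full("semi_auto_parallel", True)   # full batch already on every device
assert not need_to_full("data_parallel", False)       # other modes are untouched
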
@@ -17,11 +17,10 @@ import math
 
 from mindspore._checkparam import check_bool
 from .. import context
-from .parallel_utils import ParallelMode
 from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
     _construct_tensor_list, _to_full_shapes, _to_full_tensor
 from ..nn.wrap import GetNextSingleOp
-from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
 
 
 class DatasetHelper:
@@ -118,10 +117,10 @@ class _DatasetIterMSLoopSink(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
-        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
-        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
-        # times the batch dimension of tensors for run. Now only support LoopSink.
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
+        # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
+        # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
+        if _need_to_full():
            device_num = _get_device_num()
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
 
@@ -146,10 +145,8 @@ class _DatasetIterGE(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterGE, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
-        parallel_mode = _get_parallel_mode()
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         batch_expand_num = 1
-        if self.need_to_full:
+        if _need_to_full():
            batch_expand_num = _get_device_num()
         tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
 
@@ -170,9 +167,6 @@ class _DatasetIterFeed:
         self.loop_count = dataset.get_dataset_size()
         self.ind = 0
 
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-
     def __iter__(self):
         if self.repeat_ind % self.repeat_count == 0:
             self.iter = self.dataset.__iter__()
@@ -186,6 +180,6 @@
             raise StopIteration()
         self.ind += 1
         data = self.iter.__next__()
-        if self.need_to_full:
+        if _need_to_full():
            return _to_full_tensor(data, self.device_num, self.global_rank)
         return _to_tensor(data)

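For reference, the shape expansion that _need_to_full() now gates works on the leading (batch) dimension. A rough stand-in for _to_full_shapes, written here only to illustrate the "device_number times the batch dimension" rule described in the comment above, and not the library helper itself:

def to_full_shapes_sketch(dataset_shapes, device_num):
    """Illustrative stand-in: scale the leading (batch) dimension by device_num."""
    return [[shape[0] * device_num] + list(shape[1:]) for shape in dataset_shapes]

# A per-device batch of 32 compiles against a global batch of 256 on 8 devices.
print(to_full_shapes_sketch([[32, 39], [32, 1]], 8))  # [[256, 39], [256, 1]]
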
@@ -22,6 +22,7 @@ def argparse_init():
     parser = argparse.ArgumentParser(description='WideDeep')
     parser.add_argument("--data_path", type=str, default="./test_raw_data/")
     parser.add_argument("--epochs", type=int, default=15)
+    parser.add_argument("--full_batch", type=bool, default=False)
     parser.add_argument("--batch_size", type=int, default=16000)
     parser.add_argument("--eval_batch_size", type=int, default=16000)
     parser.add_argument("--field_size", type=int, default=39)
@@ -44,6 +45,7 @@ class WideDeepConfig():
     """
     def __init__(self):
         self.data_path = "./test_raw_data/"
+        self.full_batch = False
         self.epochs = 15
         self.batch_size = 16000
         self.eval_batch_size = 16000
@@ -72,6 +74,7 @@ class WideDeepConfig():
         args, _ = parser.parse_known_args()
         self.data_path = args.data_path
         self.epochs = args.epochs
+        self.full_batch = args.full_batch
         self.batch_size = args.batch_size
         self.eval_batch_size = args.eval_batch_size
         self.field_size = args.field_size

@@ -17,8 +17,10 @@
 Area under cure metric
 """
 
-from mindspore.nn.metrics import Metric
 from sklearn.metrics import roc_auc_score
+from mindspore import context
+from mindspore.nn.metrics import Metric
+from mindspore.communication.management import get_rank, get_group_size
 
 class AUCMetric(Metric):
     """
@@ -28,6 +30,7 @@ class AUCMetric(Metric):
     def __init__(self):
         super(AUCMetric, self).__init__()
         self.clear()
+        self.full_batch = context.get_auto_parallel_context("full_batch")
 
     def clear(self):
         """Clear the internal evaluation result."""
@@ -35,10 +38,17 @@ class AUCMetric(Metric):
         self.pred_probs = []
 
     def update(self, *inputs): # inputs
-        all_predict = inputs[1].asnumpy() # predict
-        all_label = inputs[2].asnumpy() # label
-        self.true_labels.extend(all_label.flatten().tolist())
-        self.pred_probs.extend(all_predict.flatten().tolist())
+        """Update list of predicts and labels."""
+        all_predict = inputs[1].asnumpy().flatten().tolist() # predict
+        all_label = inputs[2].asnumpy().flatten().tolist() # label
+        self.pred_probs.extend(all_predict)
+        if self.full_batch:
+            rank_id = get_rank()
+            group_size = get_group_size()
+            gap = len(all_label) // group_size
+            self.true_labels.extend(all_label[rank_id*gap: (rank_id+1)*gap])
+        else:
+            self.true_labels.extend(all_label)
 
     def eval(self):
         if len(self.true_labels) != len(self.pred_probs):

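Under full_batch every rank receives the complete batch of labels, so update() keeps only the slice that lines up with that rank's share of predictions (eval() later requires the two lists to have equal length). The slicing arithmetic, shown with hypothetical numbers (8 ranks, a toy global batch of 16 labels):

group_size = 8                      # hypothetical number of ranks
all_label = list(range(16))         # toy global batch of labels

gap = len(all_label) // group_size  # 2 labels belong to each rank
shards = [all_label[rank_id * gap: (rank_id + 1) * gap] for rank_id in range(group_size)]
print(shards[0], shards[7])         # [0, 1] [14, 15]
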
@@ -17,6 +17,7 @@
 
 import os
 import sys
+import mindspore.dataset.engine as de
 from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train import ParallelMode
@@ -79,10 +80,18 @@ def test_train_eval():
     batch_size = config.batch_size
     epochs = config.epochs
     print("epochs is {}".format(epochs))
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
-                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
-                             batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+    if config.full_batch:
+        context.set_auto_parallel_context(full_batch=True)
+        de.config.set_seed(1)
+        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+                                  batch_size=batch_size*get_group_size())
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+                                 batch_size=batch_size*get_group_size())
+    else:
+        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+                                  batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+                                 batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
 
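In the full_batch branch above the dataset is built with the global batch size and the same dataset seed on every rank, instead of per-rank shards. The batch-size arithmetic with the script's default batch_size and an assumed 8-device group:

batch_size = 16000                           # default from WideDeepConfig
group_size = 8                               # assumed number of devices

per_device_batch = batch_size                # sharded path: each rank reads its own slice
full_batch_size = batch_size * group_size    # full_batch path: every rank reads all 128000 samples
print(per_device_batch, full_batch_size)
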
@@ -0,0 +1,89 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.parameter import Parameter
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.ops import operations as P
+from mindspore.parallel._utils import _reset_op_id
+from mindspore.train import Model, ParallelMode
+from tests.dataset_mock import MindData
+
+class Dataset(MindData):
+    def __init__(self, predict, label, length=3):
+        super(Dataset, self).__init__(size=length)
+        self.predict = predict
+        self.label = label
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.predict, self.label
+
+    def reset(self):
+        self.index = 0
+
+
+class AllToAllNet(nn.Cell):
+    def __init__(self, strategy1):
+        super(AllToAllNet, self).__init__()
+        self.matmul = P.MatMul().set_strategy(((1, 1), (1, 8)))
+        self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
+        self.transpose1 = P.Transpose().set_strategy(strategy1)
+
+    def construct(self, x):
+        x = self.matmul(x, self.matmul_weight)
+        x = self.transpose1(x, (1, 0))
+        return x
+
+def all_to_all_net(strategy1):
+    return AllToAllNet(strategy1=strategy1)
+
+def all_to_all_common(strategy1):
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8, full_batch=True)
+    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
+    label = Tensor(np.ones([256]), dtype=ms.int32)
+    dataset = Dataset(predict, label, 2)
+    net = all_to_all_net(strategy1)
+
+    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
+    loss.one_hot.set_strategy(((8, 1), (), ()))
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, loss, opt)
+
+    model.train(epoch_size, dataset, dataset_sink_mode=False)
+
+def test_all_to_all():
+    strategy1 = ((8, 1),)
+    _reset_op_id()
+    all_to_all_common(strategy1)