enable parallel optimizer in auto parallel

Commit ddc0113058 (parent c1b9efe8e6)
@ -22,6 +22,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
|
||||
#include "frontend/parallel/status.h"
|
||||
#include "frontend/parallel/ops_info/ops_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -35,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0
|
|||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
|
||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
|
||||
|
||||
constexpr char FUSION[] = "fusion";
|
||||
constexpr char PARAMETER[] = "parameter";
|
||||
const uint32_t MAX_RECURSIVE_CALL_TIMES = 100;
|
||||
class AllreduceFusion {
|
||||
|
|
|
@ -42,15 +42,11 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
|||
auto device_arrangement = tensor_layout->device_arrangement().array();
|
||||
auto tensor_map = tensor_layout->tensor_map().array();
|
||||
auto slice_shape = tensor_layout->slice_shape().array();
|
||||
Shape field_size = {tensor_layout->get_field_size()};
|
||||
Shape uniform_split;
|
||||
if (tensor_layout->uniform_split()) {
|
||||
uniform_split.push_back(1);
|
||||
} else {
|
||||
uniform_split.push_back(0);
|
||||
}
|
||||
|
||||
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
|
||||
int32_t field_size = tensor_layout->get_field_size();
|
||||
bool uniform_split = tensor_layout->uniform_split();
|
||||
std::string opt_shard_group = tensor_layout->opt_shard_group();
|
||||
py::tuple layout =
|
||||
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
|
||||
dict[py::str(name)] = layout;
|
||||
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
|
||||
}
|
||||
|
|
|
@@ -226,6 +226,21 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
  return op;
}

Operator CreateAllGatherOp(const std::string &group) {
  OperatorName operator_name = ALL_GATHER;
  ValuePtr attr0_value = MakeValue(group);  // group
  Attr attr0 = std::make_pair(GROUP, attr0_value);
  OperatorAttrs operator_attrs;
  operator_attrs.push_back(attr0);

  OperatorParams operator_param;
  OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);

  Operator op = std::make_pair(operator_name, operator_arg);
  MS_LOG(INFO) << "Create allgather op success, the group is " << group;
  return op;
}

// use for get tensor slice
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
  Shape tensor_map = tensor_layout.tensor_map().array();
@ -164,6 +164,10 @@ class OperatorInfo {
|
|||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
|
||||
int32_t stage_id() const { return stage_id_; }
|
||||
void set_opt_shard_flag(bool flag) { opt_shard_flag_ = flag; }
|
||||
bool opt_shard_flag() { return opt_shard_flag_; }
|
||||
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "OpInfo";
|
||||
|
||||
|
@ -180,7 +184,6 @@ class OperatorInfo {
|
|||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
|
||||
void SetDeviceListByStrategy();
|
||||
void SetRepeatedCalcDevMatrix();
|
||||
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
|
||||
Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
|
||||
Status InferAttrs();
|
||||
void ResetQueueMember();
|
||||
|
@ -263,6 +266,7 @@ class OperatorInfo {
|
|||
private:
|
||||
OperatorCostPtr operator_cost_;
|
||||
std::vector<TypePtr> outputs_type_;
|
||||
bool opt_shard_flag_ = false;
|
||||
};
|
||||
|
||||
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
|
||||
|
@ -270,6 +274,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
|
|||
Operator CreateVirtualDivOp(int32_t div_num);
|
||||
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
|
||||
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
|
||||
Operator CreateAllGatherOp(const std::string &group);
|
||||
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
|
||||
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
|
||||
int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
|
||||
|
|
|
@ -98,6 +98,7 @@ constexpr char BEGIN[] = "begin";
|
|||
constexpr char END[] = "end";
|
||||
constexpr char STRIDES[] = "strides";
|
||||
constexpr char GROUP[] = "group";
|
||||
constexpr char FUSION[] = "fusion";
|
||||
constexpr char AXIS[] = "axis";
|
||||
constexpr char OUTPUT_NUM[] = "output_num";
|
||||
constexpr char SPLIT_COUNT[] = "split_count";
|
||||
|
@ -140,6 +141,7 @@ constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
|||
constexpr char FIELD_SIZE[] = "field_size";
|
||||
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
|
||||
constexpr char DEVICE[] = "Device";
|
||||
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
|
||||
|
||||
// Operator
|
||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||
|
|
|
@ -121,6 +121,7 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
|
|||
new_node->set_scope(scope);
|
||||
node_input[0]->set_scope(scope);
|
||||
manager->SetEdge(node, SizeToInt(index), new_node);
|
||||
MS_LOG(INFO) << "Insert " << instance_name << " success";
|
||||
}
|
||||
|
||||
std::string CreateInstanceName(const CNodePtr &node, size_t index) {
|
||||
|
@ -924,7 +925,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
|
|||
MirrorOps mirror_ops = distribute_operator->mirror_ops();
|
||||
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
|
||||
// insert mirror op
|
||||
if (!mirror_ops.empty()) {
|
||||
if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) {
|
||||
MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
|
||||
InsertMirrorOps(mirror_ops, node);
|
||||
}
|
||||
|
@@ -1263,6 +1264,37 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
  return std::make_pair(nullptr, 0);
}

void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index) {
  MS_EXCEPTION_IF_NULL(distribute_operator);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(parameter);
  std::vector<Group> dev_group;
  // create communication group for allgather operator
  if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) ==
        Status::SUCCESS &&
      !dev_group.empty()) {
    // set optimizer shard split flag to avoid inserting mirror_ops
    distribute_operator->set_opt_shard_flag(true);
    // insert allgather operator between shard parameter and cnode
    Operator op = CreateAllGatherOp(dev_group[0].name());
    auto graph = cnode->func_graph();
    MS_EXCEPTION_IF_NULL(graph);
    InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
    // set communication group in tensor layout for checkpoint saving
    tensor_layout->set_opt_shard_group(dev_group[0].name());
    // add fusion flag
    auto allgather = cnode->input(index)->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
    auto attrs = prim->attrs();
    attrs["fusion"] = MakeValue(1);
    prim->SetAttrs(attrs);
    MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString();
  } else {
    MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << " failed!";
  }
}

void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
  MS_EXCEPTION_IF_NULL(parameter);
  AbstractBasePtr abstract = parameter->abstract();
@@ -1280,7 +1312,22 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
                      << distribute_operator->inputs_tensor_info().size();
  }
  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
  Shape slice_shape = tensorinfo_in.slice_shape();
  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
  Shape slice_shape = tensor_layout.slice_shape().array();
  if (enable_parallel_optimizer) {
    if (!ParameterRequireGrad(parameter)) {
      // only trainable parameters need parallel optimizer
      MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString();
    } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
      // get a totally shard tensor slice shape if the weight is repeated on devices
      // and the shape of the first dimension could be divided
      // apply parallel optimizer on parameters
      ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second));
      slice_shape = tensor_layout.opt_shard_slice_shape();
    }
  }
  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
  std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
@ -1290,7 +1337,6 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
cloned_abstract->set_shape(parallel_shape);
|
||||
parameter->set_abstract(cloned_abstract);
|
||||
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
|
||||
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter_ptr);
|
||||
parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
|
||||
|
|
|
@ -160,6 +160,9 @@ RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
|
|||
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
|
||||
|
||||
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
|
||||
|
||||
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
|
||||
const CNodePtr &cnode, const AnfNodePtr ¶meter, size_t index);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@@ -389,5 +389,39 @@ TensorLayout TensorLayout::SqueezeShape() const {
  (void)out.Init(device_arrangement_, out_map, out_shape);
  return out;
}

// Generate a totally shard tensor slice shape for parallel optimizer
Status TensorLayout::GenerateOptShardSliceShape() {
  MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
  Shape dev_max = device_arrangement_.array();
  Shape tensor_map = tensor_map_.array();
  Shape repeated_dev;
  for (size_t i = 0; i < dev_max.size(); i++) {
    if (tensor_map_.GetIndexByValue(i) == MAP_NONE) {
      repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
      dev_max[dev_max.size() - 1 - i] = 1;
    }
  }
  if (repeated_dev.empty()) {
    MS_LOG(INFO) << "Tensor is totally shard already.";
    return Status::FAILED;
  }
  int64_t repeated_num =
      std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t split_num;
  if (tensor_map[0] == MAP_NONE) {
    split_num = repeated_num;
  } else {
    split_num = dev_max[dev_max.size() - 1 - tensor_map[0]] * repeated_num;
  }
  if (tensor_shape_.array()[0] % split_num != 0) {
    MS_LOG(INFO) << "Tensor could not be shard on the first dimension.";
    return Status::FAILED;
  }
  Shape origin_slice_shape = slice_shape().array();
  origin_slice_shape[0] = tensor_shape_.array()[0] / split_num;
  opt_shard_slice_shape_ = origin_slice_shape;
  return Status::SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
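To make the arithmetic above concrete, here is a minimal Python sketch of the same computation. It is illustrative only, not a MindSpore API; the function name and arguments are made up. It reproduces the expected shapes used by the new TensorLayout unit tests added later in this commit.

MAP_NONE = -1

def opt_shard_slice_shape(dev_arrangement, tensor_map, tensor_shape):
    """Mirror of TensorLayout::GenerateOptShardSliceShape, for illustration."""
    dev_max = list(dev_arrangement)
    n = len(dev_max)
    # device dimensions the tensor is NOT split across (repeated devices)
    repeated = []
    for value in range(n):
        if value not in tensor_map:
            repeated.append(dev_max[n - 1 - value])
            dev_max[n - 1 - value] = 1
    if not repeated:
        return None  # already fully sharded, nothing left for the optimizer to split
    repeated_num = 1
    for r in repeated:
        repeated_num *= r
    # total split factor for the first tensor dimension
    if tensor_map[0] == MAP_NONE:
        split_num = repeated_num
    else:
        split_num = dev_max[n - 1 - tensor_map[0]] * repeated_num
    if tensor_shape[0] % split_num != 0:
        return None  # first dimension not divisible, keep the normal slice
    # ordinary model-parallel slice shape, then shard dim 0 further
    slice_shape = list(tensor_shape)
    for dim, value in enumerate(tensor_map):
        if value != MAP_NONE:
            slice_shape[dim] //= dev_arrangement[n - 1 - value]
    slice_shape[0] = tensor_shape[0] // split_num
    return slice_shape

# Values taken from the unit tests in this commit:
print(opt_shard_slice_shape([8, 4], [-1, 0], [512, 1024]))    # [64, 256]
print(opt_shard_slice_shape([4, 4, 2], [1, 0], [512, 1024]))  # [32, 512]
print(opt_shard_slice_shape([8, 4], [1, 0], [512, 1024]))     # None: already fully sharded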
@ -21,7 +21,9 @@
|
|||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "frontend/parallel/device_manager.h"
|
||||
#include "frontend/parallel/status.h"
|
||||
#include "frontend/parallel/tensor_layout/arrangement.h"
|
||||
|
@ -86,6 +88,14 @@ class TensorLayout {
|
|||
|
||||
TensorLayout SqueezeShape() const;
|
||||
|
||||
Status GenerateOptShardSliceShape();
|
||||
|
||||
Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }
|
||||
|
||||
void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); }
|
||||
|
||||
std::string opt_shard_group() { return opt_shard_group_; }
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "TLayout";
|
||||
|
||||
|
@ -109,6 +119,8 @@ class TensorLayout {
|
|||
bool skip_redistribution_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
bool uniform_split_ = true;
|
||||
Shape opt_shard_slice_shape_;
|
||||
std::string opt_shard_group_ = "";
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -396,8 +396,8 @@ class Parameter(MetaTensor):
|
|||
if self.inited_param is not None:
|
||||
return self.inited_param
|
||||
if layout is not None:
|
||||
if not isinstance(layout, list):
|
||||
raise TypeError("The layout should be list! layout is {}.".format(layout))
|
||||
if not isinstance(layout, tuple):
|
||||
raise TypeError("The layout should be tuple! layout is {}.".format(layout))
|
||||
if len(layout) < 3:
|
||||
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
|
|
|
@ -334,7 +334,7 @@ def _context():
|
|||
all_reduce_fusion_config=list, pipeline_stages=int)
|
||||
def set_auto_parallel_context(**kwargs):
|
||||
r"""
|
||||
Set auto parallel context.
|
||||
Set auto parallel context, which is valid only for Ascend and GPU target.
|
||||
|
||||
Auto parallel context should be configured before the initialization of your network.
|
||||
|
||||
|
@ -348,17 +348,17 @@ def set_auto_parallel_context(**kwargs):
|
|||
|
||||
Some configurations are parallel mode specific, see the below table for details:
|
||||
|
||||
=========================== =========================== =================
|
||||
Common AUTO_PARALLEL DATA_PARALLEL
|
||||
=========================== =========================== =================
|
||||
device_num gradient_fp32_sync enable_parallel_optimizer
|
||||
=========================== ===========================
|
||||
Common AUTO_PARALLEL
|
||||
=========================== ===========================
|
||||
device_num gradient_fp32_sync
|
||||
global_rank loss_repeated_mean
|
||||
gradients_mean auto_parallel_search_mode
|
||||
parallel_mode strategy_ckpt_load_file
|
||||
all_reduce_fusion_config strategy_ckpt_save_file
|
||||
\ full_batch
|
||||
\ pipeline_stages
|
||||
=========================== =========================== =================
|
||||
enable_parallel_optimizer full_batch
|
||||
\ pipeline_stages
|
||||
=========================== ===========================
|
||||
|
||||
Args:
|
||||
device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
|
||||
|
@ -387,7 +387,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
- recursive_programming: Recursive programming search mode.
|
||||
|
||||
- dynamic_programming: Dynamic programming search mode.
|
||||
parameter_broadcast (bool): Whether to broadcast parameters before training.
|
||||
parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training.
|
||||
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
|
||||
broadcast. Default: False.
|
||||
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
|
||||
|
@@ -395,9 +395,9 @@ def set_auto_parallel_context(**kwargs):
        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
            should be set with True. Default: False.
        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
            data parallel training in the benefit of time and memory saving. For now,
            `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No Default, if it is not set,
            the fusion is closed.
            data parallel training in the benefit of time and memory saving. For now, auto parallel mode
            supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`.
            Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
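For reference, the flag documented above is set through the same context call the new tests in this commit use; a minimal, hedged usage sketch (device numbers are illustrative):

from mindspore import context

# Auto parallel: any optimizer may have its update sharded (the new
# test_parallel_optimizer cases below compile Momentum this way).
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8,
                                  enable_parallel_optimizer=True)

# Data parallel: only Lamb and AdamWeightDecay are accepted
# (see the Optimizer change in the next hunk).
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=8,
                                  enable_parallel_optimizer=True)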
@@ -148,15 +148,18 @@ class Optimizer(Cell):
        self.reciprocal_scale = 1.0 / loss_scale
        self.param_length = len(self.parameters)
        self.map_ = C.Map()

        use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
        self.use_parallel = use_parallel
        if use_parallel:
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL:
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.STAND_ALONE:
                raise RuntimeError("Parallel optimizer is not supported in stand alone mode.")
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False
        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
            if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
                                   (_get_parallel_mode()))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
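The effect of these checks can be seen with a small, hedged sketch; a two-parameter Dense layer stands in for a real network, and the constructor arguments mirror the tests later in this commit:

import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim import AdamWeightDecay, Momentum

context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)
net = nn.Dense(16, 8)  # two trainable parameters, so dev_num <= param_length holds
opt_ok = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)  # accepted
try:
    Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
except RuntimeError as err:
    print(err)  # Optimizer segmentation does not support optimizer Momentum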
@ -83,8 +83,10 @@ def get_bprop_broad_cast(self):
|
|||
def get_bprop_all_gather(self):
|
||||
"""Generate bprop for AllGather"""
|
||||
all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
|
||||
fusion = self.get_attr_dict()["fusion"]
|
||||
all_gather_grad.add_prim_attr("fusion", fusion)
|
||||
if self.instance_name:
|
||||
instance_name = "grad" + self.instance_name
|
||||
instance_name = "grad_" + self.instance_name
|
||||
all_gather_grad.set_prim_instance_name(instance_name)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""comm_ops"""
|
||||
|
||||
from mindspore.common import Tensor
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
|
||||
|
@ -158,6 +159,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
|
||||
self.add_prim_attr('rank_size', self.rank_size)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
|
||||
|
@ -268,6 +270,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
self.rank_size = get_group_size(_get_group(group))
|
||||
self.add_prim_attr('rank_size', self.rank_size)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if x_shape[0] % self.rank_size != 0:
|
||||
|
@ -526,4 +529,4 @@ class _GetTensorSlice(PrimitiveWithInfer):
|
|||
from mindspore.parallel._tensor import _load_tensor
|
||||
validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
|
||||
validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
|
||||
return _load_tensor(x, dev_mat, tensor_map)
|
||||
return Tensor(_load_tensor(x, dev_mat, tensor_map))
|
||||
|
|
|
@@ -37,12 +37,34 @@ class AllGatherCell(Cell):
        return x


def get_allgather_cell():
class SaveOptShardCkptCell(Cell):
    """
    Allgather cell, used in optimizer parallel scenario.
    Firstly gather the tensor to original layout in the specified device group.
    Then gather the whole parameter slices from all devices.

    Note:
        This could be optimized later with less communication consumption.
    """
    def __init__(self, group):
        super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
        self.allgather1 = AllGather(group)
        self.allgather2 = AllGather()

    def construct(self, x):
        x = self.allgather1(x)
        x = self.allgather2(x)

        return x


def get_allgather_cell(group):
    """Get AllGatherCell object."""
    global _allgather_cell
    if not _allgather_cell:
        if group:
            _allgather_cell = SaveOptShardCkptCell(group)
        else:
            _allgather_cell = AllGatherCell()

    return _allgather_cell
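Downstream, the checkpoint merge path (see _get_merged_param_data further below) asks this factory for a cell, passing the optimizer shard group recorded in the parameter layout. Roughly, as a hedged sketch with a made-up helper name:

from mindspore.parallel._cell_wrapper import get_allgather_cell

def merge_param_for_ckpt(param_data, opt_shard_group):
    # opt_shard_group is layout[5]; '' when the parameter was not optimizer-sharded
    allgather_net = get_allgather_cell(opt_shard_group)  # SaveOptShardCkptCell or AllGatherCell
    return allgather_net(param_data)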
@ -16,7 +16,7 @@
|
|||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from ..communication.management import get_rank
|
||||
from ..communication.management import get_rank, get_group_size
|
||||
|
||||
|
||||
def _get_tensor_strategy(dev_mat, tensor_map):
|
||||
|
@ -168,6 +168,7 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
|
|||
raise ValueError("The length of np_tensor does not match the length of strategy!")
|
||||
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
||||
|
||||
|
||||
def _get_slice_index(dev_mat, tensor_map):
|
||||
"""
|
||||
Get the slice index for current slice.
|
||||
|
@ -184,6 +185,7 @@ def _get_slice_index(dev_mat, tensor_map):
|
|||
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
||||
return tensor_slice_index
|
||||
|
||||
|
||||
def _load_tensor(tensor, dev_mat, tensor_map):
|
||||
"""
|
||||
Get the tensor slice of the local device by the device matrix and the tensor map
|
||||
|
@ -194,7 +196,7 @@ def _load_tensor(tensor, dev_mat, tensor_map):
|
|||
tensor_map (list): The split strategy of tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, the sliced tensor.
|
||||
numpy.array, the sliced array.
|
||||
|
||||
Examples:
|
||||
>>> tensor = Tensor(np.ones([32, 32]))
|
||||
|
@ -208,8 +210,7 @@ def _load_tensor(tensor, dev_mat, tensor_map):
|
|||
np_tensor = tensor.asnumpy()
|
||||
np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
|
||||
np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
|
||||
tensor_slice = Tensor(np_tensor_slice)
|
||||
return tensor_slice
|
||||
return np_tensor_slice
|
||||
|
||||
|
||||
def _load_tensor_by_layout(tensor, layout):
|
||||
|
@@ -227,18 +228,25 @@ def _load_tensor_by_layout(tensor, layout):
        TypeError: If layout is not list.
        ValueError: If the length of layout is not 3.
    """
    if not isinstance(layout, list):
        raise TypeError("The layout should be list! layout is {}".format(layout))
    if len(layout) < 5:
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 6:
        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
    dev_mat = layout[0]
    tensor_map = layout[1]
    uniform_split = layout[4]
    if uniform_split[0] == 0:
    group = layout[5]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only support uniform split now")
    if tensor.size() == 1:
        return tensor
    return _load_tensor(tensor, dev_mat, tensor_map)
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
    if group:
        # get a totally shard tensor slice for parallel optimizer
        rank = get_rank(group)
        size = get_group_size(group)
        tensor_slice = np.split(tensor_slice, size)[rank]
    return Tensor(tensor_slice)


def _reshape_param_data(param_data, dev_mat, tensor_map):
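The layout consumed here is now the six-element tuple produced by GetParameterLayout. A NumPy-only illustration of the extra optimizer-shard split, with rank and group size hard-coded in place of get_rank/get_group_size and a made-up group name:

import numpy as np

# Six-element layout, matching the new GetParameterLayout output:
# (dev_mat, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group)
layout = ([2, 4], [1, -1], [16, 32], 0, True, "group_example")  # group name illustrative

full = np.arange(32 * 32, dtype=np.float32).reshape(32, 32)
model_slice = np.split(full, 2, axis=0)[0]   # ordinary slice for tensor_map [1, -1] -> (16, 32)
opt_rank, opt_size = 1, 4                    # stand-ins for get_rank(group)/get_group_size(group)
opt_slice = np.split(model_slice, opt_size)[opt_rank]
print(model_slice.shape, opt_slice.shape)    # (16, 32) (4, 32)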
@ -294,6 +302,7 @@ def _reshape_param_data(param_data, dev_mat, tensor_map):
|
|||
|
||||
return Tensor(tensor_slices_new[0])
|
||||
|
||||
|
||||
def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
|
||||
"""
|
||||
Combine param slice by the device matrix, used in model parallel scenario.
|
||||
|
@ -318,10 +327,10 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
|
|||
tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
|
||||
tensor_slices_col = []
|
||||
for i in range(len(tensor_slices[0][0])):
|
||||
tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1)
|
||||
tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
|
||||
for j in range(1, device_count):
|
||||
tensor_slices_new = np.concatenate((tensor_slices_new,\
|
||||
np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1)
|
||||
np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
|
||||
tensor_slices_col.append(tensor_slices_new)
|
||||
new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
|
||||
for i in range(1, len(tensor_slices_col)):
|
||||
|
|
|
@ -398,7 +398,7 @@ def _get_merged_param_data(net, param_name, param_data):
|
|||
Tensor, the combined tensor which with the whole data value.
|
||||
"""
|
||||
layout = net.parameter_layout_dict[param_name]
|
||||
if len(layout) < 5:
|
||||
if len(layout) < 6:
|
||||
logger.info("layout dict does not contain the key %s", param_name)
|
||||
return param_data
|
||||
|
||||
|
@ -406,17 +406,19 @@ def _get_merged_param_data(net, param_name, param_data):
|
|||
tensor_map = layout[1]
|
||||
field_size = layout[3]
|
||||
uniform_split = layout[4]
|
||||
if uniform_split[0] == 0:
|
||||
opt_shard_group = layout[5]
|
||||
if uniform_split == 0:
|
||||
raise RuntimeError("Save checkpoint only support uniform split tensor now.")
|
||||
|
||||
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
||||
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
||||
# while any dim is not equal to -1, means param is splited and needs to be merged
|
||||
# while any dim is not equal to -1, means param is split and needs to be merged
|
||||
# pipeline parallel need to be supported here later
|
||||
for dim in tensor_map:
|
||||
if dim != -1:
|
||||
allgather_net = get_allgather_cell()
|
||||
if dim != -1 or opt_shard_group:
|
||||
allgather_net = get_allgather_cell(opt_shard_group)
|
||||
param_data = allgather_net(param_data)
|
||||
if field_size[0]:
|
||||
if field_size:
|
||||
return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
|
||||
return _reshape_param_data(param_data, dev_mat, tensor_map)
|
||||
|
||||
|
|
|
@ -35,11 +35,11 @@ class TestRedistributionLayoutTransfer : public UT::Common {
|
|||
};
|
||||
|
||||
void RedistributionLayoutTransferTestFunction(
|
||||
const DeviceArrangement& in_device_arrangement_shape, const TensorMap& in_tensor_map_shape,
|
||||
const TensorShape& tensor_shape_shape, const DeviceArrangement& out_device_arrangement_shape,
|
||||
const TensorMap& out_tensor_map_shape, DeviceArrangement* unified_device_arrangement_shape,
|
||||
TensorMap* unified_in_tensor_map_shape, TensorMap* unified_out_tensor_map_shape,
|
||||
TensorMap* unified_tensor_shape_shape) {
|
||||
const DeviceArrangement &in_device_arrangement_shape, const TensorMap &in_tensor_map_shape,
|
||||
const TensorShape &tensor_shape_shape, const DeviceArrangement &out_device_arrangement_shape,
|
||||
const TensorMap &out_tensor_map_shape, DeviceArrangement *unified_device_arrangement_shape,
|
||||
TensorMap *unified_in_tensor_map_shape, TensorMap *unified_out_tensor_map_shape,
|
||||
TensorMap *unified_tensor_shape_shape) {
|
||||
Arrangement in_device_arrangement;
|
||||
Status status = in_device_arrangement.Init(in_device_arrangement_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, status);
|
||||
|
@ -86,13 +86,13 @@ void RedistributionLayoutTransferTestFunction(
|
|||
*unified_tensor_shape_shape = unified_in_tensor_shape.array();
|
||||
}
|
||||
|
||||
void RedistributionLayoutCheck(const DeviceArrangement& in_device_arrangement, const TensorMap& in_tensor_map,
|
||||
const TensorShape& tensor_shape, const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map,
|
||||
const DeviceArrangement& unified_device_arrangement_expect,
|
||||
const TensorMap& unified_in_tensor_map_expect,
|
||||
const TensorMap& unified_out_tensor_map_expect,
|
||||
const TensorMap& unified_tensor_shape_expect) {
|
||||
void RedistributionLayoutCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &tensor_shape, const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map,
|
||||
const DeviceArrangement &unified_device_arrangement_expect,
|
||||
const TensorMap &unified_in_tensor_map_expect,
|
||||
const TensorMap &unified_out_tensor_map_expect,
|
||||
const TensorMap &unified_tensor_shape_expect) {
|
||||
DeviceArrangement unified_device_arrangement;
|
||||
TensorMap unified_in_tensor_map;
|
||||
TensorMap unified_out_tensor_map;
|
||||
|
@ -224,9 +224,9 @@ TEST_F(TestRedistributionLayoutTransfer, RedistributionLayoutTransfer5) {
|
|||
unified_out_tensor_map_expect, unified_tensor_shape_expect);
|
||||
}
|
||||
|
||||
void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangement, const TensorMap& in_tensor_map,
|
||||
const TensorShape& tensor_shape, const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map) {
|
||||
void ValidRedistributionLayoutCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &tensor_shape, const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map) {
|
||||
DeviceArrangement unified_device_arrangement;
|
||||
TensorMap unified_in_tensor_map;
|
||||
TensorMap unified_out_tensor_map;
|
||||
|
@ -242,8 +242,8 @@ void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangeme
|
|||
unified_out_tensor_map, unified_tensor_shape);
|
||||
}
|
||||
|
||||
void ValidRedistributionLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size,
|
||||
int64_t max_device_dim, int64_t max_shape_dim) {
|
||||
void ValidRedistributionLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim,
|
||||
int64_t max_shape_dim) {
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list;
|
||||
GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim,
|
||||
&layout_list);
|
||||
|
|
|
@ -49,7 +49,7 @@ class TestRedistributionOperatorInfer : public UT::Common {
|
|||
};
|
||||
|
||||
// check if in_tensor_map could be changed to out_tensor_map with operator_list
|
||||
void InferOperatorCheck(Shape in_tensor_map, const Shape& out_tensor_map, const OperatorList& operator_list) {
|
||||
void InferOperatorCheck(Shape in_tensor_map, const Shape &out_tensor_map, const OperatorList &operator_list) {
|
||||
for (auto op_cost : operator_list) {
|
||||
OperatorR op = op_cost.first;
|
||||
Args args = op.second;
|
||||
|
|
|
@ -35,11 +35,11 @@ class TestReshapeLayoutTransfer : public UT::Common {
|
|||
virtual void TearDown() {}
|
||||
};
|
||||
|
||||
void InferUnifiedLayout(const DeviceArrangement& device_arrangement_shape, const TensorMap& in_tensor_map_shape,
|
||||
const TensorShape& in_tensor_shape_shape, const TensorMap& out_tensor_map_shape,
|
||||
const TensorShape& out_tensor_shape_shape, DeviceArrangement* unified_device_arrangement_shape,
|
||||
TensorMap* unified_in_tensor_map_shape, TensorMap* unified_out_tensor_map_shape,
|
||||
TensorMap* unified_tensor_shape_shape) {
|
||||
void InferUnifiedLayout(const DeviceArrangement &device_arrangement_shape, const TensorMap &in_tensor_map_shape,
|
||||
const TensorShape &in_tensor_shape_shape, const TensorMap &out_tensor_map_shape,
|
||||
const TensorShape &out_tensor_shape_shape, DeviceArrangement *unified_device_arrangement_shape,
|
||||
TensorMap *unified_in_tensor_map_shape, TensorMap *unified_out_tensor_map_shape,
|
||||
TensorMap *unified_tensor_shape_shape) {
|
||||
Arrangement device_arrangement;
|
||||
Status status = device_arrangement.Init(device_arrangement_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, status);
|
||||
|
@ -85,13 +85,13 @@ void InferUnifiedLayout(const DeviceArrangement& device_arrangement_shape, const
|
|||
*unified_out_tensor_map_shape = unified_out_tensor_map.array();
|
||||
}
|
||||
|
||||
void InferUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const TensorMap& in_tensor_map,
|
||||
const TensorShape& in_tensor_shape, const TensorMap& out_tensor_map,
|
||||
const TensorShape& out_tensor_shape,
|
||||
const DeviceArrangement& unified_device_arrangement_expect,
|
||||
const TensorMap& unified_in_tensor_map_expect,
|
||||
const TensorMap& unified_out_tensor_map_expect,
|
||||
const TensorMap& unified_tensor_shape_expect) {
|
||||
void InferUnifiedLayoutCheck(const DeviceArrangement &device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &in_tensor_shape, const TensorMap &out_tensor_map,
|
||||
const TensorShape &out_tensor_shape,
|
||||
const DeviceArrangement &unified_device_arrangement_expect,
|
||||
const TensorMap &unified_in_tensor_map_expect,
|
||||
const TensorMap &unified_out_tensor_map_expect,
|
||||
const TensorMap &unified_tensor_shape_expect) {
|
||||
DeviceArrangement unified_device_arrangement;
|
||||
TensorMap unified_in_tensor_map;
|
||||
TensorMap unified_out_tensor_map;
|
||||
|
@ -109,9 +109,9 @@ void InferUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const
|
|||
ASSERT_EQ(unified_tensor_shape_expect, unified_tensor_shape);
|
||||
}
|
||||
|
||||
void ValidUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const TensorMap& in_tensor_map,
|
||||
const TensorShape& in_tensor_shape, const TensorMap& out_tensor_map,
|
||||
const TensorShape& out_tensor_shape) {
|
||||
void ValidUnifiedLayoutCheck(const DeviceArrangement &device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &in_tensor_shape, const TensorMap &out_tensor_map,
|
||||
const TensorShape &out_tensor_shape) {
|
||||
DeviceArrangement unified_device_arrangement;
|
||||
TensorMap unified_in_tensor_map;
|
||||
TensorMap unified_out_tensor_map;
|
||||
|
@ -257,8 +257,8 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheck11) {
|
|||
ValidUnifiedLayoutCheck(device_arrangement, in_tensor_map, in_tensor_shape, out_tensor_map, out_tensor_shape);
|
||||
}
|
||||
|
||||
void ValidInferUnifiedLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size,
|
||||
int64_t max_device_dim, int64_t max_shape_dim) {
|
||||
void ValidInferUnifiedLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim,
|
||||
int64_t max_shape_dim) {
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list;
|
||||
GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim,
|
||||
&layout_list);
|
||||
|
@ -297,7 +297,7 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheckAll) {
|
|||
ValidInferUnifiedLayoutCheckAll(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim);
|
||||
tensor_pow_size++;
|
||||
}
|
||||
device_pow_size++;
|
||||
device_pow_size++;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,12 +32,12 @@ class TestTensorLayout : public UT::Common {
|
|||
virtual void TearDown() {}
|
||||
};
|
||||
|
||||
void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
||||
const TensorMap& in_tensor_map_shape,
|
||||
const TensorShape& in_tensor_shape_shape,
|
||||
const DeviceArrangement& out_device_arrangement_shape,
|
||||
const TensorMap& out_tensor_map_shape,
|
||||
const TensorShape& out_tensor_shape_shape) {
|
||||
void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
||||
const TensorMap &in_tensor_map_shape,
|
||||
const TensorShape &in_tensor_shape_shape,
|
||||
const DeviceArrangement &out_device_arrangement_shape,
|
||||
const TensorMap &out_tensor_map_shape,
|
||||
const TensorShape &out_tensor_shape_shape) {
|
||||
Arrangement device_arrangement;
|
||||
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, status);
|
||||
|
@ -154,12 +154,10 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement5) {
|
|||
tensor_map_expect, tensor_shape_expect);
|
||||
}
|
||||
|
||||
void ExpandTensorShapeTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
||||
const TensorMap& in_tensor_map_shape,
|
||||
const TensorShape& in_tensor_shape_shape,
|
||||
const DeviceArrangement& out_device_arrangement_shape,
|
||||
const TensorMap& out_tensor_map_shape,
|
||||
const TensorShape& out_tensor_shape_shape) {
|
||||
void ExpandTensorShapeTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
||||
const TensorMap &in_tensor_map_shape, const TensorShape &in_tensor_shape_shape,
|
||||
const DeviceArrangement &out_device_arrangement_shape,
|
||||
const TensorMap &out_tensor_map_shape, const TensorShape &out_tensor_shape_shape) {
|
||||
Arrangement device_arrangement;
|
||||
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, status);
|
||||
|
@ -251,12 +249,12 @@ TEST_F(TestTensorLayout, UpdateTensorMap) {
|
|||
ASSERT_EQ(in_tensor_map, new_tensor_map);
|
||||
}
|
||||
|
||||
void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
||||
const TensorMap& in_tensor_map_shape,
|
||||
const TensorShape& in_tensor_shape_shape,
|
||||
const DeviceArrangement& out_device_arrangement_shape,
|
||||
const TensorMap& out_tensor_map_shape,
|
||||
const TensorShape& out_tensor_shape_shape) {
|
||||
void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
||||
const TensorMap &in_tensor_map_shape,
|
||||
const TensorShape &in_tensor_shape_shape,
|
||||
const DeviceArrangement &out_device_arrangement_shape,
|
||||
const TensorMap &out_tensor_map_shape,
|
||||
const TensorShape &out_tensor_shape_shape) {
|
||||
Arrangement device_arrangement;
|
||||
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, status);
|
||||
|
@ -310,15 +308,82 @@ TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement3) {
|
|||
device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new);
|
||||
}
|
||||
|
||||
TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement4) {
|
||||
DeviceArrangement device_arrangement = {1, 1, 1};
|
||||
TensorMap tensor_map = {2, 1};
|
||||
TensorShape tensor_shape = {128, 4096};
|
||||
DeviceArrangement device_arrangement_expect = {};
|
||||
TensorMap tensor_map_expect = {-1, -1};
|
||||
TensorShape tensor_shape_new = {128, 4096};
|
||||
RemoveElementEqualToOneInDeviceArrangementTestFunction(
|
||||
device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new);
|
||||
/*
|
||||
* example:
|
||||
* device_arrangement = [8, 4],
|
||||
* tensor_map = [1, 0],
|
||||
* tensor_shape = [512, 1024],
|
||||
*/
|
||||
TEST_F(TestTensorLayout, GenerateOptShardSliceShape1) {
|
||||
Arrangement device_arrangement;
|
||||
device_arrangement.Init({8, 4});
|
||||
Map tensor_map;
|
||||
tensor_map.Init({1, 0});
|
||||
Arrangement tensor_shape;
|
||||
tensor_shape.Init({512, 1024});
|
||||
TensorLayout tensor_layout;
|
||||
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
||||
ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape());
|
||||
}
|
||||
|
||||
/*
|
||||
* example:
|
||||
* device_arrangement = [8, 4],
|
||||
* tensor_map = [-1, 0],
|
||||
* tensor_shape = [512, 1024],
|
||||
*/
|
||||
TEST_F(TestTensorLayout, GenerateOptShardSliceShape2) {
|
||||
Arrangement device_arrangement;
|
||||
device_arrangement.Init({8, 4});
|
||||
Map tensor_map;
|
||||
tensor_map.Init({-1, 0});
|
||||
Arrangement tensor_shape;
|
||||
tensor_shape.Init({512, 1024});
|
||||
TensorLayout tensor_layout;
|
||||
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape());
|
||||
|
||||
Shape slice_shape_expect = {64, 256};
|
||||
ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect);
|
||||
}
|
||||
|
||||
/*
|
||||
* example:
|
||||
* device_arrangement = [4, 4, 2],
|
||||
* tensor_map = [1, 0],
|
||||
* tensor_shape = [512, 1024],
|
||||
*/
|
||||
TEST_F(TestTensorLayout, GenerateOptShardSliceShape3) {
|
||||
Arrangement device_arrangement;
|
||||
device_arrangement.Init({4, 4, 2});
|
||||
Map tensor_map;
|
||||
tensor_map.Init({1, 0});
|
||||
Arrangement tensor_shape;
|
||||
tensor_shape.Init({512, 1024});
|
||||
TensorLayout tensor_layout;
|
||||
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
||||
ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape());
|
||||
|
||||
Shape slice_shape_expect = {32, 512};
|
||||
ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect);
|
||||
}
|
||||
|
||||
/*
|
||||
* example:
|
||||
* device_arrangement = [4, 4, 2],
|
||||
* tensor_map = [1, 0],
|
||||
* tensor_shape = [20, 1024],
|
||||
*/
|
||||
TEST_F(TestTensorLayout, GenerateOptShardSliceShape4) {
|
||||
Arrangement device_arrangement;
|
||||
device_arrangement.Init({4, 4, 2});
|
||||
Map tensor_map;
|
||||
tensor_map.Init({1, 0});
|
||||
Arrangement tensor_shape;
|
||||
tensor_shape.Init({20, 1024});
|
||||
TensorLayout tensor_layout;
|
||||
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
||||
ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape());
|
||||
}
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -28,7 +28,7 @@ using std::pow;
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
std::vector<Shape> combine(const Shape& in, int64_t target) {
|
||||
std::vector<Shape> combine(const Shape &in, int64_t target) {
|
||||
std::vector<Shape> output;
|
||||
for (int64_t i = 0; i < pow(2, in.size()); i++) {
|
||||
size_t temp = 0;
|
||||
|
@ -54,7 +54,7 @@ std::vector<Shape> combine(const Shape& in, int64_t target) {
|
|||
return output;
|
||||
}
|
||||
|
||||
void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape>* out) {
|
||||
void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape> *out) {
|
||||
out->clear();
|
||||
Shape in;
|
||||
for (int64_t i = 1; i < pow_size; i++) {
|
||||
|
@ -80,7 +80,7 @@ void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<S
|
|||
return;
|
||||
}
|
||||
|
||||
void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out) {
|
||||
void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape> *out) {
|
||||
out->clear();
|
||||
for (int64_t dim = 1; dim <= pow_size; dim++) {
|
||||
std::vector<Shape> combine_result;
|
||||
|
@ -92,7 +92,7 @@ void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out) {
|
|||
return;
|
||||
}
|
||||
|
||||
TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value) {
|
||||
TensorMap GenerateTensorMap(const int64_t &map_size, const Shape &pos_index, const Shape &pos_value) {
|
||||
TensorMap tensor_map(map_size, -1);
|
||||
for (size_t i = 0; i < pos_index.size() && i < pos_value.size(); i++) {
|
||||
if (pos_index[i] >= map_size) {
|
||||
|
@ -103,8 +103,8 @@ TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, con
|
|||
return tensor_map;
|
||||
}
|
||||
|
||||
void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorShape& tensor_shape,
|
||||
std::vector<TensorMap>* tensor_map_list) {
|
||||
void GenerateValidTensorMap(const DeviceArrangement &device_arrangement, const TensorShape &tensor_shape,
|
||||
std::vector<TensorMap> *tensor_map_list) {
|
||||
tensor_map_list->clear();
|
||||
int64_t device_size = device_arrangement.size();
|
||||
int64_t shape_size = tensor_shape.size();
|
||||
|
@ -149,9 +149,8 @@ void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const T
|
|||
}
|
||||
|
||||
void GenerateValidLayoutByDeviceSizeAndTensorSize(
|
||||
int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim,
|
||||
int64_t max_shape_dim,
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>>* layout_list) {
|
||||
int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, int64_t max_shape_dim,
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> *layout_list) {
|
||||
layout_list->clear();
|
||||
std::vector<DeviceArrangement> device_arrangement_list;
|
||||
GenerateValidShapeBySize(device_pow_size, &device_arrangement_list);
|
||||
|
@ -174,8 +173,8 @@ void GenerateValidLayoutByDeviceSizeAndTensorSize(
|
|||
return;
|
||||
}
|
||||
|
||||
bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map,
|
||||
const TensorShape& tensor_shape) {
|
||||
bool CheckLayoutValid(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape) {
|
||||
bool flag = false;
|
||||
if ((tensor_map.size() - ComputeNoneNumber(tensor_map)) > device_arrangement.size()) {
|
||||
return flag;
|
||||
|
@ -186,7 +185,7 @@ bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorM
|
|||
return true;
|
||||
}
|
||||
|
||||
size_t ComputeNoneNumber(const TensorMap& tensor_map) {
|
||||
size_t ComputeNoneNumber(const TensorMap &tensor_map) {
|
||||
size_t num = 0;
|
||||
for (size_t i = 0; i < tensor_map.size(); i++) {
|
||||
if (tensor_map[i] == -1) {
|
||||
|
@ -196,8 +195,8 @@ size_t ComputeNoneNumber(const TensorMap& tensor_map) {
|
|||
return num;
|
||||
}
|
||||
|
||||
bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map,
|
||||
const TensorShape& tensor_shape) {
|
||||
bool ShapeIsDividedByDevice(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape) {
|
||||
bool flag = false;
|
||||
for (uint32_t i = 0; i < tensor_map.size() && i < tensor_shape.size(); i++) {
|
||||
if (tensor_map[i] == -1) {
|
||||
|
@ -211,7 +210,7 @@ bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const T
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsExpended(const Shape& in1, const Shape& in2) {
|
||||
bool IsExpended(const Shape &in1, const Shape &in2) {
|
||||
int64_t size = 1;
|
||||
uint32_t ind = 0;
|
||||
for (uint32_t i = 0; i < in1.size(); i++) {
|
||||
|
@ -234,9 +233,9 @@ bool IsExpended(const Shape& in1, const Shape& in2) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement,
|
||||
const TensorMap& tensor_map, const TensorShape& tensor_shape,
|
||||
std::map<int64_t, int64_t>* accum_device_to_accum_shape_map) {
|
||||
void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape,
|
||||
std::map<int64_t, int64_t> *accum_device_to_accum_shape_map) {
|
||||
accum_device_to_accum_shape_map->clear();
|
||||
std::vector<int64_t> shape_accum_reverse;
|
||||
Status status = ShapeToAccumulateProductReverse(tensor_shape, &shape_accum_reverse);
|
||||
|
@ -263,12 +262,10 @@ void IsLinearValue(int64_t small, int64_t big, int64_t small_value, int64_t big_
|
|||
ASSERT_EQ(middle_value, value);
|
||||
}
|
||||
|
||||
void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement,
|
||||
const TensorMap& in_tensor_map,
|
||||
const TensorShape& in_tensor_shape,
|
||||
const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map,
|
||||
const TensorShape& out_tensor_shape) {
|
||||
void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement,
|
||||
const TensorMap &in_tensor_map, const TensorShape &in_tensor_shape,
|
||||
const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape) {
|
||||
bool is_expended = IsExpended(out_device_arrangement, in_device_arrangement);
|
||||
ASSERT_EQ(true, is_expended);
|
||||
is_expended = IsExpended(out_tensor_shape, in_tensor_shape);
|
||||
|
@ -317,10 +314,9 @@ void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arr
|
|||
}
|
||||
}
|
||||
|
||||
void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement,
|
||||
const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape,
|
||||
const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape) {
|
||||
void ValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &in_tensor_shape, const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape) {
|
||||
LayoutTransferValidLayoutChangeCheck(in_device_arrangement, in_tensor_map, in_tensor_shape, out_device_arrangement,
|
||||
out_tensor_map, out_tensor_shape);
|
||||
}
|
||||
|
|
|
@ -26,45 +26,41 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
||||
std::vector<Shape> combine(const Shape& in, int64_t target);
|
||||
std::vector<Shape> combine(const Shape &in, int64_t target);
|
||||
|
||||
void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape>* out);
|
||||
void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape> *out);
|
||||
|
||||
void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out);
|
||||
void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape> *out);
|
||||
|
||||
TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value);
|
||||
TensorMap GenerateTensorMap(const int64_t &map_size, const Shape &pos_index, const Shape &pos_value);
|
||||
|
||||
void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorMap& tensor_shape,
|
||||
std::vector<TensorMap>* tensor_map_list);
|
||||
void GenerateValidTensorMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_shape,
|
||||
std::vector<TensorMap> *tensor_map_list);
|
||||
|
||||
void GenerateValidLayoutByDeviceSizeAndTensorSize(
|
||||
int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim,
|
||||
int64_t max_shape_dim,
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>>* layout_list);
|
||||
int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, int64_t max_shape_dim,
|
||||
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> *layout_list);
|
||||
|
||||
size_t ComputeNoneNumber(const TensorMap& tensor_map);
|
||||
size_t ComputeNoneNumber(const TensorMap &tensor_map);
|
||||
|
||||
bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map,
|
||||
const TensorShape& tensor_shape);
|
||||
bool ShapeIsDividedByDevice(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape);
|
||||
|
||||
bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map,
|
||||
const TensorShape& tensor_shape);
|
||||
bool CheckLayoutValid(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape);
|
||||
|
||||
void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement,
|
||||
const TensorMap& tensor_map, const TensorShape& tensor_shape,
|
||||
std::map<int64_t, int64_t>* accum_device_to_accum_shape_map);
|
||||
void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map,
|
||||
const TensorShape &tensor_shape,
|
||||
std::map<int64_t, int64_t> *accum_device_to_accum_shape_map);
|
||||
|
||||
void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement,
|
||||
const TensorMap& in_tensor_map,
|
||||
const TensorShape& in_tensor_shape,
|
||||
const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map,
|
||||
const TensorShape& out_tensor_shape);
|
||||
void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement,
|
||||
const TensorMap &in_tensor_map, const TensorShape &in_tensor_shape,
|
||||
const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape);
|
||||
|
||||
void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement,
|
||||
const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape,
|
||||
const DeviceArrangement& out_device_arrangement,
|
||||
const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape);
|
||||
void ValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map,
|
||||
const TensorShape &in_tensor_shape, const DeviceArrangement &out_device_arrangement,
|
||||
const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape);
|
||||
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""api definition"""
|
||||
import threading
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
|
||||
class Hccl():
|
||||
|
@ -62,7 +63,9 @@ def get_rank_id(group=None):
|
|||
def get_rank_size(group=None):
|
||||
hccl = Hccl()
|
||||
if group is None or "nccl_world_group" in group:
|
||||
return hccl.rank_size
|
||||
if auto_parallel_context().get_device_num_is_set() is False:
|
||||
return 1
|
||||
return auto_parallel_context().get_device_num()
|
||||
if isinstance(group, str):
|
||||
return int(group.split("-")[0])
|
||||
raise ValueError
|
||||
|
|
|
@ -49,8 +49,8 @@ def test_get_parameter_layout():
|
|||
net.set_auto_parallel()
|
||||
exe = me._executor
|
||||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
|
|
@ -17,16 +17,17 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Dense(128, 768, activation='relu')
|
||||
|
@ -50,6 +51,56 @@ class Net(nn.Cell):
|
|||
return s
|
||||
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super(Net2, self).__init__()
|
||||
self.fc1 = P.MatMul().shard(strategy=strategy1)
|
||||
self.fc2 = P.MatMul().shard(strategy=strategy2)
|
||||
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
|
||||
self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2")
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.fc1(x, self.p1)
|
||||
x = self.fc2(x, self.p2)
|
||||
return x - y
|
||||
|
||||
|
||||
def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
|
||||
inputs = Tensor(np.ones([32, 48]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 16]).astype(np.float32))
|
||||
net = Net2(strategy1, strategy2)
|
||||
net = _VirtualDatasetCell(net)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_network = TrainOneStepCell(net, optimizer)
|
||||
train_network.set_auto_parallel()
|
||||
_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_1():
|
||||
auto_parallel_compile_net("auto_parallel", 8)
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_2():
|
||||
# data parallel case
|
||||
auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_3():
|
||||
# hybrid parallel case
|
||||
# weight1 could not be shard and weight2 is repeated
|
||||
auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
|
||||
|
||||
|
||||
def test_auto_parallel_momentum_4():
|
||||
# hybrid parallel cases
|
||||
# devices are repeatedly used
|
||||
auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
|
||||
|
||||
|
||||
def test_AdamWeightDecay():
|
||||
""" test_AdamWeightDecay """
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
||||
|
@ -98,6 +149,7 @@ def test_lamb_split_fusion():
|
|||
_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_edge_case():
|
||||
""" test_edge_case """
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||
|
|
|
@ -121,10 +121,10 @@ def test_grad_sens_parameter_type():
|
|||
sens = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[8, 8], [1, -1], [16, 32], [0], [1]]
|
||||
y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
|
||||
b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
|
||||
sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
|
||||
x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
|
||||
y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
|
||||
b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
|
||||
sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
|
||||
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
|
||||
assert net.parameter_layout_dict == expect_dict