forked from mindspore-Ecosystem/mindspore

commit 563622874a
parent e32d539b5f

    update
@@ -28,9 +28,14 @@ namespace parallel {
 std::string GetOpPythonPath(const OperatorName &op_name) {
   // almost all ops are defined in two main paths
   const std::string ops_module = OP_PATH;
+  const std::string inner_ops_module = INNER_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
+  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
   if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name;
+    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
+    }
+    return inner_ops_module;
   }
   return ops_module;
 }
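For context, the lookup order this hunk implements can be sketched in Python. This is an illustrative stand-in, not the project's API, and it only runs where MindSpore is installed:

    import importlib

    def get_op_python_path(op_name,
                           ops_module="mindspore.ops.operations",
                           inner_ops_module="mindspore.ops.operations._inner_ops"):
        # Prefer the public ops module; fall back to the inner ops module;
        # fail only when the op is defined in neither (mirrors GetOpPythonPath).
        if hasattr(importlib.import_module(ops_module), op_name):
            return ops_module
        if hasattr(importlib.import_module(inner_ops_module), op_name):
            return inner_ops_module
        raise AttributeError(f"{ops_module} or {inner_ops_module} don't have op: {op_name}")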
@@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() {
     }
   }
 
+  // target=CPU, axis must be 0
+  if (target_ == "CPU" && axis_ != 0) {
+    MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
+    return FAILED;
+  }
+
   return SUCCESS;
 }
 
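The added guard in plain terms: when the gather targets the host ("CPU"), only axis 0 is supported. A hypothetical one-function rendering, for illustration only (check_cpu_target is not a helper in the codebase):

    def check_cpu_target(target, axis):
        # Mirrors the new GetAttrs check: host-side gather only supports axis 0.
        if target == "CPU" and axis != 0:
            raise ValueError(f"target is CPU, axis must be 0, but got {axis}")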
@@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() {
   int32_t rank = g_device_manager->global_rank();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
+  // axis don't split
+  if (params_strategy.at(axis_) == 1) {
+    bias_ = 0;
+    return SUCCESS;
+  }
   // params_size=1, axis=0
   if ((input_shape.size() == 1) && (axis_ == 0)) {
     slice_size_ = input_shape.at(0) / params_strategy.at(0);
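A worked example of the early return added here. The offset formula in the split case is an assumption for illustration; the remainder of InferBias (outside this hunk) derives the rank's position from the device arrangement:

    def infer_bias(rank, input_shape, params_strategy, axis):
        # Added early return: the gather axis is not split, so every rank
        # holds the full parameter table and needs no offset.
        if params_strategy[axis] == 1:
            return 0
        # 1-D params, axis 0: each rank owns a contiguous slice of the table.
        slice_size = input_shape[0] // params_strategy[0]
        # Hypothetical position-in-axis formula, for illustration only.
        return (rank % params_strategy[0]) * slice_size

    # e.g. a 64000-row table split 8 ways: rank 3 starts at 3 * 8000 = 24000
    assert infer_bias(3, [64000], [8], 0) == 24000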
@@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() {
   }
   auto group_size = group_.GetDevNum();
   Attr attr_group;
-  // group size <= 8
-  std::vector<int32_t> rank_list;
-  if (group_size <= 8) {
-    reduce_scatter_flag_ = false;
-    operator_name = HOST_REDUCE_SCATTER;
-    rank_list = GetRankFromGroup(group_);
-    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
-  } else {
-    // group size > 8
-    reduce_scatter_flag_ = true;
-    split_num_ = SizeToInt(group_size / 8);
-    CheckGlobalDeviceManager();
-    operator_name = REDUCE_SCATTER;
-    int32_t rank = g_device_manager->global_rank();
-    size_t repeat = group_size / 8;
-    for (size_t i = 0; i < repeat; ++i) {
-      rank_list.push_back(rank + SizeToInt(i * 8));
+  if (host_reduce_scatter_) {
+    // group size <= 8
+    std::vector<int32_t> rank_list;
+    if (group_size <= 8) {
+      reduce_scatter_flag_ = false;
+      operator_name = HOST_REDUCE_SCATTER;
+      rank_list = GetRankFromGroup(group_);
+      attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+    } else {
+      // group size > 8, don't support host reduce_scatter
+      reduce_scatter_flag_ = true;
+      split_num_ = SizeToInt(group_size / 8);
+      CheckGlobalDeviceManager();
+      operator_name = REDUCE_SCATTER;
+      int32_t rank = g_device_manager->global_rank();
+      size_t repeat = group_size / 8;
+      for (size_t i = 0; i < repeat; ++i) {
+        rank_list.push_back(rank + SizeToInt(i * 8));
+      }
+      Group g = g_device_manager->CreateGroup(rank_list);
+      attr_group = std::make_pair(GROUP, MakeValue(g.name()));
     }
-    Group g = g_device_manager->CreateGroup(rank_list);
-    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+  } else {
+    operator_name = REDUCE_SCATTER;
+    if (InferGroup() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+      return FAILED;
+    }
+    attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  }
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   OperatorAttrs attrs = {attr_op, attr_group};
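The rank-list construction for the group-size > 8 branch is easy to check numerically; this sketch lifts the stride-8 loop directly from the hunk above:

    def build_rank_list(rank, group_size):
        # split_num = group_size / 8; this rank joins the ranks offset by
        # multiples of 8, as in the loop in InferForwardCommunication.
        repeat = group_size // 8
        return [rank + i * 8 for i in range(repeat)]

    # e.g. rank 3 in a 16-device group: split_num = 2, rank list [3, 11]
    assert build_rank_list(3, 16) == [3, 11]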
@@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
   Attr param_offset = std::make_pair("offset", MakeValue(bias_));
   Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
   Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
-  OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
-                           std::make_pair(param_split_num, 6)};
+  OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
+                           std::make_pair(param_split_num, 5)};
   OperatorArgs args = std::make_pair(attrs, params);
   Operator op = std::make_pair(op_name, args);
   replace_op_.push_back(op);
@@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool reduce_scatter_flag_ = false;
   int32_t split_num_ = 1;
+  bool host_reduce_scatter_ = false;
 };
 
 class SparseGatherV2Info : public GatherV2PInfo {
@@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum";
 constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
+constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
 constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
 constexpr char GET_OP_FUNCTION[] = "_get_python_op";
 constexpr char KEEP_DIMS[] = "keep_dims";
@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
   auto prim = GetValueNode<PrimitivePtr>(node->input(0));
   if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
-    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
   }
   if (!params.empty()) {
     Param param_first = *(params.begin());
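This hunk and the ComputeReplaceOp hunk above are two halves of one change: the runtime axis input (node->input(3)) is dropped, so the constant parameters appended by ComputeReplaceOp each shift down one slot, from positions 4/5/6 to 3/4/5. A sketch of that slot arithmetic; the 1-based slot convention and the param_slots helper are inferred for illustration, not confirmed API:

    def param_slots(num_runtime_inputs, param_names):
        # Appended constant parameters occupy the slots immediately after
        # the runtime inputs.
        return {name: num_runtime_inputs + 1 + i for i, name in enumerate(param_names)}

    names = ["offset", "reduce_scatter_flag", "split_num"]
    assert param_slots(3, names) == {"offset": 4, "reduce_scatter_flag": 5, "split_num": 6}  # before
    assert param_slots(2, names) == {"offset": 3, "reduce_scatter_flag": 4, "split_num": 5}  # after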
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
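For readers of the re-enabled tests (renaming need_fix_test_* to test_* puts them back under pytest collection): strategy1 = ((8, 1), (1, 1)) splits the 2-D parameter tensor 8 ways along axis 0 and leaves the indices unsplit. A quick sanity sketch of the resulting shard shapes; the tensor sizes here are made up, since the tests' actual shapes lie outside these hunks:

    def shard_shape(shape, strategy):
        # Each dimension is divided by its strategy factor.
        return tuple(d // s for d, s in zip(shape, strategy))

    # params split 8 ways along axis 0, e.g. a (64, 32) table -> (8, 32) per rank
    assert shard_shape((64, 32), (8, 1)) == (8, 32)
    # indices strategy (1, 1): unsplit
    assert shard_shape((2, 4), (1, 1)) == (2, 4)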
@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu0():
+def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu1():
+def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
     _executor.compile(net, x, y)
 
 
-def need_fix_test_gatherv2_cpu2():
+def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))