forked from mindspore-Ecosystem/mindspore
auto parallel support adafactor opt
This commit is contained in:
parent 62a780f20f
commit 7ca64d2235
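In short, this commit moves the parameter-handling passes into frontend/parallel/parameter_manager.cc/.h and adds HandleAdaFactorOpt, which derives sliced shapes and tensor layouts for AdaFactor's factored optimizer state (exp_avg_sq_row_*, exp_avg_sq_col_*, exp_avg_sq_*) from the layout of the parameter each state belongs to. A minimal standalone sketch of the shape rule (the helper names and the plain std::vector representation are illustrative only, not part of the commit):

#include <cstdint>
#include <vector>

// Illustrative helpers only: they mirror the shape rule HandleAdaFactorOpt applies to a
// parameter slice of shape [d0, ..., d{n-2}, d{n-1}] (only parameters with >= 2 dims are factored).
// The row-factored state keeps every dimension except the last, and the column-factored state
// keeps every dimension except the second one from the end.
std::vector<int64_t> AdaFactorRowShape(std::vector<int64_t> slice_shape) {
  if (slice_shape.size() > 1) {
    slice_shape.pop_back();  // drop d{n-1}
  }
  return slice_shape;
}

std::vector<int64_t> AdaFactorColShape(std::vector<int64_t> slice_shape) {
  if (slice_shape.size() > 1) {
    slice_shape.erase(slice_shape.end() - 2);  // drop d{n-2}
  }
  return slice_shape;
}

// Example: a [1024, 4096] weight slice gives a [1024] row state and a [4096] column state.
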
@@ -23,6 +23,7 @@
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {

@@ -24,6 +24,7 @@
#include "pipeline/jit/parse/python_adapter.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {

@@ -0,0 +1,619 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
// Dealing with the RefKey case
|
||||
ParameterUsersInfo parameter_user_info;
|
||||
auto refkeys = ref_key_pair.second;
|
||||
auto cnode = ref_key_pair.first;
|
||||
|
||||
auto cnode_ptr = cnode->cast<CNodePtr>();
|
||||
if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
if (refkeys.size() > 1) {
|
||||
MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph());
|
||||
auto cnode_func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
|
||||
|
||||
// Find the RefKey being used
|
||||
auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
|
||||
for (auto &candidate : candidate_set_by_refkey) {
|
||||
auto candidate_node = candidate.first;
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
|
||||
continue;
|
||||
}
|
||||
parameter_user_info.second.second.add(candidate);
|
||||
}
|
||||
|
||||
// Find the corresponding Parameter being used
|
||||
std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
|
||||
if (parameters.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
|
||||
}
|
||||
parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
|
||||
parameter_user_info.second.first = parameters[0];
|
||||
auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
|
||||
for (auto &candidate : candidate_set_by_para) {
|
||||
auto candidate_node = candidate.first;
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(candidate);
|
||||
}
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
// In this case, node is a Parameter
|
||||
ParameterUsersInfo parameter_user_info;
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
|
||||
auto candidate_set = node->func_graph()->manager()->node_users()[node];
|
||||
for (auto &candidate : candidate_set) {
|
||||
auto candidate_node = candidate.first;
|
||||
if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
|
||||
if (candidate.second != 1) {
|
||||
continue;
|
||||
}
|
||||
auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
|
||||
for (auto &node_user : load_node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(node_user);
|
||||
}
|
||||
} else {
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(candidate);
|
||||
}
|
||||
}
|
||||
parameter_user_info.first = node->cast<ParameterPtr>()->name();
|
||||
parameter_user_info.second.first = node;
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> refkeys;
|
||||
if (cnode->isa<CNode>()) {
|
||||
auto cnode_ptr = cnode->cast<CNodePtr>();
|
||||
auto inputs = cnode_ptr->inputs();
|
||||
for (auto &one_input : inputs) {
|
||||
if (IsValueNode<RefKey>(one_input)) {
|
||||
refkeys.push_back(one_input);
|
||||
}
|
||||
}
|
||||
if (refkeys.size() >= 1) {
|
||||
return std::make_pair(cnode, refkeys);
|
||||
}
|
||||
}
|
||||
return {nullptr, refkeys};
|
||||
}
|
||||
|
||||
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
ParameterUsersInfo parameter_users_info;
|
||||
|
||||
auto cnode_with_refkeys = CNodeWithRefKeys(node);
|
||||
if (cnode_with_refkeys.first != nullptr) {
|
||||
// the node is a ref key node
|
||||
return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
|
||||
} else if (node->isa<Parameter>()) {
|
||||
// the node is a parameter node
|
||||
return FindParameterNodeUsers(node, IsCareNode);
|
||||
}
|
||||
|
||||
return parameter_users_info;
|
||||
}
|
||||
|
||||
static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
auto manager = graph->manager();
|
||||
auto node_users = manager->node_users()[parameter];
|
||||
if (node_users.empty()) {
|
||||
return false;
|
||||
}
|
||||
for (auto node_user : node_users) {
|
||||
auto use_node = node_user.first->cast<CNodePtr>();
|
||||
if (IsValueNode<FuncGraph>(use_node->input(0))) {
|
||||
auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
|
||||
auto parameters = graph_sub->parameters();
|
||||
auto parameter_sub = parameters[node_user.second - 1];
|
||||
return IsUsedParameter(graph_sub, parameter_sub);
|
||||
}
|
||||
if (use_node->input(0)->isa<CNode>()) {
|
||||
auto cnode = use_node->input(0)->cast<CNodePtr>();
|
||||
if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
|
||||
return true;
|
||||
}
|
||||
auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
auto parameters = graph_sub->parameters();
|
||||
auto parameter_sub = parameters[node_user.second - 1];
|
||||
return IsUsedParameter(graph_sub, parameter_sub);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
|
||||
CheckGlobalDeviceManager();
|
||||
int64_t rank = g_device_manager->global_rank();
|
||||
RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
|
||||
Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
|
||||
Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();
|
||||
|
||||
DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
|
||||
RankList group_devices;
|
||||
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
|
||||
}
|
||||
|
||||
std::sort(group_devices.begin(), group_devices.end());
|
||||
return group_devices;
|
||||
}
|
||||
|
||||
static ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
|
||||
auto user_cnode = param_info.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
auto user_input_index = param_info.second;
|
||||
OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
|
||||
TensorInfo tensor_info;
|
||||
if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
|
||||
auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
|
||||
tensor_info = op_info->inputs_tensor_info()[param_index];
|
||||
} else {
|
||||
size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
|
||||
if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
|
||||
MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
|
||||
<< ", but the index is " << user_input_index - 1;
|
||||
}
|
||||
tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
|
||||
}
|
||||
|
||||
ParameterSliceInfo parameter_slice_info;
|
||||
parameter_slice_info.slice_shape = tensor_info.slice_shape();
|
||||
parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
|
||||
MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
|
||||
<< ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
|
||||
<< tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
|
||||
return parameter_slice_info;
|
||||
}
|
||||
|
||||
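// Validates multi-user parameters: every user of the same parameter must expect the same slice
// shape and, when pipeline parallelism is not in use, the same group rank list; a mismatch raises
// an exception.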
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
|
||||
auto users_set = parameter_users_info.second.second;
|
||||
if (users_set.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parameter_name = parameter_users_info.first;
|
||||
MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
|
||||
auto first_user = users_set.pop();
|
||||
ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
|
||||
Shape first_user_slice_shape = parameter_slice_info.slice_shape;
|
||||
RankList first_user_group_list = parameter_slice_info.group_ranks;
|
||||
|
||||
for (auto &user : users_set) {
|
||||
ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
|
||||
Shape user_slice_shape = user_slice_info.slice_shape;
|
||||
RankList user_group_list = user_slice_info.group_ranks;
|
||||
if (first_user_slice_shape != user_slice_shape) {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
|
||||
<< " has multiple users, but the slice shapes are different";
|
||||
}
|
||||
|
||||
if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
|
||||
<< " has multiple users, but the group rank list are different, "
|
||||
<< "the group rank list for first user is " << first_user_group_list
|
||||
<< ", and the group rank list for this user is " << user_group_list;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(symbolic_key);
|
||||
auto all_upstream_node = root->manager()->node_users()[node];
|
||||
for (auto &upstream_node : all_upstream_node) {
|
||||
FuncGraphPtr fg = upstream_node.first->func_graph();
|
||||
if (symbolic_key->node()->isa<Parameter>()) {
|
||||
for (auto &param : root->parameters()) {
|
||||
if (*param == *symbolic_key->node()) {
|
||||
AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
|
||||
MS_EXCEPTION_IF_NULL(reverted_node);
|
||||
MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
|
||||
(void)fg->manager()->Replace(node, reverted_node);
|
||||
MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
for (auto &node : all_nodes) {
|
||||
// revert back SymbolicKeyInstance to embed() primitive
|
||||
if (IsValueNode<SymbolicKeyInstance>(node)) {
|
||||
RevertSymbolicKeyInstance(root, node);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
|
||||
MS_EXCEPTION_IF_NULL(parameter_node);
|
||||
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter);
|
||||
|
||||
// find the clone parameter
|
||||
if (!cloned_parameter->has_default()) {
|
||||
return false;
|
||||
}
|
||||
auto param_value = cloned_parameter->param_info();
|
||||
if (param_value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
bool cloned = param_value->cloned();
|
||||
if (!cloned) {
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
|
||||
return true;
|
||||
}
|
||||
|
||||
void HandleNoUsedParameter(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
||||
if (full_batch) {
|
||||
return;
|
||||
}
|
||||
|
||||
// In grad accumulation mode with a dynamic learning rate, the optimizer holds some parameters
// (such as global_step) that are unused by the first graph but used by the second graph, so their shapes cannot be changed.
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
if (grad_accumulation_step > 1) {
|
||||
MS_LOG(INFO) << "In grad accumulation mode, do not handle no used parameters";
|
||||
return;
|
||||
}
|
||||
|
||||
auto dev_num = g_device_manager->stage_device_num();
|
||||
auto parameters = root->parameters();
|
||||
for (auto &parameter : parameters) {
|
||||
if (IsUsedParameter(root, parameter)) {
|
||||
continue;
|
||||
}
|
||||
auto parameter_shape = GetNodeShape(parameter);
|
||||
if (parameter_shape.empty()) {
|
||||
continue;
|
||||
}
|
||||
Shape slice_shape = parameter_shape[0];
|
||||
if (slice_shape.empty()) {
|
||||
continue;
|
||||
}
|
||||
slice_shape[0] = slice_shape[0] / dev_num;
|
||||
auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
|
||||
auto abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto abstract_cloned = abstract->Clone();
|
||||
MS_EXCEPTION_IF_NULL(abstract_cloned);
|
||||
abstract_cloned->set_shape(slice_shape_ptr);
|
||||
parameter->set_abstract(abstract_cloned);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
|
||||
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
|
||||
if (tensor_layout == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto dev_mat_shape = tensor_layout->device_arrangement().array();
|
||||
auto tensor_map = tensor_layout->tensor_map().array();
|
||||
int64_t rank = g_device_manager->global_rank();
|
||||
RankList rank_list = g_device_manager->GetDeviceListInThisStage();
|
||||
DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
|
||||
RankList group_devices;
|
||||
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
|
||||
MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (group_devices.size() == 1) {
|
||||
MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
|
||||
const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
|
||||
return;
|
||||
}
|
||||
OperatorAttrs attrs;
|
||||
auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
|
||||
auto value_node = NewValueNode(py_instance);
|
||||
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
|
||||
auto graph = cnode->func_graph();
|
||||
auto virtual_node = graph->NewCNode(virtual_node_input);
|
||||
manager->SetEdge(cnode, node_user.second, virtual_node);
|
||||
}
|
||||
|
||||
void HandleFullySplitParameters(const FuncGraphPtr &root) {
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto parameters = root->parameters();
|
||||
auto node_users_map = root->manager()->node_users();
|
||||
for (auto &parameter : parameters) {
|
||||
auto param_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
|
||||
if (!IsFullySplitParameter(param_ptr)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
|
||||
if (!accu_parameter) {
|
||||
continue;  // some parameters need no handling, such as the parameter itself or the learning rate
|
||||
}
|
||||
|
||||
auto node_users = node_users_map[parameter];
|
||||
for (auto &user : node_users) {
|
||||
auto node = user.first;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!cnode->in_forward_flag()) {
|
||||
continue;
|
||||
}
|
||||
InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
|
||||
MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
|
||||
break; // only need to insert once, if the parameter has many users
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
for (auto &cloned_parameter_node : root->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node);
|
||||
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter);
|
||||
|
||||
if (!ParameterIsCloned(cloned_parameter_node)) {
|
||||
continue;
|
||||
}
|
||||
auto param_value = cloned_parameter->param_info();
|
||||
if (param_value == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// get the cloned index
|
||||
int64_t cloned_index = param_value->cloned_index();
|
||||
|
||||
// find the be cloned parameter
|
||||
bool found_be_cloned_parameter = false;
|
||||
ParameterPtr cloned_from_parameter = nullptr;
|
||||
AnfNodePtr cloned_from_node = nullptr;
|
||||
for (auto &be_cloned_parameter_node : root->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
|
||||
auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(be_cloned_parameter);
|
||||
if (!be_cloned_parameter->has_default()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto param_value_in = be_cloned_parameter->param_info();
|
||||
if (param_value_in == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!param_value_in->be_cloned()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// get the be cloned index
|
||||
auto &be_cloned_index = param_value_in->be_cloned_index();
|
||||
if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
|
||||
found_be_cloned_parameter = true;
|
||||
cloned_from_parameter = be_cloned_parameter;
|
||||
cloned_from_node = be_cloned_parameter_node;
|
||||
}
|
||||
}
|
||||
|
||||
if (found_be_cloned_parameter) {
|
||||
// set the shape and tensor layout for cloned parameter
|
||||
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
|
||||
if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
|
||||
MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
|
||||
continue;
|
||||
}
|
||||
auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
|
||||
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
|
||||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
// from pipeline or grad accumulation
|
||||
if (param_name.find(ACCU_GRADS) != std::string::npos) {
|
||||
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
cloned_abstract->set_shape(parallel_shape);
|
||||
// in opt shard, accu_grad's shape is different from the original param's shape
|
||||
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
|
||||
TensorLayout new_layout = *tensor_layout;
|
||||
new_layout.set_opt_shard_group("");
|
||||
tensor_layout = std::make_shared<TensorLayout>(new_layout);
|
||||
}
|
||||
} else {
|
||||
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
|
||||
}
|
||||
cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
|
||||
cloned_parameter_node->set_abstract(cloned_abstract);
|
||||
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
|
||||
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
|
||||
<< ", clone index is: " << cloned_index;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is "
|
||||
<< cloned_index << ", but not found the be cloned parameter";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
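// Sets the sliced shape and tensor layout for AdaFactor's factored optimizer state. For every
// network parameter, the matching exp_avg_sq_row_/exp_avg_sq_col_ states (parameters with at
// least two dimensions) each drop one dimension of the parameter's layout, and the unfactored
// exp_avg_sq_ state is adjusted only for 1-D parameters.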
void HandleAdaFactorOpt(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    std::string param_name = param->name();

    if (param_name.find(EXP_AVG) != std::string::npos) {
      continue;
    }

    auto tensor_layout = param->user_data<TensorLayout>();
    if (tensor_layout == nullptr) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;

      if ((row_col_param_name != exp_row_name) && (row_col_param_name != exp_col_name) &&
          (row_col_param_name != exp_avg_name)) {
        continue;
      }

      auto slice_shape = tensor_layout->slice_shape().array();
      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
      if (is_row_or_col_param && shape_size <= 1) {
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name) {
        slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name) {
        (void)slice_shape.erase(slice_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
                   << ", new slice shape is " << slice_shape;

      if (row_col_count == 2 || exp_avg_sq_count == 1) {
        break;
      }
    }
  }
}
} // namespace parallel
} // namespace mindspore

@@ -0,0 +1,51 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PARAMETER_MANAGER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PARAMETER_MANAGER_H_

#include <vector>
#include <string>
#include <utility>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
constexpr char EXP_AVG[] = "exp_avg_";
constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_";
constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_";
constexpr char EXP_AVG_SQ[] = "exp_avg_sq_";
using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;
struct ParameterSliceInfo {
  Shape slice_shape;
  RankList group_ranks;
};

ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes);
void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
void HandleNoUsedParameter(const FuncGraphPtr &root);
void HandleFullySplitParameters(const FuncGraphPtr &root);
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root);
void HandleAdaFactorOpt(const FuncGraphPtr &root);
bool ParameterIsCloned(const AnfNodePtr &parameter_node);
} // namespace parallel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PARAMETER_MANAGER_H_
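For reference, these prefixes are prepended to the owning parameter's name when HandleAdaFactorOpt looks up its AdaFactor state; a small illustration with a hypothetical parameter name (not taken from the commit):

#include <iostream>
#include <string>

int main() {
  // Hypothetical parameter name; the prefixes mirror the constants declared above.
  const std::string param_name = "network.embedding.table";
  std::cout << "exp_avg_sq_row_" + param_name << "\n";  // factored row statistic (params with >= 2 dims)
  std::cout << "exp_avg_sq_col_" + param_name << "\n";  // factored column statistic (params with >= 2 dims)
  std::cout << "exp_avg_sq_" + param_name << "\n";      // unfactored statistic (1-D params)
  return 0;
}
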
@@ -30,6 +30,7 @@
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "base/core_ops.h"

@@ -43,6 +43,7 @@
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/anf.h"
#include "ir/param_info.h"

@@ -39,6 +39,7 @@
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "ir/param_info.h"
#include "ir/tensor.h"

@@ -141,28 +142,6 @@ std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node,
|
|||
return new_node_input;
|
||||
}
|
||||
|
||||
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
|
||||
MS_EXCEPTION_IF_NULL(parameter_node);
|
||||
auto cloned_parameter = parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter);
|
||||
|
||||
// find the clone parameter
|
||||
if (!cloned_parameter->has_default()) {
|
||||
return false;
|
||||
}
|
||||
auto param_value = cloned_parameter->param_info();
|
||||
if (param_value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
bool cloned = param_value->cloned();
|
||||
if (!cloned) {
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
|
||||
const std::string &instance_name, const std::string &weight_name) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
|
@@ -579,31 +558,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsParallelCareNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (prim_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (IsInParallelBlackList(prim)) {
|
||||
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
|
||||
return false;
|
||||
}
|
||||
// get_next is not in the forward graph, so we need to mark it as a forward node
|
||||
if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
|
||||
return true;
|
||||
}
|
||||
if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return cnode->in_forward_flag();
|
||||
}
|
||||
|
||||
void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
|
||||
const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) {
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
@@ -898,16 +852,6 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|||
MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name();
|
||||
}
|
||||
|
||||
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
||||
if (!cnode) {
|
||||
return false;
|
||||
}
|
||||
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
|
||||
return (prim->name() == name);
|
||||
}
|
||||
|
||||
void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(replace_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@@ -1468,72 +1412,6 @@ StrategyPtr ExtractStrategy(const ValuePtr &stra) {
|
|||
return strategyPtr;
|
||||
}
|
||||
|
||||
Shapes GetValueListShape(const AnfNodePtr &node) {
|
||||
Shapes shapes;
|
||||
std::vector<ValuePtr> inputs_seq;
|
||||
if (IsValueNode<ValueList>(node)) {
|
||||
inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
|
||||
} else if (IsValueNode<ValueTuple>(node)) {
|
||||
inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "node is eigther ValueList or ValueTuple";
|
||||
}
|
||||
for (auto &ele : inputs_seq) {
|
||||
auto tensor = ele->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto one_shape = tensor->shape();
|
||||
shapes.push_back(one_shape);
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
|
||||
Shapes GetNodeShape(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
Shapes shapes;
|
||||
if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
|
||||
return GetValueListShape(node);
|
||||
}
|
||||
BaseShapePtr base_shape_ptr = node->Shape();
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsValueNode<Primitive>(cnode->input(0))) {
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == MAKEREF) {
|
||||
AnfNodePtr ref_node = cnode->input(1);
|
||||
auto func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(ref_node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
return GetRefKeyNodeShape(ref_node, func_graph);
|
||||
}
|
||||
}
|
||||
if (cnode->input(0)->isa<CNode>()) {
|
||||
if (cnode->inputs().size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
|
||||
}
|
||||
base_shape_ptr = cnode->input(1)->Shape();
|
||||
}
|
||||
}
|
||||
if (base_shape_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
|
||||
<< node->fullname_with_scope();
|
||||
}
|
||||
auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
|
||||
if (tuple_shape_ptr != nullptr) {
|
||||
auto tuple_shape = tuple_shape_ptr->shape();
|
||||
for (auto &shape : tuple_shape) {
|
||||
auto each_shape = dyn_cast<abstract::Shape>(shape);
|
||||
MS_EXCEPTION_IF_NULL(each_shape);
|
||||
shapes.push_back(each_shape->shape());
|
||||
}
|
||||
} else {
|
||||
auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
shapes.push_back(shape_ptr->shape());
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
|
||||
Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@@ -1918,91 +1796,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
|
|||
g_RefMap.clear();
|
||||
}
|
||||
|
||||
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
for (auto &cloned_parameter_node : root->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node);
|
||||
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter);
|
||||
|
||||
if (!ParameterIsCloned(cloned_parameter_node)) {
|
||||
continue;
|
||||
}
|
||||
auto param_value = cloned_parameter->param_info();
|
||||
if (param_value == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// get the cloned index
|
||||
int64_t cloned_index = param_value->cloned_index();
|
||||
|
||||
// find the be cloned parameter
|
||||
bool found_be_cloned_parameter = false;
|
||||
ParameterPtr cloned_from_parameter = nullptr;
|
||||
AnfNodePtr cloned_from_node = nullptr;
|
||||
for (auto &be_cloned_parameter_node : root->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
|
||||
auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(be_cloned_parameter);
|
||||
if (!be_cloned_parameter->has_default()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto param_value_in = be_cloned_parameter->param_info();
|
||||
if (param_value_in == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!param_value_in->be_cloned()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// get the be cloned index
|
||||
auto &be_cloned_index = param_value_in->be_cloned_index();
|
||||
if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
|
||||
found_be_cloned_parameter = true;
|
||||
cloned_from_parameter = be_cloned_parameter;
|
||||
cloned_from_node = be_cloned_parameter_node;
|
||||
}
|
||||
}
|
||||
|
||||
if (found_be_cloned_parameter) {
|
||||
// set the shape and tensor layout for cloned parameter
|
||||
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
|
||||
if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
|
||||
MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
|
||||
continue;
|
||||
}
|
||||
auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
|
||||
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
|
||||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
// from pipeline or grad accumulation
|
||||
if (param_name.find(ACCU_GRADS) != std::string::npos) {
|
||||
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
cloned_abstract->set_shape(parallel_shape);
|
||||
// in opt shard, accu_grad's shape is different from the original param's shape
|
||||
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
|
||||
TensorLayout new_layout = *tensor_layout;
|
||||
new_layout.set_opt_shard_group("");
|
||||
tensor_layout = std::make_shared<TensorLayout>(new_layout);
|
||||
}
|
||||
} else {
|
||||
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
|
||||
}
|
||||
cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
|
||||
cloned_parameter_node->set_abstract(cloned_abstract);
|
||||
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
|
||||
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
|
||||
<< ", clone index is: " << cloned_index;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is "
|
||||
<< cloned_index << ", but not found the be cloned parameter";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetVirtualDatasetStrategy(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
|
@@ -2931,41 +2724,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(symbolic_key);
|
||||
auto all_upstream_node = root->manager()->node_users()[node];
|
||||
for (auto &upstream_node : all_upstream_node) {
|
||||
FuncGraphPtr fg = upstream_node.first->func_graph();
|
||||
if (symbolic_key->node()->isa<Parameter>()) {
|
||||
for (auto &param : root->parameters()) {
|
||||
if (*param == *symbolic_key->node()) {
|
||||
AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
|
||||
MS_EXCEPTION_IF_NULL(reverted_node);
|
||||
MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
|
||||
(void)fg->manager()->Replace(node, reverted_node);
|
||||
MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
for (auto &node : all_nodes) {
|
||||
// revert back SymbolicKeyInstance to embed() primitive
|
||||
if (IsValueNode<SymbolicKeyInstance>(node)) {
|
||||
RevertSymbolicKeyInstance(root, node);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IsCohesiveNode(const CNodePtr &cnode) {
|
||||
return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) ||
|
||||
|
@@ -3365,200 +3123,6 @@ void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes)
|
|||
}
|
||||
}
|
||||
|
||||
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> refkeys;
|
||||
if (cnode->isa<CNode>()) {
|
||||
auto cnode_ptr = cnode->cast<CNodePtr>();
|
||||
auto inputs = cnode_ptr->inputs();
|
||||
for (auto &one_input : inputs) {
|
||||
if (IsValueNode<RefKey>(one_input)) {
|
||||
refkeys.push_back(one_input);
|
||||
}
|
||||
}
|
||||
if (refkeys.size() >= 1) {
|
||||
return std::make_pair(cnode, refkeys);
|
||||
}
|
||||
}
|
||||
return {nullptr, refkeys};
|
||||
}
|
||||
|
||||
ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
// In this case, node is a Parameter
|
||||
ParameterUsersInfo parameter_user_info;
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
|
||||
auto candidate_set = node->func_graph()->manager()->node_users()[node];
|
||||
for (auto &candidate : candidate_set) {
|
||||
auto candidate_node = candidate.first;
|
||||
if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
|
||||
if (candidate.second != 1) {
|
||||
continue;
|
||||
}
|
||||
auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
|
||||
for (auto &node_user : load_node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(node_user);
|
||||
}
|
||||
} else {
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(candidate);
|
||||
}
|
||||
}
|
||||
parameter_user_info.first = node->cast<ParameterPtr>()->name();
|
||||
parameter_user_info.second.first = node;
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
// Dealing with the RefKey case
|
||||
ParameterUsersInfo parameter_user_info;
|
||||
auto refkeys = ref_key_pair.second;
|
||||
auto cnode = ref_key_pair.first;
|
||||
|
||||
auto cnode_ptr = cnode->cast<CNodePtr>();
|
||||
if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
if (refkeys.size() > 1) {
|
||||
MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph());
|
||||
auto cnode_func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
|
||||
|
||||
// Find the RefKey being used
|
||||
auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
|
||||
for (auto &candidate : candidate_set_by_refkey) {
|
||||
auto candidate_node = candidate.first;
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
|
||||
continue;
|
||||
}
|
||||
parameter_user_info.second.second.add(candidate);
|
||||
}
|
||||
|
||||
// Find the corresponding Parameter being used
|
||||
std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
|
||||
if (parameters.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
|
||||
}
|
||||
parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
|
||||
parameter_user_info.second.first = parameters[0];
|
||||
auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
|
||||
for (auto &candidate : candidate_set_by_para) {
|
||||
auto candidate_node = candidate.first;
|
||||
auto c = candidate_node->cast<CNodePtr>();
|
||||
if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
|
||||
continue;
|
||||
}
|
||||
(void)parameter_user_info.second.second.insert(candidate);
|
||||
}
|
||||
return parameter_user_info;
|
||||
}
|
||||
|
||||
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
|
||||
ParameterUsersInfo parameter_users_info;
|
||||
|
||||
auto cnode_with_refkeys = CNodeWithRefKeys(node);
|
||||
if (cnode_with_refkeys.first != nullptr) {
|
||||
// the node is a ref key node
|
||||
return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
|
||||
} else if (node->isa<Parameter>()) {
|
||||
// the node is a parameter node
|
||||
return FindParameterNodeUsers(node, IsCareNode);
|
||||
}
|
||||
|
||||
return parameter_users_info;
|
||||
}
|
||||
|
||||
RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
|
||||
CheckGlobalDeviceManager();
|
||||
int64_t rank = g_device_manager->global_rank();
|
||||
RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
|
||||
Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
|
||||
Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();
|
||||
|
||||
DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
|
||||
RankList group_devices;
|
||||
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
|
||||
}
|
||||
|
||||
std::sort(group_devices.begin(), group_devices.end());
|
||||
return group_devices;
|
||||
}
|
||||
|
||||
ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
|
||||
auto user_cnode = param_info.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
auto user_input_index = param_info.second;
|
||||
OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
|
||||
TensorInfo tensor_info;
|
||||
if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
|
||||
auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
|
||||
tensor_info = op_info->inputs_tensor_info()[param_index];
|
||||
} else {
|
||||
size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
|
||||
if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
|
||||
MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
|
||||
<< ", but the index is " << user_input_index - 1;
|
||||
}
|
||||
tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
|
||||
}
|
||||
|
||||
ParameterSliceInfo parameter_slice_info;
|
||||
parameter_slice_info.slice_shape = tensor_info.slice_shape();
|
||||
parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
|
||||
MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1
|
||||
<< ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
|
||||
<< tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
|
||||
return parameter_slice_info;
|
||||
}
|
||||
|
||||
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
|
||||
auto users_set = parameter_users_info.second.second;
|
||||
if (users_set.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parameter_name = parameter_users_info.first;
|
||||
MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
|
||||
auto first_user = users_set.pop();
|
||||
ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
|
||||
Shape first_user_slice_shape = parameter_slice_info.slice_shape;
|
||||
RankList first_user_group_list = parameter_slice_info.group_ranks;
|
||||
|
||||
for (auto &user : users_set) {
|
||||
ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
|
||||
Shape user_slice_shape = user_slice_info.slice_shape;
|
||||
RankList user_group_list = user_slice_info.group_ranks;
|
||||
if (first_user_slice_shape != user_slice_shape) {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
|
||||
<< " has multiple users, but the slice shapes are different";
|
||||
}
|
||||
|
||||
if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
|
||||
MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
|
||||
<< " has multiple users, but the group rank list are different, "
|
||||
<< "the group rank list for first user is " << first_user_group_list
|
||||
<< ", and the group rank list for this user is " << user_group_list;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CreateGroupsByCkptFile(const std::string &file) {
|
||||
GroupInfoMap group_info_map;
|
||||
if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
|
||||
|
@@ -3572,154 +3136,6 @@ bool CreateGroupsByCkptFile(const std::string &file) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
auto manager = graph->manager();
|
||||
auto node_users = manager->node_users()[parameter];
|
||||
if (node_users.empty()) {
|
||||
return false;
|
||||
}
|
||||
for (auto node_user : node_users) {
|
||||
auto use_node = node_user.first->cast<CNodePtr>();
|
||||
if (IsValueNode<FuncGraph>(use_node->input(0))) {
|
||||
auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
|
||||
auto parameters = graph_sub->parameters();
|
||||
auto parameter_sub = parameters[node_user.second - 1];
|
||||
return IsUsedParameter(graph_sub, parameter_sub);
|
||||
}
|
||||
if (use_node->input(0)->isa<CNode>()) {
|
||||
auto cnode = use_node->input(0)->cast<CNodePtr>();
|
||||
if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
|
||||
return true;
|
||||
}
|
||||
auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
auto parameters = graph_sub->parameters();
|
||||
auto parameter_sub = parameters[node_user.second - 1];
|
||||
return IsUsedParameter(graph_sub, parameter_sub);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void HandleNoUsedParameter(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
||||
if (full_batch) {
|
||||
return;
|
||||
}
|
||||
|
||||
// In grad accumulation mode with a dynamic learning rate, the optimizer holds some parameters
// (such as global_step) that are unused by the first graph but used by the second graph, so their shapes cannot be changed.
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
if (grad_accumulation_step > 1) {
|
||||
MS_LOG(INFO) << "In grad accumulation mode, do not handle no used parameters";
|
||||
return;
|
||||
}
|
||||
|
||||
auto dev_num = g_device_manager->stage_device_num();
|
||||
auto parameters = root->parameters();
|
||||
for (auto &parameter : parameters) {
|
||||
if (IsUsedParameter(root, parameter)) {
|
||||
continue;
|
||||
}
|
||||
auto parameter_shape = GetNodeShape(parameter);
|
||||
if (parameter_shape.empty()) {
|
||||
continue;
|
||||
}
|
||||
Shape slice_shape = parameter_shape[0];
|
||||
if (slice_shape.empty()) {
|
||||
continue;
|
||||
}
|
||||
slice_shape[0] = slice_shape[0] / dev_num;
|
||||
auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
|
||||
auto abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto abstract_cloned = abstract->Clone();
|
||||
MS_EXCEPTION_IF_NULL(abstract_cloned);
|
||||
abstract_cloned->set_shape(slice_shape_ptr);
|
||||
parameter->set_abstract(abstract_cloned);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
|
||||
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
|
||||
if (tensor_layout == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto dev_mat_shape = tensor_layout->device_arrangement().array();
|
||||
auto tensor_map = tensor_layout->tensor_map().array();
|
||||
int64_t rank = g_device_manager->global_rank();
|
||||
RankList rank_list = g_device_manager->GetDeviceListInThisStage();
|
||||
DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
|
||||
RankList group_devices;
|
||||
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
|
||||
MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (group_devices.size() == 1) {
|
||||
MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
|
||||
const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
|
||||
return;
|
||||
}
|
||||
OperatorAttrs attrs;
|
||||
auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
|
||||
auto value_node = NewValueNode(py_instance);
|
||||
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
|
||||
auto graph = cnode->func_graph();
|
||||
auto virtual_node = graph->NewCNode(virtual_node_input);
|
||||
manager->SetEdge(cnode, node_user.second, virtual_node);
|
||||
}
|
||||
|
||||
static void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters need no handling, such as the accumulation parameters themselves or the lr
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
      break;  // insert only once, even if the parameter has many users
    }
  }
}

void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) {
  if (!root->has_flag(BACKWARD) && pipeline_stages > 1) {
    root->set_flag(BACKWARD, true);
@ -3833,6 +3249,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  // set the shape for optimizer's clone tensor
  SetClonedTensorShapeForOptimizer(root);

  HandleAdaFactorOpt(root);

  // save strategy as checkpoint for multi-train
  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
    CheckpointStrategy(all_nodes, root);

@ -54,11 +54,6 @@ struct CommInfo {
  std::string communication_backend;
};

struct ParameterSliceInfo {
  Shape slice_shape;
  RankList group_ranks;
};

std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);

@ -77,8 +72,6 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera

bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs);

bool IsParallelCareNode(const CNodePtr &cnode);

void MarkForwardCNode(const FuncGraphPtr &root);

bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);

@ -108,8 +101,6 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
// Extract strategy from attr
StrategyPtr ExtractStrategy(const ValuePtr &strategy);

Shapes GetNodeShape(const AnfNodePtr &node);

// Extract shape from anfnode
std::vector<Shapes> ExtractShape(const CNodePtr &node);

@ -160,15 +151,8 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);

std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);

using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;

RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);

std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);

ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));

bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);

void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,

@ -0,0 +1,149 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel_utils.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
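// Return true if the cnode's first input is a value node holding a primitive with the given name.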
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  if (!cnode) {
    return false;
  }
  ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(anf_node);
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  return (prim->name() == name);
}

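// A cnode is "parallel care" when the parallel pass should process it: black-listed primitives
// are skipped, GetNext and virtual-output nodes are always accepted, Cast is only accepted once
// operator info has been attached to it, and everything else depends on the forward flag.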
bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
  if (prim_node == nullptr) {
    return false;
  }
  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    return false;
  }
  if (IsInParallelBlackList(prim)) {
    MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
    return false;
  }
  // get_next is not in the forward graph, so mark it as a forward node explicitly
  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
    return true;
  }
  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }

  return cnode->in_forward_flag();
}

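// Collect the shape of every tensor held by a ValueList or ValueTuple node.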
Shapes GetValueListShape(const AnfNodePtr &node) {
  Shapes shapes;
  std::vector<ValuePtr> inputs_seq;
  if (IsValueNode<ValueList>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else if (IsValueNode<ValueTuple>(node)) {
    inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  } else {
    MS_LOG(EXCEPTION) << "node is neither a ValueList nor a ValueTuple";
  }
  for (auto &ele : inputs_seq) {
    auto tensor = ele->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    auto one_shape = tensor->shape();
    shapes.push_back(one_shape);
  }
  return shapes;
}

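// Return the shape(s) of a node: tuple shapes are expanded element by element, MakeRef nodes
// are resolved through their RefKey, and a cnode whose first input is itself a cnode reports
// the shape of its second input.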
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    if (cnode->input(0)->isa<CNode>()) {
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    auto tuple_shape = tuple_shape_ptr->shape();
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}
}  // namespace parallel
}  // namespace mindspore

@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_

#include <vector>
#include <string>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell
from mindspore.nn.optim.adafactor import AdaFactor
from mindspore.ops import operations as P

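# A minimal MatMul + BiasAdd network whose two operators are sharded with the given strategies.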
class Net(Cell):
    def __init__(self, matmul_weight, add_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.matmul = P.MatMul().shard(strategy1)
        self.add = P.BiasAdd().shard(strategy2)
        self.mul_weight = Parameter(matmul_weight, "w1")
        self.bias = Parameter(add_weight, "bias")

    def construct(self, x, b):
        out = self.matmul(x, self.mul_weight)
        out = self.add(out, self.bias)
        return out

_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)

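# Build a TrainOneStepCell around the network with the AdaFactor optimizer (relative step size,
# warmup init and compression enabled, no explicit learning rate) and compile it under the
# current auto-parallel context.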
def compile_net(net):
    scale_parameter = False
    relative_step = True
    warmup_init = True
    compression = True
    optimizer = AdaFactor(net.trainable_params(), learning_rate=None, weight_decay=0.9,
                          scale_parameter=scale_parameter, relative_step=relative_step,
                          warmup_init=warmup_init, compression=compression)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()

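# Both cases below compile under semi_auto_parallel with 16 devices: the first shards only the
# batch dimension (data parallel, weights replicated), the second also slices the weights
# (model parallel), so the AdaFactor optimizer states are exercised for both replicated and
# sliced parameters.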
def test_opt_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((16, 1), (1,))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)

def test_opt_model_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)