From d457e81134c6314b01eb5265646ce38dc57f4c04 Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Tue, 16 Nov 2021 14:42:12 +0800
Subject: [PATCH] Add cast for mirror

Add detail cast

Fix weight cast empty

Format the code

Add code and name optimizer
---
 .../ccsrc/frontend/parallel/step_parallel.cc  |  11 +-
 .../frontend/parallel/step_parallel_utils.cc  | 122 ++++++++++++++++++
 .../frontend/parallel/step_parallel_utils.h   |   6 +
 3 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index adb6427629b..91ed78b1fb9 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -48,6 +48,7 @@
 #include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "mindspore/core/utils/parallel_node_check.h"
+#include "mindspore/ccsrc/pybind_api/ir/primitive_py.h"
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/util.h"
 #include "ps/ps_context.h"
@@ -203,12 +204,20 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
                 const FuncGraphPtr &root = nullptr) {
   // insert new node before the node
   FuncGraphManagerPtr manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
+  auto node_user_map = manager->node_users();
   ScopePtr scope = node->scope();
   MS_EXCEPTION_IF_NULL(scope);
   std::vector<AnfNodePtr> node_input;
+  AnfNodePtr pre_node_ = pre_node;
   if (root && !param_name.empty()) {
-    node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
+    TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map);
+    if (next_node_dtype) {
+      MS_LOG(INFO) << "Inserting cast from float32 to float16 for node " << node->fullname_with_scope()
+                   << " to save communication.";
+      pre_node_ = CreateFP16Cast(node, pre_node, node_user_map, next_node_dtype);
+    }
+    node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name);
   } else {
     node_input = CreateInput(op, pre_node, instance_name);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
index d8dc4965c72..bcfe4809892 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
@@ -24,6 +24,8 @@
 #include <memory>
 #include <string>
 #include <utility>
+#include <queue>
+#include <set>
 
 #include "utils/hash_map.h"
 #include "base/core_ops.h"
@@ -57,6 +59,10 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
   return (prim->name() == name);
 }
 
+bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
+  return std::any_of(check_list.begin(), check_list.end(),
+                     [cnode](const string &in) { return IsSomePrimitive(cnode, in); });
+}
+
 bool IsParallelCareNode(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -301,5 +307,121 @@ void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
     }
   }
 }
+
+// Check the given type, return false if it is not a TensorType
+bool CheckTensorType(const TypePtr &node_type) {
+  MS_EXCEPTION_IF_NULL(node_type);
+  if (!node_type->isa<mindspore::TensorType>()) {
+    return false;
+  }
+  return true;
+}
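
The point of the inserted cast is bandwidth: the mirror operator introduces an all-reduce over a tensor of the parameter's shape, and fp16 elements are half the width of fp32. A minimal standalone sketch of the saving (the 4096 x 4096 weight shape is an illustrative assumption, not taken from the patch):

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative parameter shape (assumption): a 4096 x 4096 fp32 weight.
  const std::size_t elems = 4096ULL * 4096ULL;
  // fp32 elements take 4 bytes, fp16 elements take 2, so casting before the
  // mirror communication halves the payload.
  std::printf("fp32 mirror payload: %zu MiB\n", elems * 4 / (1024 * 1024));
  std::printf("fp16 mirror payload: %zu MiB\n", elems * 2 / (1024 * 1024));
  return 0;
}
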
+
+// For a weight used by a cast and another operator at the same time, like the following
+//   weight1 -> mirror -> cast1 -> matmul1
+//   weight1 -> add
+// we will not insert the cast (FP32 -> FP16), as it would also change the input of the add operator to fp16.
+AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
+  std::queue<AnfNodePtr> visited;
+  AnfNodePtr queue_node = nullptr;
+  CNodePtr cnode = nullptr;
+  AnfNodePtr node = nullptr;
+  if (!node_ptr) {
+    return nullptr;
+  }
+  auto users = node_users_map.at(node_ptr);
+  for (auto &node_user : users) {
+    if (node_user.first) {
+      visited.push(node_user.first);
+    }
+  }
+  while (!visited.empty()) {
+    queue_node = visited.front();
+    visited.pop();
+    cnode = queue_node->cast<CNodePtr>();
+    if (!cnode || !cnode->in_forward_flag()) {
+      continue;
+    }
+    if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE, DEPEND, UPDATESTATE, MAKE_TUPLE})) {
+      auto node_set = node_users_map.at(queue_node);
+      for (auto &node_user : node_set) {
+        visited.push(node_user.first);
+      }
+    } else if (!IsSomePrimitive(cnode, CAST)) {
+      MS_LOG(INFO) << "The weight has a user that is not a cast node, so no cast will be inserted for this parameter "
+                   << node_ptr->DebugString();
+      return nullptr;
+    } else if (!node) {
+      node = queue_node;
+    }
+  }
+  return node;
+}
+
+// Given the cnode ptr, find its users until we find the computation node, then return the type of the
+// computation node. This function is used to find the target type for CreateFP16Cast. It returns the target type
+// only if that type is float16 and the source node is float32; otherwise it returns nullptr.
+TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map) {
+  auto node_ptr = cnode_ptr->cast<AnfNodePtr>();
+  if (!node_ptr) {
+    return nullptr;
+  }
+  auto cnode_inputs = cnode_ptr->inputs();
+  if (cnode_inputs.size() < TWO_INPUT_SIZE) {
+    return nullptr;
+  }
+  // Since IsWeightValidUsed has already run by the time we start to insert the mirror, the second input
+  // is always the parameter.
+  auto weight = cnode_inputs[1];
+  if (!weight->isa<Parameter>()) {
+    return nullptr;
+  }
+  MS_LOG(INFO) << "Start to search the weight params: " << weight->DebugString();
+
+  AnfNodePtr node = GetChildCastNode(weight, node_users_map);
+  if (!node) {
+    return nullptr;
+  }
+  // Get the output dtype of the operator.
+  auto node_type = node->Type();
+  if (!CheckTensorType(node_type)) {
+    return nullptr;
+  }
+  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
+  MS_EXCEPTION_IF_NULL(input_element_type);
+  auto source_node_type = node_ptr->Type();
+  if (!CheckTensorType(source_node_type)) {
+    return nullptr;
+  }
+  auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
+  MS_EXCEPTION_IF_NULL(source_element_type);
+  // We only add the cast operation when the source is fp32 and the user is fp16.
+  if (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) {
+    return input_element_type;
+  }
+  return nullptr;
+}
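
The dtype gate at the end of FindChildCastWithFP32ToFP16 is worth stating in isolation: a cast is created only when the parameter side is fp32 and the consuming computation is fp16. A standalone sketch of that rule (the enum values are simplified stand-ins for MindSpore's type ids, not the real API):

#include <iostream>

enum class TypeId { kFloat16, kFloat32, kInt32 };  // stand-ins for kNumberTypeFloat16 etc.

// Mirrors the final condition above: only an fp32 source feeding an fp16
// consumer qualifies for the inserted cast.
bool ShouldInsertFP16Cast(TypeId source, TypeId consumer) {
  return source == TypeId::kFloat32 && consumer == TypeId::kFloat16;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << ShouldInsertFP16Cast(TypeId::kFloat32, TypeId::kFloat16) << "\n";  // true
  std::cout << ShouldInsertFP16Cast(TypeId::kFloat16, TypeId::kFloat16) << "\n";  // false: already fp16
  std::cout << ShouldInsertFP16Cast(TypeId::kFloat32, TypeId::kInt32) << "\n";    // false: consumer is not fp16
  return 0;
}
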
+
+// Create a cast node given the current node and the previous node. The target type of the cast is taken from
+// compute_node_type.
+// Return the new cast node with pre_node as its input.
+AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
+                          const TypePtr &compute_node_type) {
+  const char kOpsFunctionModelName[] = "mindspore.ops.functional";
+  static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
+  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
+  MS_EXCEPTION_IF_NULL(adapter);
+  MS_EXCEPTION_IF_NULL(compute_node_type);
+  auto prim = adapter->attached_primitive();
+  if (prim == nullptr) {
+    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
+  }
+  // Insert the cast.
+  auto type_node = NewValueNode(compute_node_type);
+  type_node->set_abstract(compute_node_type->ToAbstract());
+  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
+  new_node->set_abstract(node->abstract());
+  return new_node;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
index 3dfa5b759e5..101e1277d83 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
@@ -25,6 +25,8 @@
 
 namespace mindspore {
 namespace parallel {
+const size_t TWO_INPUT_SIZE = 2;
+
 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 bool IsParallelCareNode(const CNodePtr &cnode);
 Shapes GetNodeShape(const AnfNodePtr &node);
@@ -34,6 +36,10 @@
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                        const CNodePtr &node);
 void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
+AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
+                          const TypePtr &compute_node_type);
+AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map);
+TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
 }  // namespace parallel
 }  // namespace mindspore
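
For reference, the eligibility walk in GetChildCastNode can be modeled on a toy user map: users are visited breadth-first, pass-through ops (Load, Reshape, Depend, UpdateState, MakeTuple; the AllGather variants are omitted in the toy) are traversed, the first cast found is remembered, and any other compute user vetoes the rewrite. A self-contained sketch under those assumptions (a simplified string-keyed graph, not the MindSpore API):

#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// node name -> names of its users; "#" separates the op name from an id.
using UserMap = std::map<std::string, std::vector<std::string>>;

std::string FindChildCast(const UserMap &users, const std::string &weight) {
  const std::set<std::string> transparent = {"Load", "Reshape", "Depend", "UpdateState", "MakeTuple"};
  std::queue<std::string> visited;
  for (const auto &u : users.at(weight)) visited.push(u);
  std::string found;
  while (!visited.empty()) {
    const std::string node = visited.front();
    visited.pop();
    const std::string op = node.substr(0, node.find('#'));
    if (transparent.count(op) > 0) {
      // Pass-through op: keep walking its users.
      for (const auto &u : users.at(node)) visited.push(u);
    } else if (op != "Cast") {
      // A non-cast compute user vetoes the fp16 rewrite entirely.
      return "";
    } else if (found.empty()) {
      found = node;  // remember the first cast user
    }
  }
  return found;
}

int main() {
  // weight -> Load -> Cast (ok) but also weight -> Add (veto), matching the
  // weight1 -> cast1 -> matmul1 / weight1 -> add example in the patch comment.
  UserMap users = {{"weight", {"Load#1", "Add#1"}}, {"Load#1", {"Cast#1"}}, {"Cast#1", {}}, {"Add#1", {}}};
  std::cout << (FindChildCast(users, "weight").empty() ? "veto" : "rewrite") << "\n";  // prints "veto"
  return 0;
}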