forked from mindspore-Ecosystem/mindspore
!27900 Add forcement Cast For Float16
Merge pull request !27900 from huangxinjing/second_version_add_cast
This commit is contained in:
commit
41bf56ab6c
|
@ -48,6 +48,7 @@
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "utils/symbolic.h"
|
#include "utils/symbolic.h"
|
||||||
#include "mindspore/core/utils/parallel_node_check.h"
|
#include "mindspore/core/utils/parallel_node_check.h"
|
||||||
|
#include "mindspore/ccsrc/pybind_api/ir/primitive_py.h"
|
||||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||||
#include "ps/util.h"
|
#include "ps/util.h"
|
||||||
#include "ps/ps_context.h"
|
#include "ps/ps_context.h"
|
||||||
|
@ -203,12 +204,20 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
|
||||||
const FuncGraphPtr &root = nullptr) {
|
const FuncGraphPtr &root = nullptr) {
|
||||||
// insert new node before the node
|
// insert new node before the node
|
||||||
FuncGraphManagerPtr manager = func_graph->manager();
|
FuncGraphManagerPtr manager = func_graph->manager();
|
||||||
|
auto node_user_map = manager->node_users();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
ScopePtr scope = node->scope();
|
ScopePtr scope = node->scope();
|
||||||
MS_EXCEPTION_IF_NULL(scope);
|
MS_EXCEPTION_IF_NULL(scope);
|
||||||
std::vector<AnfNodePtr> node_input;
|
std::vector<AnfNodePtr> node_input;
|
||||||
|
AnfNodePtr pre_node_ = pre_node;
|
||||||
if (root && !param_name.empty()) {
|
if (root && !param_name.empty()) {
|
||||||
node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
|
TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map);
|
||||||
|
if (next_node_dtype) {
|
||||||
|
MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving"
|
||||||
|
<< " communication.";
|
||||||
|
pre_node_ = CreateFP16Cast(node, pre_node, node_user_map, next_node_dtype);
|
||||||
|
}
|
||||||
|
node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name);
|
||||||
} else {
|
} else {
|
||||||
node_input = CreateInput(op, pre_node, instance_name);
|
node_input = CreateInput(op, pre_node, instance_name);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <queue>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
#include "utils/hash_map.h"
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
|
@ -57,6 +59,10 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
||||||
return (prim->name() == name);
|
return (prim->name() == name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
|
||||||
|
return std::any_of(check_list.begin(), check_list.end(), [cnode](string in) { return IsSomePrimitive(cnode, in); });
|
||||||
|
}
|
||||||
|
|
||||||
bool IsParallelCareNode(const CNodePtr &cnode) {
|
bool IsParallelCareNode(const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
|
ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
|
@ -301,5 +307,121 @@ void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check the given tensor, return nullptr if the given type is not an TensorType
|
||||||
|
bool CheckTensorType(const TypePtr &node_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node_type);
|
||||||
|
if (!node_type->isa<mindspore::TensorType>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the weight used by cast and matmul at the same time, like the followings
|
||||||
|
// weight1->mirror->cast1-> matmul1;
|
||||||
|
// weight1->add
|
||||||
|
// we will not insert the cast(FP32->FP16), as it will cause the input of the operator add to be changed to fp16.
|
||||||
|
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
|
||||||
|
std::queue<AnfNodePtr> visited;
|
||||||
|
AnfNodePtr queue_node = nullptr;
|
||||||
|
CNodePtr cnode = nullptr;
|
||||||
|
AnfNodePtr node = nullptr;
|
||||||
|
if (!node_ptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto users = node_users_map.at(node_ptr);
|
||||||
|
for (auto &node_user : users) {
|
||||||
|
if (node_user.first) {
|
||||||
|
visited.push(node_user.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!visited.empty()) {
|
||||||
|
queue_node = visited.front();
|
||||||
|
visited.pop();
|
||||||
|
cnode = queue_node->cast<CNodePtr>();
|
||||||
|
if (!cnode || !cnode->in_forward_flag()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE, DEPEND, UPDATESTATE, MAKE_TUPLE})) {
|
||||||
|
auto node_set = node_users_map.at(queue_node);
|
||||||
|
for (auto &node_user : node_set) {
|
||||||
|
visited.push(node_user.first);
|
||||||
|
}
|
||||||
|
} else if (!IsSomePrimitive(cnode, CAST)) {
|
||||||
|
MS_LOG(INFO) << "The weight's users including the non cast node So "
|
||||||
|
<< "will not insert cast for this parameter " << node_ptr->DebugString();
|
||||||
|
return nullptr;
|
||||||
|
} else if (!node) {
|
||||||
|
node = queue_node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
// Given the cnode ptr, find its users until we find the computation node, then return the type of the
|
||||||
|
// computation node. This function is used to find the target type for CreateFP16Cast. Only returns the target type if
|
||||||
|
// it is float16, and the source node is float32. If the situation is not matched, then return the nullptr.
|
||||||
|
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map) {
|
||||||
|
auto node_ptr = cnode_ptr->cast<AnfNodePtr>();
|
||||||
|
if (!node_ptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto cnode_inputs = cnode_ptr->inputs();
|
||||||
|
if (cnode_inputs.size() < TWO_INPUT_SIZE) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// As we execute the function IsWeightValidUsed when we start to insert the mirror, so the second parameter
|
||||||
|
// is always the parameter.
|
||||||
|
auto weight = cnode_inputs[1];
|
||||||
|
if (!weight->isa<Parameter>()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Start to search the weight params:" << weight->DebugString();
|
||||||
|
|
||||||
|
AnfNodePtr node = GetChildCastNode(weight, node_users_map);
|
||||||
|
|
||||||
|
if (!node) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// get the output dtype of the operator
|
||||||
|
auto node_type = node->Type();
|
||||||
|
if (!CheckTensorType(node_type)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_element_type);
|
||||||
|
auto source_node_type = node_ptr->Type();
|
||||||
|
if (!CheckTensorType(source_node_type)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_element_type);
|
||||||
|
// We only add cast operation when the source is fp32 type, and the users is fp16 type.
|
||||||
|
if (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) {
|
||||||
|
return input_element_type;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a cast node given the current node and the previous node. The target type of the the cast is from the
|
||||||
|
// compute_node_type.
|
||||||
|
// Return the new cast node with pre_node as the inputs.
|
||||||
|
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
|
||||||
|
const TypePtr &compute_node_type) {
|
||||||
|
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
|
||||||
|
static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
|
||||||
|
const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
|
||||||
|
MS_EXCEPTION_IF_NULL(adapter);
|
||||||
|
MS_EXCEPTION_IF_NULL(compute_node_type);
|
||||||
|
auto prim = adapter->attached_primitive();
|
||||||
|
if (prim == nullptr) {
|
||||||
|
prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
|
||||||
|
}
|
||||||
|
// Insert cast.
|
||||||
|
auto type_node = NewValueNode(compute_node_type);
|
||||||
|
type_node->set_abstract(compute_node_type->ToAbstract());
|
||||||
|
auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
|
||||||
|
new_node->set_abstract(node->abstract());
|
||||||
|
return new_node;
|
||||||
|
} // namespace parallel
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -25,6 +25,8 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
const int64_t TWO_INPUT_SIZE = 2;
|
||||||
|
|
||||||
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
|
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
|
||||||
bool IsParallelCareNode(const CNodePtr &cnode);
|
bool IsParallelCareNode(const CNodePtr &cnode);
|
||||||
Shapes GetNodeShape(const AnfNodePtr &node);
|
Shapes GetNodeShape(const AnfNodePtr &node);
|
||||||
|
@ -34,6 +36,10 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
|
||||||
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
|
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
|
||||||
const CNodePtr &node);
|
const CNodePtr &node);
|
||||||
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
|
||||||
|
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
|
||||||
|
const TypePtr &compute_node_type);
|
||||||
|
AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
|
||||||
|
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue