forked from mindspore-Ecosystem/mindspore
!27900 Add Forced Cast For Float16
Merge pull request !27900 from huangxinjing/second_version_add_cast
commit 41bf56ab6c
@@ -48,6 +48,7 @@
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "mindspore/ccsrc/pybind_api/ir/primitive_py.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
@@ -203,12 +204,20 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
                const FuncGraphPtr &root = nullptr) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_user_map = manager->node_users();
  ScopePtr scope = node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input;
  AnfNodePtr pre_node_ = pre_node;
  if (root && !param_name.empty()) {
    TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map);
    if (next_node_dtype) {
      MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope()
                   << " to save communication.";
      pre_node_ = CreateFP16Cast(node, pre_node, node_user_map, next_node_dtype);
    }
    node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name);
  } else {
    node_input = CreateInput(op, pre_node, instance_name);
  }
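The point of the new branch is communication volume: when the weight's forward users already cast it to float16, the mirror operator can be fed the float16 tensor instead of the float32 one, roughly halving the payload. Below is a minimal standalone sketch (plain C++, no MindSpore types; MirrorPayloadBytes and the sizes are illustrative assumptions, not part of the commit) that just quantifies that saving.

#include <cstdint>
#include <iostream>

// Hypothetical helper: bytes moved when mirroring a parameter of
// elem_count elements at the given per-element width.
std::uint64_t MirrorPayloadBytes(std::uint64_t elem_count, std::uint64_t bytes_per_elem) {
  return elem_count * bytes_per_elem;
}

int main() {
  const std::uint64_t elem_count = 4096ull * 4096ull;                 // e.g. a 4096x4096 weight
  const std::uint64_t fp32_bytes = MirrorPayloadBytes(elem_count, 4);  // float32
  const std::uint64_t fp16_bytes = MirrorPayloadBytes(elem_count, 2);  // float16
  std::cout << "fp32 mirror: " << fp32_bytes << " bytes\n"
            << "fp16 mirror: " << fp16_bytes << " bytes\n"
            << "saved:       " << (fp32_bytes - fp16_bytes) << " bytes\n";
  return 0;
}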
@@ -24,6 +24,8 @@
#include <set>
#include <string>
#include <utility>
#include <queue>
#include <memory>

#include "utils/hash_map.h"
#include "base/core_ops.h"
@@ -57,6 +59,10 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  return (prim->name() == name);
}

bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
  return std::any_of(check_list.begin(), check_list.end(), [cnode](string in) { return IsSomePrimitive(cnode, in); });
}

bool IsParallelCareNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
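IsInNodeList is a thin std::any_of wrapper over IsSomePrimitive. As a rough standalone illustration of the same pattern (plain STL types only; IsInNameList and the node-as-string simplification are hypothetical, not MindSpore API):

#include <algorithm>
#include <iostream>
#include <set>
#include <string>

// Stand-in for IsSomePrimitive: here a "node" is reduced to its primitive name.
bool IsSomePrimitiveName(const std::string &node_prim, const std::string &name) {
  return node_prim == name;
}

// Same shape as IsInNodeList: true if the node matches any name in check_list.
bool IsInNameList(const std::string &node_prim, const std::set<std::string> &check_list) {
  return std::any_of(check_list.begin(), check_list.end(),
                     [&node_prim](const std::string &in) { return IsSomePrimitiveName(node_prim, in); });
}

int main() {
  const std::set<std::string> pass_through = {"Load", "Reshape", "Depend", "UpdateState", "MakeTuple"};
  std::cout << std::boolalpha
            << IsInNameList("Reshape", pass_through) << "\n"   // true
            << IsInNameList("MatMul", pass_through) << "\n";   // false
  return 0;
}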
@@ -301,5 +307,121 @@ void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
    }
  }
}

// Check the given node type; return false if it is not a TensorType.
bool CheckTensorType(const TypePtr &node_type) {
  MS_EXCEPTION_IF_NULL(node_type);
  if (!node_type->isa<mindspore::TensorType>()) {
    return false;
  }
  return true;
}

// For a weight that is used by a cast and another operator at the same time, e.g.
// weight1 -> mirror -> cast1 -> matmul1;
// weight1 -> add
// we will not insert the cast (FP32->FP16), as it would also change the input of the add operator to fp16.
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
  std::queue<AnfNodePtr> visited;
  AnfNodePtr queue_node = nullptr;
  CNodePtr cnode = nullptr;
  AnfNodePtr node = nullptr;
  if (!node_ptr) {
    return nullptr;
  }
  auto users = node_users_map.at(node_ptr);
  for (auto &node_user : users) {
    if (node_user.first) {
      visited.push(node_user.first);
    }
  }
  while (!visited.empty()) {
    queue_node = visited.front();
    visited.pop();
    cnode = queue_node->cast<CNodePtr>();
    if (!cnode || !cnode->in_forward_flag()) {
      continue;
    }
    if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE, DEPEND, UPDATESTATE, MAKE_TUPLE})) {
      auto node_set = node_users_map.at(queue_node);
      for (auto &node_user : node_set) {
        visited.push(node_user.first);
      }
    } else if (!IsSomePrimitive(cnode, CAST)) {
      MS_LOG(INFO) << "The weight's users include a non-Cast node, so no cast will be inserted for this parameter "
                   << node_ptr->DebugString();
      return nullptr;
    } else if (!node) {
      node = queue_node;
    }
  }
  return node;
}
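GetChildCastNode is a breadth-first walk over the node-user map: pass-through users (Load, Reshape, Depend, UpdateState, MakeTuple, and the AllGather variants) are expanded further, the first Cast user is remembered, and any other user aborts the search so no cast is inserted. A simplified standalone sketch of that traversal, with nodes reduced to integer ids and primitive-name strings (all names here are hypothetical stand-ins, not the MindSpore types):

#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Simplified graph: node id -> primitive name, and node id -> ids of its users.
using NameMap = std::map<int, std::string>;
using UserMap = std::map<int, std::vector<int>>;

// Return the id of the first Cast user reachable through pass-through nodes,
// or -1 if any non-Cast, non-pass-through user is found (mirrors the abort in GetChildCastNode).
int FindChildCast(int weight, const NameMap &names, const UserMap &users) {
  const std::set<std::string> pass_through = {"Load", "Reshape", "Depend", "UpdateState", "MakeTuple"};
  std::queue<int> visited;
  for (int user : users.at(weight)) {
    visited.push(user);
  }
  int cast_node = -1;
  while (!visited.empty()) {
    int cur = visited.front();
    visited.pop();
    const std::string &prim = names.at(cur);
    if (pass_through.count(prim) > 0) {
      auto it = users.find(cur);
      if (it != users.end()) {
        for (int user : it->second) {
          visited.push(user);
        }
      }
    } else if (prim != "Cast") {
      return -1;  // a non-Cast consumer exists, so no fp16 cast should be inserted
    } else if (cast_node == -1) {
      cast_node = cur;  // remember the first Cast user
    }
  }
  return cast_node;
}

int main() {
  // weight(0) -> Load(1) -> Cast(2); weight(0) -> Cast(3)
  NameMap names = {{0, "Parameter"}, {1, "Load"}, {2, "Cast"}, {3, "Cast"}};
  UserMap users = {{0, {1, 3}}, {1, {2}}};
  std::cout << "child cast node id: " << FindChildCast(0, names, users) << "\n";  // prints 3
  return 0;
}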
// Given the cnode ptr, find its users until we find the computation node, then return the type of the
// computation node. This function is used to find the target type for CreateFP16Cast. It only returns the target
// type if that type is float16 and the source node is float32; otherwise it returns nullptr.
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map) {
  auto node_ptr = cnode_ptr->cast<AnfNodePtr>();
  if (!node_ptr) {
    return nullptr;
  }
  auto cnode_inputs = cnode_ptr->inputs();
  if (cnode_inputs.size() < TWO_INPUT_SIZE) {
    return nullptr;
  }
  // Because IsWeightValidUsed is executed before the mirror is inserted, the second input is always the parameter.
  auto weight = cnode_inputs[1];
  if (!weight->isa<Parameter>()) {
    return nullptr;
  }
  MS_LOG(INFO) << "Start to search the weight param: " << weight->DebugString();

  AnfNodePtr node = GetChildCastNode(weight, node_users_map);

  if (!node) {
    return nullptr;
  }
  // get the output dtype of the operator
  auto node_type = node->Type();
  if (!CheckTensorType(node_type)) {
    return nullptr;
  }
  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(input_element_type);
  auto source_node_type = node_ptr->Type();
  if (!CheckTensorType(source_node_type)) {
    return nullptr;
  }
  auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(source_element_type);
  // We only add the cast operation when the source is fp32 and the user is fp16.
  if (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) {
    return input_element_type;
  }
  return nullptr;
}
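The final gate in FindChildCastWithFP32ToFP16 is a pure type comparison: the target type is returned only when the weight is float32 and its Cast consumer computes in float16. A minimal standalone sketch of just that predicate, with the type ids reduced to a hypothetical enum (the real code compares kNumberTypeFloat32/kNumberTypeFloat16 type ids):

#include <iostream>

// Hypothetical stand-in for MindSpore's numeric TypeIds.
enum class TypeId { kFloat16, kFloat32, kFloat64 };

// True only for the fp32 (weight) -> fp16 (Cast consumer) combination,
// mirroring the check at the end of FindChildCastWithFP32ToFP16.
bool ShouldInsertFP16Cast(TypeId source, TypeId consumer) {
  return source == TypeId::kFloat32 && consumer == TypeId::kFloat16;
}

int main() {
  std::cout << std::boolalpha
            << ShouldInsertFP16Cast(TypeId::kFloat32, TypeId::kFloat16) << "\n"   // true
            << ShouldInsertFP16Cast(TypeId::kFloat16, TypeId::kFloat32) << "\n"   // false
            << ShouldInsertFP16Cast(TypeId::kFloat32, TypeId::kFloat64) << "\n";  // false
  return 0;
}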

// Create a Cast node given the current node and the previous node. The target type of the cast is taken from
// compute_node_type.
// Return the new Cast node with pre_node as its input.
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
                          const TypePtr &compute_node_type) {
  const char kOpsFunctionModelName[] = "mindspore.ops.functional";
  static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  MS_EXCEPTION_IF_NULL(compute_node_type);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
  }
  // Insert cast.
  auto type_node = NewValueNode(compute_node_type);
  type_node->set_abstract(compute_node_type->ToAbstract());
  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
  new_node->set_abstract(node->abstract());
  return new_node;
}
}  // namespace parallel
}  // namespace mindspore
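CreateFP16Cast assembles the new CNode from three inputs: the value node of the cast primitive, the previous node, and the target dtype, and copies the downstream node's abstract onto it. A toy standalone sketch of that input layout (ToyNode, CreateToyCast, and the string labels are illustrative assumptions, not the MindSpore API):

#include <iostream>
#include <string>
#include <vector>

// Toy "node": a label plus the labels of its inputs.
struct ToyNode {
  std::string label;
  std::vector<std::string> inputs;
};

// Mirrors the input layout of the node built by CreateFP16Cast:
// {value node of the Cast primitive, previous node, target dtype}.
ToyNode CreateToyCast(const std::string &pre_node, const std::string &target_type) {
  return ToyNode{"Cast", {"PrimCast", pre_node, target_type}};
}

int main() {
  // Before: weight1 feeds the mirror directly; after: weight1 -> Cast(fp16) -> mirror.
  ToyNode cast = CreateToyCast("weight1", "float16");
  std::cout << cast.label << "(";
  for (size_t i = 0; i < cast.inputs.size(); ++i) {
    std::cout << cast.inputs[i] << (i + 1 < cast.inputs.size() ? ", " : "");
  }
  std::cout << ")\n";  // prints: Cast(PrimCast, weight1, float16)
  return 0;
}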
@@ -25,6 +25,8 @@

namespace mindspore {
namespace parallel {
const int64_t TWO_INPUT_SIZE = 2;

bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
@@ -34,6 +36,10 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
                          const TypePtr &compute_node_type);
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map);
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
}  // namespace parallel
}  // namespace mindspore