!27900 Add Forced Cast For Float16

Merge pull request !27900 from huangxinjing/second_version_add_cast
i-robot 2021-12-27 03:11:50 +00:00 committed by Gitee
commit 41bf56ab6c
3 changed files with 138 additions and 1 deletion


@@ -48,6 +48,7 @@
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "mindspore/ccsrc/pybind_api/ir/primitive_py.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
@@ -203,12 +204,20 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
const FuncGraphPtr &root = nullptr) {
// insert new node before the node
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_user_map = manager->node_users();
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
std::vector<AnfNodePtr> node_input;
AnfNodePtr pre_node_ = pre_node;
if (root && !param_name.empty()) {
node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map);
if (next_node_dtype) {
MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving"
<< " communication.";
pre_node_ = CreateFP16Cast(node, pre_node, node_user_map, next_node_dtype);
}
node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name);
} else {
node_input = CreateInput(op, pre_node, instance_name);
}
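
The hunk above is the core of this change: when the parameter about to be mirrored is only consumed, possibly through pass-through ops such as Load or Reshape, by a Cast to float16, the float32->float16 cast is moved in front of the mirror operator so the parallel communication carries half-sized fp16 tensors. A rough, purely illustrative sketch of the two possible chains (none of these names are MindSpore helpers):

#include <string>
#include <vector>
// Nodes inserted in front of a parameter's user, in execution order (illustrative sketch only).
std::vector<std::string> MirrorChain(bool downstream_compute_is_fp16) {
  if (downstream_compute_is_fp16) {
    // New behaviour: cast first, so the mirror communicates half-sized fp16 tensors.
    return {"param (fp32)", "Cast -> fp16", "Mirror (fp16)"};
  }
  // Fallback: mirror the full fp32 tensor, exactly as before this change.
  return {"param (fp32)", "Mirror (fp32)"};
}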


@@ -24,6 +24,8 @@
#include <set>
#include <string>
#include <utility>
#include <queue>
#include <memory>
#include "utils/hash_map.h"
#include "base/core_ops.h"
@@ -57,6 +59,10 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
return (prim->name() == name);
}
bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
return std::any_of(check_list.begin(), check_list.end(), [cnode](string in) { return IsSomePrimitive(cnode, in); });
}
bool IsParallelCareNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -301,5 +307,121 @@ void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes) {
}
}
}
// Check the given type; return false if it is not a TensorType
bool CheckTensorType(const TypePtr &node_type) {
MS_EXCEPTION_IF_NULL(node_type);
if (!node_type->isa<mindspore::TensorType>()) {
return false;
}
return true;
}
// For a weight used by both a cast and another operator at the same time, like the following:
// weight1->mirror->cast1->matmul1;
// weight1->add
// we will not insert the cast (FP32->FP16), as it would change the input of the add operator to fp16.
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map) {
std::queue<AnfNodePtr> visited;
AnfNodePtr queue_node = nullptr;
CNodePtr cnode = nullptr;
AnfNodePtr node = nullptr;
if (!node_ptr) {
return nullptr;
}
auto users = node_users_map.at(node_ptr);
for (auto &node_user : users) {
if (node_user.first) {
visited.push(node_user.first);
}
}
while (!visited.empty()) {
queue_node = visited.front();
visited.pop();
cnode = queue_node->cast<CNodePtr>();
if (!cnode || !cnode->in_forward_flag()) {
continue;
}
if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE, DEPEND, UPDATESTATE, MAKE_TUPLE})) {
auto node_set = node_users_map.at(queue_node);
for (auto &node_user : node_set) {
visited.push(node_user.first);
}
} else if (!IsSomePrimitive(cnode, CAST)) {
MS_LOG(INFO) << "The weight's users including the non cast node So "
<< "will not insert cast for this parameter " << node_ptr->DebugString();
return nullptr;
} else if (!node) {
node = queue_node;
}
}
return node;
}
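// ------------------------------------------------------------------------------------------------
// Illustrative, standalone sketch of the traversal above (not part of MindSpore; all names below
// are invented): walk the parameter's users breadth-first, look through the "transparent" ops, and
// give up as soon as a non-Cast compute user is reached.
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>
using ToyGraph = std::map<std::string, std::vector<std::string>>;  // node -> its users
// Returns true only if every compute user reachable from `param` is a Cast node.
bool OnlyCastUsers(const ToyGraph &users, const std::string &param) {
  static const std::set<std::string> kTransparent = {"Load", "Reshape", "Depend",
                                                     "UpdateState", "MakeTuple", "AllGather"};
  std::queue<std::string> pending;
  for (const auto &u : users.at(param)) pending.push(u);
  bool found_cast = false;
  while (!pending.empty()) {
    const std::string node = pending.front();
    pending.pop();
    const std::string op = node.substr(0, node.find('#'));  // toy naming scheme: "Op#id"
    if (kTransparent.count(op) > 0) {
      for (const auto &u : users.at(node)) pending.push(u);  // look through and keep walking
    } else if (op != "Cast") {
      return false;  // a non-Cast compute user blocks the fp16 cast
    } else {
      found_cast = true;
    }
  }
  return found_cast;
}
// Example: with users = {{"weight1", {"Load#1"}}, {"Load#1", {"Cast#1", "Add#1"}},
// {"Cast#1", {}}, {"Add#1", {}}}, OnlyCastUsers(users, "weight1") is false because Add also
// consumes the weight, which is exactly the weight1->add case described in the comment above.
// ------------------------------------------------------------------------------------------------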
// Given the cnode, walk its users until a computation node is found, then return that node's type.
// This function finds the target type for CreateFP16Cast. It returns the target type only if that
// type is float16 and the source node is float32; otherwise it returns nullptr.
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map) {
auto node_ptr = cnode_ptr->cast<AnfNodePtr>();
if (!node_ptr) {
return nullptr;
}
auto cnode_inputs = cnode_ptr->inputs();
if (cnode_inputs.size() < TWO_INPUT_SIZE) {
return nullptr;
}
// Since IsWeightValidUsed has already been executed when the mirror insertion starts, the second
// input is always the parameter.
auto weight = cnode_inputs[1];
if (!weight->isa<Parameter>()) {
return nullptr;
}
MS_LOG(INFO) << "Start to search the weight params:" << weight->DebugString();
AnfNodePtr node = GetChildCastNode(weight, node_users_map);
if (!node) {
return nullptr;
}
// get the output dtype of the operator
auto node_type = node->Type();
if (!CheckTensorType(node_type)) {
return nullptr;
}
auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
MS_EXCEPTION_IF_NULL(input_element_type);
auto source_node_type = node_ptr->Type();
if (!CheckTensorType(source_node_type)) {
return nullptr;
}
auto source_element_type = source_node_type->cast<mindspore::TensorTypePtr>()->element();
MS_EXCEPTION_IF_NULL(source_element_type);
// We only add the cast operation when the source is fp32 and the user is fp16.
if (source_element_type->type_id() == kNumberTypeFloat32 && input_element_type->type_id() == kNumberTypeFloat16) {
return input_element_type;
}
return nullptr;
}
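// ------------------------------------------------------------------------------------------------
// Illustrative, standalone sketch of the decision rule above (not part of MindSpore; the enum and
// function are invented): a cast is only worth inserting when the parameter is stored in fp32
// while its downstream computation already runs in fp16.
#include <optional>
enum class ToyDtype { kFloat16, kFloat32, kOther };
// Returns the dtype to cast to, or nothing when no cast should be inserted before the mirror.
std::optional<ToyDtype> TargetCastType(ToyDtype source, ToyDtype consumer) {
  if (source == ToyDtype::kFloat32 && consumer == ToyDtype::kFloat16) {
    return ToyDtype::kFloat16;  // fp32 parameter feeding fp16 compute: communicate in fp16
  }
  return std::nullopt;  // every other combination keeps the original dtype
}
// ------------------------------------------------------------------------------------------------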
// Create a cast node given the current node and the previous node. The target type of the cast
// comes from compute_node_type.
// Return the new cast node with pre_node as its input.
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
const TypePtr &compute_node_type) {
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
MS_EXCEPTION_IF_NULL(adapter);
MS_EXCEPTION_IF_NULL(compute_node_type);
auto prim = adapter->attached_primitive();
if (prim == nullptr) {
prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
}
// Insert cast.
auto type_node = NewValueNode(compute_node_type);
type_node->set_abstract(compute_node_type->ToAbstract());
auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
new_node->set_abstract(node->abstract());
return new_node;
}
} // namespace parallel
} // namespace mindspore
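
For readers less familiar with MindSpore's graph representation: the cast built by CreateFP16Cast is a CNode whose inputs are, in order, the Cast primitive (fetched from mindspore.ops.functional), the data node pre_node, and a value node carrying the target dtype. A toy sketch of that input layout, using invented types rather than MindSpore's AnfNode API:

#include <memory>
#include <string>
#include <vector>
// Invented stand-in for an ANF node: a label plus an ordered input list.
struct ToyNode {
  std::string label;
  std::vector<std::shared_ptr<ToyNode>> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;
// Mirrors the input order passed to NewCNode above: {Cast primitive, data node, dtype node}.
ToyNodePtr MakeToyCast(const ToyNodePtr &pre_node, const std::string &target_dtype) {
  auto prim = std::make_shared<ToyNode>(ToyNode{"Primitive(Cast)", {}});
  auto type_node = std::make_shared<ToyNode>(ToyNode{target_dtype, {}});
  return std::make_shared<ToyNode>(ToyNode{"CNode(Cast)", {prim, pre_node, type_node}});
}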


@@ -25,6 +25,8 @@
namespace mindspore {
namespace parallel {
const int64_t TWO_INPUT_SIZE = 2;
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
@@ -34,6 +36,10 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
const TypePtr &compute_node_type);
AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map);
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
} // namespace parallel
} // namespace mindspore