forked from mindspore-Ecosystem/mindspore
!2223 Make those AdamXX and LambXX fusion pass not work for unexpect data type
Merge pull request !2223 from huanghui/TMP
This commit is contained in:
commit
3c1b8308cf
|
@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_node = CreateAdamApplyOneNode(func_graph, equiv);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_scope(node->scope());
|
||||
|
|
|
@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
|
|||
if (graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
|
||||
auto fusion_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(fusion_node);
|
||||
|
|
|
@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2
|
|||
|
||||
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> old_pattern_outputs;
|
||||
if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
|
||||
MS_EXCEPTION_IF_NULL(mul4);
|
||||
// Get add3 and match the add3 pattern
|
||||
|
|
|
@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
|
|||
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr mul4 = nullptr;
|
||||
AnfNodePtr real_div0 = nullptr;
|
||||
AnfNodePtr real_div1 = nullptr;
|
||||
|
|
|
@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_node = CreateLambNextRightNode(func_graph, equiv);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
// Set abstract of new node
|
||||
|
|
|
@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
|
||||
auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
|
||||
auto input2 = utils::cast<AnfNodePtr>((*equiv)[input2_]);
|
||||
|
|
|
@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_shared<Primitive>(kLambUpdateWithLrV2OpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs),
|
||||
|
|
|
@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
|
||||
}
|
||||
|
||||
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
|
||||
if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
#include "session/kernel_graph.h"
|
||||
|
@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
|
|||
|
||||
// Get attr which is bool from cnode
|
||||
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
|
||||
|
||||
// Check node's data type is in supported data type set
|
||||
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <set>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/dtype/type.h"
|
||||
|
||||
namespace mindspore {
|
||||
// op name. Op which not exists in operator/ops.h, so define it's name here
|
||||
|
@ -270,6 +271,8 @@ const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo
|
|||
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
|
||||
kOpFormat_FRACTAL_Z_C04};
|
||||
|
||||
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
|
||||
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
|
||||
try {
|
||||
if (chmod(file_name.c_str(), mode) != 0) {
|
||||
|
|
Loading…
Reference in New Issue