!2639 [ad] improve the performance of ad

Merge pull request !2639 from Kang/master
This commit is contained in:
mindspore-ci-bot 2020-06-29 10:04:51 +08:00 committed by Gitee
commit ef8cdbd82e
4 changed files with 32 additions and 32 deletions

View File

@ -424,7 +424,6 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
} }
auto k_prim = g_k_prims.KPrimitive(value_node, resources_); auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
if (k_prim != nullptr) { if (k_prim != nullptr) {
k_prim = BasicClone(k_prim);
return NewValueNode(k_prim); return NewValueNode(k_prim);
} }
// When failed to find k_prim, try k_meta. // When failed to find k_prim, try k_meta.

View File

@ -47,12 +47,12 @@ struct PrimitiveTotalEqual {
return false; return false;
} }
for (auto &attr : attrs1) { for (auto &attr1 : attrs1) {
if (!t2->HasAttr(attr.first)) { if (!t2->HasAttr(attr1.first)) {
return false; return false;
} }
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) { if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) {
return false; return false;
} }
} }
@ -61,7 +61,7 @@ struct PrimitiveTotalEqual {
} }
}; };
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher>; using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim; class KPrim;
extern KPrim g_k_prims; extern KPrim g_k_prims;
class DFunctor; class DFunctor;

View File

@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
MS_LOG(EXCEPTION) << "Primitive node is not valid."; MS_LOG(EXCEPTION) << "Primitive node is not valid.";
} }
auto prim = value_node->value()->cast<PrimitivePtr>(); auto prim = GetValueNode<PrimitivePtr>(value_node);
MS_EXCEPTION_IF_NULL(prim); if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
return iter->second;
}
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
auto fprop = GetFprop(prim); auto fprop = GetFprop(prim);
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
return fprop; return fprop;
} } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr; return nullptr;
} }
bool is_faked_bprop = false;
FuncGraphPtr bprop_fg = nullptr; FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") { auto iter = bprop_registry_.find(prim);
bprop_fg = BpropCut(value_node, resources); if (iter != bprop_registry_.end()) {
} else { bprop_fg = iter->second;
bprop_fg = GetBprop(prim); }
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources); if (bprop_fg == nullptr) {
is_faked_bprop = true; bool is_faked_bprop = false;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
} }
} }
@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
<< trace::GetDebugInfo(bprop_fg->debug_info()); << trace::GetDebugInfo(bprop_fg->debug_info());
} }
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = expanded_fg;
}
return expanded_fg; return expanded_fg;
} }

View File

@ -38,7 +38,10 @@ TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
ASSERT_TRUE(fg == fg1);
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
ASSERT_TRUE(Isomorphic(fg, fg1, &equiv_graph, &equiv_node));
} }
} // namespace prim } // namespace prim