forked from mindspore-Ecosystem/mindspore
!2639 [ad] improve the performance of ad
Merge pull request !2639 from Kang/master
This commit is contained in:
commit
ef8cdbd82e
|
@ -424,7 +424,6 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
|||
}
|
||||
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
|
||||
if (k_prim != nullptr) {
|
||||
k_prim = BasicClone(k_prim);
|
||||
return NewValueNode(k_prim);
|
||||
}
|
||||
// When failed to find k_prim, try k_meta.
|
||||
|
|
|
@ -47,12 +47,12 @@ struct PrimitiveTotalEqual {
|
|||
return false;
|
||||
}
|
||||
|
||||
for (auto &attr : attrs1) {
|
||||
if (!t2->HasAttr(attr.first)) {
|
||||
for (auto &attr1 : attrs1) {
|
||||
if (!t2->HasAttr(attr1.first)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
|
||||
if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ struct PrimitiveTotalEqual {
|
|||
}
|
||||
};
|
||||
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher>;
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
class KPrim;
|
||||
extern KPrim g_k_prims;
|
||||
class DFunctor;
|
||||
|
|
|
@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|||
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
||||
}
|
||||
|
||||
auto prim = value_node->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
||||
auto iter = bprop_registry_.find(prim);
|
||||
if (iter != bprop_registry_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
|
||||
auto fprop = GetFprop(prim);
|
||||
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
|
||||
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
|
||||
return fprop;
|
||||
}
|
||||
|
||||
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool is_faked_bprop = false;
|
||||
FuncGraphPtr bprop_fg = nullptr;
|
||||
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
|
||||
bprop_fg = BpropCut(value_node, resources);
|
||||
} else {
|
||||
bprop_fg = GetBprop(prim);
|
||||
if (bprop_fg == nullptr) {
|
||||
bprop_fg = FakeBprop(value_node, resources);
|
||||
is_faked_bprop = true;
|
||||
auto iter = bprop_registry_.find(prim);
|
||||
if (iter != bprop_registry_.end()) {
|
||||
bprop_fg = iter->second;
|
||||
}
|
||||
|
||||
if (bprop_fg == nullptr) {
|
||||
bool is_faked_bprop = false;
|
||||
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
|
||||
bprop_fg = BpropCut(value_node, resources);
|
||||
} else {
|
||||
bprop_fg = GetBprop(prim);
|
||||
if (bprop_fg == nullptr) {
|
||||
bprop_fg = FakeBprop(value_node, resources);
|
||||
is_faked_bprop = true;
|
||||
}
|
||||
}
|
||||
|
||||
// To support primitives with variable params, do not cache faked bprop
|
||||
if (!is_faked_bprop) {
|
||||
// Set bprop_g graph cache
|
||||
bprop_registry_[prim] = bprop_fg;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|||
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
||||
}
|
||||
|
||||
// To support primitives with variable params, do not cache faked bprop
|
||||
if (!is_faked_bprop) {
|
||||
// Set bprop_g graph cache
|
||||
bprop_registry_[prim] = expanded_fg;
|
||||
}
|
||||
return expanded_fg;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,10 @@ TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
|
|||
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
|
||||
|
||||
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
|
||||
ASSERT_TRUE(fg == fg1);
|
||||
|
||||
FuncGraphPairMapEquiv equiv_graph;
|
||||
NodeMapEquiv equiv_node;
|
||||
ASSERT_TRUE(Isomorphic(fg, fg1, &equiv_graph, &equiv_node));
|
||||
}
|
||||
|
||||
} // namespace prim
|
||||
|
|
Loading…
Reference in New Issue