forked from mindspore-Ecosystem/mindspore
!2639 [ad] improve the performance of ad
Merge pull request !2639 from Kang/master
This commit is contained in:
commit
ef8cdbd82e
|
@ -424,7 +424,6 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
||||||
}
|
}
|
||||||
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
|
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
|
||||||
if (k_prim != nullptr) {
|
if (k_prim != nullptr) {
|
||||||
k_prim = BasicClone(k_prim);
|
|
||||||
return NewValueNode(k_prim);
|
return NewValueNode(k_prim);
|
||||||
}
|
}
|
||||||
// When failed to find k_prim, try k_meta.
|
// When failed to find k_prim, try k_meta.
|
||||||
|
|
|
@ -47,12 +47,12 @@ struct PrimitiveTotalEqual {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &attr : attrs1) {
|
for (auto &attr1 : attrs1) {
|
||||||
if (!t2->HasAttr(attr.first)) {
|
if (!t2->HasAttr(attr1.first)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
|
if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ struct PrimitiveTotalEqual {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher>;
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||||
class KPrim;
|
class KPrim;
|
||||||
extern KPrim g_k_prims;
|
extern KPrim g_k_prims;
|
||||||
class DFunctor;
|
class DFunctor;
|
||||||
|
|
|
@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto prim = value_node->value()->cast<PrimitivePtr>();
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
|
||||||
|
|
||||||
auto iter = bprop_registry_.find(prim);
|
|
||||||
if (iter != bprop_registry_.end()) {
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
|
|
||||||
auto fprop = GetFprop(prim);
|
auto fprop = GetFprop(prim);
|
||||||
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
|
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
|
||||||
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
|
|
||||||
return fprop;
|
return fprop;
|
||||||
}
|
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||||
|
|
||||||
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_faked_bprop = false;
|
|
||||||
FuncGraphPtr bprop_fg = nullptr;
|
FuncGraphPtr bprop_fg = nullptr;
|
||||||
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
|
auto iter = bprop_registry_.find(prim);
|
||||||
bprop_fg = BpropCut(value_node, resources);
|
if (iter != bprop_registry_.end()) {
|
||||||
} else {
|
bprop_fg = iter->second;
|
||||||
bprop_fg = GetBprop(prim);
|
}
|
||||||
if (bprop_fg == nullptr) {
|
|
||||||
bprop_fg = FakeBprop(value_node, resources);
|
if (bprop_fg == nullptr) {
|
||||||
is_faked_bprop = true;
|
bool is_faked_bprop = false;
|
||||||
|
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
|
||||||
|
bprop_fg = BpropCut(value_node, resources);
|
||||||
|
} else {
|
||||||
|
bprop_fg = GetBprop(prim);
|
||||||
|
if (bprop_fg == nullptr) {
|
||||||
|
bprop_fg = FakeBprop(value_node, resources);
|
||||||
|
is_faked_bprop = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To support primitives with variable params, do not cache faked bprop
|
||||||
|
if (!is_faked_bprop) {
|
||||||
|
// Set bprop_g graph cache
|
||||||
|
bprop_registry_[prim] = bprop_fg;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
// To support primitives with variable params, do not cache faked bprop
|
|
||||||
if (!is_faked_bprop) {
|
|
||||||
// Set bprop_g graph cache
|
|
||||||
bprop_registry_[prim] = expanded_fg;
|
|
||||||
}
|
|
||||||
return expanded_fg;
|
return expanded_fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,10 @@ TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
|
||||||
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
|
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
|
||||||
|
|
||||||
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
|
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
|
||||||
ASSERT_TRUE(fg == fg1);
|
|
||||||
|
FuncGraphPairMapEquiv equiv_graph;
|
||||||
|
NodeMapEquiv equiv_node;
|
||||||
|
ASSERT_TRUE(Isomorphic(fg, fg1, &equiv_graph, &equiv_node));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
|
|
Loading…
Reference in New Issue