diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index c0242ccacbb..5d8f4185417 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { } ResolveIRPassLib::ResolveIRPassLib() { + resolver_resolve_attr_ = + MakeSubstitution(std::make_shared(), "resolver_resolve_attr", prim::kPrimGetAttr); resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 31aaeac7816..0d19a9daa43 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -118,6 +118,7 @@ class ResolveIRPassLib { ResolveIRPassLib(); ~ResolveIRPassLib() = default; + SubstitutionPtr resolver_resolve_attr_; SubstitutionPtr resolver_resolve_; SubstitutionPtr resolver_getattr_; }; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h index 73dc395c0bf..e529c7ce046 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h @@ -21,15 +21,21 @@ #include #include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/optimizer_caller.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" +#include "ir/pattern_matcher.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse_base.h" namespace mindspore { namespace opt { namespace irpass { + +const char PARSE_SUPER_NAME[] = "namespace"; + // {prim::kPrimResolve, Ns, Sym} class ResolverResolve : public AnfVisitor { public: @@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor { parse::NameSpacePtr ns_{nullptr}; parse::SymbolPtr sym_{nullptr}; }; + +// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} +class ResolveAttr : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + PatternNode ns_node, sym_node, attr_node; + auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr { + auto node_to_getattr = node->cast()->input(1); + std::string attr_as_string = GetValueNode(attr_node.GetNode(node))->value(); + + auto ns_ = GetValueNode(ns_node.GetNode(node)); + auto sym_ = GetValueNode(sym_node.GetNode(node)); + + if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) { + // deal with the case of getting attr from a class member + // and avoid the case of getting attr from self (the result of ParseSuper) + auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string); + return result; + } + return nullptr; + }; + MATCH_REPLACE_LAMBDA_IF( + node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node), + ResolveAttrLambda, attr_node.CheckFunc(IsValueNode, node)); + + return nullptr; + } +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 8d4c4026391..228c9ae1845 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func return true; } -} // namespace -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node) { - if (node->func_graph() == nullptr || manager == nullptr) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; - } - SymbolResolver symbol_resolver(name_space, symbol, node); - if (!symbol_resolver.Resolve()) { - MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - - py::object obj = symbol_resolver.result(); +// resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager +AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj, + const AnfNodePtr &node) { ScopeGuard scope_guard(node->scope()); AnfNodePtr resolved_node = nullptr; TraceManager::DebugTrace(std::make_shared(node->debug_info())); @@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr TraceManager::EndTrace(); return resolved_node; } +} // namespace + +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + + AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node); + + return resolved_node; +} + +AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, + const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + if (!data_converter::IsCellInstance(obj)) { + return nullptr; + } + py::object obj_attr = obj.attr(attr.c_str()); + + AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node); + + return resolved_node; +} namespace { opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap map({ + {"resolve_attr", + { + // for resolve primitive; + irpass.resolver_resolve_attr_, + }}, {"resolve", { // for resolve and getattr primitive; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index 1024012d46b..db937daebf1 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr; AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node); +// Resolve Cell with attr name. +AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, + const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr); + // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); diff --git a/tests/st/ops/ascend/test_distribution/test_normal_new_api.py b/tests/st/ops/ascend/test_distribution/test_normal_new_api.py new file mode 100644 index 00000000000..eabd5624e89 --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_normal_new_api.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for new api of normal distribution""" +import numpy as np +from scipy import stats +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor +import mindspore.context as context + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + """ + Test class: new api of normal distribution. + """ + + def __init__(self): + super(Net, self).__init__() + self.normal = nn.Normal(0., 1., dtype=dtype.float32) + + def construct(self, x_, y_): + kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_) + prob = self.normal.prob('prob', kl) + return prob + + +def test_new_api(): + """ + Test new api of normal distribution. + """ + prob = Net() + mean_a = np.array([0.0]).astype(np.float32) + sd_a = np.array([1.0]).astype(np.float32) + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + ans = prob(Tensor(mean_b), Tensor(sd_b)) + + diff_log_scale = np.log(sd_a) - np.log(sd_b) + squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) + expect_kl_loss = 0.5 * squared_diff + 0.5 * \ + np.expm1(2 * diff_log_scale) - diff_log_scale + + norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0])) + expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32) + + tol = 1e-6 + assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()