!3453 Support resolving an attribute of a Cell class instance

Merge pull request !3453 from zichun_ye/resolve_attr_pr
This commit is contained in:
mindspore-ci-bot 2020-07-28 14:32:31 +08:00 committed by Gitee
commit bca16792be
6 changed files with 149 additions and 12 deletions

View File

@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
}
ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_attr_ =
MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
}

View File

@ -118,6 +118,7 @@ class ResolveIRPassLib {
ResolveIRPassLib();
~ResolveIRPassLib() = default;
SubstitutionPtr resolver_resolve_attr_;
SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_;
};

View File

@ -21,15 +21,21 @@
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "ir/pattern_matcher.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/parse_base.h"
namespace mindspore {
namespace opt {
namespace irpass {
const char PARSE_SUPER_NAME[] = "namespace";
// {prim::kPrimResolve, Ns, Sym}
class ResolverResolve : public AnfVisitor {
public:
@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor {
parse::NameSpacePtr ns_{nullptr};
parse::SymbolPtr sym_{nullptr};
};
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
class ResolveAttr : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr {
auto node_to_getattr = node->cast<CNodePtr>()->input(1);
std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value();
auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) {
// deal with the case of getting attr from a class member
// and avoid the case of getting attr from self (the result of ParseSuper)
auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string);
return result;
}
return nullptr;
};
MATCH_REPLACE_LAMBDA_IF(
node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node),
ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node));
return nullptr;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func
return true;
}
} // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
// resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager
AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &node) {
ScopeGuard scope_guard(node->scope());
AnfNodePtr resolved_node = nullptr;
TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
TraceManager::EndTrace();
return resolved_node;
}
} // namespace
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
return resolved_node;
}
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
if (node->func_graph() == nullptr || manager == nullptr) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
py::object obj = symbol_resolver.result();
if (!data_converter::IsCellInstance(obj)) {
return nullptr;
}
py::object obj_attr = obj.attr(attr.c_str());
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
return resolved_node;
}
namespace {
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
opt::OptPassGroupMap map({
{"resolve_attr",
{
// for resolve primitive;
irpass.resolver_resolve_attr_,
}},
{"resolve",
{
// for resolve and getattr primitive;

View File

@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
const AnfNodePtr &node);
// Resolve Cell with attr name.
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr);
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);

View File

@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import numpy as np
from scipy import stats
import mindspore.nn as nn
from mindspore import dtype
from mindspore import Tensor
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: new api of normal distribution.
"""
def __init__(self):
super(Net, self).__init__()
self.normal = nn.Normal(0., 1., dtype=dtype.float32)
def construct(self, x_, y_):
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
prob = self.normal.prob('prob', kl)
return prob
def test_new_api():
"""
Test new api of normal distribution.
"""
prob = Net()
mean_a = np.array([0.0]).astype(np.float32)
sd_a = np.array([1.0]).astype(np.float32)
mean_b = np.array([1.0]).astype(np.float32)
sd_b = np.array([1.0]).astype(np.float32)
ans = prob(Tensor(mean_b), Tensor(sd_b))
diff_log_scale = np.log(sd_a) - np.log(sd_b)
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
np.expm1(2 * diff_log_scale) - diff_log_scale
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
tol = 1e-6
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()