forked from OSSInnovation/mindspore
!3453 Support resolving an attribute of a Cell class instance
Merge pull request !3453 from zichun_ye/resolve_attr_pr
This commit is contained in:
commit
bca16792be
|
@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
resolver_resolve_attr_ =
|
||||
MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
|
||||
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
}
|
||||
|
|
|
@ -118,6 +118,7 @@ class ResolveIRPassLib {
|
|||
ResolveIRPassLib();
|
||||
~ResolveIRPassLib() = default;
|
||||
|
||||
SubstitutionPtr resolver_resolve_attr_;
|
||||
SubstitutionPtr resolver_resolve_;
|
||||
SubstitutionPtr resolver_getattr_;
|
||||
};
|
||||
|
|
|
@ -21,15 +21,21 @@
|
|||
#include <memory>
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/optimizer_caller.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
const char PARSE_SUPER_NAME[] = "namespace";
|
||||
|
||||
// {prim::kPrimResolve, Ns, Sym}
|
||||
class ResolverResolve : public AnfVisitor {
|
||||
public:
|
||||
|
@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor {
|
|||
parse::NameSpacePtr ns_{nullptr};
|
||||
parse::SymbolPtr sym_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
|
||||
class ResolveAttr : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
|
||||
auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto node_to_getattr = node->cast<CNodePtr>()->input(1);
|
||||
std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value();
|
||||
|
||||
auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
|
||||
if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) {
|
||||
// deal with the case of getting attr from a class member
|
||||
// and avoid the case of getting attr from self (the result of ParseSuper)
|
||||
auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string);
|
||||
return result;
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node),
|
||||
ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func
|
|||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node) {
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
|
||||
py::object obj = symbol_resolver.result();
|
||||
// resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager
|
||||
AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
|
||||
const AnfNodePtr &node) {
|
||||
ScopeGuard scope_guard(node->scope());
|
||||
AnfNodePtr resolved_node = nullptr;
|
||||
TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
|
@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
TraceManager::EndTrace();
|
||||
return resolved_node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node) {
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
|
||||
py::object obj = symbol_resolver.result();
|
||||
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
|
||||
return resolved_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) {
|
||||
if (node->func_graph() == nullptr || manager == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
|
||||
py::object obj = symbol_resolver.result();
|
||||
if (!data_converter::IsCellInstance(obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
py::object obj_attr = obj.attr(attr.c_str());
|
||||
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
|
||||
|
||||
return resolved_node;
|
||||
}
|
||||
|
||||
namespace {
|
||||
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
|
||||
opt::OptPassGroupMap map({
|
||||
{"resolve_attr",
|
||||
{
|
||||
// for resolve primitive;
|
||||
irpass.resolver_resolve_attr_,
|
||||
}},
|
||||
{"resolve",
|
||||
{
|
||||
// for resolve and getattr primitive;
|
||||
|
|
|
@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
|
|||
AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol,
|
||||
const AnfNodePtr &node);
|
||||
|
||||
// Resolve Cell with attr name.
|
||||
AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space,
|
||||
const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr);
|
||||
|
||||
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
|
||||
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cases for new api of normal distribution"""
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Test class: new api of normal distribution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.normal = nn.Normal(0., 1., dtype=dtype.float32)
|
||||
|
||||
def construct(self, x_, y_):
|
||||
kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_)
|
||||
prob = self.normal.prob('prob', kl)
|
||||
return prob
|
||||
|
||||
|
||||
def test_new_api():
|
||||
"""
|
||||
Test new api of normal distribution.
|
||||
"""
|
||||
prob = Net()
|
||||
mean_a = np.array([0.0]).astype(np.float32)
|
||||
sd_a = np.array([1.0]).astype(np.float32)
|
||||
mean_b = np.array([1.0]).astype(np.float32)
|
||||
sd_b = np.array([1.0]).astype(np.float32)
|
||||
ans = prob(Tensor(mean_b), Tensor(sd_b))
|
||||
|
||||
diff_log_scale = np.log(sd_a) - np.log(sd_b)
|
||||
squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
|
||||
expect_kl_loss = 0.5 * squared_diff + 0.5 * \
|
||||
np.expm1(2 * diff_log_scale) - diff_log_scale
|
||||
|
||||
norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
|
||||
expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)
|
||||
|
||||
tol = 1e-6
|
||||
assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()
|
Loading…
Reference in New Issue