!11530 support call inner net attr

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-27 16:50:43 +08:00 committed by Gitee
commit b9e1c3f045
5 changed files with 146 additions and 15 deletions

View File

@ -202,10 +202,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
}
ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_attr_ =
MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
resolver_resolve_and_getattr_ =
MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
{prim::kPrimGetAttr, prim::kPrimResolve});
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
}
InferenceOptPrepareLib::InferenceOptPrepareLib() {

View File

@ -141,7 +141,7 @@ class ResolveIRPassLib {
ResolveIRPassLib();
~ResolveIRPassLib() = default;
SubstitutionPtr resolver_resolve_attr_;
SubstitutionPtr resolver_resolve_and_getattr_;
SubstitutionPtr resolver_resolve_;
SubstitutionPtr resolver_getattr_;
};

View File

@ -19,6 +19,7 @@
#include <string>
#include <memory>
#include <vector>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/optimizer_caller.h"
@ -66,7 +67,7 @@ class ResolverResolve : public AnfVisitor {
};
// {prim::kPrimGetAttr, Ns, Str}
class ResolverGetattr : public AnfVisitor {
class ResolverGetAttr : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
@ -97,7 +98,7 @@ class ResolverGetattr : public AnfVisitor {
};
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
class ResolveAttr : public OptimizerCaller {
class ResolverGetAttrResolve : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
@ -122,6 +123,29 @@ class ResolveAttr : public OptimizerCaller {
return nullptr;
}
};
class ResolverResolveAndGetAttr : public OptimizerCaller {
public:
ResolverResolveAndGetAttr() {
resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
std::make_shared<ResolverGetAttr>()};
}
~ResolverResolveAndGetAttr() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (const auto &resolver_opt : resolver_optimizers_) {
new_node = (*resolver_opt)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
std::vector<OptimizerCallerPtr> resolver_optimizers_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -276,8 +276,15 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
if (!data_converter::IsCellInstance(obj)) {
return nullptr;
}
py::object obj_attr = obj.attr(attr.c_str());
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
const std::string module = "mindspore._extends.parse.parser";
py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
auto new_symbol = std::make_shared<Symbol>(attr);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
TraceManager::ClearParseOrResolveDebugInfo();
return resolved_node;
}
@ -285,16 +292,10 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
namespace {
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
opt::OptPassGroupMap map({
{"resolve_attr",
{
// for resolve primitive;
irpass.resolver_resolve_attr_,
}},
{"resolve",
{
// for resolve and getattr primitive;
irpass.resolver_resolve_,
irpass.resolver_getattr_,
irpass.resolver_resolve_and_getattr_,
}},
});
return map;

View File

@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test call inner net attr"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class InnerInNet(nn.Cell):
def __init__(self, init_data, const):
super(InnerInNet, self).__init__()
self.weight = Parameter(init_data, name="weight_s")
self.t = init_data
self.const = const
def construct(self, input_x):
if self.const:
return input_x * self.t
return input_x * self.weight
class InnerNet(nn.Cell):
def __init__(self, init_data, const):
super(InnerNet, self).__init__()
self.inner_in_net = InnerInNet(init_data, const)
self.t = init_data
self.const = const
def construct(self, input_x):
if self.const:
return self.inner_in_net.t / self.inner_in_net(input_x)
return self.inner_in_net.weight / self.inner_in_net(input_x)
class Net(nn.Cell):
def __init__(self, init_data, const):
super(Net, self).__init__()
self.inner_net = InnerNet(init_data, const)
self.x = Tensor(np.ones((2, 3)) * 5)
self.y = Tensor(np.ones((2, 3)) * 6)
self.const = const
self.weight = Parameter(init_data, name="weight_s")
def construct(self, input_x, input_y):
if self.const:
return self.inner_net.t + self.inner_net(self.x) - self.y
return self.inner_net.t + self.inner_net(input_x) - input_y
class OuterMostNet(nn.Cell):
def __init__(self, init_data, const):
super(OuterMostNet, self).__init__()
self.net = Net(init_data, const)
def construct(self, input_x, input_y):
return self.net.inner_net.inner_in_net.t
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=True)
def construct(self, input_x, input_y):
return self.grad_all(self.forward_net)(input_x, input_y)
def test_inner_net_attr():
input_x = Tensor(np.ones((2, 3)) * 2)
input_y = Tensor(np.ones((2, 3)) * 3)
init_data = Tensor(np.ones((2, 3)) * 4)
test_var_net = Net(init_data, False)
test_var_net(input_x, input_y)
grad_net = GradNet(test_var_net)
grad_net(input_x, input_y)
test_const_net = Net(init_data, True)
ret = test_const_net(input_x, input_y)
expect = -1.8 * np.ones((2, 3))
assert np.allclose(ret.asnumpy(), expect)
test_outer_net = OuterMostNet(init_data, True)
ret = test_outer_net(input_x, input_y)
assert np.allclose(ret.asnumpy(), init_data.asnumpy())