forked from mindspore-Ecosystem/mindspore
!11530 support call inner net attr
From: @zhangbuxue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
b9e1c3f045
|
@ -202,10 +202,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
resolver_resolve_attr_ =
|
||||
MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
|
||||
resolver_resolve_and_getattr_ =
|
||||
MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
|
||||
{prim::kPrimGetAttr, prim::kPrimResolve});
|
||||
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
|
||||
}
|
||||
|
||||
InferenceOptPrepareLib::InferenceOptPrepareLib() {
|
||||
|
|
|
@ -141,7 +141,7 @@ class ResolveIRPassLib {
|
|||
ResolveIRPassLib();
|
||||
~ResolveIRPassLib() = default;
|
||||
|
||||
SubstitutionPtr resolver_resolve_attr_;
|
||||
SubstitutionPtr resolver_resolve_and_getattr_;
|
||||
SubstitutionPtr resolver_resolve_;
|
||||
SubstitutionPtr resolver_getattr_;
|
||||
};
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/optimizer_caller.h"
|
||||
|
@ -66,7 +67,7 @@ class ResolverResolve : public AnfVisitor {
|
|||
};
|
||||
|
||||
// {prim::kPrimGetAttr, Ns, Str}
|
||||
class ResolverGetattr : public AnfVisitor {
|
||||
class ResolverGetAttr : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
|
@ -97,7 +98,7 @@ class ResolverGetattr : public AnfVisitor {
|
|||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
|
||||
class ResolveAttr : public OptimizerCaller {
|
||||
class ResolverGetAttrResolve : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
|
||||
|
@ -122,6 +123,29 @@ class ResolveAttr : public OptimizerCaller {
|
|||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
class ResolverResolveAndGetAttr : public OptimizerCaller {
|
||||
public:
|
||||
ResolverResolveAndGetAttr() {
|
||||
resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
|
||||
std::make_shared<ResolverGetAttr>()};
|
||||
}
|
||||
~ResolverResolveAndGetAttr() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (const auto &resolver_opt : resolver_optimizers_) {
|
||||
new_node = (*resolver_opt)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<OptimizerCallerPtr> resolver_optimizers_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -276,8 +276,15 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
if (!data_converter::IsCellInstance(obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
py::object obj_attr = obj.attr(attr.c_str());
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
|
||||
auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
|
||||
auto new_symbol = std::make_shared<Symbol>(attr);
|
||||
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
|
||||
AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return resolved_node;
|
||||
}
|
||||
|
@ -285,16 +292,10 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
|
|||
namespace {
|
||||
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
|
||||
opt::OptPassGroupMap map({
|
||||
{"resolve_attr",
|
||||
{
|
||||
// for resolve primitive;
|
||||
irpass.resolver_resolve_attr_,
|
||||
}},
|
||||
{"resolve",
|
||||
{
|
||||
// for resolve and getattr primitive;
|
||||
irpass.resolver_resolve_,
|
||||
irpass.resolver_getattr_,
|
||||
irpass.resolver_resolve_and_getattr_,
|
||||
}},
|
||||
});
|
||||
return map;
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test call inner net attr"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import context
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
|
||||
class InnerInNet(nn.Cell):
|
||||
def __init__(self, init_data, const):
|
||||
super(InnerInNet, self).__init__()
|
||||
self.weight = Parameter(init_data, name="weight_s")
|
||||
self.t = init_data
|
||||
self.const = const
|
||||
|
||||
def construct(self, input_x):
|
||||
if self.const:
|
||||
return input_x * self.t
|
||||
return input_x * self.weight
|
||||
|
||||
|
||||
class InnerNet(nn.Cell):
|
||||
def __init__(self, init_data, const):
|
||||
super(InnerNet, self).__init__()
|
||||
self.inner_in_net = InnerInNet(init_data, const)
|
||||
self.t = init_data
|
||||
self.const = const
|
||||
|
||||
def construct(self, input_x):
|
||||
if self.const:
|
||||
return self.inner_in_net.t / self.inner_in_net(input_x)
|
||||
return self.inner_in_net.weight / self.inner_in_net(input_x)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, init_data, const):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet(init_data, const)
|
||||
self.x = Tensor(np.ones((2, 3)) * 5)
|
||||
self.y = Tensor(np.ones((2, 3)) * 6)
|
||||
self.const = const
|
||||
self.weight = Parameter(init_data, name="weight_s")
|
||||
|
||||
def construct(self, input_x, input_y):
|
||||
if self.const:
|
||||
return self.inner_net.t + self.inner_net(self.x) - self.y
|
||||
return self.inner_net.t + self.inner_net(input_x) - input_y
|
||||
|
||||
|
||||
class OuterMostNet(nn.Cell):
|
||||
def __init__(self, init_data, const):
|
||||
super(OuterMostNet, self).__init__()
|
||||
self.net = Net(init_data, const)
|
||||
|
||||
def construct(self, input_x, input_y):
|
||||
return self.net.inner_net.inner_in_net.t
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.forward_net = net
|
||||
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
|
||||
self.grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
def construct(self, input_x, input_y):
|
||||
return self.grad_all(self.forward_net)(input_x, input_y)
|
||||
|
||||
|
||||
def test_inner_net_attr():
|
||||
input_x = Tensor(np.ones((2, 3)) * 2)
|
||||
input_y = Tensor(np.ones((2, 3)) * 3)
|
||||
init_data = Tensor(np.ones((2, 3)) * 4)
|
||||
|
||||
test_var_net = Net(init_data, False)
|
||||
test_var_net(input_x, input_y)
|
||||
|
||||
grad_net = GradNet(test_var_net)
|
||||
grad_net(input_x, input_y)
|
||||
|
||||
test_const_net = Net(init_data, True)
|
||||
ret = test_const_net(input_x, input_y)
|
||||
expect = -1.8 * np.ones((2, 3))
|
||||
assert np.allclose(ret.asnumpy(), expect)
|
||||
|
||||
test_outer_net = OuterMostNet(init_data, True)
|
||||
ret = test_outer_net(input_x, input_y)
|
||||
assert np.allclose(ret.asnumpy(), init_data.asnumpy())
|
Loading…
Reference in New Issue