fix InsertGradientOf with class method
This commit is contained in:
parent
7a367af9c6
commit
113c0d8cd2
|
@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object&
|
|||
if (para_node == nullptr) {
|
||||
ParameterPtr node = top_graph->AddWeightParameter(param_name);
|
||||
node->set_default_param(obj);
|
||||
|
||||
// set_abstract for parameter
|
||||
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
|
||||
ValuePtr converted = nullptr;
|
||||
(void)ConvertData(to_convert, &converted);
|
||||
bool broaden = true;
|
||||
node->set_abstract(abstract::FromValue(converted, broaden));
|
||||
|
||||
para_node = node;
|
||||
}
|
||||
auto iter = func_graph->make_ref_params().find(para_node);
|
||||
|
|
|
@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
|
|||
});
|
||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
|
||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||
|
||||
opt::OptPassConfig resolve_pass = opt::OptPassConfig({
|
||||
resolve_irpass.resolver_resolve_,
|
||||
resolve_irpass.resolver_getattr_,
|
||||
irpass.get_make_ref_eliminate_,
|
||||
});
|
||||
|
||||
OptPassGroupMap map_a({{"a_1", a_1},
|
||||
{"a_2", a_2},
|
||||
|
@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
|
|||
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
|
||||
{"virtual_dataset", virtual_dataset},
|
||||
{"grad", grad},
|
||||
{"resolve", resolve_pass},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"cse", opt::OptPassConfig(opt::CSE(false))},
|
||||
{"a_3", a_3}});
|
||||
|
|
|
@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
|
|||
return eng->ForwardConfig(old_conf, fn_conf);
|
||||
}
|
||||
|
||||
AbstractBasePtr GenerateResolveAbstract(const AnfNodeConfigPtr &out_conf, const py::object &obj,
|
||||
const ValuePtr &converted_ret) {
|
||||
if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
|
||||
TypePtr cls_ptr = parse::ParseDataClass(converted_ret->cast<std::shared_ptr<parse::PyObjectWrapper>>()->obj());
|
||||
|
||||
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial), NewValueNode(prim::kPrimMakeRecord),
|
||||
NewValueNode(cls_ptr)};
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
FuncGraphPtr func_graph = out_conf->node()->func_graph();
|
||||
CNodePtr new_cnode = func_graph->NewCNode(input);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, out_conf->context());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
} else {
|
||||
return ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
|
@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
|
|||
// item_name to func addr from obj_map
|
||||
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
|
||||
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
|
||||
FuncGraphPtr func_graph = out_conf->node()->func_graph();
|
||||
|
||||
parse::SymbolResolverPtr symbol_resolver =
|
||||
std::make_shared<parse::SymbolResolver>(name_space, symbol, out_conf->node());
|
||||
if (!symbol_resolver->Resolve()) {
|
||||
auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
|
||||
if (new_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Resolve node failed";
|
||||
}
|
||||
|
||||
py::object obj = symbol_resolver->result();
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
}
|
||||
if (converted_ret->isa<FuncGraph>()) {
|
||||
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
|
||||
}
|
||||
return GenerateResolveAbstract(out_conf, obj, converted_ret);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
|
|
|
@ -17,13 +17,14 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.api import ms_function
|
||||
from ....mindspore_test_framework.utils.bprop_util import bprop
|
||||
from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell
|
||||
from mindspore import Tensor
|
||||
|
||||
from mindspore import context
|
||||
|
||||
import mindspore
|
||||
|
||||
def setup_module(module):
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
@ -107,3 +108,36 @@ def test_print_shape_type():
|
|||
return z
|
||||
bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
|
||||
Tensor(np.ones([2, 2]).astype(np.float32)))
|
||||
|
||||
def test_cell_assign():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
class GradNetWrap(nn.Cell):
|
||||
""" GradNetWrap definition """
|
||||
def __init__(self, net):
|
||||
super(GradNetWrap, self).__init__()
|
||||
self.net = net
|
||||
self.weights = mindspore.ParameterTuple(net.get_parameters())
|
||||
|
||||
def construct(self, x, y):
|
||||
return C.grad_by_list(self.net, self.weights)(x, y)
|
||||
|
||||
class Mul(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Mul, self).__init__()
|
||||
self.get_g = P.InsertGradientOf(self.save_gradient)
|
||||
self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w")
|
||||
self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g")
|
||||
|
||||
def save_gradient(self, dout):
|
||||
self.matrix_g = dout
|
||||
return dout
|
||||
|
||||
def construct(self, x, y):
|
||||
z = x * self.matrix_w
|
||||
z = self.get_g(z)
|
||||
z = z * y
|
||||
return z
|
||||
|
||||
input_x = Tensor(np.ones([2, 2], np.float32))
|
||||
input_y = Tensor(np.ones([2, 2], np.float32))
|
||||
GradNetWrap(Mul())(input_x, input_y)
|
||||
|
|
Loading…
Reference in New Issue