add support for pynative multi-cell cases
clear df builder, add testcases
parent 6a946bddab
commit 341b8468eb
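For context, the scenario this change enables looks roughly like the sketch below, condensed from the new testcases at the end of this diff (MulNet, AddNet, and GradAll are illustrative stand-ins, not part of the commit): two independent cells each run forward and are then differentiated, so the executor must keep one top graph, resource, and df builder per cell instead of a single global one.

import numpy as np
from mindspore import context, nn, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class MulNet(nn.Cell):
    def construct(self, x, y):
        return (x * x) * (y * y)

class AddNet(nn.Cell):
    def construct(self, x, y):
        return (x + x) * (y + y)

class GradAll(nn.Cell):
    # wraps GradOperation in a cell, matching the test helpers below
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = C.GradOperation(get_all=True, sens_param=True)

    def construct(self, x, y, sens):
        return self.grad(self.net)(x, y, sens)

mulnet, addnet = MulNet(), AddNet()
x = Tensor(np.ones([32]), dtype=mstype.float32)
y = Tensor(np.ones([32]) * 2, dtype=mstype.float32)
sens = Tensor(np.ones([32]), dtype=mstype.float32)
mulnet.set_grad()
addnet.set_grad()
mulnet(x, y)                    # each forward pass records its own top graph
addnet(x, y)
GradAll(mulnet)(x, y, sens)     # both backward passes should now succeed
GradAll(addnet)(x, y, sens)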
@@ -1031,7 +1031,8 @@ PynativeExecutor::PynativeExecutor() {
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
-  if (cell_graph_map_.count(cell_id) != 0) {
+  // check graph_context_.empty() so that separate graphs are created for everything except the top cell
+  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
     if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
       resource_ = cell_resource_map_[cell_id];
     }
@@ -1040,21 +1041,24 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
   }

   auto g = std::make_shared<FuncGraph>();

-  if (top_g_ == nullptr) {
+  if (graph_context_.empty()) {
+    // a df builder is built for every top function graph
+    df_builder_ = std::make_shared<FuncGraph>();
+    df_builder_map_[cell_id] = df_builder_;
     top_g_ = curr_g_ = g;
     resource_ = std::make_shared<pipeline::Resource>();
     resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
     cell_resource_map_[cell_id] = resource_;
-    df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     first_grad_step_ = true;
     top_graph_cells_.insert(cell_id);
+    Pushp();
   } else {
+    Pushp();
     if (df_builder_ == nullptr) {
       MS_LOG(EXCEPTION) << "In NewGraphInner, got df builder is nullptr";
     }
     curr_g_ = g;
   }
-  Pushp();
   if (graph_info_map_.count(g) == 0) {
     graph_info_map_[g] = GraphInfo();
   }
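In miniature, the new bookkeeping: a df builder and a resource are created only when the context stack is empty, i.e. once per top cell, and cached under that cell's id so a later GradNetInner can retrieve them. A loose plain-Python model (the dict layout and names are illustrative, not the C++ types):

# Illustrative model of NewGraphInner's per-top-cell caching.
def new_graph(state, cell_id):
    g = {"cell": cell_id}                      # stands in for a FuncGraph
    if not state["graph_context"]:             # empty stack => this is a top cell
        state["df_builder"] = {"cell": cell_id, "role": "df_builder"}
        state["df_builder_map"][cell_id] = state["df_builder"]
        state["top_g"] = g
        state["resource"] = {"cell": cell_id, "role": "resource"}
        state["cell_resource_map"][cell_id] = state["resource"]
    state["curr_g"] = g
    state["graph_context"].append(g)           # Pushp()
    return g

state = {"graph_context": [], "df_builder_map": {}, "cell_resource_map": {}}
new_graph(state, "outer")    # creates df builder + resource for the top cell
new_graph(state, "inner")    # nested cell: reuses the outer cell's df builder
assert list(state["df_builder_map"]) == ["outer"]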
@@ -1171,22 +1175,25 @@ void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &pa
   }
 }

-void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
+void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); }

 void PynativeExecutor::Popp() {
-  if (graph_p_.empty()) {
-    MS_LOG(EXCEPTION) << "Stack graph_p_ is empty";
+  if (graph_context_.empty()) {
+    MS_LOG(EXCEPTION) << "Stack graph_context_ is empty";
   }
-  curr_g_ = graph_p_.top();
-  graph_p_.pop();
+  graph_context_.pop();
+  if (!graph_context_.empty()) {
+    curr_g_ = graph_context_.top();
+  }
 }

 void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
-  if (cell_graph_map_.count(cell_id) != 0) {
+  if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) {
     MS_LOG(DEBUG) << "Endgraph already compiled";
     return;
   }

   cell_graph_map_[cell_id] = curr_g_;
   auto out_id = GetId(out);
   if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
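The important behavioural change is in Popp(): instead of restoring curr_g_ from the top element and then popping, it now pops first and, if anything remains, restores curr_g_ from the new top, so finishing an inner cell naturally re-enters the outer graph. A toy model in plain Python (not the MindSpore API):

class GraphContext:
    def __init__(self):
        self.stack = []        # bottom element is the top graph
        self.curr = None

    def pushp(self, g):
        self.stack.append(g)
        self.curr = g

    def popp(self):
        if not self.stack:
            raise RuntimeError("Stack graph_context_ is empty")
        self.stack.pop()
        if self.stack:
            self.curr = self.stack[-1]

ctx = GraphContext()
ctx.pushp("top_g")       # NewGraphInner for the outermost cell
ctx.pushp("inner_g")     # NewGraphInner for a nested cell
ctx.popp()               # EndGraphInner of the nested cell
assert ctx.curr == "top_g"   # the outer graph is current again
ctx.popp()               # EndGraphInner of the top cell empties the stack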
@@ -1246,7 +1253,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
       (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
     }
   }
-  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1);
   if (need_replace_param) {
     auto params = newfg->parameters();
     auto manager = Manage({newfg}, false);
@@ -1257,26 +1264,29 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     }
   }
   graph_info_map_.erase(curr_g_);
-  if (curr_g_ != top_g_) {
+  if (graph_context_.size() > 1) {
+    Popp();
+    // connect the inner graph to the previous (outer) graph
+    auto graph_prev = graph_context_.top();
     for (size_t i = 0; i < args.size(); i++) {
       auto input = GetInput(args[i], false);
       inputs.push_back(input);
     }
-    auto out_cnode = curr_g_->NewCNode(inputs);
-    set_pyobj(curr_g_, GetCellId(cell, args));
+    auto out_cnode = graph_prev->NewCNode(inputs);
+    set_pyobj(graph_prev, GetCellId(cell, args));
     if (py::isinstance<py::tuple>(out)) {
       auto out_list = py::cast<py::tuple>(out);
       auto out_size = static_cast<int>(out_list.size());
       for (int i = 0; i < out_size; i++) {
-        set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
+        set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i);
         SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i});
       }
     }
-    set_obj_node_map(curr_g_, GetId(out), out_cnode);
+    set_obj_node_map(graph_prev, GetId(out), out_cnode);
   } else {
     parse::ResolveFuncGraph(newfg, resource_);
     resource_->set_func_graph(newfg);
+    Popp();
   }
 }
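Put differently: once an inner cell ends, the executor pops back to the outer context and registers the inner graph's output as a node owned by graph_prev, the graph now on top of the stack, rather than by curr_g_; only the single remaining top graph goes through ResolveFuncGraph and grad. A rough plain-Python outline of that branch (placeholder dicts, not real AnfNode/FuncGraph types):

def end_inner_graph(graph_context, curr_g, inputs, out_ids):
    graph_context.pop()                  # Popp(): leave the inner context first
    graph_prev = graph_context[-1]       # the previous (outer) graph
    out_cnode = {"call": curr_g, "inputs": list(inputs)}
    for i, out_id in enumerate(out_ids):
        # outputs are re-registered against the outer graph, not curr_g
        graph_prev["obj_node_map"][out_id] = (out_cnode, i)
    return out_cnode

top = {"obj_node_map": {}}
inner = {"obj_node_map": {}}
stack = [top, inner]
end_inner_graph(stack, inner, ["x", "y"], ["out0"])
assert "out0" in top["obj_node_map"]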
@@ -1348,14 +1358,36 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
 void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                     const py::args &args) {
   MS_LOG(INFO) << "GradNet start" << args.size();

   std::size_t size = args.size();
   std::string cell_id = GetCellId(cell, args);
   if (graph_map_.count(cell_id) != 0) {
     MS_LOG(DEBUG) << "GradNet already compiled";
     return;
   }
+  size_t forward_args_count = args.size();
+  if (grad->sens_param()) {
+    forward_args_count = forward_args_count - 1;
+  }
+  py::tuple forward_args(forward_args_count);
+  for (size_t i = 0; i < forward_args_count; i++) {
+    forward_args[i] = args[i];
+  }
+  std::string forward_cell_id = GetCellId(cell, forward_args);
+  MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id;
+  if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) {
+    MS_LOG(EXCEPTION) << "Cannot find df builder";
+  }
+  df_builder_ = df_builder_map_[forward_cell_id];
+  if (df_builder_ == nullptr) {
+    MS_LOG(EXCEPTION) << "Got unexpected null df builder";
+  }
+
+  if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) {
+    MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id;
+  }
   MS_LOG(DEBUG) << "GradNet first compiled";
+  resource_ = cell_resource_map_[forward_cell_id];

   std::vector<AnfNodePtr> new_params;
   for (size_t i = 0; i < size; i++) {
     ParameterPtr p = std::make_shared<Parameter>(df_builder_);
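The subtle point in the block above: GradNetInner is called with the sens tensor appended to the forward arguments, but NewGraphInner cached the df builder and resource under a cell id computed from the forward arguments alone, so the sens argument has to be stripped before the lookup. A one-function mirror in plain Python (get_cell_id and fake_id stand in for the C++ GetCellId):

def forward_cell_id(get_cell_id, cell, args, sens_param):
    # drop the trailing sens tensor so the id matches what NewGraphInner saw
    forward_args = args[:-1] if sens_param else args
    return get_cell_id(cell, forward_args)

# e.g. the forward ran as net(x, y) but grad is invoked as grad(x, y, sens):
fake_id = lambda cell, args: "%s<%d args>" % (cell, len(args))
assert forward_cell_id(fake_id, "net", ("x", "y", "sens"), True) == "net<2 args>"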
@@ -1368,6 +1400,10 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje

   std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
   MS_EXCEPTION_IF_NULL(resource_->func_graph());
+  if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) {
+    MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id;
+  }
+  top_g_ = cell_graph_map_[forward_cell_id];
   auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
   resource_->set_func_graph(g);
   resource_->manager()->KeepRoots({g});
@@ -1409,6 +1445,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_, flag);
     MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_, flag);
     MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag);
+    MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&df_builder_map_, flag);
     Clean();
     // We may exit while a pynative op is running, so the pynative flag needs to be reset.
     auto ms_context = MsContext::GetInstance();
@@ -1431,7 +1468,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   graph_info_map_.clear();
   op_id_map_.clear();
   obj_to_forward_id_.clear();
-  std::stack<FuncGraphPtr>().swap(graph_p_);
+  std::stack<FuncGraphPtr>().swap(graph_context_);
   ConfigManager::GetInstance().ResetIterNum();
 }
@@ -1509,7 +1546,6 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase)
   }

   std::string backend = MsContext::GetInstance()->backend_policy();
-
   MS_LOG(DEBUG) << "Eval run" << backend;
   BaseRef value = (*run)(arg_list);
   MS_LOG(DEBUG) << "Run end" << value.ToString();
@@ -155,7 +155,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<std::string, size_t> op_id_map_;
   std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
-  std::stack<FuncGraphPtr> graph_p_;
+  std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
+  // the stack that records the graph-creation context; the bottom element is the top graph
+  std::stack<FuncGraphPtr> graph_context_;
   FuncGraphPtr top_g_;
   FuncGraphPtr df_builder_;
   FuncGraphPtr curr_g_;
@@ -21,7 +21,7 @@ from types import FunctionType

 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
@@ -475,6 +475,7 @@ class _ListAppend(ListAppend_):
     Args:
         name (str): The name of the metafuncgraph object.
     """

     def __init__(self, name):
         ListAppend_.__init__(self, name)
@@ -0,0 +1,207 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+from mindspore import context, nn, Tensor, Parameter, ParameterTuple
+from mindspore.common import dtype as mstype
+from mindspore.ops import composite as C
+
+
+def setup_module():
+    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
+
+
+class _Grad(nn.Cell):
+    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
+        super().__init__()
+        self.network = network
+        self.grad = grad
+        self.sens_param = self.grad.sens_param
+        self.wrt_params = wrt_params
+        self.real_inputs_count = real_inputs_count
+        if self.wrt_params:
+            self.params = ParameterTuple(self.network.trainable_params())
+
+    def construct(self, *inputs):
+        if self.wrt_params:
+            if self.real_inputs_count is None or self.sens_param is False:
+                return self.grad(self.network, self.params)(*inputs)
+            real_inputs = inputs[:self.real_inputs_count]
+            sense_param_inputs = inputs[self.real_inputs_count:]
+            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
+
+        if self.real_inputs_count is None or self.sens_param is False:
+            return self.grad(self.network)(*inputs)
+        real_inputs = inputs[:self.real_inputs_count]
+        sense_param_inputs = inputs[self.real_inputs_count:]
+        return self.grad(self.network)(*real_inputs, sense_param_inputs)
+
+
+class GradOfFirstInput(_Grad):
+    """
+    get grad of the first input
+    """
+
+    def __init__(self, network, sens_param=True, real_inputs_count=None):
+        super().__init__(grad=C.GradOperation(sens_param=sens_param),
+                         network=network, real_inputs_count=real_inputs_count)
+
+
+class GradOfAllInputs(_Grad):
+    """
+    get grads of all inputs
+    """
+
+    def __init__(self, network, sens_param=True, real_inputs_count=None):
+        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
+                         network=network, real_inputs_count=real_inputs_count)
+
+
+def test_multi_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x + x + x
+            b = y + y
+            return a * b
+
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd()
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]) * 2, dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(mulnet)
+    grad_add = GradOfAllInputs(addnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_multi_same_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * 3
+            b = y * 2
+            return a + b
+
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd()
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(mulnet)
+    grad_add = GradOfFirstInput(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_net_inner_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+
+        def construct(self, x, y):
+            a = x * x
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.net = net
+
+        def construct(self, x, y):
+            a = x + x
+            b = y + y
+            res = self.net(a, b)
+            return res
+
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd(mulnet)
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(addnet)
+    grad_add = GradOfAllInputs(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
+
+
+def test_net_inner_first_run_grad():
+    class ForwardNetMul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.z1 = Parameter(Tensor(np.ones([32]) * 2, dtype=mstype.float32), name='z1')
+
+        def construct(self, x, y):
+            a = x * self.z1
+            b = y * y
+            return a * b
+
+    class ForwardNetAdd(nn.Cell):
+        def __init__(self, net):
+            super().__init__()
+            self.net = net
+            self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
+            self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z3')
+
+        def construct(self, x, y):
+            a = x + x * self.z3
+            b = y + y * self.z2
+            res = self.net(a, b)
+            return res
+
+    mulnet = ForwardNetMul()
+    addnet = ForwardNetAdd(mulnet)
+    x = Tensor(np.ones([32]), dtype=mstype.float32)
+    y = Tensor(np.ones([32]), dtype=mstype.float32)
+    sens = Tensor(np.ones([32]), dtype=mstype.float32)
+    mulnet.set_grad()
+    addnet.set_grad()
+    out1 = mulnet(x, y)
+    out2 = addnet(x, y)
+    grad_mul = GradOfAllInputs(addnet)
+    grad_add = GradOfFirstInput(mulnet)
+    grad_mul(x, y, sens)
+    grad_add(x, y, sens)
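A note on the helpers above: every test passes real_inputs_count=None, so _Grad.construct forwards all positional arguments, sens tensor included, straight into the grad call. The real_inputs_count split only kicks in for networks with a variable argument list; schematically (values illustrative, not taken from the tests):

# GradOfFirstInput(net, real_inputs_count=2) called as grad(x, y, sens):
#   real_inputs        = (x, y)
#   sense_param_inputs = (sens,)
#   dispatch           = self.grad(net)(x, y, (sens,))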