Support parameter tuple input in PyNative mode

This commit is contained in:
kingfo 2020-08-28 17:36:22 +08:00
parent 236952ca6c
commit 7765d44b76
3 changed files with 134 additions and 8 deletions

View File

@ -793,6 +793,20 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return node;
}
// Resolve the AnfNode backing parameter object `obj` in the current graph.
// A parameter recorded with the sentinel index path {-1} maps directly to its
// Parameter node; any other index path addresses an element inside a (possibly
// nested) tuple parameter, so a chain of TupleGetItem nodes is built to
// extract it.
AnfNodePtr PynativeExecutor::GetParamNode(const py::object &obj) {
  const auto obj_id = GetId(obj);
  auto &entry = graph_info_map_[curr_g_].param_map[obj_id];
  const auto &index_path = entry.second;
  // {-1} marks a plain (non-tuple) parameter: return its node unchanged.
  if (index_path.size() == 1 && index_path[0] == -1) {
    return entry.first;
  }
  AnfNodePtr node = entry.first;
  for (const auto &index : index_path) {
    std::vector<AnfNodePtr> getitem_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(index)};
    node = curr_g_->NewCNode(getitem_inputs);
  }
  return node;
}
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
for (size_t i = 0; i < args.size(); i++) {
@ -995,9 +1009,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
graph_info_map_[g] = GraphInfo();
}
for (size_t i = 0; i < args.size(); i++) {
auto param = args[i];
auto new_param = g->add_parameter();
std::string param_obj = GetId(args[i]);
graph_info_map_[g].param_map[param_obj] = new_param;
std::string param_obj = GetId(param);
if (py::isinstance<py::tuple>(param)) {
auto tuple = param.cast<py::tuple>();
auto tuple_size = static_cast<int>(tuple.size());
for (int j = 0; j < tuple_size; j++) {
set_param_map(curr_g_, GetId(tuple[j]), new_param, j);
SetTupleParam(tuple[j], new_param, std::vector<int>{j});
}
}
set_param_map(curr_g_, param_obj, new_param);
}
}
@ -1028,16 +1051,16 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
auto value = py::cast<tensor::TensorPtr>(obj);
free_param->set_default_param(value);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
graph_info_map_[df_builder_].param_map[obj_id] = free_param;
set_param_map(df_builder_, obj_id, free_param);
return free_param;
}
return graph_info_map_[df_builder_].param_map[obj_id];
return graph_info_map_[df_builder_].param_map[obj_id].first;
}
// if input is graph output
if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
// op(x, y)
node = graph_info_map_[curr_g_].param_map[obj_id];
node = GetParamNode(obj);
} else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
// out = op(op1(x, y))
// out = op(cell1(x, y))
@ -1085,6 +1108,19 @@ void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &c
}
}
// Recursively register every element of a (possibly nested) tuple parameter,
// e.g. ((a, (b, c)), d), so each leaf can later be materialized through a
// chain of TupleGetItem nodes. `idx` is the index path accumulated so far.
void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx) {
  if (!py::isinstance<py::tuple>(obj)) {
    return;
  }
  auto sub_tuple = obj.cast<py::tuple>();
  const int element_count = static_cast<int>(sub_tuple.size());
  for (int pos = 0; pos < element_count; ++pos) {
    std::vector<int> child_path(idx);
    child_path.push_back(pos);
    set_param_map(curr_g_, GetId(sub_tuple[pos]), para_node, child_path);
    SetTupleParam(sub_tuple[pos], para_node, child_path);
  }
}
// Save the current graph on the graph stack before descending into a sub-cell.
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
void PynativeExecutor::Popp() {
@ -1132,7 +1168,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
const py::args &args) {
AnfNodePtr output_node;
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
output_node = graph_info_map_[curr_g_].param_map[out_id];
output_node = GetParamNode(out);
} else {
output_node = GetObjNode(out);
}
@ -1186,7 +1222,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
auto param_id = GetId(param);
AnfNodePtr para_node = nullptr;
if (graph_info_map_[df_builder_].param_map.count(param_id)) {
para_node = graph_info_map_[df_builder_].param_map[param_id];
para_node = graph_info_map_[df_builder_].param_map[param_id].first;
} else {
auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
if (py::isinstance<py::none>(name_attr)) {

View File

@ -59,7 +59,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tupl
void ClearPyNativeSession();
struct GraphInfo {
std::unordered_map<std::string, AnfNodePtr> param_map;
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> param_map;
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> obj_node_map;
AnfNodePtr output;
std::vector<std::string> objects;
@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
AnfNodePtr GetObjNode(const py::object &obj);
AnfNodePtr GetParamNode(const py::object &obj);
std::string GetCellId(const py::object &obj, const py::args &args);
FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
@ -104,6 +105,17 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// Record that python object `obj` corresponds to the tuple element of `node`
// addressed by index path `index`, within graph `g`.
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
  graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
// Map `obj` to a whole parameter node; the {-1} index path is the sentinel
// meaning "not a tuple element" (consumed by GetParamNode).
void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
  graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{-1});
}
// Map `obj` to element `index` of tuple parameter `node`.
void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
  graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{index});
}
// Map `obj` to a nested element of tuple parameter `node` addressed by the
// full index path `index`.
void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
  graph_info_map_[g].param_map[obj] = std::make_pair(node, index);
}
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
@ -119,6 +131,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size);
void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
void SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
py::tuple RunOpInner(const py::args &args);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);

View File

@ -0,0 +1,77 @@
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
def setup_module(module):
    """Pytest hook: run every test in this module in PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE)
class Block1(nn.Cell):
    """ Define Cell with tuple input as parameter."""
    def __init__(self):
        super(Block1, self).__init__()
        self.mul = P.Mul()

    def construct(self, tuple_xy):
        # Unpack the single tuple argument and multiply its two elements.
        first, second = tuple_xy
        return self.mul(first, second)
class Block2(nn.Cell):
    """ definition with tuple in tuple output in Cell."""
    def __init__(self):
        super(Block2, self).__init__()
        self.mul = P.Mul()
        self.add = P.TensorAdd()

    def construct(self, x, y):
        # Produce a nested tuple output: (x*y, (x*y + x, x*y + y)).
        product = self.mul(x, y)
        left = self.add(product, x)
        right = self.add(product, y)
        return (product, (left, right))
class Net1(nn.Cell):
    """Wraps Block1, packing the two tensor inputs into one tuple argument."""
    def __init__(self):
        super(Net1, self).__init__()
        self.block = Block1()

    def construct(self, x, y):
        return self.block((x, y))
class Net2(nn.Cell):
    """Wraps Block2 and reduces its nested tuple output to a single tensor."""
    def __init__(self):
        super(Net2, self).__init__()
        self.add = P.TensorAdd()
        self.block = Block2()

    def construct(self, x, y):
        # Sum all three elements of the nested tuple: z1 + z2 + z3.
        z1, (z2, z3) = self.block(x, y)
        total = self.add(z1, z2)
        total = self.add(total, z3)
        return total
def test_net():
    """Check gradients through tuple parameters (Net1) and tuple-in-tuple
    cell outputs (Net2) in PyNative mode.

    Net1: res = x * y          -> d/dx = y,       d/dy = x
    Net2: res = 3*x*y + x + y  -> d/dx = 3*y + 1, d/dy = 3*x + 1
    """
    # NOTE: dropped leftover debug flag `context.set_context(save_graphs=True)`
    # so the test no longer dumps IR graph files on every run.
    x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 2)
    y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 3)
    grad_op = C.GradOperation(get_all=True)

    # Tuple passed as a single parameter: gradients are (y, x).
    net1 = Net1()
    output = grad_op(net1)(x, y)
    assert np.all(output[0].asnumpy() == y.asnumpy())
    assert np.all(output[1].asnumpy() == x.asnumpy())

    # Nested tuple output: res = z1 + z2 + z3 = 3*x*y + x + y.
    net2 = Net2()
    output = grad_op(net2)(x, y)
    expect_x = np.ones([1, 1, 3, 3]).astype(np.float32) * 10  # 3*3 + 1
    expect_y = np.ones([1, 1, 3, 3]).astype(np.float32) * 7   # 3*2 + 1
    assert np.all(output[0].asnumpy() == expect_x)
    assert np.all(output[1].asnumpy() == expect_y)