support parameter tuple input in pynative mode
commit 7765d44b76
parent 236952ca6c
@@ -793,6 +793,20 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   return node;
 }
 
+AnfNodePtr PynativeExecutor::GetParamNode(const py::object &obj) {
+  auto id = GetId(obj);
+  auto &param = graph_info_map_[curr_g_].param_map[id];
+  if (param.second.size() == 1 && param.second[0] == -1) {
+    return param.first;
+  }
+  auto para_node = param.first;
+  for (auto &idx : param.second) {
+    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(idx)};
+    para_node = curr_g_->NewCNode(tuple_get_item_inputs);
+  }
+  return para_node;
+}
+
 std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   for (size_t i = 0; i < args.size(); i++) {
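GetParamNode replays the index path recorded for an object as a chain of TupleGetItem nodes: each stored index becomes one getitem on top of the previous node, and the sentinel path {-1} short-circuits to the Parameter node itself. A minimal Python sketch of the same resolution logic (resolve_path is an illustrative name, not the executor's API):

# Sketch: resolve a recorded index path against a nested tuple, mirroring
# the TupleGetItem chain that GetParamNode emits node by node.
def resolve_path(value, path):
    if path == [-1]:       # sentinel: the object is the whole parameter
        return value
    for idx in path:       # one "getitem" per recorded index
        value = value[idx]
    return value

assert resolve_path((("a", ("b", "c")), "d"), [0, 1, 0]) == "b"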
@@ -995,9 +1009,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
     graph_info_map_[g] = GraphInfo();
   }
   for (size_t i = 0; i < args.size(); i++) {
+    auto param = args[i];
     auto new_param = g->add_parameter();
-    std::string param_obj = GetId(args[i]);
-    graph_info_map_[g].param_map[param_obj] = new_param;
+    std::string param_obj = GetId(param);
+    if (py::isinstance<py::tuple>(param)) {
+      auto tuple = param.cast<py::tuple>();
+      auto tuple_size = static_cast<int>(tuple.size());
+      for (int j = 0; j < tuple_size; j++) {
+        set_param_map(curr_g_, GetId(tuple[j]), new_param, j);
+        SetTupleParam(tuple[j], new_param, std::vector<int>{j});
+      }
+    }
+    set_param_map(curr_g_, param_obj, new_param);
   }
 }
 
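NewGraphInner now registers a tuple argument at three granularities: the whole tuple maps to the new Parameter with the sentinel path {-1}, each top-level element maps with its own index, and SetTupleParam recurses into nested tuples to extend the path. A rough Python analogue of that bookkeeping (register_tuple is a hypothetical helper, not MindSpore code):

# Sketch: record (parameter_node, index_path) for every element of a
# nested tuple argument, as NewGraphInner and SetTupleParam do together.
def register_tuple(param_map, obj, node, path):
    for i, elem in enumerate(obj):
        elem_path = path + [i]
        param_map[id(elem)] = (node, elem_path)      # set_param_map analogue
        if isinstance(elem, tuple):                  # SetTupleParam recursion
            register_tuple(param_map, elem, node, elem_path)

param_map = {}
arg = (("a", ("b", "c")), "d")
param_map[id(arg)] = ("param0", [-1])  # whole tuple -> the Parameter itself
register_tuple(param_map, arg, "param0", [])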
@@ -1028,16 +1051,16 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
       auto value = py::cast<tensor::TensorPtr>(obj);
       free_param->set_default_param(value);
       MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
-      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
+      set_param_map(df_builder_, obj_id, free_param);
       return free_param;
     }
-    return graph_info_map_[df_builder_].param_map[obj_id];
+    return graph_info_map_[df_builder_].param_map[obj_id].first;
   }
 
   // if input is graph output
   if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
     // op(x, y)
-    node = graph_info_map_[curr_g_].param_map[obj_id];
+    node = GetParamNode(obj);
   } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
     // out = op(op1(x, y))
     // out = op(cell1(x, y))
@@ -1085,6 +1108,19 @@ void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &c
   }
 }
 
+// for param ((a, (b, c)), d) need multi getitem
+void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx) {
+  if (py::isinstance<py::tuple>(obj)) {
+    auto tuple = obj.cast<py::tuple>();
+    for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
+      std::vector<int> tmp = idx;
+      tmp.push_back(i);
+      set_param_map(curr_g_, GetId(tuple[i]), para_node, tmp);
+      SetTupleParam(tuple[i], para_node, tmp);
+    }
+  }
+}
+
 void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
 
 void PynativeExecutor::Popp() {
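Worked example for the shape named in the comment, ((a, (b, c)), d): the top-level loop in NewGraphInner records (a, (b, c)) under path [0] and d under [1], then SetTupleParam's recursion records a under [0, 0], (b, c) under [0, 1], b under [0, 1, 0], and c under [0, 1, 1]. GetParamNode later turns each path back into the matching chain of TupleGetItem nodes.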
@@ -1132,7 +1168,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
                                        const py::args &args) {
   AnfNodePtr output_node;
   if (graph_info_map_[curr_g_].param_map.count(out_id)) {
-    output_node = graph_info_map_[curr_g_].param_map[out_id];
+    output_node = GetParamNode(out);
   } else {
     output_node = GetObjNode(out);
   }
@@ -1186,7 +1222,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
     auto param_id = GetId(param);
     AnfNodePtr para_node = nullptr;
     if (graph_info_map_[df_builder_].param_map.count(param_id)) {
-      para_node = graph_info_map_[df_builder_].param_map[param_id];
+      para_node = graph_info_map_[df_builder_].param_map[param_id].first;
     } else {
       auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
       if (py::isinstance<py::none>(name_attr)) {
@@ -59,7 +59,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tupl
 void ClearPyNativeSession();
 
 struct GraphInfo {
-  std::unordered_map<std::string, AnfNodePtr> param_map;
+  std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> param_map;
   std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> obj_node_map;
   AnfNodePtr output;
   std::vector<std::string> objects;
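With this change param_map's value type matches obj_node_map: a pair of the Parameter node and the index path into it, where {-1} means "no indexing, use the node itself", consistent with GetParamNode's early return. A Python analogue of the new entry shape (ParamEntry and the keys are illustrative only):

# Sketch of the new param_map entry shape: obj_id -> (node, index_path),
# with [-1] as the "use the parameter node directly" sentinel.
from typing import Dict, List, Tuple

ParamEntry = Tuple[str, List[int]]           # stand-in for (AnfNodePtr, path)
param_map: Dict[str, ParamEntry] = {
    "id(whole_tuple)": ("param0", [-1]),     # the Parameter itself
    "id(b)": ("param0", [0, 1, 0]),          # b inside ((a, (b, c)), d)
}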
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   AnfNodePtr GetObjNode(const py::object &obj);
+  AnfNodePtr GetParamNode(const py::object &obj);
   std::string GetCellId(const py::object &obj, const py::args &args);
   FuncGraphPtr curr_g() { return curr_g_; }
   void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
@@ -104,6 +105,17 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
     graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
   }
+
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{-1});
+  }
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{index});
+  }
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, index);
+  }
+
   AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                        abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
@@ -119,6 +131,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
                          size_t arg_size);
   void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
+  void SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx);
   AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
   py::tuple RunOpInner(const py::args &args);
   py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
@@ -0,0 +1,77 @@
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import context, Tensor
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+
+
+def setup_module(module):
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+
+class Block1(nn.Cell):
+    """Define Cell with tuple input as parameter."""
+
+    def __init__(self):
+        super(Block1, self).__init__()
+        self.mul = P.Mul()
+
+    def construct(self, tuple_xy):
+        x, y = tuple_xy
+        z = self.mul(x, y)
+        return z
+
+
+class Block2(nn.Cell):
+    """Define Cell with tuple-in-tuple output."""
+
+    def __init__(self):
+        super(Block2, self).__init__()
+        self.mul = P.Mul()
+        self.add = P.TensorAdd()
+
+    def construct(self, x, y):
+        z1 = self.mul(x, y)
+        z2 = self.add(z1, x)
+        z3 = self.add(z1, y)
+        return (z1, (z2, z3))
+
+
+class Net1(nn.Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.block = Block1()
+
+    def construct(self, x, y):
+        res = self.block((x, y))
+        return res
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.add = P.TensorAdd()
+        self.block = Block2()
+
+    def construct(self, x, y):
+        z1, (z2, z3) = self.block(x, y)
+        res = self.add(z1, z2)
+        res = self.add(res, z3)
+        return res
+
+
+def test_net():
+    context.set_context(save_graphs=True)
+    x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 2)
+    y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 3)
+    net1 = Net1()
+    grad_op = C.GradOperation(get_all=True)
+    output = grad_op(net1)(x, y)
+    assert np.all(output[0].asnumpy() == y.asnumpy())
+    assert np.all(output[1].asnumpy() == x.asnumpy())
+
+    net2 = Net2()
+    output = grad_op(net2)(x, y)
+    expect_x = np.ones([1, 1, 3, 3]).astype(np.float32) * 10
+    expect_y = np.ones([1, 1, 3, 3]).astype(np.float32) * 7
+    assert np.all(output[0].asnumpy() == expect_x)
+    assert np.all(output[1].asnumpy() == expect_y)
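The expected values in the test check out by hand: for Net1, d(x*y)/dx = y and d(x*y)/dy = x, matching the first two asserts; for Net2, res = z1 + z2 + z3 = 3xy + x + y, so the elementwise gradients are 3y + 1 = 10 and 3x + 1 = 7 at x = 2, y = 3. A quick standalone recomputation of those constants:

# Recompute Net2's expected gradients: res = 3*x*y + x + y, so
# d(res)/dx = 3*y + 1 and d(res)/dy = 3*x + 1, elementwise.
import numpy as np
x = np.full((1, 1, 3, 3), 2.0, dtype=np.float32)
y = np.full((1, 1, 3, 3), 3.0, dtype=np.float32)
assert np.all(3 * y + 1 == 10)   # matches expect_x
assert np.all(3 * x + 1 == 7)    # matches expect_y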