pynative-fix-bug-of-tuple-set-item-index-wrong

This commit is contained in:
lvliang 2020-09-19 09:58:39 +08:00
parent 473b9614a7
commit 37e59f826a
2 changed files with 1 additions and 7 deletions

View File

@ -255,7 +255,7 @@ class PynativeEliminater : public OptimizerCaller {
MS_LOG(DEBUG) << "Start FillZero";
ValuePtr out = nullptr;
if (value->isa<Int32Imm>()) {
return MakeValue(0);
return value;
}
if (value->isa<tensor::Tensor>()) {

View File

@ -70,7 +70,6 @@ const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertG
namespace mindspore {
namespace pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
@ -1213,7 +1212,6 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
}
return graph_info_map_[df_builder_].param_map[obj_id].first;
}
// if input is graph output
if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
// op(x, y)
@ -1227,20 +1225,16 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
// out = op((x, y))
// out = cell((x, y))
auto tuple = obj.cast<py::tuple>();
// cell((1,2)): support not mix (scalar, tensor)
if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
return MakeValueNode(obj, obj_id);
}
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
auto tuple_size = static_cast<int>(tuple.size());
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], false));
}
auto cnode = curr_g_->NewCNode(args);
set_obj_node_map(curr_g_, GetId(obj), cnode);
node = cnode;