!1945 [bug]fix bug in '=', use signature to support auto cast in assign.

Merge pull request !1945 from vlne-v1/I1JXUP-resnet50-thor-assign-fail
This commit is contained in:
mindspore-ci-bot 2020-06-09 22:57:01 +08:00 committed by Gitee
commit 0e7839826e
4 changed files with 26 additions and 9 deletions

View File

@ -27,7 +27,8 @@ namespace mindspore {
// namespace to support primitive operators
namespace prim {
ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method");
const std::string &module_name = "mindspore._extends.parse.standard_method",
bool use_signature = false);
// Arithmetic
extern const PrimitivePtr kPrimScalarAdd;

View File

@ -23,10 +23,10 @@
namespace mindspore {
// namespace to support primitive operators
namespace prim {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
ValuePtr node = nullptr;
bool succ = parse::ConvertData(obj, &node);
bool succ = parse::ConvertData(obj, &node, use_signature);
if (!succ) {
MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
}

View File

@ -322,12 +322,10 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name));
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
return;
}
@ -336,8 +334,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) {
auto source = ReadVariable(item.second);
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
auto assign = func_graph()->NewCNode({assign_op, origin, source});
auto assign = func_graph()->NewCNode({assign_op, item.first, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign);
}

View File

@ -47,7 +47,7 @@ class Net(nn.Cell):
def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
net = Net()
net.to_float(ms.float16)
net.add_flags_recursive(fp16=False)
@ -57,6 +57,25 @@ def test_assign_through_cell():
net(None)
class AssignOp(nn.Cell):
def __init__(self):
super(AssignOp, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')
def construct(self, w):
self.b = w
return w
def test_assign_by_operator():
context.set_context(mode=context.GRAPH_MODE)
net = AssignOp()
net.to_float(ms.float16)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)
class NetScatterNdUpdate(nn.Cell):
def __init__(self):
super(NetScatterNdUpdate, self).__init__()