forked from mindspore-Ecosystem/mindspore
!1945 [bug]fix bug in '=', use signature to support auto cast in assign.
Merge pull request !1945 from vlne-v1/I1JXUP-resnet50-thor-assign-fail
This commit is contained in:
commit
0e7839826e
|
@ -27,7 +27,8 @@ namespace mindspore {
|
|||
// namespace to support primitive operators
|
||||
namespace prim {
|
||||
ValuePtr GetPythonOps(const std::string &op_name,
|
||||
const std::string &module_name = "mindspore._extends.parse.standard_method");
|
||||
const std::string &module_name = "mindspore._extends.parse.standard_method",
|
||||
bool use_signature = false);
|
||||
|
||||
// Arithmetic
|
||||
extern const PrimitivePtr kPrimScalarAdd;
|
||||
|
|
|
@ -23,10 +23,10 @@
|
|||
namespace mindspore {
|
||||
// namespace to support primitive operators
|
||||
namespace prim {
|
||||
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
|
||||
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
|
||||
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
|
||||
ValuePtr node = nullptr;
|
||||
bool succ = parse::ConvertData(obj, &node);
|
||||
bool succ = parse::ConvertData(obj, &node, use_signature);
|
||||
if (!succ) {
|
||||
MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
|
||||
}
|
||||
|
|
|
@ -322,12 +322,10 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
|||
|
||||
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
||||
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
||||
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
|
||||
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
||||
const std::string primitive_name("assign");
|
||||
const std::string module_name("mindspore.ops.functional");
|
||||
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name));
|
||||
|
||||
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
|
||||
if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -336,8 +334,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
|||
vec_states.emplace_back(make_tuple_op);
|
||||
for (auto &item : state_assign_) {
|
||||
auto source = ReadVariable(item.second);
|
||||
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
|
||||
auto assign = func_graph()->NewCNode({assign_op, origin, source});
|
||||
auto assign = func_graph()->NewCNode({assign_op, item.first, source});
|
||||
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
|
||||
vec_states.emplace_back(assign);
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ class Net(nn.Cell):
|
|||
|
||||
|
||||
def test_assign_through_cell():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
net.to_float(ms.float16)
|
||||
net.add_flags_recursive(fp16=False)
|
||||
|
@ -57,6 +57,25 @@ def test_assign_through_cell():
|
|||
net(None)
|
||||
|
||||
|
||||
class AssignOp(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AssignOp, self).__init__()
|
||||
self.b = Parameter(initializer('ones', [5]), name='b')
|
||||
|
||||
|
||||
def construct(self, w):
|
||||
self.b = w
|
||||
return w
|
||||
|
||||
|
||||
def test_assign_by_operator():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = AssignOp()
|
||||
net.to_float(ms.float16)
|
||||
input_data = Tensor(np.ones([5]).astype(np.float32))
|
||||
net(input_data)
|
||||
|
||||
|
||||
class NetScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetScatterNdUpdate, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue