forked from mindspore-Ecosystem/mindspore
bugfix:
* fix bug in output tuple of tuple
* check that a kRWWrite input is a variable (Parameter); raise TypeError when it is not
* input x of ScatterNdUpdate should be a parameter node
parent 64c87170fd
commit 157710ca0f
@@ -14,7 +14,6 @@
# ============================================================================
"""train_imagenet."""
import os
import math
import argparse
import random
import numpy as np

@@ -195,6 +195,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
      }
      // If sig is SignatureEnumRW::kRWRef, not do anything.
+   } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
+     MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
    }
    // add cast op here
    if (assign_source != nullptr && sig != SignatureEnumRW::kRWWrite) {

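The new branch makes the frontend reject a write target that is not backed by a variable: when an input is declared `SignatureEnumRW::kRWWrite` but its type is not a RefKey (i.e. it is not a Parameter), graph construction now raises a TypeError instead of writing into a temporary. A minimal Python-side sketch of what this enforces, mirroring the updated GPU test further down (the Cell name `AssignAddNet` and the parameter name are illustrative, and a default device context is assumed):

```python
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class AssignAddNet(nn.Cell):
    """Keeps the write target as a Parameter, so the kRWWrite check passes."""
    def __init__(self, value):
        super(AssignAddNet, self).__init__()
        self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

    def construct(self, y):
        # self.var is a Parameter (RefKey), so AssignAdd's first input is writable.
        return self.add(self.var, y)

context.set_context(mode=context.GRAPH_MODE)
net = AssignAddNet(Tensor(np.zeros((2, 2), np.float32)))
out = net(Tensor(np.ones((2, 2), np.float32)))
# Calling P.AssignAdd() with a plain Tensor or constant as the first input
# (see the 'AssignAdd_Error' case added below) now fails with a TypeError.
```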
@@ -70,12 +70,11 @@ def _wrap_func(fn):
        def _convert_data(data):
            if isinstance(data, Tensor) and not isinstance(data, MsTensor):
                return MsTensor(data)
+           if isinstance(data, tuple):
+               return tuple(_convert_data(x) for x in data)
+           if isinstance(data, list):
+               return list(_convert_data(x) for x in data)
            return data

        if isinstance(results, tuple):
            return tuple(_convert_data(x) for x in results)
        if isinstance(results, list):
            return list(_convert_data(x) for x in results)
        return _convert_data(results)

    return wrapper

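This is the "output tuple of tuple" item in the commit message: `_convert_data` now recurses into tuples and lists, so nested containers returned from a compiled function get their leaf tensors converted too, not just the top level. The shape of the recursion, shown on plain Python data with a stand-in leaf converter (illustrative only, not the module's code):

```python
def convert_nested(data, convert_leaf):
    # Recurse through tuples and lists, converting only the leaves,
    # the same way _convert_data above handles nested graph outputs.
    if isinstance(data, tuple):
        return tuple(convert_nested(x, convert_leaf) for x in data)
    if isinstance(data, list):
        return list(convert_nested(x, convert_leaf) for x in data)
    return convert_leaf(data)

print(convert_nested(((1, 2), [3, (4,)]), float))
# ((1.0, 2.0), [3.0, (4.0,)]) -- the nesting survives, every leaf is converted
```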
@@ -57,21 +57,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
    op_reshape = P.Reshape()
    op_shape = P.Shape()

-   param = op_cast(param, mstype.float32)
-   m = op_cast(m, mstype.float32)
-   v = op_cast(v, mstype.float32)
-   gradient = op_cast(gradient, mstype.float32)
+   param_fp32 = op_cast(param, mstype.float32)
+   m_fp32 = op_cast(m, mstype.float32)
+   v_fp32 = op_cast(v, mstype.float32)
+   gradient_fp32 = op_cast(gradient, mstype.float32)

-   next_m = op_mul(beta1, m) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient)
+   next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

-   next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
+   next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                           - beta2, op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
-       update = update + op_mul(weight_decay_tensor, param)
+       update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
-   next_param = param - op_reshape(update_with_lr, op_shape(param))
+   next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))

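The `*_fp32` renames look cosmetic but are the actual fix: the function ends with `F.assign(param, next_param)` and `F.assign(m, next_m)`, which must receive the original Parameters. Rebinding `param` to the result of `op_cast` meant the final assign targeted a cast copy rather than the network weight. A plain-Python sketch of the pitfall (hypothetical names, not MindSpore graph code):

```python
class Store:
    """Stands in for a Parameter: a slot that assign() writes into."""
    def __init__(self, value):
        self.value = value

def assign(store, new_value):          # stands in for F.assign(param, next_param)
    store.value = new_value
    return new_value

def bad_update(param, grad, lr=0.1):
    param = Store(float(param.value))  # rebinds `param`; the real slot is lost
    return assign(param, param.value - lr * grad)

def good_update(param, grad, lr=0.1):
    param_fp32 = float(param.value)    # cast under a fresh name instead
    return assign(param, param_fp32 - lr * grad)

w = Store(1.0)
bad_update(w, 1.0)
print(w.value)   # 1.0 -- the weight was never updated
good_update(w, 1.0)
print(w.value)   # 0.9 -- the update reached the original slot
```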
@@ -67,23 +67,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
    op_fill = P.Fill()
    op_dtype = P.DType()

-   param = op_cast(param, mstype.float32)
-   m = op_cast(m, mstype.float32)
-   v = op_cast(v, mstype.float32)
-   gradient = op_cast(gradient, mstype.float32)
+   param_fp32 = op_cast(param, mstype.float32)
+   m_fp32 = op_cast(m, mstype.float32)
+   v_fp32 = op_cast(v, mstype.float32)
+   gradient_fp32 = op_cast(gradient, mstype.float32)

-   next_m = op_mul(beta1, m) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient)
+   next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

-   next_v = op_mul(beta2, v) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient))
+   next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

    next_mm = next_m / (op_cast(num_one, mstype.float32)
                        - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
    next_vv = next_v / (op_cast(num_one, mstype.float32) -
                        op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-   w_norm = op_norm(param)
-   g_norm = op_norm(gradient)
+   w_norm = op_norm(param_fp32)
+   g_norm = op_norm(gradient_fp32)

-   g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param)
+   g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
    zeros = F.zeros_like_tensor(w_norm)
    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
    trust_ratio = op_select(

@@ -95,11 +95,11 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
    update = next_mm / (op_sqrt(next_vv) + eps)

    if decay_flag:
-       update = update + op_mul(weight_decay_tensor, param)
+       update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

-   next_param = param - op_reshape(update_with_lr, op_shape(param))
+   next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))

@@ -24,6 +24,8 @@ import itertools
import numbers
import numpy as np

+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype

@@ -1965,6 +1967,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
        >>> op = P.ScatterNdUpdate()
        >>> output = op(input_x, indices, update)
    """
+   __mindspore_signature__ = (
+       ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+       ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+       ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+   )

    @prim_attr_register
    def __init__(self, use_locking=True):

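With `input_x` declared `RW_WRITE`, the compiler now requires the first input of ScatterNdUpdate to be a Parameter, matching the third item in the commit message. A minimal usage sketch built from the shapes used in the test cases below (the Cell name is illustrative, and a default device context is assumed):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class ScatterNdUpdateNet(nn.Cell):
    """Holds input_x as a Parameter so the RW_WRITE signature is satisfied."""
    def __init__(self, value):
        super(ScatterNdUpdateNet, self).__init__()
        self.input_x = Parameter(value, name="input_x")
        self.op = P.ScatterNdUpdate()

    def construct(self, indices, update):
        return self.op(self.input_x, indices, update)

net = ScatterNdUpdateNet(Tensor(np.ones((2, 3), np.float32)))
out = net(Tensor(np.ones((2, 2), np.int32)), Tensor(np.ones((2,), np.float32)))
# Feeding a plain Tensor as input_x -- as the new 'ScatterNdUpdate' raise_set
# entry at the end of this diff does -- is now rejected with a TypeError.
```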
@@ -14,19 +14,20 @@
# ============================================================================

import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context

class AssignAdd(nn.Cell):
-   def __init__( self):
+   def __init__(self, value):
        super(AssignAdd, self).__init__()
+       self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

-   def construct(self, x, y):
-       res = self.add(x, y)
+   def construct(self, y):
+       res = self.add(self.var, y)
        return res

@pytest.mark.level0

@@ -58,15 +59,17 @@ def test_assign_add():
    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))

    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-   add = AssignAdd()
-   output1 = add(x1, y1)
+   add = AssignAdd(x1)
+   output1 = add(y1)
    assert (output1.asnumpy() == expect1).all()
-   output2 = add(output1, y1)
+   add = AssignAdd(output1)
+   output2 = add(y1)
    assert (output2.asnumpy() == expect2).all()

    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-   add = AssignAdd()
-   output1 = add(x2, y2)
+   add = AssignAdd(x2)
+   output1 = add(y2)
    assert (output1.asnumpy() == expect1).all()
-   output2 = add(output1, y2)
+   add = AssignAdd(output1)
+   output2 = add(y2)
    assert (output2.asnumpy() == expect2).all()

@@ -14,7 +14,7 @@
# ============================================================================

import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np

@@ -22,12 +22,13 @@ import mindspore.context as context


class Net(nn.Cell):
-   def __init__(self):
+   def __init__(self, value):
        super(Net, self).__init__()
+       self.var = Parameter(value, name="var")
        self.assign = P.Assign()

-   def construct(self, var, value):
-       return self.assign(var, value)
+   def construct(self, value):
+       return self.assign(self.var, value)

x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)

@@ -37,13 +38,13 @@ value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
@pytest.mark.env_onecard
def test_assign():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-   assign = Net()
    var = Tensor(x)
-   output = assign(var, Tensor(value))
+   assign = Net(var)
+   output = assign(Tensor(value))

    error = np.ones(shape=[2, 2]) * 1.0e-6
    diff1 = output.asnumpy() - value
-   diff2 = var.asnumpy() - value
+   diff2 = assign.var.default_input.asnumpy() - value
    assert np.all(diff1 < error)
    assert np.all(-diff1 < error)
    assert np.all(diff2 < error)

@@ -341,6 +341,15 @@ class SignNet(nn.Cell):
    def construct(self, x):
        return self.sign(x)

+class AssignAdd(nn.Cell):
+   def __init__(self):
+       super().__init__()
+       self.op = P.AssignAdd()
+       self.inputdata = Parameter(initializer(1, [1], ms.float32), name="global_step")
+
+   def construct(self, input_):
+       self.inputdata = input_
+       return self.op(self.inputdata, input_)

test_case_math_ops = [
    ('MatMulGrad', {

@@ -413,6 +422,9 @@ raise_set = [
    ('StridedSlice_4_Error', {
        'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
        'desc_inputs': [0]}),
+   ('AssignAdd_Error', {
+       'block': (P.AssignAdd(), {'exception': TypeError}),
+       'desc_inputs': [[1]]}),
]

@@ -38,8 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                   gradient, variable, moment):
    """ tensor_run_opt """
    success = True
-   new_weight = opt(gradient, moment, variable,
-                    learning_rate, momentum)
+   new_weight = opt(variable, moment, learning_rate, gradient, momentum)
    success = F.depend(success, F.assign(variable, new_weight))
    return success

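The corrected call passes the optimizer state in ApplyMomentum's positional order, variable first: (variable, accumulation, learning_rate, gradient, momentum), and the result is assigned back into `variable`. A rough sketch of a Cell using that order (names and shapes are illustrative, and an appropriate device context is assumed for the commented-out call):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class MomentumStep(nn.Cell):
    """One ApplyMomentum step: opt(variable, accumulation, learning_rate, gradient, momentum)."""
    def __init__(self, init):
        super(MomentumStep, self).__init__()
        self.variable = Parameter(Tensor(init), name="variable")
        self.moment = Parameter(Tensor(np.zeros_like(init)), name="moment")
        self.opt = P.ApplyMomentum()

    def construct(self, gradient, learning_rate, momentum):
        return self.opt(self.variable, self.moment, learning_rate, gradient, momentum)

step = MomentumStep(np.ones((2,), np.float32))
# new_weight = step(Tensor(np.full((2,), 0.5, np.float32)),
#                   Tensor(np.array(0.1, np.float32)),
#                   Tensor(np.array(0.9, np.float32)))
```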
@@ -446,12 +446,6 @@ test_cases = [
        'desc_inputs': [[128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
    }),
-   ('ApplyMomentum', {
-       'block': P.ApplyMomentum(),
-       'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
-       'desc_bprop': [[128, 32, 32, 64]],
-       'skip': ['backward']
-   }),
    ('ScalarSummary', {
        'block': ScalarSummaryNet(),
        'desc_inputs': [2.2],

@@ -515,6 +509,12 @@ test_cases = [
]

test_cases_for_verify_exception = [
+   ('ApplyMomentum_Error', {
+       'block': (P.ApplyMomentum(), {'exception': TypeError}),
+       'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
+       'desc_bprop': [[128, 32, 32, 64]],
+       'skip': ['backward']
+   }),
    ('Conv2d_ValueError_1', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
        'desc_inputs': [0],

@@ -674,12 +674,6 @@ test_case_nn_ops = [
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': ['backward']}),
-   ('ApplyMomentum', {
-       'block': P.ApplyMomentum(),
-       'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
-                       [32, 32, 64], [32, 32, 64], [32, 32, 64]],
-       'desc_bprop': [[128, 32, 32, 64]],
-       'skip': ['backward']}),
    ('TopK', {
        'block': P.TopK(),
        'desc_const': [5],

@@ -1113,12 +1107,6 @@ test_case_other_ops = [
        'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),
                        Tensor(np.ones((2, 4), np.int32))),
        'desc_bprop': [[2]]}),
-   ('ScatterNdUpdate', {
-       'block': P.ScatterNdUpdate(),
-       'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
-                       Tensor(np.ones((2, 2), np.int32)),
-                       Tensor(np.ones((2,), np.float32))),
-       'desc_bprop': [[2, 3]]}),
    ('ScatterNd', {
        'block': P.ScatterNd(),
        'desc_const': [(3, 3)],

@@ -1178,7 +1166,7 @@ import mindspore.context as context
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
-   context.set_context(mode=context.GRAPH_MODE)
+   context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    return test_exec_case

@@ -1207,6 +1195,12 @@ raise_set = [
        'block': (NetForFlatten0D(), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.array(0).astype(np.int32))],
        'desc_bprop': [Tensor(np.array(0).astype(np.int32))]}),
+   ('ScatterNdUpdate', {
+       'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
+       'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
+                       Tensor(np.ones((2, 2), np.int32)),
+                       Tensor(np.ones((2,), np.float32))),
+       'desc_bprop': [[2, 3]]}),
]