Add the check of the return value of Assign, AssignAdd, AssignSub.

This commit is contained in:
Margaret_wangrui 2022-09-17 10:36:03 +08:00
parent 61d2bb4b1e
commit dd2c45a829
41 changed files with 482 additions and 284 deletions

View File

@ -155,19 +155,59 @@ void CheckValueTuple(const AnfNodePtr &node) {
auto value_tuple = value->cast_ptr<ValueTuple>();
MS_EXCEPTION_IF_NULL(value_tuple);
const auto &tuple_values = value_tuple->value();
for (size_t i = 0; i < tuple_values.size(); ++i) {
auto input_node = NewValueNode(tuple_values[i]);
for (const auto &tuple_value : tuple_values) {
auto input_node = NewValueNode(tuple_value);
ValidateOperation(input_node);
ValidateValueNode(input_node);
}
}
void CheckAssignReturnValue(const AnfNodePtr &node) {
static const PrimitiveSet assign_prims = {prim::kPrimAssign, prim::kPrimAssignAdd, prim::kPrimAssignSub};
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto real_input = node->cast_ptr<CNode>()->input(1);
while (IsPrimitiveCNode(real_input, prim::kPrimDepend)) {
real_input = real_input->cast_ptr<CNode>()->input(1);
}
if (!IsOneOfPrimitiveCNode(real_input, assign_prims)) {
return;
}
} else if (!IsOneOfPrimitiveCNode(node, assign_prims)) {
return;
}
auto fg = node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return;
}
static const PrimitiveSet virtual_prims = {
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimLoad,
prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimDynamicLossScale};
auto users = iter->second;
for (const auto &user : users) {
auto user_node = user.first;
if (!IsOneOfPrimitiveCNode(user_node, virtual_prims)) {
MS_LOG(WARNING) << "Deprecated: the return value of Assign/AssignAdd/AssignSub operator will be removed "
<< "in subsequent releases.\n"
<< "You can modify the code from:\na = P.Assign()(param, value)\nb = a * 2\nto: \n"
<< "P.Assign()(param, value)\nb = param * 2\n"
<< "Please check your code:" << trace::GetDebugInfo(node->debug_info());
}
}
}
void Validate(const FuncGraphPtr &func_graph) {
FuncGraphManagerPtr mgr = Manage(func_graph, false);
MS_EXCEPTION_IF_NULL(mgr);
const AnfNodeSet &all_nodes = mgr->all_nodes();
for (auto node : all_nodes) {
TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
CheckAssignReturnValue(node);
while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
node = node->cast_ptr<CNode>()->input(1);
}

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -37,7 +37,8 @@ def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parame
shape = F.shape(delta_weight)
update_delta_weight = P.Reshape()(update_delta_weight, shape)
new_parameter = old_parameter - update_delta_weight
return P.Assign()(parameter, new_parameter)
P.Assign()(parameter, new_parameter)
return parameter
def _send_before_receive(send_part, send, recv):

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -53,7 +53,8 @@ _save_weight = C.MultitypeFuncGraph("_save_weight")
@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(new_parameter, old_parameter):
return P.Assign()(new_parameter, old_parameter)
P.Assign()(new_parameter, old_parameter)
return new_parameter
_grad_scale = C.MultitypeFuncGraph("grad_scale")

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -44,7 +44,8 @@ _save_weight = C.MultitypeFuncGraph("_save_weight")
@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(parameter, new_parameter):
return P.Assign()(parameter, new_parameter)
P.Assign()(parameter, new_parameter)
return parameter
_pca_projection = C.MultitypeFuncGraph("_pca_projection")

View File

@ -33,7 +33,8 @@ gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
@gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def cumulative_grad_process(accumulation_step, cumulative_grad, grad):
"""Apply gradient accumulation to cumulative grad."""
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
P.AssignAdd()(cumulative_grad, grad / accumulation_step)
return cumulative_grad
gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")

View File

@ -82,13 +82,15 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
update_mean = reduce_mean_keep_alive(update, -1) * (1.0 - beta2t)
exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
exp_avg_sq_row_update = exp_avg_sq_row
exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
update_mean = reduce_mean_keep_alive(update, -2) * (1.0 - beta2t)
exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
exp_avg_sq_col_update = F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))
F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))
exp_avg_sq_col_update = exp_avg_sq_col
update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
update = P.Mul()(update, grad)
@ -96,7 +98,8 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
update = update * (1.0 - beta2t)
exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
exp_avg_sq_update = F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
exp_avg_sq_update = exp_avg_sq
exp_avg_sq_update = 1.0 / P.Sqrt()(exp_avg_sq_update)
update = P.Mul()(exp_avg_sq_update, grad)
@ -109,7 +112,8 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
if compression:
exp_avg_update = F.cast(exp_avg, grad_dtype)
exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
update = F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))
F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))
update = exp_avg
if weight_decay_flag:
p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
@ -404,7 +408,8 @@ class AdaFactor(Optimizer):
def construct(self, gradients):
gradients = self.flatten_gradients(gradients)
lr = self.get_lr()
step = F.assign_add(self.step, 1)
F.assign_add(self.step, 1)
step = self.step
if self.scale_lr and self.relative_step:
if self.warmup_init:
min_step = 1e-6 * step

View File

@ -47,7 +47,8 @@ def _update_parameters_adasum(delta_weight, update_delta_weight, parameter, old_
shape = F.shape(delta_weight)
update_delta_weight = reshape(update_delta_weight, shape)
new_parameter = old_parameter - update_delta_weight
return P.Assign()(parameter, new_parameter)
P.Assign()(parameter, new_parameter)
return parameter
@_reshape_grads.register("Tensor", "Tensor", "Function")
@ -377,7 +378,8 @@ def _get_delta_weight_process(new_parameter, old_parameter):
@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(new_parameter, old_parameter):
return P.Assign()(new_parameter, old_parameter)
P.Assign()(new_parameter, old_parameter)
return new_parameter
@_clone_weight.register("Tensor", "Tensor")

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1678,8 +1678,10 @@ class TransformerEncoderLayer(Cell):
if self.use_past:
# reset states, init_reset True for reuse and False for reset
key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
key_reset = self.key_past
self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
value_reset = self.value_past
# add dependency for desired execution order
input_x = F.depend(input_x, key_reset)
input_x = F.depend(input_x, value_reset)
@ -1707,8 +1709,10 @@ class TransformerEncoderLayer(Cell):
# current key and value
key_present, value_present = layer_present
# update key and value calculated this step
key_update = self.assign(self.key_past, key_present)
value_update = self.assign(self.value_past, value_present)
self.assign(self.key_past, key_present)
key_update = self.key_past
self.assign(self.value_past, value_present)
value_update = self.value_past
# add dependency for desired execution order
key_update = F.depend(key_update, key_reset)
value_update = F.depend(value_update, value_reset)
@ -2115,8 +2119,10 @@ class TransformerDecoderLayer(Cell):
value_reset = None
if self.use_past:
# reset states, init_reset True for reuse and False for reset
key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
key_reset = self.key_past
self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
value_reset = self.value_past
# add dependency for desired execution order
input_x = F.depend(input_x, key_reset)
input_x = F.depend(input_x, value_reset)
@ -2159,8 +2165,10 @@ class TransformerDecoderLayer(Cell):
# current key and value
key_present, value_present = layer_present
# update key and value calculated this step
key_update = self.assign(self.key_past, key_present)
value_update = self.assign(self.value_past, value_present)
self.assign(self.key_past, key_present)
key_update = self.key_past
self.assign(self.value_past, value_present)
value_update = self.value_past
# add dependency for desired execution order
key_update = F.depend(key_update, key_reset)
value_update = F.depend(value_update, value_reset)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -151,8 +151,7 @@ def get_bprop_virtual_assign_add(self):
def bprop(x, y, out, dout):
if reduce_scatter:
dout = reduce_scatter(dout)
temp = assign_add(y, dout)
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), assign_add(y, dout))
return bprop
@ -208,12 +207,14 @@ def get_bprop_mirror_micro_step_operator(self):
z = F.depend(z, dout)
real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale)
assign_out = assign(z, real_grad)
assign(z, real_grad)
assign_out = z
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign_out = assign(z, real_grad)
assign(z, real_grad)
assign_out = z
if opt_shard:
return (real_grad, cast(out_tensor, dtype(z)))
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
@ -284,7 +285,7 @@ def get_bprop_mini_step_all_gather(self):
if mean_flag:
dx = F.tensor_mul(dx, scale)
if add_accu:
z = assign_add(z, dx)
assign_add(z, dx)
dx = F.depend(dx, z)
else:
dx = dout

View File

@ -1,10 +1,9 @@
0.1.1 MindSpore*1.9.0:ˆ
s
bprop.22:doutbprop.22:[CNode]23:1bprop.22:[CNode]23:1"REF::S-Prim-MakeTuple:2:Default/S-Prim-MakeTuple-op15bprop.22*
bprop.22:x*
bprop.22:out*
bprop.22:dout2
bprop.22:[CNode]23:1:@d656093b51d0012f58d8fd8c87351fc64fdc19fa3aa49c67d08964bd3af12156Pb&
0.1.1 MindSpore*2.0.0:ü
m
bprop.1:doutbprop.1:[CNode]2:1bprop.1:[CNode]2:1"REF::S-Prim-MakeTuple:2:Default/S-Prim-MakeTuple-op0bprop.1*
bprop.1:x*
bprop.1:out*
bprop.1:dout2
bprop.1:[CNode]2:1:@b347c311cc2bda87367b6fcbdcf96b857d6d2e109f47aed3df832027dc31a978Pb&
S-Prim-MakeTuple:2S-Prim-MakeTupleh

View File

@ -18,8 +18,9 @@
from mindspore.ops import operations as P
from .._primitive_cache import _get_cache_prim
assign_ = P.Assign()
def assign(variable, value):
"""
Assigns `Parameter` with a value.
@ -48,14 +49,16 @@ def assign(variable, value):
Examples:
>>> value = Tensor([2.0], mindspore.float32)
>>> variable = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="variable")
>>> output = ops.assign(variable, value)
>>> print(output)
>>> ops.assign(variable, value)
>>> print(variable)
[2.]
"""
return assign_(variable, value)
assign_sub_ = P.AssignSub()
def assign_sub(variable, value):
"""
Updates a `Parameter` by subtracting a value from it.
@ -93,7 +96,7 @@ def assign_sub(variable, value):
Examples:
>>> variable = mindspore.Parameter(initializer(1, [1], mindspore.int32), name="global_step")
>>> value = Tensor(np.ones([1]).astype(np.int32) * 100)
>>> output = ops.assign_sub(variable, value)
>>> ops.assign_sub(variable, value)
>>> print(variable.asnumpy())
[-99]
"""
@ -101,6 +104,8 @@ def assign_sub(variable, value):
assign_add_ = P.AssignAdd()
def assign_add(variable, value):
"""
Updates a `Parameter` by adding a value to it.
@ -138,7 +143,7 @@ def assign_add(variable, value):
Examples:
>>> variable = mindspore.Parameter(initializer(1, [1], mindspore.int32), name="global_step")
>>> value = Tensor(np.ones([1]).astype(np.int32) * 100)
>>> output = ops.assign_add(variable, value)
>>> ops.assign_add(variable, value)
>>> print(variable.asnumpy())
[101]
"""
@ -186,7 +191,7 @@ def index_add(x, indices, y, axis, use_lock=True, check_index_bound=True):
>>> x = Parameter(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32), name="name_x")
>>> indices = Tensor(np.array([0, 2]), mindspore.int32)
>>> y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
>>> output = ops.index_add(x, indices, y, 1)
>>> ops.index_add(x, indices, y, 1)
>>> print(output)
[[ 1.5 2. 4. ]
[ 5. 5. 7.5]

View File

@ -36,8 +36,8 @@ class Assign(Primitive):
>>> value = Tensor([2.0], mindspore.float32)
>>> variable = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="variable")
>>> assign = ops.Assign()
>>> output = assign(variable, value)
>>> print(output)
>>> assign(variable, value)
>>> print(variable)
[2.]
"""
__mindspore_signature__ = (

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -709,10 +709,10 @@ def test_multi_assign():
self.para3 = Parameter(Tensor(3, dtype=ms.int32), name='para3')
def construct(self, x, y, z):
a = self.assign(self.para1, x)
a = self.assign(self.para2, y)
a = self.assign(self.para3, z)
return self.para1 + self.para2 + a
self.assign(self.para1, x)
self.assign(self.para2, y)
self.assign(self.para3, z)
return self.para1 + self.para2 + self.para3
x = Tensor(4, dtype=ms.int32)
y = Tensor(5, dtype=ms.int32)
@ -1275,11 +1275,12 @@ class AssignNet(Cell):
1, [1, 3, 2, 2], ms.float32), name='value')
def construct(self, x):
x = self.assign_sub(self.input_data, x)
x = self.relu(x)
self.assign_sub(self.input_data, x)
x = self.relu(self.input_data)
x = self.mean(x, (2, 3))
return x
@security_off_wrap
def test_auto_mixed_precision_train_1(pynative_save_graphs):
net = AssignNet()
@ -1287,6 +1288,7 @@ def test_auto_mixed_precision_train_1(pynative_save_graphs):
label32 = Tensor(np.zeros([1, 3]).astype(np.float32))
use_build_train_network_check_cast_num(net, "O0", input32, label32, 0)
@security_off_wrap
def test_auto_mixed_precision_train_2(pynative_save_graphs):
net = AssignNet()
@ -1364,6 +1366,7 @@ def use_build_train_network_controlflow_check_cast_num(network, level, input_x,
assert len(castnum) == cast_num
return out_me
@security_off_wrap
def test_auto_mixed_precision_controlflow_auto(pynative_save_graphs):
net = MixControlNet(3, 5)
@ -1527,6 +1530,7 @@ def test_print_assign_print():
Description: Test load eliminate when umonad and iomona both exist.
Expectation: No exception.
"""
class Print(Cell):
def __init__(self):
super().__init__()

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -82,10 +82,10 @@ def _count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_me)*rtol)
greater = np.greater(error, atol + np.abs(data_me) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count/total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
@ -130,6 +130,7 @@ class SideEffectCastAll(Cell):
out_b = self.cast(self.parameter_b, self.dtype)
return out_a, out_b
@security_off_wrap
def test_side_effect_castall():
clear_files()
@ -335,6 +336,7 @@ class InplaceNet(Cell):
output = self.add(tmp_c1, tmp_c2)
return output
@security_off_wrap
def test_ir_fusion_inplace_bn_conv_conv():
clear_files()
@ -393,7 +395,7 @@ class Add(Cell):
class MixControlNet(Cell):
def __init__(self, in_channel, x):
super().__init__()
#self._save_graphs(save_graph_flag=True, save_graph_path=".")
# self._save_graphs(save_graph_flag=True, save_graph_path=".")
self.biasadd = P.BiasAdd()
self.equal = P.Equal()
self.addn = P.AddN()
@ -460,6 +462,7 @@ def use_build_train_network_controlflow_check_cast_num(network, level, input_x,
assert len(castnum) == cast_num
return out_me
@security_off_wrap
def test_auto_mixed_precision_controlflow_auto():
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
@ -474,6 +477,7 @@ def test_auto_mixed_precision_controlflow_auto():
use_build_train_network_controlflow_check_cast_num(net, "auto", input_x,
label, cast_num)
@security_off_wrap
def test_updatestate_between_assigns():
class UpdateState_Assigns(Cell):
@ -499,6 +503,7 @@ def test_updatestate_between_assigns():
updatestate_num = re.findall('UpdateState', content)
assert len(updatestate_num) == 1
@security_off_wrap
def test_updatestate_between_maketuple_assign():
class UpdateState_MakeTuple_Assign(Cell):
@ -526,6 +531,7 @@ def test_updatestate_between_maketuple_assign():
updatestate_num = re.findall('UpdateState', content)
assert len(updatestate_num) == 1
@security_off_wrap
def test_updatestate_between_assign_maketuple():
class UpdateState_Assign_MakeTuple(Cell):
@ -563,6 +569,7 @@ def test_cycle_parameter_binding():
Description: Auto-monad should work properly when cycle parameter binding existed.
Expectation: Normal output, no core dump.
"""
class MyActor(Cell):
def construct(self, inputs):
return inputs

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -266,7 +266,8 @@ class ControlOneIfOneParaOneAddn(Cell):
else:
out = self.addn([input_data, input_data, input_data])
if x > y:
out = self.assign(self.inputdata, input_data)
self.assign(self.inputdata, input_data)
out = self.inputdata
return out
@ -368,7 +369,8 @@ class SideEffectReturnParameterNet(Cell):
self.relu = P.ReLU()
def construct(self, inputs):
p1 = self.assign(self.para, inputs)
self.assign(self.para, inputs)
p1 = self.para
out = self.addn((inputs, inputs, inputs))
out = self.relu(out)
return p1
@ -401,7 +403,8 @@ class SideEffectAssignAddnReluReturnParNet(Cell):
self.relu = P.ReLU()
def construct(self, inputs):
p1 = self.assign(self.parameter1, inputs)
self.assign(self.parameter1, inputs)
p1 = self.parameter1
out = self.addN((inputs, inputs, inputs))
out = self.relu(out)
return p1
@ -426,10 +429,8 @@ def test_side_effect_grad_read_dependency_assign_addn_relu_return_parameter():
try:
context.set_context(mode=context.PYNATIVE_MODE)
out2 = net.grad_mindspore_impl(inputs, grad_ys)
allclose_nparray(out1[0][0].asnumpy(), out2[0]
[0].asnumpy(), 0.001, 0.001)
allclose_nparray(out1[1][0].asnumpy(), out2[1]
[0].asnumpy(), 0.001, 0.001)
allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001)
allclose_nparray(out1[1][0].asnumpy(), out2[1][0].asnumpy(), 0.001, 0.001)
finally:
context.set_context(mode=context.GRAPH_MODE)
@ -462,6 +463,7 @@ class SideEffectPrintInHighOrdeAddnNet(Cell):
grad_out = grad_net(params, grad_ys)
return grad_out
@security_off_wrap
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@ -495,7 +497,8 @@ class SideEffectControlFlowAssignDependTwoIfNet(Cell):
self.assign(self.parameter1, x)
if self.parameter1 > y:
x = self.mul(x, x)
p2 = self.assign(self.parameter1, x)
self.assign(self.parameter1, x)
p2 = self.parameter1
if self.parameter1 > y:
x = self.addn((x, self.parameter1))
p3 = self.assign(self.parameter1, x)
@ -661,10 +664,12 @@ class SideEffectControlFlowAssignDependWhileNet(Cell):
self.depend = P.Depend()
def construct(self, x, y, z):
p1 = self.assign(self.parameter1, x)
self.assign(self.parameter1, x)
p1 = self.parameter1
while self.parameter1 < y:
x = self.addn((x, x))
p2 = self.assignadd(self.parameter1, z)
self.assignadd(self.parameter1, z)
p2 = self.parameter1
self.depend(p2, p1)
return x

View File

@ -39,17 +39,18 @@ def test_monad_vmap():
self.value = Tensor([3, 4], mstype.int32)
def construct(self, x):
return self.assign(x, self.value)
x = self.assign(x, self.value)
return x
vampfunc = F.vmap(AssignNet())
@ms_function
def test_monad(a):
c = Tensor([[1, 2], [3, 4], [5, 6]], mstype.int32)
out = vampfunc(a)
c = a + c
out2 = P.AssignAdd()(a, c)
P.AssignAdd()(a, c)
out2 = a
return out, out2
a = Parameter(Tensor([[1, 2], [3, 4], [5, 6]], mstype.int32), name='param_a')

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,6 +22,8 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -75,6 +77,7 @@ def test_for_in_while_01():
assert graph_forward_res == expect_forward_res
assert graph_backward_res == expect_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -87,6 +88,7 @@ def test_for_after_for_in_while_01():
assert graph_forward_res == Tensor([1], mstype.int32)
assert graph_backward_res == (Tensor([0], mstype.int32), Tensor([0], mstype.int32))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -58,7 +58,8 @@ class ControlSimpleIfWithAssign(nn.Cell):
if x > y:
out = self.addn([input_data, input_data, input_data])
else:
out = self.assign(self.input_data, input_data)
self.assign(self.input_data, input_data)
out = self.input_data
return out

View File

@ -26,7 +26,8 @@ _clear_op = C.MultitypeFuncGraph("clear_op")
def _cumulative_gard(grad_sum, grad):
"""Apply gard sum to cumulative gradient."""
add = P.AssignAdd()
return add(grad_sum, grad)
add(grad_sum, grad)
return grad_sum
@_clear_op.register("Tensor", "Tensor")

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,17 +29,21 @@ class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.AssignAdd = P.AssignAdd()
self.assign_add = P.AssignAdd()
self.inputdata = Parameter(initializer('normal', [1]), name="global_step")
print("inputdata: ", self.inputdata)
def construct(self, x):
out = self.AssignAdd(self.inputdata, x)
return out
self.assign_add(self.inputdata, x)
return self.inputdata
def test_net():
"""test AssignAdd"""
"""
Feature: test AssignAdd.
Description: test AssignAdd.
Expectation: No exception.
"""
net = Net()
x = Tensor(np.ones([1]).astype(np.float32) * 100)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,17 +29,21 @@ class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.AssignSub = P.AssignSub()
self.assign_sub = P.AssignSub()
self.inputdata = Parameter(initializer('normal', [1]), name="global_step")
print("inputdata: ", self.inputdata)
def construct(self, x):
out = self.AssignSub(self.inputdata, x)
return out
self.assign_sub(self.inputdata, x)
return self.inputdata
def test_net():
"""test AssignSub"""
"""
Feature: test AssignSub.
Description: test AssignSub.
Expectation: No exception.
"""
net = Net()
x = Tensor(np.ones([1]).astype(np.int32) * 100)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,8 +29,8 @@ class AssignAdd(nn.Cell):
self.add = P.AssignAdd()
def construct(self, y):
res = self.add(self.var, y)
return res
self.add(self.var, y)
return self.var
@pytest.mark.level0

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,8 +29,8 @@ class AssignSub(nn.Cell):
self.sub = P.AssignSub()
def construct(self, y):
res = self.sub(self.var, y)
return res
self.sub(self.var, y)
return self.var
@pytest.mark.level0
@ -117,13 +117,17 @@ def test_assign_sub_func():
y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
output1 = F.assign_sub(x1, y1)
F.assign_sub(x1, y1)
output1 = x1
assert (output1.asnumpy() == expect1).all()
output2 = F.assign_sub(output1, y1)
F.assign_sub(output1, y1)
output2 = output1
assert (output2.asnumpy() == expect2).all()
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
output1 = F.assign_sub(x2, y2)
F.assign_sub(x2, y2)
output1 = x2
assert (output1.asnumpy() == expect1).all()
output2 = F.assign_sub(output1, y2)
F.assign_sub(output1, y2)
output2 = output1
assert (output2.asnumpy() == expect2).all()

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,8 +29,8 @@ class AssignAdd(nn.Cell):
self.add = P.AssignAdd()
def construct(self, y):
res = self.add(self.var, y)
return res
self.add(self.var, y)
return self.var
def assign_add(nptype):

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,7 +29,8 @@ class Net(nn.Cell):
self.assign = P.Assign()
def construct(self, param):
return self.assign(self.var, param)
self.assign(self.var, param)
return self.var
x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ class AssignAdd(nn.Cell):
self.add(self.var, y)
return self.var
def get_output(x2, y2, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
add = AssignAdd(x2)
@ -41,6 +42,7 @@ def get_output(x2, y2, enable_graph_kernel=False):
output = [result_gk_on_1, result_gk_on_2]
return output
def assign_add():
x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
@ -53,6 +55,7 @@ def assign_add():
assert np.allclose(o1.asnumpy(), e1.asnumpy())
assert np.allclose(o2.asnumpy(), e2.asnumpy())
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -60,6 +63,7 @@ def test_assign_add_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
assign_add()
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training

View File

@ -98,9 +98,9 @@ class LambGPUOrigin(nn.Cell):
next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(self.param, self.op_cast(next_param, F.dtype(self.param))))
next_param = F.depend(next_param, F.assign(self.m, self.op_cast(next_m, F.dtype(self.m))))
next_param = F.depend(next_param, F.assign(self.v, self.op_cast(next_v, F.dtype(self.v))))
F.assign(self.param, self.op_cast(next_param, F.dtype(self.param)))
F.assign(self.m, self.op_cast(next_m, F.dtype(self.m)))
F.assign(self.v, self.op_cast(next_v, F.dtype(self.v)))
return self.op_cast(next_param, F.dtype(self.param))

View File

@ -34,6 +34,7 @@ def test_parameter_1_1():
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -58,6 +59,7 @@ def test_parameter_1_2():
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -82,6 +84,7 @@ def test_parameter_2_1():
Description: If parameters in init have same name, an exception will be thrown.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -107,6 +110,7 @@ def test_parameter_2_2():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -134,6 +138,7 @@ def test_parameter_3():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -158,6 +163,7 @@ def test_parameter_4():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -183,6 +189,7 @@ def test_parameter_5_1():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -207,6 +214,7 @@ def test_parameter_5_2():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -233,6 +241,7 @@ def test_parameter_list_tuple_no_name():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -257,6 +266,7 @@ def test_parameter_in_tuple():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -282,6 +292,7 @@ def test_parameter_parameter_tuple_1():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -308,6 +319,7 @@ def test_parameter_parameter_tuple_2():
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -332,6 +344,7 @@ def test_parameter():
Description: If parameter in list or tuple is not given a name, will give it a unique name.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -369,6 +382,7 @@ def test_parameter_same_name_between_tuple_or_list():
Description: If the same name exists between tuple and list, an exception will be thrown.
Expectation: Get the expected exception report.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
@ -402,6 +416,7 @@ def test_parameter_argument_and_fv():
Expectation: Parameter used as argument should equal to used as FV.
"""
y = Parameter(Tensor([1]))
class Demo(Cell):
def construct(self, x):
ms.ops.Assign()(x, Tensor([0]))
@ -428,6 +443,7 @@ def test_parameter_argument_grad():
Description: Use Parameter as input argmument, and pass it to varargs.
Expectation: Parameter used as argument should equal to used as FV.
"""
class ParameterArgumentCell(Cell):
def __init__(self):
super(ParameterArgumentCell, self).__init__()

View File

@ -56,13 +56,14 @@ class EmaUpdate(nn.Cell):
def ema(self, tau, policy_param, target_param):
new_param = (1 - tau) * target_param + tau * policy_param
out = P.Assign()(target_param, new_param)
return out
P.Assign()(target_param, new_param)
return target_param
def construct(self):
if self.step % self.period == 0:
self.hyper_map(F.partial(self.ema, self.tau), self.policy_param, self.target_param)
return self.assignadd(self.step, 1)
self.assignadd(self.step, 1)
return self.step
@pytest.mark.level1

View File

@ -1,3 +1,17 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import ops, Tensor, context
@ -12,7 +26,9 @@ class AssignNet(Cell):
self.input_data = input_variable
def construct(self, input_x):
return self.op(self.input_data, input_x)
self.op(self.input_data, input_x)
return self.input_data
@pytest.mark.level0
@pytest.mark.platform_x86_cpu

View File

@ -160,7 +160,8 @@ def test_vmap_monad():
def construct(self, assign_add_val, assign_add_var, scatter_ref, indices, updates):
self.assign(self.assign_ref, self.replace_tensor)
F.print(self.assign_ref)
out = self.assign_add(assign_add_var, assign_add_val) + self.scatter_add(scatter_ref, indices, updates)
self.assign_add(assign_add_var, assign_add_val)
out = assign_add_var + self.scatter_add(scatter_ref, indices, updates)
return out
class VmapMonadNet(nn.Cell):
@ -411,8 +412,8 @@ def test_vmap_with_celllist_input():
self.ref_b = Parameter(Tensor([0, 1, 2], mstype.float32), name='ref_b')
def construct(self, replace_tensor):
out = self.assign(self.ref_a, replace_tensor)
out = self.ref_b + out
self.assign(self.ref_a, replace_tensor)
out = self.ref_b + self.ref_a
return out
m1 = AssignNet()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -255,10 +255,10 @@ def test_tuple_assignadd_tuple():
def test_string_assignadd_string():
result1 = "string111"
result2 = "string111"
input2 = "string111"
input_x = "string222"
result1 += input_x
result2 = AssignAdd(result2, input_x)()
result2 = AssignAdd(input2, input_x)
expect = "string111string222"
assert result1 == expect
assert result2 == expect
@ -279,9 +279,9 @@ class AssignSub(nn.Cell):
def test_number_assignsub_number():
input_x = 2
result1 = 5
result2 = 5
input2 = 5
result1 -= input_x
result2 = AssignSub(result2, input_x)()
result2 = AssignSub(input2, input_x)
expect = 3
assert np.all(result1 == expect)
assert np.all(result2 == expect)
@ -292,7 +292,7 @@ def test_tensor_assignsub_tensor():
result1 = Tensor(np.array([[4, -2], [2, 17]]))
result2 = Tensor(np.array([[4, -2], [2, 17]]))
result1 -= input_x
result2 = AssignSub(result2, input_x)()
result2 = AssignSub(result2, input_x)
expect = Tensor(np.array([[2, -4], [-1, 14]]))
assert np.all(result1.asnumpy() == expect)
assert np.all(result2.asnumpy() == expect)
@ -301,9 +301,9 @@ def test_tensor_assignsub_tensor():
def test_tensor_assignsub_number():
input_x = 3
result1 = Tensor(np.array([[4, -2], [2, 17]])).astype(np.float16)
result2 = Tensor(np.array([[4, -2], [2, 17]])).astype(np.float16)
input2 = Tensor(np.array([[4, -2], [2, 17]])).astype(np.float16)
result1 -= input_x
result2 = AssignSub(result2, input_x)()
result2 = AssignSub(input2, input_x)
expect = Tensor(np.array([[1, -5], [-1, 14]]))
assert np.all(result1.asnumpy() == expect)
assert np.all(result2.asnumpy() == expect)
@ -311,10 +311,10 @@ def test_tensor_assignsub_number():
def test_number_assignsub_tensor():
result1 = 3
result2 = 3
input2 = 3
input_x = Tensor(np.array([[4, -2], [2, 17]])).astype(np.float16)
result1 -= input_x
result2 = AssignSub(result2, input_x)()
result2 = AssignSub(input2, input_x)
expect = Tensor(np.array([[-1, 5], [1, -14]]))
assert np.all(result1.asnumpy() == expect)
assert np.all(result2.asnumpy() == expect)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -33,18 +33,22 @@ class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.AssignAdd = P.AssignAdd()
self.assign_add = P.AssignAdd()
self.inputdata = Parameter(initializer(1, [1], ms.int64), name="global_step")
print("inputdata: ", self.inputdata)
def construct(self, x):
out = self.AssignAdd(self.inputdata, x)
return out
self.assign_add(self.inputdata, x)
return self.inputdata
@non_graph_engine
def test_AssignAdd_1():
"""test AssignAdd 1"""
def test_assign_add_1():
"""
Feature: test AssignAdd.
Description: test AssignAdd.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
x = Tensor(np.ones([1]).astype(np.int64) * 100)
@ -62,8 +66,12 @@ def test_AssignAdd_1():
@non_graph_engine
def test_AssignAdd_2():
"""test AssignAdd 2"""
def test_assign_add_2():
"""
Feature: test AssignAdd.
Description: test AssignAdd.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
x = Tensor(np.ones([1]).astype(np.int64) * 102)
@ -85,18 +93,22 @@ class AssignAddNet(nn.Cell):
def __init__(self):
super(AssignAddNet, self).__init__()
self.AssignAdd = P.AssignAdd()
self.assign_add = P.AssignAdd()
self.inputdata = Parameter(initializer(1, [1], ms.float16), name="KIND_AUTOCAST_SCALAR_TO_TENSOR")
self.one = 1
def construct(self, ixt):
z1 = self.AssignAdd(self.inputdata, self.one)
return z1
self.assign_add(self.inputdata, self.one)
return self.inputdata
@non_graph_engine
def test_assignadd_scalar_cast():
"""
Feature: test AssignAdd.
Description: test AssignAdd.
Expectation: No exception.
"""
net = AssignAddNet()
x = Tensor(np.ones([1]).astype(np.int64) * 102)
# _executor.compile(net, 1)
_ = net(x)

View File

@ -34,7 +34,8 @@ class Net(nn.Cell):
self.sub = P.AssignSub()
def construct(self, value):
return self.sub(self.b, value)
self.sub(self.b, value)
return self.b
def test_net():

View File

@ -30,6 +30,9 @@ from mindspore.ops.operations._grad_ops import IgammaGradA
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.operations.math_ops import Zeta, Igamma, Igammac
from mindspore.ops.operations.sparse_ops import DenseToDenseSetOperation
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -368,7 +371,8 @@ class AssignAdd(nn.Cell):
def construct(self, input_):
self.inputdata = input_
return self.op(self.inputdata, input_)
self.op(self.inputdata, input_)
return self.inputdata
class FloorNet(nn.Cell):

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -206,7 +206,8 @@ def test_assign():
self.variable = Parameter(Tensor([1.0], mstype.float32), name="variable")
def construct(self, x):
return self.assign(self.variable, x)
self.assign(self.variable, x)
return self.variable
value = Tensor([2.0], mstype.float32)
assign = AssignNet()
@ -222,7 +223,8 @@ def test_assign_add():
self.variable = Parameter(initializer(1, [1], mstype.int64), name="global_step")
def construct(self, x):
return self.assign_add(self.variable, x)
self.assign_add(self.variable, x)
return self.variable
value = Tensor(np.ones([1]).astype(np.int64) * 100)
assign_add = AssignAddNet()
@ -234,11 +236,12 @@ def test_assign_sub():
class AssignSubNet(nn.Cell):
def __init__(self):
super(AssignSubNet, self).__init__()
self.assign = P.AssignSub()
self.assign_sub = P.AssignSub()
self.variable = Parameter(initializer(1, [1], mstype.int32), name="global_step")
def construct(self, x):
return self.assign(self.variable, x)
self.assign_sub(self.variable, x)
return self.variable
value = Tensor(np.ones([1]).astype(np.int32) * 100)
assign_sub = AssignSubNet()
@ -500,6 +503,7 @@ def test_switch():
Expectation: Load the bprop mindir successfully.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class SwitchNet(nn.Cell):
def construct(self, x, y):
if x > y:
@ -527,7 +531,8 @@ def test_update_state():
self.variable = Parameter(initializer(1, [1], mstype.int64), name="global_step")
def construct(self, x):
return self.assign_add(self.variable, x)
self.assign_add(self.variable, x)
return self.variable
value = Tensor(np.ones([1]).astype(np.int64) * 100)
update_state = UpdateStateNet()

View File

@ -62,6 +62,7 @@ def test_matmul_sub():
Description: matmul-sub net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -91,6 +92,7 @@ def test_matmul_add():
Description: matmul-add net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -120,6 +122,7 @@ def test_matmul_mul():
Description: matmul-mul net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -142,12 +145,14 @@ def test_matmul_mul():
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_mod():
"""
Feature: distribute operator sub in auto parallel.
Description: matmul-mod net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -170,12 +175,14 @@ def test_matmul_mod():
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_floormod():
"""
Feature: distribute operator sub in auto parallel.
Description: matmul-floormod net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -205,6 +212,7 @@ def test_matmul_atan2():
Description: matmul-atan2 net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -234,6 +242,7 @@ def test_matmul_divNoNan():
Description: matmul-divNoNan net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -263,6 +272,7 @@ def test_matmul_logicaland():
Description: matmul-logical_and net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -297,6 +307,7 @@ def test_matmul_logicalor():
Description: matmul-logical_or net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -331,6 +342,7 @@ def test_matmul_div():
Description: matmul-div net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -360,6 +372,7 @@ def test_matmul_add_broadcast():
Description: matmul-add broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -389,6 +402,7 @@ def test_matmul_add_broadcast2():
Description: matmul-add broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -418,6 +432,7 @@ def test_matmul_sub_broadcast():
Description: matmul-sub broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -447,6 +462,7 @@ def test_matmul_sub_broadcast2():
Description: matmul-sub broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -476,6 +492,7 @@ def test_matmul_mul_broadcast():
Description: matmul-mul broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -505,6 +522,7 @@ def test_matmul_mul_broadcast2():
Description: matmul-mul broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -534,6 +552,7 @@ def test_matmul_div_broadcast():
Description: matmul-div broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -563,6 +582,7 @@ def test_matmul_div_broadcast2():
Description: matmul-div broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -592,6 +612,7 @@ def test_matmul_greater_broadcast():
Description: matmul-greater broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -621,6 +642,7 @@ def test_matmul_greater_broadcast2():
Description: matmul-greater broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -650,6 +672,7 @@ def test_matmul_floordiv():
Description: matmul-floordiv net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -679,6 +702,7 @@ def test_matmul_floordiv_broadcast():
Description: matmul-floordiv broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -708,6 +732,7 @@ def test_matmul_floordiv_broadcast2():
Description: matmul-floordiv broadcast net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -737,6 +762,7 @@ def test_assign_sub():
Description: mul-assign_sub net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
@ -751,7 +777,7 @@ def test_assign_sub():
def construct(self, x):
out = self.mul(x, self.mul_weight)
out = self.assign_sub(self.assignsub_weight, out)
self.assign_sub(self.assignsub_weight, out)
return out
class SubNetWithLoss(nn.Cell):
@ -786,25 +812,26 @@ def test_assign_sub():
def test_assign_add():
"""
Feature: distribute operator sub in auto parallel.
Feature: distribute operator add in auto parallel.
Description: mul-assign_add net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.assign_sub = P.AssignAdd()
self.assign_add = P.AssignAdd()
self.mul = P.Mul()
self.mul_weight = Parameter(Tensor(np.full([128, 32],
0.5, dtype=np.float32)),
name="mul_weight")
self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
self.assignadd_weight = Parameter(Tensor(np.full([128, 32],
1.1, dtype=np.float32)),
name="assignsub_weight")
name="assignadd_weight")
def construct(self, x):
out = self.mul(x, self.mul_weight)
out = self.assign_sub(self.assignsub_weight, out)
self.assign_add(self.assignadd_weight, out)
return out
class SubNetWithLoss(nn.Cell):
@ -839,25 +866,26 @@ def test_assign_add():
def test_assign():
"""
Feature: distribute operator sub in auto parallel.
Feature: distribute operator assign in auto parallel.
Description: mul-assign_sub net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.assign_sub = P.Assign()
self.assign = P.Assign()
self.mul = P.Mul()
self.mul_weight = Parameter(Tensor(np.full([128, 32],
0.5, dtype=np.float32)),
name="mul_weight")
self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
self.assign_weight = Parameter(Tensor(np.full([128, 32],
1.1, dtype=np.float32)),
name="assignsub_weight")
name="assign_weight")
def construct(self, x):
out = self.mul(x, self.mul_weight)
out = self.assign_sub(self.assignsub_weight, out)
self.assign(self.assign_weight, out)
return out
class SubNetWithLoss(nn.Cell):
@ -896,17 +924,16 @@ def test_matmul_bitwise_and_broadcast():
Description: mul-BitwiseAnd net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.bitwise_and = P.BitwiseAnd().shard(strategy1)
self.matmul = P.MatMul().shard(strategy2)
def construct(self, x, y, z):
out = self.bitwise_and(x, y)
out = self.matmul(out, z)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
@ -926,6 +953,7 @@ def test_matmul_bitwise_or_broadcast():
Description: mul-BitwiseOr net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -954,6 +982,7 @@ def test_matmul_bitwise_xor_broadcast():
Description: mul-BitwiseXor net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -982,6 +1011,7 @@ def test_matmul_mul_no_nan_broadcast():
Description: mul-MulNoNan net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1010,6 +1040,7 @@ def test_matmul_truncate_div_broadcast():
Description: mul-TruncateDiv net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1038,6 +1069,7 @@ def test_matmul_truncate_mod_broadcast():
Description: mul-TruncateMod net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1066,6 +1098,7 @@ def test_matmul_xdivy_broadcast():
Description: mul-Xdivy net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1094,6 +1127,7 @@ def test_matmul_xlogy_broadcast():
Description: mul-Xlogy net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1122,6 +1156,7 @@ def test_matmul_squared_difference_broadcast():
Description: mul-SquaredDifference net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1150,6 +1185,7 @@ def test_matmul_masked_fill_broadcast_with_value_float():
Description: mul-MaskedFill net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
@ -1179,6 +1215,7 @@ def test_matmul_masked_fill_broadcast_with_value_tensor():
Description: mul-MaskedFill net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,8 +21,10 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
import mindspore.ops.operations as op
def test_net_infer():
""" test_net_infer """
class Net(nn.Cell):
""" Net definition """
@ -40,13 +42,14 @@ def test_net_infer():
x = self.flatten(x)
out = self.fc(x)
return out
Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
Net()
def test_assign_in_while():
context.set_context(device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self, input_shape):
super().__init__()
@ -58,7 +61,8 @@ def test_assign_in_while():
while x < y:
inputdata = self.inputdata
x = x + 1
out = self.assign(inputdata, z)
self.assign(inputdata, z)
out = inputdata
return out
x = Tensor(np.array(1).astype(np.int32))
@ -76,9 +80,6 @@ def test_dup_context():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
def identity(f):
return f
@ -106,9 +107,6 @@ def test_maybe_poly_func():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y, z):
def identity(f, inp):
return f(inp)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.functional import tensor_add
from ...ut_filter import non_graph_engine
# pylint: disable=W0613,W0612
# W0613: unused-argument
@ -46,11 +47,11 @@ def fixture_enable_check_bprop():
grad_all = C.GradOperation(get_all=True)
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
context.set_context(mode=context.GRAPH_MODE)
# Test case: use the parse obj interface use default parameter
class Net(nn.Cell):
""" Net definition """
@ -60,7 +61,7 @@ class Net(nn.Cell):
self.softmax1 = nn.Softmax(dim)
self.softmax2 = nn.Softmax(dim + 1)
def construct(self, input_data, input1=1+2+3+4):
def construct(self, input_data, input1=1 + 2 + 3 + 4):
return self.softmax1(input_data)
@ -84,10 +85,6 @@ def test_parse_defalut_parameter_case2():
# Test case: use the variable parameter for parse object
class Net1(nn.Cell):
""" Net1 definition """
def __init__(self):
super(Net1, self).__init__()
def construct(self, *args):
x = args[0]
return x
@ -177,9 +174,6 @@ def test_bprop_with_wrong_output_num(enable_check_bprop):
return bprop
class BpropWithWrongOutputNumCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputNumCell, self).__init__()
def construct(self, x, y):
return BpropWithWrongOutputNum()(x, y)
@ -187,6 +181,7 @@ def test_bprop_with_wrong_output_num(enable_check_bprop):
grad_all(BpropWithWrongOutputNumCell())(Tensor(np.array(1).astype(np.int32)),
Tensor(np.array(2).astype(np.int32)))
def test_bprop_with_wrong_output_type(enable_check_bprop):
class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register
@ -212,9 +207,6 @@ def test_bprop_with_wrong_output_type(enable_check_bprop):
return bprop
class BpropWithWrongOutputTypeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputTypeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputType()(x)
@ -248,9 +240,6 @@ def test_bprop_with_wrong_output_shape(enable_check_bprop):
return bprop
class BpropWithWrongOutputShapeCell(nn.Cell):
def __init__(self):
super(BpropWithWrongOutputShapeCell, self).__init__()
def construct(self, x):
return BpropWithWrongOutputShape()(x)
@ -259,6 +248,7 @@ def test_bprop_with_wrong_output_shape(enable_check_bprop):
net.set_grad()
grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
class AssignWhenInsertGrad(nn.Cell):
""" NetWithNDarray definition """
@ -280,8 +270,10 @@ class AssignWhenInsertGrad(nn.Cell):
out = self.getG(out)
return out
grad_all = C.GradOperation(get_all=True)
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
@ -291,6 +283,7 @@ class GradNet(nn.Cell):
out = self.net(*inputs)
return out, grad_all(self.net)(*inputs)
def test_assign_in_insert_grad():
context.set_context(mode=context.GRAPH_MODE)
net = AssignWhenInsertGrad().to_float(ms.float16)
@ -298,6 +291,7 @@ def test_assign_in_insert_grad():
net_back = GradNet(net)
net_back(ms.Tensor(input_data))
class Assign(nn.Cell):
""" NetWithNDarray definition """
@ -317,6 +311,7 @@ def test_assign(enable_check_bprop):
net_back = GradNet(net)
net_back(input_data)
class AssignCheck(nn.Cell):
""" NetWithNDarray definition """

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -44,6 +44,7 @@ def test_while_loop_phi():
net = WhileSubGraphParam()
net(x, y, z)
class WhileSubGraphParam2(Cell):
def __init__(self):
super().__init__()
@ -73,15 +74,15 @@ class WhileSubGraphParam3(Cell):
def __init__(self, initial_input_x):
super().__init__()
self.initial_input_x = initial_input_x
self.X = ms.Parameter(initial_input_x, name="parameter_x")
self.Y = ms.Parameter(self.initial_input_x, name="parameter_y")
self.x = ms.Parameter(initial_input_x, name="parameter_x")
self.y = ms.Parameter(self.initial_input_x, name="parameter_y")
def construct(self):
a = 0
while a < 3:
self.X = self.X + self.Y
self.x = self.x + self.y
a += 1
return self.X
return self.x
def test_while_loop_phi_3():
@ -91,6 +92,7 @@ def test_while_loop_phi_3():
net = WhileSubGraphParam3(x)
net()
class ControlMixedWhileIf(nn.Cell):
def __init__(self):
super().__init__()
@ -101,24 +103,28 @@ class ControlMixedWhileIf(nn.Cell):
def construct(self, x, y, z, c2, c4):
out = self.assign(self.var, c4)
while x < c2:
y = self.assign(self.var, c4)
self.assign(self.var, c4)
y = self.var
while y < c2 and x < c2:
if 2 * y < c2:
y = y + 2
else:
y = y + 1
out = out + y
z = self.assign(self.var, c4)
self.assign(self.var, c4)
z = self.var
while z < c2:
z = z + 1
out = out + z
x = x + 1
out = out + x
while x < 2 * c2:
y = self.assign(self.var, c4)
self.assign(self.var, c4)
y = self.var
x = x + 1
while y < c2:
z = self.assign(self.var, c4)
self.assign(self.var, c4)
z = self.var
while z < c2:
z = z + 1
if x < c2:
@ -130,6 +136,7 @@ class ControlMixedWhileIf(nn.Cell):
out = out + x
return out
def test_mixed_while_if():
context.set_context(mode=context.PYNATIVE_MODE)
x = np.array(2).astype(np.int32)