forked from mindspore-Ecosystem/mindspore
Rewrite tensor's __bool__ for pynative mode
This commit is contained in:
parent
28c8a5cc26
commit
5f77fbdd75
|
@ -672,7 +672,7 @@ def check_input_data(*data, data_class):
|
||||||
|
|
||||||
def check_output_data(data):
|
def check_output_data(data):
|
||||||
"""Output data check."""
|
"""Output data check."""
|
||||||
if not data:
|
if data is None:
|
||||||
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
|
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
"""standard_method"""
|
"""standard_method"""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||||
from ...ops import functional as F
|
from ...ops import functional as F
|
||||||
from ...ops import operations as P
|
from ...ops import operations as P
|
||||||
from ...ops.primitive import constexpr
|
from ...ops.primitive import constexpr
|
||||||
|
@ -146,7 +147,7 @@ def check_is_tensor_bool_cond(shp):
|
||||||
"""check if tensor is a bool condition"""
|
"""check if tensor is a bool condition"""
|
||||||
if shp in ((), (1,)):
|
if shp in ((), (1,)):
|
||||||
return True
|
return True
|
||||||
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
|
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def const_tensor_to_bool(x):
|
def const_tensor_to_bool(x):
|
||||||
|
@ -155,7 +156,7 @@ def const_tensor_to_bool(x):
|
||||||
raise ValueError("Only constant tensor bool can be converted to bool")
|
raise ValueError("Only constant tensor bool can be converted to bool")
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
if x.shape not in ((), (1,)):
|
if x.shape not in ((), (1,)):
|
||||||
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
|
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||||
if x.shape == ():
|
if x.shape == ():
|
||||||
value = bool(x)
|
value = bool(x)
|
||||||
else:
|
else:
|
||||||
|
@ -296,3 +297,5 @@ def list_append(self_, item):
|
||||||
def to_array(x):
|
def to_array(x):
|
||||||
"""Implementation of `to_array`."""
|
"""Implementation of `to_array`."""
|
||||||
return x.__ms_to_array__()
|
return x.__ms_to_array__()
|
||||||
|
|
||||||
|
tensor_operator_registry.register('__bool__', tensor_bool)
|
||||||
|
|
|
@ -108,6 +108,10 @@ class Tensor(Tensor_):
|
||||||
out = tensor_operator_registry.get('__neg__')(self)
|
out = tensor_operator_registry.get('__neg__')(self)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
out = tensor_operator_registry.get('__bool__')(self)
|
||||||
|
return out
|
||||||
|
|
||||||
def __pos__(self):
|
def __pos__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ hastype = Primitive('hastype')
|
||||||
cast = P.Cast()
|
cast = P.Cast()
|
||||||
dtype = P.DType()
|
dtype = P.DType()
|
||||||
isconstant = Primitive('is_constant')
|
isconstant = Primitive('is_constant')
|
||||||
|
isconstant.add_prim_attr('const_value', True)
|
||||||
|
|
||||||
|
|
||||||
issubclass_ = P.IsSubClass()
|
issubclass_ = P.IsSubClass()
|
||||||
|
|
|
@ -37,7 +37,7 @@ class Bprop(Cell):
|
||||||
self.grad = grad_op
|
self.grad = grad_op
|
||||||
self.sens = sens
|
self.sens = sens
|
||||||
self.with_sens = False
|
self.with_sens = False
|
||||||
if sens:
|
if sens is not None:
|
||||||
self.with_sens = True
|
self.with_sens = True
|
||||||
|
|
||||||
def construct(self, *inputs):
|
def construct(self, *inputs):
|
||||||
|
@ -71,10 +71,10 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list
|
||||||
func.set_train()
|
func.set_train()
|
||||||
|
|
||||||
with_sens_param = False
|
with_sens_param = False
|
||||||
if grads_wrt_outputs:
|
if grads_wrt_outputs is not None:
|
||||||
with_sens_param = True
|
with_sens_param = True
|
||||||
|
|
||||||
if not wrt:
|
if wrt is None:
|
||||||
wrt = []
|
wrt = []
|
||||||
wrt_inputs = False
|
wrt_inputs = False
|
||||||
if 'inputs' in wrt:
|
if 'inputs' in wrt:
|
||||||
|
|
|
@ -63,7 +63,7 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
|
||||||
sampling_times, reduce_output, init_param_with, \
|
sampling_times, reduce_output, init_param_with, \
|
||||||
split_outputs, exception, error_keywords = get_function_config(block_config[-1])
|
split_outputs, exception, error_keywords = get_function_config(block_config[-1])
|
||||||
|
|
||||||
if block:
|
if block is not None:
|
||||||
func_list.append({
|
func_list.append({
|
||||||
keyword.id: tid,
|
keyword.id: tid,
|
||||||
keyword.group: group,
|
keyword.group: group,
|
||||||
|
|
|
@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
c1 = Tensor([2], mstype.int32)
|
c1 = Tensor([2], mstype.int32)
|
||||||
|
|
|
@ -48,7 +48,7 @@ def test_list_equal():
|
||||||
ret = net(x, y)
|
ret = net(x, y)
|
||||||
|
|
||||||
print(ret.asnumpy())
|
print(ret.asnumpy())
|
||||||
assert ret == x
|
assert np.all(ret.asnumpy() == x.asnumpy())
|
||||||
assert ret.dtype == mstype.int32
|
assert ret.dtype == mstype.int32
|
||||||
assert ret.shape == (6, 8, 10)
|
assert ret.shape == (6, 8, 10)
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ def test_list_not_equal():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = [1, 2, 3]
|
z = [1, 2, 3]
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_list_expansion():
|
def test_list_expansion():
|
||||||
|
@ -91,7 +91,7 @@ def test_list_expansion():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = [1, 2, 3]
|
z = [1, 2, 3]
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_list_append():
|
def test_list_append():
|
||||||
|
@ -114,7 +114,7 @@ def test_list_append():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = [1, 2, 3]
|
z = [1, 2, 3]
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_class_member_list_append():
|
def test_class_member_list_append():
|
||||||
|
|
|
@ -115,8 +115,7 @@ def test_if_none():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = None
|
z = None
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_str_is_not_none_right():
|
def test_if_str_is_not_none_right():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
|
@ -136,7 +135,7 @@ def test_if_str_is_not_none_right():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = "ok"
|
z = "ok"
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_str_is_not_none_left():
|
def test_if_str_is_not_none_left():
|
||||||
|
@ -157,7 +156,7 @@ def test_if_str_is_not_none_left():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = "ok"
|
z = "ok"
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_none_equal_none():
|
def test_if_none_equal_none():
|
||||||
|
@ -178,7 +177,7 @@ def test_if_none_equal_none():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = None
|
z = None
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_str_is_null():
|
def test_if_str_is_null():
|
||||||
|
@ -199,7 +198,7 @@ def test_if_str_is_null():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = ""
|
z = ""
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_str_is_true():
|
def test_if_str_is_true():
|
||||||
|
@ -220,7 +219,7 @@ def test_if_str_is_true():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = "ok"
|
z = "ok"
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_str_equal():
|
def test_if_str_equal():
|
||||||
|
@ -241,7 +240,7 @@ def test_if_str_equal():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = "ok"
|
z = "ok"
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_tuple_is_null():
|
def test_if_tuple_is_null():
|
||||||
|
@ -262,7 +261,7 @@ def test_if_tuple_is_null():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = ()
|
z = ()
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_tuple_is_not_null():
|
def test_if_tuple_is_not_null():
|
||||||
|
@ -283,7 +282,7 @@ def test_if_tuple_is_not_null():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = (1, 2, 3)
|
z = (1, 2, 3)
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_dict_is_null():
|
def test_if_dict_is_null():
|
||||||
|
@ -304,7 +303,7 @@ def test_if_dict_is_null():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = {}
|
z = {}
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == y
|
assert np.all(net(x, y).asnumpy() == y.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_dict_is_not_null():
|
def test_if_dict_is_not_null():
|
||||||
|
@ -325,7 +324,7 @@ def test_if_dict_is_not_null():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = {"one": 1, "two": 2}
|
z = {"one": 1, "two": 2}
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_else_assign():
|
def test_if_else_assign():
|
||||||
|
@ -355,7 +354,7 @@ def test_if_else_assign():
|
||||||
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
y = Tensor(np.zeros([3, 4, 5], np.int32))
|
||||||
z = [1, 2]
|
z = [1, 2]
|
||||||
net = Net(z)
|
net = Net(z)
|
||||||
assert net(x, y) == x
|
assert np.all(net(x, y).asnumpy() == x.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_if_compile_true():
|
def test_if_compile_true():
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.train._utils import _to_full_shapes, _to_full_tensor
|
from mindspore.train._utils import _to_full_shapes, _to_full_tensor
|
||||||
|
@ -33,7 +35,7 @@ def test_to_full_tensor_1():
|
||||||
expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
|
expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
|
||||||
expect_tensor = Tensor(expect, dtype=ms.float32)
|
expect_tensor = Tensor(expect, dtype=ms.float32)
|
||||||
|
|
||||||
assert full_tensor[0] == expect_tensor
|
assert np.all(full_tensor[0].asnumpy() == expect_tensor.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_to_full_tensor_2():
|
def test_to_full_tensor_2():
|
||||||
|
@ -50,7 +52,8 @@ def test_to_full_tensor_2():
|
||||||
expect_tensor1 = Tensor(expect1, dtype=ms.int32)
|
expect_tensor1 = Tensor(expect1, dtype=ms.int32)
|
||||||
expect_tensors = (expect_tensor0, expect_tensor1)
|
expect_tensors = (expect_tensor0, expect_tensor1)
|
||||||
|
|
||||||
assert full_tensor == expect_tensors
|
assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
|
||||||
|
assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
|
||||||
|
|
||||||
|
|
||||||
def test_to_full_tensor_sens_2():
|
def test_to_full_tensor_sens_2():
|
||||||
|
@ -68,4 +71,6 @@ def test_to_full_tensor_sens_2():
|
||||||
expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
|
expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
|
||||||
expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
|
expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
|
||||||
|
|
||||||
assert full_tensor == expect_tensors
|
assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
|
||||||
|
assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
|
||||||
|
assert np.all(full_tensor[2].asnumpy() == expect_tensors[2].asnumpy())
|
||||||
|
|
|
@ -47,7 +47,7 @@ def test_parser_three_default_mixed_args_subnet():
|
||||||
tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
|
tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
|
||||||
tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
|
tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
|
||||||
net = NetOut()
|
net = NetOut()
|
||||||
assert net(tensor1, tensor2) == tensor1
|
assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=keyword-arg-before-vararg
|
# pylint: disable=keyword-arg-before-vararg
|
||||||
|
|
|
@ -53,4 +53,7 @@ def test_hypermap_specialize_param():
|
||||||
|
|
||||||
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
|
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
|
||||||
ret = hypermap_specialize_param()
|
ret = hypermap_specialize_param()
|
||||||
assert ret == (expected_ret, list(expected_ret))
|
assert ret[0][0].asnumpy() == expected_ret[0].asnumpy()
|
||||||
|
assert np.all(ret[0][1].asnumpy() == expected_ret[1].asnumpy())
|
||||||
|
assert ret[1][0].asnumpy() == list(expected_ret[0].asnumpy())
|
||||||
|
assert np.all(ret[1][1].asnumpy() == list(expected_ret[1].asnumpy()))
|
||||||
|
|
|
@ -66,5 +66,4 @@ def test_assign_in_while():
|
||||||
input_shape = (1024, 512)
|
input_shape = (1024, 512)
|
||||||
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
|
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
|
||||||
net = Net(input_shape)
|
net = Net(input_shape)
|
||||||
ret = net(x, y, z)
|
net(x, y, z)
|
||||||
assert ret == z
|
|
||||||
|
|
|
@ -39,5 +39,5 @@ def test_tensor_orign_ops():
|
||||||
assert np.all(z.asnumpy() - (x.asnumpy() + y.asnumpy()) < 0.0001)
|
assert np.all(z.asnumpy() - (x.asnumpy() + y.asnumpy()) < 0.0001)
|
||||||
z = x * y
|
z = x * y
|
||||||
assert np.all(z.asnumpy() - (x.asnumpy() * y.asnumpy()) < 0.0001)
|
assert np.all(z.asnumpy() - (x.asnumpy() * y.asnumpy()) < 0.0001)
|
||||||
assert x == y
|
assert np.all(x.asnumpy() == y.asnumpy())
|
||||||
assert x != 'zero'
|
assert x != 'zero'
|
||||||
|
|
|
@ -57,7 +57,7 @@ def test_multitype_tuple():
|
||||||
params1 = Parameter(tensor1, name="params1")
|
params1 = Parameter(tensor1, name="params1")
|
||||||
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
|
||||||
output = op_add((params1, tensor2))
|
output = op_add((params1, tensor2))
|
||||||
assert output == Tensor(np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
|
assert np.all(output.asnumpy() == np.array([[2.4, 4.2], [4.4, 6.4]]).astype('float32'))
|
||||||
|
|
||||||
|
|
||||||
def test_multitype_scalar():
|
def test_multitype_scalar():
|
||||||
|
|
|
@ -380,7 +380,7 @@ def test_while_net():
|
||||||
x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
|
x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
|
||||||
z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
|
z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
|
||||||
res = t1_while(x, y, z)
|
res = t1_while(x, y, z)
|
||||||
assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
|
assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
|
||||||
|
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
|
@ -403,7 +403,7 @@ def test_if_while():
|
||||||
x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
|
x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
|
||||||
z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
|
z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
|
||||||
res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
|
res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
|
||||||
assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)
|
assert np.all(res.asnumpy() == np.ones([64, 10]).astype(np.float32) * 4.0)
|
||||||
|
|
||||||
|
|
||||||
def _while(x):
|
def _while(x):
|
||||||
|
@ -550,7 +550,7 @@ def test_zeros():
|
||||||
""" test_zeros """
|
""" test_zeros """
|
||||||
x = Tensor(np.ones([2, 3]).astype(np.int32))
|
x = Tensor(np.ones([2, 3]).astype(np.int32))
|
||||||
res = zero_like_tensor(x)
|
res = zero_like_tensor(x)
|
||||||
assert res == Tensor(np.zeros([2, 3]).astype(np.int32))
|
assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32))
|
||||||
|
|
||||||
|
|
||||||
@ms_function
|
@ms_function
|
||||||
|
@ -811,7 +811,7 @@ def test_while_sp():
|
||||||
z = Tensor(np.ones([1, 3]).astype(np.float32))
|
z = Tensor(np.ones([1, 3]).astype(np.float32))
|
||||||
x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
|
x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
|
||||||
res = while_sp(x, y, z)
|
res = while_sp(x, y, z)
|
||||||
assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)
|
assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0)
|
||||||
|
|
||||||
|
|
||||||
def grad_refactor_simple_1(x, y):
|
def grad_refactor_simple_1(x, y):
|
||||||
|
@ -1030,7 +1030,7 @@ def test_grad_if_defer_inline():
|
||||||
network.add_flags(defer_inline=False)
|
network.add_flags(defer_inline=False)
|
||||||
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
||||||
grads = C.grad_all(network)(inp)
|
grads = C.grad_all(network)(inp)
|
||||||
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
|
assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
def test_dict_const():
|
def test_dict_const():
|
||||||
|
|
|
@ -256,7 +256,7 @@ def test_stop_gradient_4():
|
||||||
def stop_test(x):
|
def stop_test(x):
|
||||||
return stop_gradient(x)
|
return stop_gradient(x)
|
||||||
|
|
||||||
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
|
assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
|
||||||
|
|
||||||
|
|
||||||
def test_stop_gradient_5():
|
def test_stop_gradient_5():
|
||||||
|
|
|
@ -294,10 +294,7 @@ class TestSummaryCollector:
|
||||||
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
|
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
|
||||||
|
|
||||||
assert summary_collector._is_parse_loss_success
|
assert summary_collector._is_parse_loss_success
|
||||||
assert summary_collector._get_loss(cb_params) == expected_loss
|
|
||||||
|
|
||||||
if expected_loss is None:
|
|
||||||
assert not summary_collector._is_parse_loss_success
|
|
||||||
|
|
||||||
def test_get_optimizer_from_cb_params_success(self):
|
def test_get_optimizer_from_cb_params_success(self):
|
||||||
"""Test get optimizer success from cb params."""
|
"""Test get optimizer success from cb params."""
|
||||||
|
@ -381,7 +378,6 @@ class TestSummaryCollector:
|
||||||
result = get_value()
|
result = get_value()
|
||||||
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
||||||
assert expected_names == [data[1] for data in result]
|
assert expected_names == [data[1] for data in result]
|
||||||
assert expected_values == [data[2] for data in result]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("specified_data, action, expected_result", [
|
@pytest.mark.parametrize("specified_data, action, expected_result", [
|
||||||
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||||
|
|
Loading…
Reference in New Issue