forked from mindspore-Ecosystem/mindspore
461 lines
9.5 KiB
Python
461 lines
9.5 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""
|
|
@File : test_parse_method.py
|
|
@Author:
|
|
@Date : 2019-06-27
|
|
@Desc : test parse the object's method
|
|
"""
|
|
import logging
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import mindspore.nn as nn
|
|
from mindspore import context
|
|
from mindspore._extends.parse.standard_method import ms_len
|
|
from mindspore.common.api import ms_function
|
|
from mindspore.common.tensor import Tensor
|
|
from mindspore.ops.composite import core
|
|
from mindspore.ops.primitive import constexpr
|
|
from mindspore.ops import functional as F
|
|
from ..ut_filter import non_graph_engine
|
|
|
|
|
|
def setup_module(module):
    """ Module-level pytest hook: switch the MindSpore context to PyNative mode. """
    pynative_mode = context.PYNATIVE_MODE
    context.set_context(mode=pynative_mode)
|
|
|
|
|
|
# Logger shared by the tests in this module; only ERROR and above is emitted.
log = logging.getLogger("test")
log.setLevel(logging.ERROR)
|
|
|
|
|
|
@ms_function
def default_parameter_f(x, y=3):
    """ Return x + y, where y falls back to its default of 3 when omitted. """
    return x + y
|
|
|
|
|
|
# Test case: parse a function that uses a default parameter
|
|
def test_parse_defalut_parameter_case1():
    """ Call default_parameter_f with only the first argument so the default is used. """
    log.debug("begin test_parse_defalut_parameter_case1")
    result = default_parameter_f(2)
    log.debug("finished test_parse_defalut_parameter_case1, ret = %r", result)
|
|
|
|
|
|
def get_val_fn(x):
    """ Return x increased by 3. """
    return x + 3
|
|
|
|
|
|
# Test case: test bool not
|
|
@ms_function
def bool_exp(x, y):
    """ Return the negation of the comparison x > y. """
    # The explicit `not` is kept on purpose: this function exists to exercise
    # the boolean-not expression.
    return not (x > y)
|
|
|
|
|
|
def test_bool_exp():
    """ Invoke bool_exp with two scalars; the test only checks that the call succeeds. """
    bool_exp(1, 2)
|
|
|
|
|
|
# Test case: use variable positional parameters (*args) with @ms_function
|
|
@ms_function
def var_parameter_f(x, *args):
    """ Return x plus the first three variable positional arguments. """
    # Only args[0..2] are read; any additional positional arguments are ignored.
    total = x + args[0]
    total = total + args[1]
    total = total + args[2]
    return total
|
|
|
|
|
|
def test_var_parameter_case1():
    """ Call var_parameter_f with four extra positional arguments. """
    log.debug("start test_var_parameter_case1")
    extra_args = (2, 3, 4, 5)
    var_parameter_f(1, *extra_args)
    log.debug("end test_var_parameter_case1")
|
|
|
|
|
|
class Net(nn.Cell):
    """ Cell whose construct adds a stored value to its input via a helper method. """

    def __init__(self, value1):
        """ Store value1 as the addend used by get_test_value. """
        super().__init__()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", 1.2)
        self.value = value1

    @ms_function
    def construct(self, x):
        """ Delegate to another method of this cell and return its result. """
        return self.get_test_value(x)

    def get_test_value(self, x):
        """ Return x + self.value. """
        return x + self.value
|
|
|
|
|
|
class ClassTest:
    """ Plain helper object holding a name and a value. """

    def __init__(self, name, value1):
        """ Remember the given name and value. """
        self.name, self.value = name, value1

    def get_name(self):
        """ Return the stored name. """
        return self.name

    def get_value(self, inc):
        """ Return the stored value increased by inc. """
        return self.value + inc

    def __call__(self, *args, **kwargs):
        """ Accept any arguments and do nothing. """
        return None
|
|
|
|
|
|
# Test: call method on parse graph code
|
|
@non_graph_engine
def test_call_method_on_construct():
    """ Check that Net.construct, which calls another method of the cell, adds the stored value. """
    log.debug("begin test_call_method_on_construct")

    input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]], dtype=np.int32))
    addend = Tensor(np.array([[2, 3, 4], [1, 1, 2]], dtype=np.int32))
    # Element-wise sum of input_x and addend.
    expected = np.array([[3, 5, 7], [2, 3, 5]], dtype=np.int32)

    net = Net(addend)
    result = net.construct(input_x).asnumpy()
    print(result)
    assert np.all(result == expected)

    log.debug("finished test_call_method_on_construct")
|
|
|
|
|
|
# Test: call another object's method in graph-parsed code
|
|
class Net1(nn.Cell):
    """ Cell whose construct calls a method on a ClassTest attribute. """

    def __init__(self, v1, v2):
        """ Wrap v1 in a ClassTest helper and keep v2 as the increment passed to it. """
        super().__init__()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", v1)
        self.value = v2

    @ms_function
    def construct(self, x):
        """ Return x plus self.TC.get_value(self.value). """
        increment = self.TC.get_value(self.value)
        return x + increment
|
|
|
|
|
|
@non_graph_engine
def test_call_other_object_method():
    """ Expect a TypeError when construct calls a method of a non-cell member object. """
    log.debug("begin test_call_other_object_method")

    input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]], dtype=np.int32))
    helper_value = Tensor(np.array([[2, 3, 4], [1, 1, 2]], dtype=np.int32))
    increment = Tensor(np.array([[5, 4, 5], [1, 1, 2]], dtype=np.int32))
    expected = np.array([[8, 9, 12], [3, 4, 7]], dtype=np.int32)

    net = Net1(helper_value, increment)
    with pytest.raises(TypeError):
        # Everything that depends on `output` stays inside the block; once a
        # statement raises, the ones after it are not executed.
        output = net.construct(input_x)
        result = output.asnumpy()
        print(result)
        assert np.all(result == expected)

    log.debug("finished test_call_other_object_method")
|
|
|
|
|
|
# Test: call global object method(not self) on parse graph code
|
|
# Module-level helper object referenced by Net2 without going through `self`.
value = Tensor(np.array([[3, 4, 5], [1, 1, 2]], dtype=np.int32))
TC = ClassTest("test_class", value)
|
|
|
|
|
|
class Net2(nn.Cell):
    """ Cell whose construct methods use the module-level TC object instead of a member. """

    def __init__(self, value1):
        """ Keep value1 as the cell's own addend. """
        super().__init__()
        self.value = value1

    @ms_function
    def construct(self, x):
        """ Return x plus the result of calling the global TC.get_value with self.value. """
        increment = TC.get_value(self.value)
        return x + increment

    @ms_function
    def construct1(self, x):
        """ Return x plus the global TC.value plus self.value. """
        with_global = x + TC.value
        return with_global + self.value
|
|
|
|
|
|
@non_graph_engine
def test_call_no_self_other_object_method():
    """ Expect a TypeError when construct calls a method of a module-level object.

    Net2.construct calls the global TC.get_value rather than a member of the
    cell. The test passes only if a TypeError is raised inside the
    pytest.raises block.
    """
    # The log messages previously named test_call_other_object_method
    # (copy-paste), which made the two tests indistinguishable in the log.
    log.debug("begin test_call_no_self_other_object_method")
    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)

    net = Net2(y)
    with pytest.raises(TypeError):
        # Statements after the one that raises are not executed; they are kept
        # inside the block because they all depend on `output`.
        output = net.construct(x)
        result = output.asnumpy()
        print(result)
        assert np.all(result == z)

    log.debug("finished test_call_no_self_other_object_method")
|
|
|
|
|
|
def test_call_no_self_other_object_attr_value():
    """ Placeholder test that intentionally does nothing. """
    # do not support tensor as init input.
    return None
|
|
|
|
|
|
# Test case: use * to unpack the varargs inside an @ms_function
|
|
def vararg1(x, y):
    """ Return the sum of the two arguments. """
    return x + y
|
|
|
|
|
|
def varargs_main(fn):
    """ Wrap fn in an @ms_function that forwards all positional arguments to it. """

    @ms_function
    def t1(*args):
        # Unpack the collected varargs back into positional arguments of fn.
        return fn(*args)

    return t1
|
|
|
|
|
|
def test_var_parameter_case3():
    """ Call vararg1 through the varargs-forwarding wrapper built by varargs_main. """
    log.debug("start test_var_parameter_case3")
    forwarded = varargs_main(vararg1)
    result = forwarded(1, 2)
    log.debug("ret = %r", result)
    log.debug("end test_var_parameter_case3")
|
|
|
|
|
|
# Test case: test the flag set
|
|
@core(tg=True)
def set_flag(x):
    """ Return x + 1; the function is decorated with core(tg=True). """
    incremented = x + 1
    return incremented
|
|
|
|
|
|
@ms_function
def set_test_flag_main(x, y):
    """ Apply set_flag to x and add y to the result. """
    flagged = set_flag(x)
    return flagged + y
|
|
|
|
|
|
def test_set_flag():
    """ Call set_test_flag_main, which in turn calls the @core-decorated set_flag. """
    log.debug("begin test_set_flag")
    result = set_test_flag_main(2, 3)
    log.debug("finished test_set_flag, ret = %r", result)
|
|
|
|
|
|
@dataclass
class Access:
    """ Dataclass with two fields and a method returning the larger one. """
    a: int
    b: int

    def max(self):
        """ Return a when a > b, otherwise b. """
        if self.a > self.b:
            larger = self.a
        else:
            larger = self.b
        return larger
|
|
|
|
|
|
@ms_function
def invoke_dataclass(x, y):
    """ Build an Access instance from x and y and return the result of its max method. """
    instance = Access(x, y)
    return instance.max()
|
|
|
|
|
|
def test_access():
    """ Invoke a dataclass method from inside an @ms_function. """
    invoke_dataclass(1, 2)
|
|
|
|
@dataclass
class Access2:
    """ Dataclass whose max method reads a field that is not declared. """
    a: int
    b: int

    def max(self):
        """ Return b unless a > b, in which case the undeclared field c is read. """
        if self.a > self.b:
            # `c` is not a declared field; test_access_attr_error expects an
            # AttributeError from code that uses this class.
            larger = self.c
        else:
            larger = self.b
        return larger
|
|
|
|
|
|
@ms_function
def invoke_dataclass2(x, y):
    """ Build an Access2 instance from x and y and return the result of its max method. """
    instance = Access2(x, y)
    return instance.max()
|
|
|
|
|
|
def test_access_attr_error():
    """ Expect an AttributeError when invoke_dataclass2 is called. """
    with pytest.raises(AttributeError):
        invoke_dataclass2(1, 2)
|
|
|
|
|
|
def myfunc(x):
    """ Return x multiplied by itself. """
    squared = x * x
    return squared
|
|
|
|
|
|
@ms_function
def ms_infer_for():
    """ Accumulate the elements of a constant float list with a for loop. """
    total = 0.0
    for item in [1.1, 2.3, 3.3]:
        total = total + item
    return total
|
|
|
|
|
|
def test_infer_for():
    """ Run ms_infer_for; the test only checks that the call succeeds. """
    ms_infer_for()
|
|
|
|
|
|
@ms_function
def ms_infer_for_func(y):
    """ Add myfunc(item) to y for each item of a constant float list. """
    for item in [1.0, 2.0, 3.0]:
        # A plain Python function is called from inside the loop body.
        y = myfunc(item) + y
    return y
|
|
|
|
|
|
def test_ms_infer_for_func():
    """ Run ms_infer_for_func with a float start value; only checks that the call succeeds. """
    ms_infer_for_func(1.0)
|
|
|
|
|
|
@ms_function
def add(x, y):
    """ Return the sum of x and y. """
    total = x + y
    return total
|
|
|
|
|
|
def test_add():
    """ Call add with an int and a float and return its result. """
    return add(1, 2.0)
|
|
|
|
|
|
@ms_function
def add_list():
    """ Index a constant list and return the sum of its second and third elements. """
    items = [1, 2, 3]
    return items[1] + items[2]
|
|
|
|
|
|
def test_list():
    """ Call add_list and return its result. """
    result = add_list()
    return result
|
|
|
|
|
|
@ms_function
def compare_list_len():
    """ Return ms_len applied to a constant three-element list. """
    items = [1, 2, 3]
    length = ms_len(items)
    return length
|
|
|
|
|
|
def test_list_len():
    """ Run compare_list_len; the test only checks that the call succeeds. """
    compare_list_len()
|
|
|
|
|
|
@ms_function
def add_tuple():
    """ Index a constant tuple and return the sum of its second and third elements. """
    items = (1, 2, 3)
    return items[1] + items[2]
|
|
|
|
|
|
def test_tuple():
    """ Call add_tuple and return its result. """
    result = add_tuple()
    return result
|
|
|
|
|
|
def invoke_func(x):
    """ Return x multiplied by itself. """
    product = x * x
    return product
|
|
|
|
|
|
@ms_function
def tuple_of_node(x, y):
    """ Pack two computed values into a tuple and return its second element times x. """
    first = invoke_func(x)
    second = invoke_func(y)
    # The tuple is built from computed values rather than constants.
    pair = (first, second)
    return pair[1] * x
|
|
|
|
|
|
def test_tuple_node():
    """ Call tuple_of_node with two ints and return its result. """
    return tuple_of_node(1, 2)
|
|
|
|
|
|
@ms_function
def range_spec(x, y):
    """ Increment x once per element of range(1, 10, 3), then return x + y. """
    for _ in range(1, 10, 3):
        x = x + 1
    total = x + y
    return total
|
|
|
|
|
|
def test_range():
    """ Call range_spec with two ints and return its result. """
    return range_spec(10, 10)
|
|
|
|
def test_expr():
    """ Call a @constexpr helper that asserts the length of a constant tuple. """
    pair = (1, 2)

    @constexpr
    def tuple_len(x):
        # The check lives inside the constexpr-decorated helper.
        assert len(x) == 2

    tuple_len(pair)
|
|
|
|
|
|
def test_tuple_to_array():
    """ Pass range(10) to F.tuple_to_array and print the result. """
    converted = F.tuple_to_array(range(10))
    print(converted)
|