mindspore/tests/ut/python/pynative_mode/test_parse_method.py

461 lines
9.5 KiB
Python

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_parse_method.py
@Author:
@Date : 2019-06-27
@Desc : test parse the object's method
"""
import logging
from dataclasses import dataclass
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore._extends.parse.standard_method import ms_len
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.ops.composite import core
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from ..ut_filter import non_graph_engine
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
@ms_function
def default_parameter_f(x, y=3):
""" default_parameter_f """
z = x + y
return z
# Test case: test parse fn that use default parameter
def test_parse_defalut_parameter_case1():
""" Test default parameter function call """
log.debug("begin test_parse_defalut_parameter_case1")
ret = default_parameter_f(2)
log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
def get_val_fn(x):
""" get_val_fn """
ret = x + 3
return ret
# Test case: test bool not
@ms_function
def bool_exp(x, y):
""" bool_exp """
return not x > y
def test_bool_exp():
""" test_bool_exp """
bool_exp(1, 2)
# Test case: use the variable parameter for @mindspore
@ms_function
def var_parameter_f(x, *args):
""" var_parameter_f """
z = x + args[0] + args[1] + args[2]
return z
def test_var_parameter_case1():
""" test_var_parameter_case1 """
log.debug("start test_var_parameter_case1")
var_parameter_f(1, 2, 3, 4, 5)
log.debug("end test_var_parameter_case1")
class Net(nn.Cell):
""" Net definition """
def __init__(self, value1):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.softmax = nn.Softmax(0)
self.axis = 0
self.TC = ClassTest("test_class", 1.2)
self.value = value1
@ms_function
def construct(self, x):
x = self.get_test_value(x)
return x
def get_test_value(self, x):
ret = x + self.value
return ret
class ClassTest:
""" ClassTest definition """
def __init__(self, name, value1):
self.name = name
self.value = value1
def get_name(self):
return self.name
def get_value(self, inc):
ret = self.value + inc
return ret
def __call__(self, *args, **kwargs):
pass
# Test: call method on parse graph code
@non_graph_engine
def test_call_method_on_construct():
""" test_call_method_on_construct """
log.debug("begin test_call_method_on_construct")
x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
net = Net(y)
output = net.construct(x)
result = output.asnumpy()
print(result)
assert np.all(result == z)
log.debug("finished test_call_method_on_construct")
# Test: call method on parse graph code
class Net1(nn.Cell):
""" Net1 definition """
def __init__(self, v1, v2):
super(Net1, self).__init__()
self.relu = nn.ReLU()
self.softmax = nn.Softmax(0)
self.axis = 0
self.TC = ClassTest("test_class", v1)
self.value = v2
@ms_function
def construct(self, x):
x = x + self.TC.get_value(self.value)
return x
@non_graph_engine
def test_call_other_object_method():
""" test_call_other_object_method """
log.debug("begin test_call_other_object_method")
x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
net = Net1(y, y1)
with pytest.raises(TypeError):
output = net.construct(x)
result = output.asnumpy()
print(result)
assert np.all(result == z)
log.debug("finished test_call_other_object_method")
# Test: call global object method(not self) on parse graph code
value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
TC = ClassTest("test_class", value)
class Net2(nn.Cell):
""" Net2 definition """
def __init__(self, value1):
super(Net2, self).__init__()
self.value = value1
@ms_function
def construct(self, x):
x = x + TC.get_value(self.value)
return x
@ms_function
def construct1(self, x):
x = x + TC.value
x = x + self.value
return x
@non_graph_engine
def test_call_no_self_other_object_method():
""" test_call_no_self_other_object_method """
log.debug("begin test_call_other_object_method")
x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
net = Net2(y)
with pytest.raises(TypeError):
output = net.construct(x)
result = output.asnumpy()
print(result)
assert np.all(result == z)
log.debug("finished test_call_other_object_method")
def test_call_no_self_other_object_attr_value():
""" test_call_no_self_other_object_attr_value """
# do not support tensor as init input.
return
# Test case: use the * to unlock the varargs for @mindspore
def vararg1(x, y):
""" vararg1 """
z = x + y
return z
def varargs_main(fn):
""" varargs_main """
@ms_function
def t1(*args):
return fn(*args)
return t1
def test_var_parameter_case3():
""" test_var_parameter_case3 """
log.debug("start test_var_parameter_case3")
ret = varargs_main(vararg1)(1, 2)
log.debug("ret = %r", ret)
log.debug("end test_var_parameter_case3")
# Test case: test the flag set
@core(tg=True)
def set_flag(x):
""" set_flag """
return x + 1
@ms_function
def set_test_flag_main(x, y):
""" set_test_flag_main """
z = set_flag(x)
z = z + y
return z
def test_set_flag():
""" Test default parameter function call """
log.debug("begin test_set_flag")
ret = set_test_flag_main(2, 3)
log.debug("finished test_set_flag, ret = %r", ret)
@dataclass
class Access:
a: int
b: int
def max(self):
if self.a > self.b:
return self.a
return self.b
@ms_function
def invoke_dataclass(x, y):
""" invoke_dataclass """
acs = Access(x, y)
return acs.max()
def test_access():
""" test_access """
invoke_dataclass(1, 2)
@dataclass
class Access2:
a: int
b: int
def max(self):
if self.a > self.b:
return self.c
return self.b
@ms_function
def invoke_dataclass2(x, y):
""" invoke_dataclass """
acs = Access2(x, y)
return acs.max()
def test_access_attr_error():
""" test_access """
with pytest.raises(AttributeError):
invoke_dataclass2(1, 2)
def myfunc(x):
""" myfunc """
return x * x
@ms_function
def ms_infer_for():
""" ms_infer_for """
a = 0.0
for x in [1.1, 2.3, 3.3]:
a = a + x
return a
def test_infer_for():
""" test_infer_for """
ms_infer_for()
@ms_function
def ms_infer_for_func(y):
""" ms_infer_for_func """
for x in [1.0, 2.0, 3.0]:
y = myfunc(x) + y
return y
def test_ms_infer_for_func():
""" test_ms_infer_for_func """
ms_infer_for_func(1.0)
@ms_function
def add(x, y):
""" add """
return x + y
def test_add():
""" test_add """
res = add(1, 2.0)
return res
@ms_function
def add_list():
""" add_list """
a = [1, 2, 3]
b = a[1] + a[2]
return b
def test_list():
""" test_list """
return add_list()
@ms_function
def compare_list_len():
""" compare_list_len """
a = [1, 2, 3]
return ms_len(a)
def test_list_len():
""" test_list_len """
compare_list_len()
@ms_function
def add_tuple():
""" add_tuple """
a = (1, 2, 3)
b = a[1] + a[2]
return b
def test_tuple():
""" test_tuple """
return add_tuple()
def invoke_func(x):
""" invoke_func """
return x * x
@ms_function
def tuple_of_node(x, y):
""" tuple_of_node """
a = invoke_func(x)
b = invoke_func(y)
c = (a, b)
d = c[1] * x
return d
def test_tuple_node():
""" test_tuple_node """
res = tuple_of_node(1, 2)
return res
@ms_function
def range_spec(x, y):
""" range_spec """
for _ in range(1, 10, 3):
x = x + 1
return x + y
def test_range():
""" test_range """
res = range_spec(10, 10)
return res
def test_expr():
""" test const expr """
a = (1, 2)
@constexpr
def tuple_len(x):
assert len(x) == 2
tuple_len(a)
def test_tuple_to_array():
""" test range tuple to array """
range_x = range(10)
res = F.tuple_to_array(range_x)
print(res)