Add test case for nested calling ms_function
This commit is contained in:
parent
b1be8dfd31
commit
dc14e2baef
|
@ -0,0 +1,97 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import ms_function
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
input_x = Tensor(np.ones([1, 1, 120, 640]), dtype=mstype.float32)
|
||||
input_y = Tensor(np.full((1, 1, 120, 640), 4), dtype=mstype.float32)
|
||||
ret_output_2 = Tensor(np.full((1, 1, 120, 640), 3.125), dtype=mstype.float32)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.env_Ascend_1p
|
||||
@pytest.mark.env_Gpu_1p
|
||||
@pytest.mark.env_CPU
|
||||
@pytest.mark.Function
|
||||
def test_ms_function_nested_local():
|
||||
@ms_function
|
||||
def function1(x, y):
|
||||
x = x ** y
|
||||
x /= y
|
||||
x += y
|
||||
x -= 1
|
||||
x %= 2
|
||||
return x
|
||||
|
||||
@ms_function
|
||||
def function11(x, y):
|
||||
r = function1(x, y)
|
||||
out = r + r
|
||||
return out
|
||||
|
||||
@ms_function
|
||||
def function2(x, y):
|
||||
r1 = function1(x, y)
|
||||
r2 = function11(x, y)
|
||||
z = r1 * r2
|
||||
return z
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
output2 = function2(input_x, input_y)
|
||||
print(output2)
|
||||
assert "Not support nested calling of local ms_function, please delete decorator of 'function11'." in str(
|
||||
info.value)
|
||||
|
||||
|
||||
@ms_function
|
||||
def function1_g(x, y):
|
||||
x = x ** y
|
||||
x /= y
|
||||
x += y
|
||||
x -= 1
|
||||
x %= 2
|
||||
return x
|
||||
|
||||
@ms_function
|
||||
def function11_g(x, y):
|
||||
r = function1_g(x, y)
|
||||
out = r + r
|
||||
return out
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.env_Ascend_1p
|
||||
@pytest.mark.env_Gpu_1p
|
||||
@pytest.mark.env_CPU
|
||||
@pytest.mark.Function
|
||||
def test_ms_function_nested_global():
|
||||
@ms_function
|
||||
def function2_g(x, y):
|
||||
r1 = function1_g(x, y)
|
||||
r2 = function11_g(x, y)
|
||||
z = r1 * r2
|
||||
return z
|
||||
|
||||
output2 = function2_g(input_x, input_y)
|
||||
assert np.allclose(output2.asnumpy(), ret_output_2.asnumpy(), 0.0001, 0.0001)
|
Loading…
Reference in New Issue