forked from mindspore-Ecosystem/mindspore
add more pynative st test_cases
This commit is contained in:
parent
f8fa043f37
commit
41f520c9fa
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_pynative_embeddinglookup """
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.ops.operations as op
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import Cell
|
||||
|
||||
def setup_module():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
|
||||
class MetaFactory:
|
||||
def __init__(self):
|
||||
self.device_target = context.get_context('device_target')
|
||||
self.rank_size = None
|
||||
self.device_id = None
|
||||
self.global_rank_id = None
|
||||
|
||||
class OpsFactory(MetaFactory):
|
||||
def __init__(self, dtype=np.float16):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
if self.dtype == np.float16:
|
||||
self.loss = 1e-3
|
||||
elif self.dtype == np.float32:
|
||||
self.loss = 1e-4
|
||||
elif self.dtype == np.float64:
|
||||
self.loss = 1e-5
|
||||
else:
|
||||
self.loss = 0
|
||||
|
||||
class EmbeddingLookup(Cell):
|
||||
def __init__(self, offset):
|
||||
super().__init__()
|
||||
self.op = op.EmbeddingLookup()
|
||||
self.offset = offset
|
||||
|
||||
def construct(self, params, indices):
|
||||
x = self.op(params, indices, self.offset)
|
||||
return x
|
||||
|
||||
class EmbeddingLookupFactory(OpsFactory):
|
||||
def __init__(self, params_shape, indices_shape, offset=0, low=0, high=2, dtype=np.float32, ids_type=np.int32):
|
||||
super().__init__(dtype=dtype)
|
||||
self.input_np = np.random.randn(*params_shape).astype(dtype)
|
||||
self.indices_np = np.random.randint(low, high, size=indices_shape).astype(ids_type)
|
||||
self.offset = offset
|
||||
self.output_grad_np = None
|
||||
|
||||
def forward_mindspore_impl(self):
|
||||
net = EmbeddingLookup(self.offset)
|
||||
out = net(Tensor(self.input_np), Tensor(self.indices_np))
|
||||
return out.asnumpy()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_embeddinglookup_indices_outrange():
|
||||
fact = EmbeddingLookupFactory(params_shape=(2, 4), indices_shape=(2, 3), low=1, high=3, offset=10, dtype=np.int8)
|
||||
out = fact.forward_mindspore_impl()
|
||||
out_expect = np.zeros((2, 3, 4))
|
||||
np.allclose(out_expect, out)
|
|
@ -232,7 +232,7 @@ def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_if_net_register_diff_hook_at_each_hook():
|
||||
def test_pynative_hook_diff_hook():
|
||||
input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
|
||||
ms_net = FinalNet()
|
||||
ms_net.set_grad()
|
||||
|
@ -248,7 +248,7 @@ def test_pynative_hook_if_net_register_diff_hook_at_each_hook():
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_one_input_network_register_hook_at_outermost_cell_not_change_grad():
|
||||
def test_pynative_hook_outermost_cell_not_change_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsOneInputNet()
|
||||
|
@ -273,7 +273,7 @@ def test_pynative_hook_one_input_network_register_hook_at_outermost_cell_not_cha
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_one_input_network_register_hook_to_all_cell_record_grad():
|
||||
def test_pynative_hook_all_cell_record_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsOneInputNet()
|
||||
|
@ -305,7 +305,7 @@ def test_pynative_hook_one_input_network_register_hook_to_all_cell_record_grad()
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_one_input_network_register_hook_to_mul_change_input_grad():
|
||||
def test_pynative_hook_mul_change_input_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsOneInputNet()
|
||||
|
@ -325,7 +325,7 @@ def test_pynative_hook_one_input_network_register_hook_to_mul_change_input_grad(
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_multi_input_network_register_hook_to_mul2_change_input_grad():
|
||||
def test_pynative_hook_mul2_change_input_grad():
|
||||
input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
|
||||
input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
|
||||
|
||||
|
@ -349,7 +349,7 @@ def test_pynative_hook_multi_input_network_register_hook_to_mul2_change_input_gr
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_network_with_cell_in_cell_register_hook_at_outermost_cell_change_grad():
|
||||
def test_pynative_hook_outermost_cell_change_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsNetWithCellinCell()
|
||||
|
@ -371,7 +371,7 @@ def test_pynative_hook_network_with_cell_in_cell_register_hook_at_outermost_cell
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_network_with_bprop_register_hook_at_outermost_cell_record_grad():
|
||||
def test_pynative_hook_outermost_cell_record_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsSingleOpNetWithBprop()
|
||||
|
@ -397,7 +397,7 @@ def test_pynative_hook_network_with_bprop_register_hook_at_outermost_cell_record
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_network_with_bprop_in_child_register_hook_at_outermost_cell_record_grad():
|
||||
def test_pynative_hook_bprop_outermost_cell_record_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsNetHasBpropInChild()
|
||||
|
@ -428,7 +428,7 @@ def test_pynative_hook_network_with_bprop_in_child_register_hook_at_outermost_ce
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_hook_multi_op_network_with_bprop_register_hook_at_child_cell_record_grad():
|
||||
def test_pynative_hook_child_cell_record_grad():
|
||||
input_np = np.ones([2, 2]).astype(np.float32)
|
||||
|
||||
ms_net = MsMultiOpNetWithBprop()
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_pynative_layernorm_input_and_argmaxwithvalue """
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.ops.operations as op
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import LayerNorm, Cell
|
||||
from mindspore.common import ParameterTuple
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from mindspore.train.model import Model
|
||||
|
||||
class _Grad(Cell):
|
||||
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
|
||||
super().__init__()
|
||||
self.network = network
|
||||
self.grad = grad
|
||||
self.sens_param = self.grad.sens_param
|
||||
self.wrt_params = wrt_params
|
||||
self.real_inputs_count = real_inputs_count
|
||||
if self.wrt_params:
|
||||
self.params = ParameterTuple(self.network.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
if self.wrt_params:
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network, self.params)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network)(*real_inputs, sense_param_inputs)
|
||||
|
||||
class GradOfAllInputsAndParams(_Grad):
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=GradOperation(get_all=True, get_by_list=True, sens_param=sens_param),
|
||||
network=network, wrt_params=True, real_inputs_count=real_inputs_count)
|
||||
|
||||
class MetaFactory:
|
||||
def __init__(self):
|
||||
self.device_target = context.get_context('device_target')
|
||||
self.rank_size = None
|
||||
self.device_id = None
|
||||
self.global_rank_id = None
|
||||
|
||||
class OpsFactory(MetaFactory):
|
||||
def __init__(self, dtype=np.float16):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
if self.dtype == np.float16:
|
||||
self.loss = 1e-3
|
||||
elif self.dtype == np.float32:
|
||||
self.loss = 1e-4
|
||||
elif self.dtype == np.float64:
|
||||
self.loss = 1e-5
|
||||
else:
|
||||
self.loss = 0
|
||||
|
||||
def _count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_me)*rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count/total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
|
||||
if np.any(np.isnan(data_expected)):
|
||||
assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
|
||||
elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
|
||||
_count_unequal_element(data_expected, data_me, rtol, atol)
|
||||
else:
|
||||
assert True
|
||||
|
||||
class LayerNormFactory(OpsFactory):
|
||||
def __init__(self, input_shape, norm_shape, gamma_shape, beta_shape, gamma_init=None, beta_init=None,
|
||||
norm_axis=-1, params_axis=-1, dtype=np.float32):
|
||||
super().__init__(dtype=dtype)
|
||||
np.random.seed(1)
|
||||
self.input_np = np.random.randn(*input_shape).astype(dtype=dtype)
|
||||
self.gamma_np = np.ones(shape=gamma_shape, dtype=dtype)
|
||||
self.gamma_init = gamma_init
|
||||
self.beta_np = np.zeros(shape=beta_shape, dtype=dtype)
|
||||
self.beta_init = beta_init
|
||||
self.output_grad_np = np.random.randn(*input_shape).astype(dtype=dtype)
|
||||
self.begin_norm_axis = norm_axis
|
||||
self.begin_params_axis = params_axis
|
||||
self.input_shape = norm_shape
|
||||
|
||||
def forward_mindspore_impl(self):
|
||||
input_ms = Tensor(self.input_np)
|
||||
gamma = Tensor(self.gamma_np)
|
||||
beta = Tensor(self.beta_np)
|
||||
net = LayerNorm(self.input_shape, self.begin_norm_axis, self.begin_params_axis, gamma, beta)
|
||||
net.set_train()
|
||||
model = Model(net)
|
||||
out_me = model.predict(Tensor(input_ms))
|
||||
return out_me.asnumpy()
|
||||
|
||||
def grad_mindspore_impl(self):
|
||||
input_nn = Tensor(self.input_np)
|
||||
output_grad = Tensor(self.output_grad_np)
|
||||
net = LayerNorm(self.input_shape, self.begin_norm_axis, self.begin_params_axis,
|
||||
Tensor(self.gamma_np), Tensor(self.beta_np))
|
||||
grad_net = GradOfAllInputsAndParams(net)
|
||||
grad_net.set_train()
|
||||
input_grad = grad_net(input_nn, output_grad)
|
||||
return input_grad[0][0].asnumpy(), input_grad[1][1].asnumpy(), input_grad[1][0].asnumpy()
|
||||
|
||||
def forward_cmp(self):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_out = self.forward_mindspore_impl()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_out = self.forward_mindspore_impl()
|
||||
|
||||
allclose_nparray(graph_out[0], pynative_out[0], self.loss, self.loss)
|
||||
|
||||
def grad_cmp(self):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_grad1, graph_grad2, graph_grad3 = self.grad_mindspore_impl()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_grad1, pynative_grad2, pynative_grad3 = self.grad_mindspore_impl()
|
||||
|
||||
allclose_nparray(graph_grad1, pynative_grad1, self.loss, self.loss)
|
||||
allclose_nparray(graph_grad2, pynative_grad2, self.loss, self.loss)
|
||||
allclose_nparray(graph_grad3, pynative_grad3, self.loss, self.loss)
|
||||
|
||||
class ArgMaxWithValue(Cell):
|
||||
def __init__(self, axis, keep_dims):
|
||||
super().__init__()
|
||||
self.op = op.ArgMaxWithValue(axis=axis, keep_dims=keep_dims)
|
||||
|
||||
def construct(self, input_value):
|
||||
return self.op(input_value)
|
||||
|
||||
class GradOfFirstInput(_Grad):
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=GradOperation(sens_param=sens_param),
|
||||
network=network, real_inputs_count=real_inputs_count)
|
||||
|
||||
class ArgMaxWithValueFactory(OpsFactory):
|
||||
def __init__(self, input_shape, axis, keep_dims, dtype=np.float32):
|
||||
super().__init__(dtype=dtype)
|
||||
np.random.seed(1)
|
||||
self.input_np = np.random.rand(*input_shape).astype(dtype)
|
||||
self.output_grad_np = None
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
|
||||
def forward_mindspore_impl(self):
|
||||
input_forward = Tensor(self.input_np)
|
||||
net = ArgMaxWithValue(axis=self.axis, keep_dims=self.keep_dims)
|
||||
index, value = net(input_forward)
|
||||
return index.asnumpy().reshape(1, -1), value.asnumpy()
|
||||
|
||||
def forward_numpy_impl(self):
|
||||
index = np.argmax(self.input_np, axis=self.axis)
|
||||
value = np.amax(self.input_np, axis=self.axis, keepdims=self.keep_dims)
|
||||
return index.reshape(1, -1), value.astype(self.dtype)
|
||||
|
||||
def grad_mindspore_impl(self):
|
||||
input_back = Tensor(self.input_np)
|
||||
np.random.seed(1)
|
||||
self.output_grad_np = np.random.randn(*input_back[0].shape).astype(self.dtype)
|
||||
output_grad = Tensor(self.output_grad_np, ms.int32)
|
||||
output_grad_2 = Tensor(self.output_grad_np)
|
||||
net = ArgMaxWithValue(axis=self.axis, keep_dims=self.keep_dims)
|
||||
grad_net = GradOfFirstInput(net, real_inputs_count=1)
|
||||
grad_net.set_train()
|
||||
input_grad = grad_net(input_back, output_grad, output_grad_2)
|
||||
return input_grad.asnumpy()
|
||||
|
||||
def forward_cmp(self):
|
||||
out_numpy = self.forward_numpy_impl()
|
||||
out_mindspore = self.forward_mindspore_impl()
|
||||
allclose_nparray(out_numpy[0], out_mindspore[0], self.loss, self.loss)
|
||||
allclose_nparray(out_numpy[1], out_mindspore[1], self.loss, self.loss)
|
||||
|
||||
def grad_cmp(self):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_grad = self.grad_mindspore_impl()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_grad = self.grad_mindspore_impl()
|
||||
|
||||
allclose_nparray(graph_grad, pynative_grad, self.loss, self.loss)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_layernorm_input():
|
||||
fact = LayerNormFactory(input_shape=(1, 128, 1024), norm_shape=(1024,), gamma_shape=(1024,), beta_shape=(1024,),
|
||||
norm_axis=2, params_axis=2, dtype=np.float16)
|
||||
fact.forward_cmp()
|
||||
fact.loss = 5e-3
|
||||
fact.grad_cmp()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_argmaxwithvalue_input():
|
||||
fact = ArgMaxWithValueFactory(input_shape=[1024, 1024], axis=-1, keep_dims=False)
|
||||
fact.forward_cmp()
|
||||
fact.grad_cmp()
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_pynative_mixed_precision_cells """
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import context
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import ReLU
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
class MetaFactory:
|
||||
def __init__(self):
|
||||
self.device_target = context.get_context('device_target')
|
||||
self.rank_size = None
|
||||
self.device_id = None
|
||||
self.global_rank_id = None
|
||||
|
||||
class ReluTanhSoftmax(Cell, MetaFactory):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
MetaFactory.__init__(self)
|
||||
self.relu = ReLU()
|
||||
self.tanh = nn.Tanh()
|
||||
self.softmax = nn.Softmax()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.tanh(x)
|
||||
z = self.softmax(x)
|
||||
return x, y, z
|
||||
|
||||
class Add(Cell, MetaFactory):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
MetaFactory.__init__(self)
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.add(x, y)
|
||||
|
||||
class ReluTanhAdd(Cell, MetaFactory):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
MetaFactory.__init__(self)
|
||||
self.relu = ReLU()
|
||||
self.tanh = nn.Tanh()
|
||||
self.add = Add()
|
||||
|
||||
def construct(self, x):
|
||||
x_1 = self.relu(x)
|
||||
y = self.tanh(x)
|
||||
x = self.add(x_1, y)
|
||||
return x
|
||||
|
||||
def _count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_me)*rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count/total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
|
||||
if np.any(np.isnan(data_expected)):
|
||||
assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
|
||||
elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
|
||||
_count_unequal_element(data_expected, data_me, rtol, atol)
|
||||
else:
|
||||
assert True
|
||||
|
||||
def mixed_precision_multiple_cells_01():
|
||||
np.random.seed(1)
|
||||
x = np.random.randn(1, 3, 28, 28).astype(np.float32)
|
||||
net = ReluTanhSoftmax()
|
||||
net.to_float(ms.float16)
|
||||
net.relu.to_float(ms.float32)
|
||||
net.softmax.to_float(ms.float16)
|
||||
out_me_relu_01, out_me_tanh_01, out_me_softmax_01 = net(Tensor(x))
|
||||
return out_me_relu_01, out_me_tanh_01, out_me_softmax_01
|
||||
|
||||
def mixed_precision_multiple_cells_02():
|
||||
np.random.seed(1)
|
||||
x = np.random.randn(1, 3, 28, 28).astype(np.float32)
|
||||
net = ReluTanhSoftmax()
|
||||
net.relu.to_float(ms.float32)
|
||||
net.softmax.to_float(ms.float16)
|
||||
net.to_float(ms.float16)
|
||||
out_me_relu_02, out_me_tanh_02, out_me_softmax_02 = net(Tensor(x))
|
||||
return out_me_relu_02, out_me_tanh_02, out_me_softmax_02
|
||||
|
||||
def mixed_precision_multiple_cells_03():
|
||||
np.random.seed(1)
|
||||
x = np.random.randn(1, 3, 28, 28).astype(np.float32)
|
||||
net = ReluTanhAdd()
|
||||
net.to_float(ms.float16)
|
||||
net.relu.to_float(ms.float32)
|
||||
net.add.to_float(ms.float32)
|
||||
out_me = net(Tensor(x))
|
||||
return out_me
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mixed_precision_multiples_cell_01():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_relu_01, graph_tanh_01, graph_softmax_01 = mixed_precision_multiple_cells_01()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_relu_01, pynative_tanh_01, pynative_softmax_01 = mixed_precision_multiple_cells_01()
|
||||
|
||||
allclose_nparray(graph_relu_01.asnumpy(), pynative_relu_01.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(graph_tanh_01.asnumpy(), pynative_tanh_01.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(graph_softmax_01.asnumpy(), pynative_softmax_01.asnumpy(), 0.001, 0.001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mixed_precision_multiples_cell_02():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_relu_02, graph_tanh_02, graph_softmax_02 = mixed_precision_multiple_cells_02()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_relu_02, pynative_tanh_02, pynative_softmax_02 = mixed_precision_multiple_cells_02()
|
||||
|
||||
allclose_nparray(graph_relu_02.asnumpy(), pynative_relu_02.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(graph_tanh_02.asnumpy(), pynative_tanh_02.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(graph_softmax_02.asnumpy(), pynative_softmax_02.asnumpy(), 0.001, 0.001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mixed_precision_multiples_cell_03():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
graph_output_03 = mixed_precision_multiple_cells_03()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
pynative_output_03 = mixed_precision_multiple_cells_03()
|
||||
|
||||
allclose_nparray(graph_output_03.asnumpy(), pynative_output_03.asnumpy(), 0.001, 0.001)
|
Loading…
Reference in New Issue