forked from mindspore-Ecosystem/mindspore
Add test case and fix two bugs
1. add case to guard precision 2. fix a shape bug 3. fix a funcGraph bug
This commit is contained in:
parent
cf6dd99ed7
commit
1e43c609e0
|
@ -348,16 +348,13 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
|
|||
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||
|
||||
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
|
||||
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
|
||||
processed_streams_.emplace(true_stream_id);
|
||||
|
||||
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
|
||||
if (value_ptr == nullptr) {
|
||||
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
|
||||
continue;
|
||||
}
|
||||
auto need_active = GetValue<bool>(value_ptr);
|
||||
auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
|
||||
if (need_active) {
|
||||
processed_streams_.emplace(cur_stream_id);
|
||||
}
|
||||
|
@ -371,20 +368,17 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
|
|||
void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
||||
vector<CNodePtr> *orders) {
|
||||
orders->emplace_back(switch_ptr);
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
|
||||
if (value_ptr == nullptr) {
|
||||
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto need_active = GetValue<bool>(value_ptr);
|
||||
auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
|
||||
if (!need_active) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(switch_ptr);
|
||||
auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
|
||||
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
|
||||
MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr)
|
||||
<< "; active stream id:" << true_stream_id;
|
||||
|
||||
|
@ -677,14 +671,11 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
|||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
|
||||
if (value_ptr == nullptr) {
|
||||
if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto need_active = GetValue<bool>(value_ptr);
|
||||
auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
|
||||
if (need_active) {
|
||||
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||
MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
|
||||
|
|
|
@ -276,7 +276,8 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
|
|||
input_desc_json[kName] = op_input_name;
|
||||
input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
|
||||
if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
|
||||
if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
|
||||
GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
|
||||
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
|
||||
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
|
||||
<< "], value: " << input_desc_json[kValue];
|
||||
|
|
|
@ -291,7 +291,7 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &no
|
|||
// graph kernel cnode.
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
return fg->has_flag(key);
|
||||
return fg->has_attr(key);
|
||||
}
|
||||
|
||||
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
|
||||
|
|
|
@ -68,9 +68,14 @@ class AnfRuntimeAlgorithm {
|
|||
std::string node_debug_log = node->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
|
||||
}
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return GetValue<T>(primitive->GetAttr(key));
|
||||
// single op cnode.
|
||||
if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
|
||||
return GetValue<T>(primitive->GetAttr(key));
|
||||
}
|
||||
// graph kernel cnode.
|
||||
auto fg = GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
return GetValue<T>(fg->get_attr(key));
|
||||
}
|
||||
static bool IsTupleOutput(const AnfNodePtr &anf);
|
||||
// set attr of anf node
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""train bert network without lossscale"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim import Lamb
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.model import Model
|
||||
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
|
||||
from src.bert_model import BertConfig
|
||||
|
||||
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
|
||||
SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"
|
||||
|
||||
def get_config(version='base', batch_size=1):
|
||||
"""get config"""
|
||||
if version == 'base':
|
||||
bert_config = BertConfig(
|
||||
batch_size=batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=21136,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float32)
|
||||
elif version == 'large':
|
||||
bert_config = BertConfig(
|
||||
batch_size=batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=True)
|
||||
else:
|
||||
bert_config = BertConfig(batch_size=batch_size)
|
||||
return bert_config
|
||||
|
||||
|
||||
def me_de_train_dataset():
|
||||
"""test me de train dataset"""
|
||||
# apply repeat operations
|
||||
repeat_count = 1
|
||||
ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
|
||||
"next_sentence_labels", "masked_lm_positions",
|
||||
"masked_lm_ids", "masked_lm_weights"], shuffle=False)
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
batch_size = int(os.getenv('BATCH_SIZE', '16'))
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_count)
|
||||
return ds
|
||||
|
||||
|
||||
def weight_variable(shape):
|
||||
"""weight variable"""
|
||||
np.random.seed(1)
|
||||
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
|
||||
return Tensor(ones)
|
||||
|
||||
|
||||
class ModelCallback(Callback):
|
||||
def __init__(self):
|
||||
super(ModelCallback, self).__init__()
|
||||
self.loss_list = []
|
||||
self.overflow_list = []
|
||||
self.lossscale_list = []
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
|
||||
self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
|
||||
self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
|
||||
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bert_tdt():
|
||||
"""test bert tdt"""
|
||||
np.random.seed(0)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
ds = me_de_train_dataset()
|
||||
config = get_config(version='large', batch_size=16)
|
||||
netwithloss = BertNetworkWithLoss(config, True)
|
||||
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
|
||||
start_learning_rate=5e-5, end_learning_rate=1e-9,
|
||||
power=10.0, warmup_steps=0, weight_decay=0.01)
|
||||
scale_window = 3
|
||||
scale_manager = DynamicLossScaleManager(262144, 2, scale_window)
|
||||
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
|
||||
scale_update_cell=scale_manager.get_update_cell())
|
||||
netwithgrads.set_train(True)
|
||||
model = Model(netwithgrads)
|
||||
callback = ModelCallback()
|
||||
params = netwithloss.trainable_params()
|
||||
for param in params:
|
||||
param.init_data()
|
||||
value = param.default_input
|
||||
name = param.name
|
||||
if isinstance(value, Tensor):
|
||||
if name.split('.')[-1] in ['weight']:
|
||||
if name.split('.')[-3] in ['cls2']:
|
||||
logger.info("***************** BERT param name is 1 {}".format(name))
|
||||
param.default_input = weight_variable(value.asnumpy().shape)
|
||||
else:
|
||||
logger.info("***************** BERT param name is 2 {}".format(name))
|
||||
tempshape = value.asnumpy().shape
|
||||
shape = (tempshape[1], tempshape[0])
|
||||
weight_value = weight_variable(shape).asnumpy()
|
||||
param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
|
||||
else:
|
||||
logger.info("***************** BERT param name is 3 {}".format(name))
|
||||
param.default_input = weight_variable(value.asnumpy().shape)
|
||||
model.train(1, ds, callbacks=callback, dataset_sink_mode=False)
|
||||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336, 12.973715,
|
||||
12.57929, 12.7766905]
|
||||
error = loss_value - expect_loss_value
|
||||
print("loss value: {}".format(loss_value))
|
||||
print("error value: {}".format(error))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
overflow = np.array(callback.overflow_list)
|
||||
expect_overflow = [True, True, True, True, False, False, False, True, False, False]
|
||||
print("overflow: {}".format(overflow))
|
||||
assert (overflow == expect_overflow).all()
|
||||
|
||||
loss_scale = np.array(callback.lossscale_list)
|
||||
expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0]
|
||||
print("loss scale: {}".format(loss_scale))
|
||||
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bert_tdt()
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn.graph_kernels import LambUpdateWithLR, LambNextMV
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class LambNet(Cell):
|
||||
def __init__(self, i2, i5, x6):
|
||||
super(LambNet, self).__init__()
|
||||
self.i2 = Parameter(i2, name='i2')
|
||||
self.i5 = Parameter(i5, name='i5')
|
||||
self.x6 = Parameter(x6, name='x6')
|
||||
self.lamb_next = LambNextMV()
|
||||
self.lamb_update = LambUpdateWithLR()
|
||||
|
||||
def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3,
|
||||
x1, x2, x3, x4, x5, gy, se, my):
|
||||
return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0,
|
||||
ix1, ix2, ix3), \
|
||||
self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)
|
||||
|
||||
def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my):
|
||||
trust_ratio = np.where(np.greater(x2, gy),
|
||||
np.where(np.greater(x1, gy), np.divide(x2, x3), se),
|
||||
se)
|
||||
trust_ratio = np.maximum(np.minimum(trust_ratio, my), gy)
|
||||
update_with_lr = trust_ratio * x4 * x5
|
||||
next_param = x6 - np.reshape(update_with_lr, x6.shape)
|
||||
return next_param
|
||||
|
||||
def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3):
|
||||
m_fp32 = i5.astype(np.float32)
|
||||
v_fp32 = i2.astype(np.float32)
|
||||
next_m = i8 * m_fp32 + i9 * i4
|
||||
next_v = x0 * v_fp32 + x1 * i1
|
||||
next_mm = next_m / i6
|
||||
next_vv = next_v / i3
|
||||
update = next_mm / (np.sqrt(next_vv) + x3)
|
||||
add3 = next_mm / np.sqrt(next_vv + x3) + x2 * i7
|
||||
return add3, next_m, next_v, update
|
||||
|
||||
|
||||
def tensor_all(*args):
|
||||
res = [Tensor(a) for a in args]
|
||||
return res
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_graph_kernel_lamb():
|
||||
shape = [1, 16]
|
||||
oshape = [1]
|
||||
np.random.seed(0)
|
||||
x1 = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
x2 = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
x3 = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
x4 = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
x5 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
x6 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
gy = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
se = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
my = np.random.normal(0, 1, oshape).astype(np.float32)
|
||||
|
||||
tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all(
|
||||
x1, x2, x3, x4, x5, x6, gy, se, my)
|
||||
|
||||
np.random.seed(1)
|
||||
i1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
i2 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
i3 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
i4 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
i5 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
i6 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
i7 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
i8 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
i9 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
ix0 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
ix1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
|
||||
ix2 = np.random.normal(0, 1, shape).astype(np.float32)
|
||||
ix3 = np.ones(shape).astype(np.float32) * 1e-6
|
||||
|
||||
ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3 = \
|
||||
tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3)
|
||||
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
|
||||
net = LambNet(ti2, ti5, tx6)
|
||||
(wa3, wup), _ = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3,
|
||||
tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy)
|
||||
|
||||
wi2 = net.i2.data.asnumpy().copy()
|
||||
wi5 = net.i5.data.asnumpy().copy()
|
||||
ares = net.x6.data.asnumpy().copy()
|
||||
|
||||
context.set_context(enable_graph_kernel=False)
|
||||
|
||||
a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0,
|
||||
ix1, ix2, ix3)
|
||||
|
||||
np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my)
|
||||
|
||||
rtol = 0.0001
|
||||
atol = 0.0001
|
||||
|
||||
wres = (wa3.asnumpy().copy(), wi5, wi2, wup.asnumpy().copy())
|
||||
bres = (a3, a0, a1, up)
|
||||
|
||||
cmp_res = list(map(lambda x, y: np.allclose(x, y, rtol, atol),
|
||||
wres, bres))
|
||||
|
||||
assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol)
|
Loading…
Reference in New Issue