forked from mindspore-Ecosystem/mindspore
!26280 fix mindir export's error when using _encrypt
Merge pull request !26280 from lianliguang/master
Commit 9a5fd32bd2

@@ -1018,7 +1018,9 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
  });
  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
    MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
                      << " Please check the args is same with export.\n"
                      << "Please check the args is same with export.\n"
                      << "The export input argument size : " << func_args.size() << "\n"
                      << "The load input argument size : " << broaded_args.size() << "\n"
                      << "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
                      << "The input args info:" << abstract::ArgsToString(broaded_args);
  }

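This check fires when a graph exported to MindIR is later loaded and fed arguments whose abstract signature (count, shape, dtype) differs from the one recorded at export time. A minimal, hypothetical Python sketch of the usage it guards, assuming the export/load APIs imported in the test file below and the usual nn.GraphCell wrapper; SomeNet is a placeholder for any exported Cell, not part of the patch:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import export, load

net = SomeNet()                                    # placeholder Cell
x = Tensor(np.ones((1, 3, 224, 224), np.float32))  # signature recorded at export time
export(net, x, file_name="some_net", file_format="MINDIR")

graph = load("some_net.mindir")
infer = nn.GraphCell(graph)
infer(Tensor(np.ones((1, 3, 224, 224), np.float32)))  # same signature: accepted
# infer(Tensor(np.ones((2, 3), np.float32)))          # different signature: raises the
#                                                     # "input arguments is not compatible" exception
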
@@ -69,11 +69,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction {
  std::string ToString() const override;

  AbstractFunctionPtr GetUnique() override {
    MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion";
    AbstractFunctionPtr result;
    return result;
  }
  AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }

  /// \brief Check whether the input AbstractFunction is in AbstractFuncUnion.
  ///

@@ -90,11 +86,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction {
  std::size_t hash() const override;

  AbstractFunctionPtr Copy() const override {
    MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion";
    AbstractFunctionPtr result;
    return result;
  }
  AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }

 private:
  AbstractFuncAtomPtrList func_list_;

@@ -64,8 +64,6 @@ AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -202,8 +200,7 @@ AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const Primit
                                         const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);

AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -246,8 +243,6 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -56,16 +56,16 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
 public:
  RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl,
                                      const InferValueImpl &infer_value_impl, const bool is_wight_list = true) {
    const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
                                      const InferValueImpl &infer_value_impl, const bool is_white_list = true) {
    const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_white_list};
    RegisterStandardPrimitiveImpl(primitive, impl_reg);
  }
  ~RegisterStandardPrimitiveEvalHelper() = default;
};

#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_white_list) \
  static auto helper_##name = \
    abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \
    abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
  std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
    auto out = std::make_shared<name>(); \
    return out; \

@@ -22,10 +22,11 @@ import os
import shutil
import stat
import sys
import time
from collections import defaultdict
import threading
from threading import Thread, Lock
import time
from collections import defaultdict


import numpy as np
from mindspore.train.checkpoint_pb2 import Checkpoint

@@ -67,6 +68,7 @@ _ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024
PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
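
These thresholds appear to be kilobyte counts: the split-save code below divides raw byte sizes by 1024 before comparing against them, and the new test file mocks PARAMETER_SPLIT_SIZE as 129 / 1024 to mean 129 bytes. A small illustrative check under that assumption (not part of the patch):

# Illustrative only, assuming the thresholds are expressed in KB as the /1024 arithmetic below suggests.
assert TOTAL_SAVE * 1024 == 1024 ** 3            # 1 GB: beyond this the export is split
assert PARAMETER_SPLIT_SIZE * 1024 == 1024 ** 4  # 1 TB: beyond this a new data_<n> file is started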


def _special_process_par(par, new_par):

@@ -837,6 +839,108 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
        net.set_train(mode=True)


def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
    front_info = bytes()
    check_code = sys.byteorder == "little"
    front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
    front_info += bytes(63)
    if is_encrypt():
        front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'],
                              len(kwargs['enc_key']), kwargs['enc_mode'])
    return front_info
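
Every parameter data file starts with this fixed 64-byte "front info" header: one byte recording whether the exporting host was little-endian, followed by 63 reserved zero bytes; when encryption is enabled the whole header is passed through _encrypt. A hedged sketch of how an unencrypted header could be read back (parse_front_info is a hypothetical helper, not part of the patch):

import sys

def parse_front_info(header):
    # Hypothetical reader for the 64-byte unencrypted front info block.
    assert len(header) == 64, "front info is always 64 bytes"
    is_little_endian = bool(header[0])  # first byte: endianness flag written by the exporter
    reserved = header[1:]               # 63 spare bytes, currently all zero
    return is_little_endian, reserved

# The exporter writes (sys.byteorder == "little") as the first byte:
flag, _ = parse_front_info(bytes([sys.byteorder == "little"]) + bytes(63))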


def _change_file(ori_data_file_name, dirname, external_local):
    # Only the 64-byte header has been written so far: the parameter that triggered the
    # rollover cannot fit into any single data file, so the export is aborted.
    if os.path.getsize(ori_data_file_name) == 64:
        raise RuntimeError("The parameter size is exceed 1T,cannot export to the file")
    data_file_name = os.path.join(dirname, external_local)
    if os.path.exists(data_file_name):
        os.chmod(data_file_name, stat.S_IWUSR)
    return data_file_name


def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
    '''
    Save the parameter data into separate data files when the model is too large for a single MindIR file.
    '''
    logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
    # save parameter
    file_prefix = file_name.split("/")[-1]
    if file_prefix.endswith(".mindir"):
        file_prefix = file_prefix[:-7]
    current_path = os.path.abspath(file_name)
    dirname = os.path.dirname(current_path)
    data_path = os.path.join(dirname, file_prefix + "_variables")
    if os.path.exists(data_path):
        shutil.rmtree(data_path)
    os.makedirs(data_path, exist_ok=True)
    os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
    # Reserve the first 64 bytes of each data file as spare information such as check data
    offset = 64
    index = 0
    parameter_size = (offset / 1024)
    external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
    data_file_name = os.path.join(dirname, external_local)
    if os.path.exists(data_file_name):
        os.chmod(data_file_name, stat.S_IWUSR)
    f = open(data_file_name, "wb")
    f.write(bytes(offset))
    try:
        for param_proto in model.graph.parameter:
            name = param_proto.name[param_proto.name.find(":") + 1:]
            param = net_dict[name]
            raw_data = param.data.asnumpy().tobytes()
            data_length = len(raw_data)
            append_size = 0
            if data_length % 64 != 0:
                append_size = 64 - (data_length % 64)
            parameter_size += ((append_size + data_length) / 1024)
            if parameter_size > PARAMETER_SPLIT_SIZE:
                front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
                f.seek(0, 0)
                f.write(front_info)
                f.close()
                os.chmod(data_file_name, stat.S_IRUSR)
                offset = 64
                index += 1
                parameter_size = (offset + append_size + data_length) / 1024
                external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
                data_file_name = _change_file(data_file_name, dirname, external_local)
                f = open(data_file_name, "wb")
                f.write(bytes(offset))
            param_proto.external_data.location = external_local
            param_proto.external_data.length = data_length
            param_proto.external_data.offset = offset
            write_data = raw_data + bytes(append_size)
            offset += (data_length + append_size)
            if is_encrypt():
                write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
                                      len(kwargs['enc_key']), kwargs['enc_mode'])
            f.write(write_data)

        # save graph
        graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
        if os.path.exists(graph_file_name):
            os.chmod(graph_file_name, stat.S_IWUSR)
        with open(graph_file_name, 'wb') as model_file:
            os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
            model_string = model.SerializeToString()
            if is_encrypt():
                model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
                                        len(kwargs['enc_key']),
                                        kwargs['enc_mode'])
            model_file.write(model_string)
            os.chmod(graph_file_name, stat.S_IRUSR)

        front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
        f.seek(0, 0)
        f.write(front_info)
    finally:
        f.close()
        os.chmod(data_file_name, stat.S_IRUSR)
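
The split export therefore produces <prefix>_graph.mindir alongside a <prefix>_variables/ directory holding data_0, data_1, ...; each data file starts with the 64-byte front info, and every parameter in the graph records its file, offset and length in param_proto.external_data. A hedged sketch of reading one parameter back from an unencrypted export (load_external_param and its dtype/shape arguments are hypothetical, not part of the patch):

import os
import numpy as np

def load_external_param(mindir_dir, param_proto, dtype, shape):
    # Hypothetical reader: fetch one parameter from its external data file.
    ext = param_proto.external_data
    with open(os.path.join(mindir_dir, ext.location), "rb") as f:
        f.seek(ext.offset)        # offsets already account for the 64-byte front info
        raw = f.read(ext.length)  # padding up to the next 64-byte boundary is skipped
    return np.frombuffer(raw, dtype=dtype).reshape(shape)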


def _save_mindir(net, file_name, *inputs, **kwargs):
    """Save MindIR format file."""
    model = mindir_model()

@@ -859,67 +963,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
    if save_together:
        _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
    else:
        logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
        # save parameter
        file_prefix = file_name.split("/")[-1]
        if file_prefix.endswith(".mindir"):
            file_prefix = file_prefix[:-7]
        current_path = os.path.abspath(file_name)
        dirname = os.path.dirname(current_path)
        data_path = os.path.join(dirname, file_prefix + "_variables")
        if os.path.exists(data_path):
            shutil.rmtree(data_path)
        os.makedirs(data_path, exist_ok=True)
        os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        # Reserves 4096 bytes as spare information such as check data
        offset = 64
        data_file_name = os.path.join(data_path, "veriables.data")
        if os.path.exists(data_file_name):
            os.chmod(data_file_name, stat.S_IWUSR)
        with open(data_file_name, "wb") as f:
            f.write(bytes(offset))
            for name, param in net_dict.items():
                for param_proto in model.graph.parameter:
                    if name == param_proto.name[param_proto.name.find(":") + 1:]:
                        data_file = os.path.join(file_prefix + "_variables", "veriables.data")
                        param_proto.external_data.location = data_file
                        raw_data = param.data.asnumpy().tobytes()
                        data_length = len(raw_data)
                        param_proto.external_data.length = data_length
                        param_proto.external_data.offset = offset
                        write_data = raw_data
                        offset += data_length
                        if data_length % 64 != 0:
                            append_size = 64 - (data_length % 64)
                            write_data += (bytes(append_size))
                            offset += append_size
                        if is_encrypt():
                            write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
                                                  len(kwargs['enc_key']), kwargs['enc_mode'])
                        f.write(write_data)

            # save graph
            graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
            if os.path.exists(graph_file_name):
                os.chmod(graph_file_name, stat.S_IWUSR)
            with open(graph_file_name, 'wb') as model_file:
                os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
                model_string = model.SerializeToString()
                if is_encrypt():
                    model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
                                            len(kwargs['enc_key']),
                                            kwargs['enc_mode'])
                model_file.write(model_string)
                os.chmod(graph_file_name, stat.S_IRUSR)

            front_info = bytearray()
            check_code = sys.byteorder == "little"
            front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
            f.seek(0, 0)
            if is_encrypt():
                front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'],
                                      len(kwargs['enc_key']), kwargs['enc_mode'])
            f.write(front_info)
        _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs)


def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):

@@ -1242,6 +1286,7 @@ def restore_group_info_list(group_info_file_name):
    restore_rank_list = [rank for rank in restore_list.dim]
    return restore_rank_list


def build_searched_strategy(strategy_filename):
    """
    Build strategy of every parameter in network. Used in the case of distributed inference.

@@ -1528,7 +1573,6 @@ def async_ckpt_thread_status():
def _check_predict_strategy(predict_strategy):
    """Check predict strategy."""

    def _check_int_list(arg):
        if not isinstance(arg, list):
            return False

@@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test mindir export larger than 1G """
import os
import sys

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Parameter
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load


def get_front_info():
    correct_data = bytes()
    check_code = sys.byteorder == "little"
    correct_data += check_code.to_bytes(1, byteorder=sys.byteorder)
    correct_data += bytes(63)
    return correct_data


def get_correct_data(parameter):
    correct_data = bytes()
    data = parameter.data.asnumpy().tobytes()
    data_size = len(data)
    if data_size % 64 != 0:
        data += bytes((64 - data_size % 64))
    correct_data += data
    return correct_data


def get_data(mindir_name):
    data_path = mindir_name + "_variables"
    data = bytes()
    for dirpath, _, filenames in os.walk(data_path):
        for filename in filenames:
            with open(os.path.join(dirpath, filename), "rb") as f:
                data += f.readline()
    return data
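
Taken together these helpers rebuild the byte stream the split export is expected to write: one 64-byte front info block per data file, followed by each parameter's raw bytes padded to a 64-byte boundary. A short sketch of how they compose, mirroring the first test below (add_net and the file prefix come from that test):

# Sketch only: expected contents of data_0 for an unencrypted split export.
expected = get_front_info()              # 64-byte header of data_0
expected += get_correct_data(add_net.y)  # parameter bytes padded to a 64-byte boundary
expected += get_correct_data(add_net.z)
assert get_data("mindir_export_split") == expected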


def test_mindir_export_split():
    """
    Feature: MindIR export when the model exceeds TOTAL_SAVE (1G, mocked as 0 here).
    Description: A model that exceeds TOTAL_SAVE should be split into a graph file and separate data files.
    Expectation: No exception.
    """
    ms.train.serialization.TOTAL_SAVE = 0

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.addn = ops.AddN()
            self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
            self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")

        def construct(self, x):
            return self.addn((x, self.y, self.z))

    x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
    add_net = Net()
    export(add_net, x, file_name="mindir_export_split", file_format="MINDIR")
    graph = load("mindir_export_split_graph.mindir")
    assert graph is not None
    correct_data = get_front_info()
    correct_data += get_correct_data(add_net.y)
    correct_data += get_correct_data(add_net.z)
    export_data = get_data("mindir_export_split")
    assert export_data == correct_data
    assert oct(os.stat(os.path.join("mindir_export_split_variables", "data_0")).st_mode)[-3:] == "400"
    assert oct(os.stat("mindir_export_split_graph.mindir").st_mode)[-3:] == "400"


def test_mindir_export_larger_error():
    """
    Feature: MindIR export when the model exceeds TOTAL_SAVE (1G, mocked as 0 here).
    Description: The model is split into a graph file and data files; if a single parameter exceeds
        PARAMETER_SPLIT_SIZE (1T, mocked as 0 here), an exception should be reported.
    Expectation: RuntimeError because the parameter exceeds PARAMETER_SPLIT_SIZE.
    """
    ms.train.serialization.TOTAL_SAVE = 0
    ms.train.serialization.PARAMETER_SPLIT_SIZE = 0

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = ops.Add()
            self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")

        def construct(self, x):
            return self.add(x, self.y)

    x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
    add = Net()
    with pytest.raises(RuntimeError) as e:
        export(add, x, file_name="net", file_format="MINDIR")
        assert e.message == "The parameter size is exceed 1T,cannot export to the file"


def test_mindir_export_larger_parameter_exceed_1t_mock():
    """
    Feature: MindIR export when the model exceeds TOTAL_SAVE (1G, mocked as 0 here).
    Description: The model is split into a graph file and data files; when a data file would exceed
        PARAMETER_SPLIT_SIZE (1T, mocked as 129 bytes here), the parameters are split across further
        files named data_0, data_1, data_2, ...
    Expectation: No exception.
    """
    ms.train.serialization.TOTAL_SAVE = 0
    ms.train.serialization.PARAMETER_SPLIT_SIZE = 129 / 1024

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.addn = ops.AddN()
            self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
            self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")

        def construct(self, x):
            return self.addn((x, self.y, self.z))

    x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
    add_net = Net()
    export(add_net, x, file_name="larger_parameter_exceed_1T_mock", file_format="MINDIR")
    graph = load("larger_parameter_exceed_1T_mock_graph.mindir")
    assert graph is not None
    correct_data = get_front_info()
    correct_data += get_correct_data(add_net.y)
    correct_data += get_front_info()
    correct_data += get_correct_data(add_net.z)
    export_data = get_data("larger_parameter_exceed_1T_mock")
    assert export_data == correct_data