forked from mindspore-Ecosystem/mindspore
save print data to file
This commit is contained in:
parent
60f4795f40
commit
ac23a179bf
|
@ -80,6 +80,7 @@ if (ENABLE_DUMP_PROTO)
|
||||||
"utils/anf_ir.proto"
|
"utils/anf_ir.proto"
|
||||||
"utils/summary.proto"
|
"utils/summary.proto"
|
||||||
"utils/checkpoint.proto"
|
"utils/checkpoint.proto"
|
||||||
|
"utils/print.proto"
|
||||||
)
|
)
|
||||||
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
|
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
|
||||||
|
|
||||||
|
|
|
@ -146,7 +146,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
|
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
|
||||||
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
|
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
|
||||||
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
|
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
|
||||||
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");
|
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
|
||||||
|
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.");
|
||||||
|
|
||||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||||
|
|
|
@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||||
profiling_options_ = "training_trace";
|
profiling_options_ = "training_trace";
|
||||||
check_bprop_flag_ = false;
|
check_bprop_flag_ = false;
|
||||||
max_device_memory_ = kDefaultMaxDeviceMemory;
|
max_device_memory_ = kDefaultMaxDeviceMemory;
|
||||||
|
print_file_path_ = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<MsContext> MsContext::GetInstance() {
|
std::shared_ptr<MsContext> MsContext::GetInstance() {
|
||||||
|
|
|
@ -147,6 +147,8 @@ class MsContext {
|
||||||
std::string profiling_options() const { return profiling_options_; }
|
std::string profiling_options() const { return profiling_options_; }
|
||||||
bool check_bprop_flag() const { return check_bprop_flag_; }
|
bool check_bprop_flag() const { return check_bprop_flag_; }
|
||||||
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
|
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
|
||||||
|
void set_print_file_path(const std::string &file) { print_file_path_ = file; }
|
||||||
|
const std::string &print_file_path() const { return print_file_path_; }
|
||||||
|
|
||||||
float max_device_memory() const { return max_device_memory_; }
|
float max_device_memory() const { return max_device_memory_; }
|
||||||
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
|
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
|
||||||
|
@ -192,6 +194,7 @@ class MsContext {
|
||||||
std::string profiling_options_;
|
std::string profiling_options_;
|
||||||
bool check_bprop_flag_;
|
bool check_bprop_flag_;
|
||||||
float max_device_memory_;
|
float max_device_memory_;
|
||||||
|
std::string print_file_path_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mindspore.prntpb;
|
||||||
|
|
||||||
|
message TensorProto {
|
||||||
|
// The shape of the tensor.
|
||||||
|
repeated int64 dims = 1;
|
||||||
|
// The type of the tensor.
|
||||||
|
required string tensor_type = 2;
|
||||||
|
// The data of the tensor.
|
||||||
|
required bytes tensor_content = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message Print {
|
||||||
|
message Value {
|
||||||
|
oneof value {
|
||||||
|
string desc = 1;
|
||||||
|
TensorProto tensor = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
repeated Value value = 1;
|
||||||
|
}
|
|
@ -47,6 +47,18 @@ static std::map<std::string, size_t> type_size_map = {
|
||||||
{"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2},
|
{"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2},
|
||||||
{"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}};
|
{"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}};
|
||||||
|
|
||||||
|
std::string GetParseType(const std::string &tensorType_) {
|
||||||
|
static const std::map<std::string, std::string> print_parse_map = {
|
||||||
|
{"int8_t", "Int8"}, {"uint8_t", "Uint8"}, {"int16_t", "Int16"}, {"uint16_t", "Uint16"},
|
||||||
|
{"int32_t", "Int32"}, {"uint32_t", "Uint32"}, {"int64_t", "Int64"}, {"uint64_t", "Uint64"},
|
||||||
|
{"float16", "Float16"}, {"float", "Float32"}, {"double", "Float64"}, {"bool", "Bool"}};
|
||||||
|
auto type_iter = print_parse_map.find(tensorType_);
|
||||||
|
if (type_iter == print_parse_map.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << tensorType_;
|
||||||
|
}
|
||||||
|
return type_iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) {
|
bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) {
|
||||||
if (tensor_shape == nullptr) {
|
if (tensor_shape == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -141,7 +153,7 @@ void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type,
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
|
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
}
|
||||||
|
|
||||||
bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
|
bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
|
||||||
auto type_iter = type_size_map.find(tensor_type);
|
auto type_iter = type_size_map.find(tensor_type);
|
||||||
|
@ -200,14 +212,84 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
||||||
return ret_end_sequence;
|
return ret_end_sequence;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorPrint::operator()() {
|
bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print,
|
||||||
while (true) {
|
std::fstream *output) {
|
||||||
std::vector<tdt::DataItem> bundle;
|
bool ret_end_sequence = false;
|
||||||
if (tdt::TdtHostPopData("_npu_log", bundle) != 0) {
|
for (auto &item : items) {
|
||||||
|
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) {
|
||||||
|
ret_end_sequence = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (ConvertDataItem2Tensor(bundle)) {
|
prntpb::Print_Value *value = print.add_value();
|
||||||
break;
|
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||||
|
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
|
||||||
|
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> tensor_shape;
|
||||||
|
size_t totaldims = 1;
|
||||||
|
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (item.tensorType_ == "string") {
|
||||||
|
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
|
||||||
|
value->set_desc(data);
|
||||||
|
} else {
|
||||||
|
auto parse_type = GetParseType(item.tensorType_);
|
||||||
|
prntpb::TensorProto *tensor = value->mutable_tensor();
|
||||||
|
if (!(item.tensorShape_ == kShapeScalar) && !(item.tensorShape_ == kShapeNone)) {
|
||||||
|
for (const auto &dim : tensor_shape) {
|
||||||
|
tensor->add_dims(static_cast<::google::protobuf::int64>(dim));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor->set_tensor_type(parse_type);
|
||||||
|
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);
|
||||||
|
tensor->set_tensor_content(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!print.SerializeToOstream(output)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Save print file:" << print_file_path << " fail.";
|
||||||
|
}
|
||||||
|
print.Clear();
|
||||||
|
}
|
||||||
|
return ret_end_sequence;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TensorPrint::operator()() {
|
||||||
|
prntpb::Print print;
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
std::string print_file_path = ms_context->print_file_path();
|
||||||
|
if (print_file_path == "") {
|
||||||
|
while (true) {
|
||||||
|
std::vector<tdt::DataItem> bundle;
|
||||||
|
if (tdt::TdtHostPopData("_npu_log", bundle) != 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (ConvertDataItem2Tensor(bundle)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::fstream output(print_file_path, std::ios::out | std::ios::trunc | std::ios::binary);
|
||||||
|
while (true) {
|
||||||
|
std::vector<tdt::DataItem> bundle;
|
||||||
|
if (tdt::TdtHostPopData("_npu_log", bundle) != 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (SaveDataItem2File(bundle, print_file_path, print, &output)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output.close();
|
||||||
|
std::string path_string = print_file_path;
|
||||||
|
if (chmod(common::SafeCStr(path_string), S_IRUSR) == -1) {
|
||||||
|
MS_LOG(ERROR) << "Modify file:" << print_file_path << " to r fail.";
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include "tdt/tsd_client.h"
|
#include "tdt/tsd_client.h"
|
||||||
#include "tdt/tdt_host_interface.h"
|
#include "tdt/tdt_host_interface.h"
|
||||||
#include "tdt/data_common.h"
|
#include "tdt/data_common.h"
|
||||||
|
#include "proto/print.pb.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
#endif
|
#endif
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class TensorPrint {
|
class TensorPrint {
|
||||||
|
|
|
@ -346,6 +346,15 @@ class _Context:
|
||||||
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
||||||
self._context_handle.set_max_device_memory(max_device_memory_value)
|
self._context_handle.set_max_device_memory(max_device_memory_value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def print_file_path(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@print_file_path.setter
|
||||||
|
def print_file_path(self, file):
|
||||||
|
self._context_handle.set_print_file_path(file)
|
||||||
|
|
||||||
|
|
||||||
def check_input_format(x):
|
def check_input_format(x):
|
||||||
import re
|
import re
|
||||||
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
||||||
|
@ -473,7 +482,7 @@ def reset_auto_parallel_context():
|
||||||
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
|
||||||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||||
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
|
||||||
check_bprop=bool, max_device_memory=str)
|
check_bprop=bool, max_device_memory=str, print_file_path=str)
|
||||||
def set_context(**kwargs):
|
def set_context(**kwargs):
|
||||||
"""
|
"""
|
||||||
Sets context for running environment.
|
Sets context for running environment.
|
||||||
|
|
|
@ -21,6 +21,7 @@ import mindspore.nn as nn
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.train.checkpoint_pb2 import Checkpoint
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
||||||
|
from mindspore.train.print_pb2 import Print
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
@ -30,11 +31,15 @@ from mindspore._checkparam import check_input_data
|
||||||
|
|
||||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
||||||
|
|
||||||
tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype.int32, "Int64": mstype.int64,
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||||
"Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64}
|
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
||||||
|
"Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
|
||||||
|
"Bool": mstype.bool_}
|
||||||
|
|
||||||
|
tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16,
|
||||||
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||||
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||||
|
|
||||||
tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64,
|
|
||||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64}
|
|
||||||
|
|
||||||
def _special_process_par(par, new_par):
|
def _special_process_par(par, new_par):
|
||||||
"""
|
"""
|
||||||
|
@ -442,3 +447,64 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
||||||
# restore network training mode
|
# restore network training mode
|
||||||
if is_training:
|
if is_training:
|
||||||
net.set_train(mode=True)
|
net.set_train(mode=True)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_print(print_file_name):
|
||||||
|
"""
|
||||||
|
Loads Print data from a specified file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
print_file_name (str): The file name of save print data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List, element of list is Tensor.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Print file is incorrect.
|
||||||
|
"""
|
||||||
|
if not os.path.realpath(print_file_name):
|
||||||
|
raise ValueError("Please input the correct print file name.")
|
||||||
|
|
||||||
|
if os.path.getsize(print_file_name) == 0:
|
||||||
|
raise ValueError("The print file may be empty, please make sure enter the correct file name.")
|
||||||
|
|
||||||
|
logger.info("Execute load print process.")
|
||||||
|
print_list = Print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(print_file_name, "rb") as f:
|
||||||
|
pb_content = f.read()
|
||||||
|
print_list.ParseFromString(pb_content)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error("Failed to read the print file %s, please check the correct of the file.", print_file_name)
|
||||||
|
raise ValueError(e.__str__())
|
||||||
|
|
||||||
|
tensor_list = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for print_ in print_list.value:
|
||||||
|
# String type
|
||||||
|
if print_.HasField("desc"):
|
||||||
|
tensor_list.append(print_.desc)
|
||||||
|
elif print_.HasField("tensor"):
|
||||||
|
dims = print_.tensor.dims
|
||||||
|
data_type = print_.tensor.tensor_type
|
||||||
|
data = print_.tensor.tensor_content
|
||||||
|
np_type = tensor_to_np_type[data_type]
|
||||||
|
param_data = np.fromstring(data, np_type)
|
||||||
|
ms_type = tensor_to_ms_type[data_type]
|
||||||
|
param_dim = []
|
||||||
|
for dim in dims:
|
||||||
|
param_dim.append(dim)
|
||||||
|
if param_dim:
|
||||||
|
param_value = param_data.reshape(param_dim)
|
||||||
|
tensor_list.append(Tensor(param_value, ms_type))
|
||||||
|
# Scale type
|
||||||
|
else:
|
||||||
|
tensor_list.append(Tensor(param_data, ms_type))
|
||||||
|
|
||||||
|
except BaseException as e:
|
||||||
|
logger.error("Failed to load the print file %s.", print_list)
|
||||||
|
raise RuntimeError(e.__str__())
|
||||||
|
|
||||||
|
return tensor_list
|
||||||
|
|
|
@ -16,8 +16,9 @@
|
||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
import time
|
import time
|
||||||
import pytest
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
@ -33,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
|
||||||
_exec_save_checkpoint, export, _save_graph
|
_exec_save_checkpoint, export, _save_graph
|
||||||
from ..ut_filter import non_graph_engine
|
from ..ut_filter import non_graph_engine
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb")
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
|
@ -327,8 +328,52 @@ def test_binary_export():
|
||||||
export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY")
|
export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY")
|
||||||
|
|
||||||
|
|
||||||
|
class PrintNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(PrintNet, self).__init__()
|
||||||
|
self.print = P.Print()
|
||||||
|
|
||||||
|
def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
|
||||||
|
scale1, scale2):
|
||||||
|
self.print('============tensor int8:==============', int8)
|
||||||
|
self.print('============tensor uint8:==============', uint8)
|
||||||
|
self.print('============tensor int16:==============', int16)
|
||||||
|
self.print('============tensor uint16:==============', uint16)
|
||||||
|
self.print('============tensor int32:==============', int32)
|
||||||
|
self.print('============tensor uint32:==============', uint32)
|
||||||
|
self.print('============tensor int64:==============', int64)
|
||||||
|
self.print('============tensor uint64:==============', uint64)
|
||||||
|
self.print('============tensor float16:==============', flt16)
|
||||||
|
self.print('============tensor float32:==============', flt32)
|
||||||
|
self.print('============tensor float64:==============', flt64)
|
||||||
|
self.print('============tensor bool:==============', bool_)
|
||||||
|
self.print('============tensor scale1:==============', scale1)
|
||||||
|
self.print('============tensor scale2:==============', scale2)
|
||||||
|
return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2
|
||||||
|
|
||||||
|
|
||||||
|
def test_print():
|
||||||
|
print_net = PrintNet()
|
||||||
|
int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8"))
|
||||||
|
uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8"))
|
||||||
|
int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16"))
|
||||||
|
uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16"))
|
||||||
|
int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32"))
|
||||||
|
uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32"))
|
||||||
|
int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64"))
|
||||||
|
uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64"))
|
||||||
|
float16 = Tensor(np.random.rand(224, 224).astype(np.float16))
|
||||||
|
float32 = Tensor(np.random.rand(224, 224).astype(np.float32))
|
||||||
|
float64 = Tensor(np.random.rand(224, 224).astype(np.float64))
|
||||||
|
bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_))
|
||||||
|
scale1 = Tensor(np.array(1))
|
||||||
|
scale2 = Tensor(np.array(0.1))
|
||||||
|
print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1,
|
||||||
|
scale2)
|
||||||
|
|
||||||
|
|
||||||
def teardown_module():
|
def teardown_module():
|
||||||
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
|
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb']
|
||||||
for item in files:
|
for item in files:
|
||||||
file_name = './' + item
|
file_name = './' + item
|
||||||
if not os.path.exists(file_name):
|
if not os.path.exists(file_name):
|
||||||
|
|
Loading…
Reference in New Issue