From ac23a179bf50c629d807f708f40bf869dac75e9a Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Fri, 19 Jun 2020 12:46:10 +0800 Subject: [PATCH] save print data to file --- mindspore/ccsrc/CMakeLists.txt | 1 + mindspore/ccsrc/pipeline/init.cc | 3 +- mindspore/ccsrc/utils/context/ms_context.cc | 1 + mindspore/ccsrc/utils/context/ms_context.h | 3 + mindspore/ccsrc/utils/print.proto | 39 +++++++++ mindspore/ccsrc/utils/tensorprint_utils.cc | 96 +++++++++++++++++++-- mindspore/ccsrc/utils/tensorprint_utils.h | 2 + mindspore/context.py | 11 ++- mindspore/train/serialization.py | 74 +++++++++++++++- tests/ut/python/utils/test_serialize.py | 51 ++++++++++- 10 files changed, 265 insertions(+), 16 deletions(-) create mode 100644 mindspore/ccsrc/utils/print.proto diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 88f88d49e93..627105e88c8 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -80,6 +80,7 @@ if (ENABLE_DUMP_PROTO) "utils/anf_ir.proto" "utils/summary.proto" "utils/checkpoint.proto" + "utils/print.proto" ) ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 998c530cf8c..d75f0c903c4 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -146,7 +146,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") - .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size."); + .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") + .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print."); (void)py::class_>(m, "MpiConfig") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 37d11264b6a..f9f5fa1ef12 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { profiling_options_ = "training_trace"; check_bprop_flag_ = false; max_device_memory_ = kDefaultMaxDeviceMemory; + print_file_path_ = ""; } std::shared_ptr MsContext::GetInstance() { diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index cfedefe3d5f..ec3d2f40a73 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -147,6 +147,8 @@ class MsContext { std::string profiling_options() const { return profiling_options_; } bool check_bprop_flag() const { return check_bprop_flag_; } void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } + void set_print_file_path(const std::string &file) { print_file_path_ = file; } + const std::string &print_file_path() const { return print_file_path_; } float max_device_memory() const { return max_device_memory_; } void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } @@ -192,6 +194,7 @@ class MsContext { std::string profiling_options_; bool check_bprop_flag_; float max_device_memory_; + std::string print_file_path_; }; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/print.proto b/mindspore/ccsrc/utils/print.proto new file mode 100644 index 00000000000..a82791bccfa --- /dev/null +++ b/mindspore/ccsrc/utils/print.proto @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package mindspore.prntpb; + +message TensorProto { + // The shape of the tensor. + repeated int64 dims = 1; + // The type of the tensor. + required string tensor_type = 2; + // The data of the tensor. + required bytes tensor_content = 3; +} + + +message Print { + message Value { + oneof value { + string desc = 1; + TensorProto tensor = 2; + } + } + repeated Value value = 1; +} diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index f4715b22a86..0d464e88a89 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -47,6 +47,18 @@ static std::map type_size_map = { {"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2}, {"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}}; +std::string GetParseType(const std::string &tensorType_) { + static const std::map print_parse_map = { + {"int8_t", "Int8"}, {"uint8_t", "Uint8"}, {"int16_t", "Int16"}, {"uint16_t", "Uint16"}, + {"int32_t", "Int32"}, {"uint32_t", "Uint32"}, {"int64_t", "Int64"}, {"uint64_t", "Uint64"}, + {"float16", "Float16"}, {"float", "Float32"}, {"double", "Float64"}, {"bool", "Bool"}}; + auto type_iter = print_parse_map.find(tensorType_); + if (type_iter == print_parse_map.end()) { + MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << tensorType_; + } + return type_iter->second; +} + bool ParseTensorShape(const std::string &input_shape_str, std::vector *const tensor_shape, size_t *dims) { if (tensor_shape == nullptr) { return false; @@ -141,7 +153,7 @@ void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, } else { MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; } -} // namespace mindspore +} bool judgeLengthValid(const size_t str_len, const string &tensor_type) { auto type_iter = type_size_map.find(tensor_type); @@ -200,14 +212,84 @@ bool ConvertDataItem2Tensor(const std::vector &items) { return ret_end_sequence; } -void TensorPrint::operator()() { - while (true) { - std::vector bundle; - if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { +bool SaveDataItem2File(const std::vector &items, const std::string &print_file_path, prntpb::Print print, + std::fstream *output) { + bool ret_end_sequence = false; + for (auto &item : items) { + if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) { + ret_end_sequence = true; break; } - if (ConvertDataItem2Tensor(bundle)) { - break; + prntpb::Print_Value *value = print.add_value(); + std::shared_ptr str_data_ptr = std::static_pointer_cast(item.dataPtr_); + MS_EXCEPTION_IF_NULL(str_data_ptr); + if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) { + if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { + MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; + } + } + + std::vector tensor_shape; + size_t totaldims = 1; + if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { + MS_LOG(EXCEPTION) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; + } + + if (item.tensorType_ == "string") { + std::string data(reinterpret_cast(str_data_ptr->c_str()), item.dataLen_); + value->set_desc(data); + } else { + auto parse_type = GetParseType(item.tensorType_); + prntpb::TensorProto *tensor = value->mutable_tensor(); + if (!(item.tensorShape_ == kShapeScalar) && !(item.tensorShape_ == kShapeNone)) { + for (const auto &dim : tensor_shape) { + tensor->add_dims(static_cast<::google::protobuf::int64>(dim)); + } + } + tensor->set_tensor_type(parse_type); + std::string data(reinterpret_cast(str_data_ptr->c_str()), item.dataLen_); + tensor->set_tensor_content(data); + } + + if (!print.SerializeToOstream(output)) { + MS_LOG(EXCEPTION) << "Save print file:" << print_file_path << " fail."; + } + print.Clear(); + } + return ret_end_sequence; +} + +void TensorPrint::operator()() { + prntpb::Print print; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + std::string print_file_path = ms_context->print_file_path(); + if (print_file_path == "") { + while (true) { + std::vector bundle; + if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { + break; + } + if (ConvertDataItem2Tensor(bundle)) { + break; + } + } + } else { + std::fstream output(print_file_path, std::ios::out | std::ios::trunc | std::ios::binary); + while (true) { + std::vector bundle; + if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { + break; + } + if (SaveDataItem2File(bundle, print_file_path, print, &output)) { + break; + } + } + output.close(); + std::string path_string = print_file_path; + if (chmod(common::SafeCStr(path_string), S_IRUSR) == -1) { + MS_LOG(ERROR) << "Modify file:" << print_file_path << " to r fail."; + return; } } } diff --git a/mindspore/ccsrc/utils/tensorprint_utils.h b/mindspore/ccsrc/utils/tensorprint_utils.h index c8442e62913..4a40862ea35 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.h +++ b/mindspore/ccsrc/utils/tensorprint_utils.h @@ -23,6 +23,8 @@ #include "tdt/tsd_client.h" #include "tdt/tdt_host_interface.h" #include "tdt/data_common.h" +#include "proto/print.pb.h" +#include "utils/context/ms_context.h" #endif namespace mindspore { class TensorPrint { diff --git a/mindspore/context.py b/mindspore/context.py index 1e7ba3b28b7..d3d00e8cfd1 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -346,6 +346,15 @@ class _Context: raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") self._context_handle.set_max_device_memory(max_device_memory_value) + @property + def print_file_path(self): + return None + + @print_file_path.setter + def print_file_path(self, file): + self._context_handle.set_print_file_path(file) + + def check_input_format(x): import re pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' @@ -473,7 +482,7 @@ def reset_auto_parallel_context(): save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, - check_bprop=bool, max_device_memory=str) + check_bprop=bool, max_device_memory=str, print_file_path=str) def set_context(**kwargs): """ Sets context for running environment. diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 8048eedeecd..c39104c6ffb 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -21,6 +21,7 @@ import mindspore.nn as nn import mindspore.context as context from mindspore import log as logger from mindspore.train.checkpoint_pb2 import Checkpoint +from mindspore.train.print_pb2 import Print from mindspore.common.tensor import Tensor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter @@ -30,11 +31,15 @@ from mindspore._checkparam import check_input_data __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] -tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype.int32, "Int64": mstype.int64, - "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64} +tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, + "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, + "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64, + "Bool": mstype.bool_} + +tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16, + "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, + "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} -tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64, - "Float16": np.float16, "Float32": np.float32, "Float64": np.float64} def _special_process_par(par, new_par): """ @@ -442,3 +447,64 @@ def export(net, *inputs, file_name, file_format='GEIR'): # restore network training mode if is_training: net.set_train(mode=True) + + +def parse_print(print_file_name): + """ + Loads Print data from a specified file. + + Args: + print_file_name (str): The file name of save print data. + + Returns: + List, element of list is Tensor. + + Raises: + ValueError: Print file is incorrect. + """ + if not os.path.realpath(print_file_name): + raise ValueError("Please input the correct print file name.") + + if os.path.getsize(print_file_name) == 0: + raise ValueError("The print file may be empty, please make sure enter the correct file name.") + + logger.info("Execute load print process.") + print_list = Print() + + try: + with open(print_file_name, "rb") as f: + pb_content = f.read() + print_list.ParseFromString(pb_content) + except BaseException as e: + logger.error("Failed to read the print file %s, please check the correct of the file.", print_file_name) + raise ValueError(e.__str__()) + + tensor_list = [] + + try: + for print_ in print_list.value: + # String type + if print_.HasField("desc"): + tensor_list.append(print_.desc) + elif print_.HasField("tensor"): + dims = print_.tensor.dims + data_type = print_.tensor.tensor_type + data = print_.tensor.tensor_content + np_type = tensor_to_np_type[data_type] + param_data = np.fromstring(data, np_type) + ms_type = tensor_to_ms_type[data_type] + param_dim = [] + for dim in dims: + param_dim.append(dim) + if param_dim: + param_value = param_data.reshape(param_dim) + tensor_list.append(Tensor(param_value, ms_type)) + # Scale type + else: + tensor_list.append(Tensor(param_data, ms_type)) + + except BaseException as e: + logger.error("Failed to load the print file %s.", print_list) + raise RuntimeError(e.__str__()) + + return tensor_list diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 7cf3e88b2f2..19e9bd72e69 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -16,8 +16,9 @@ import os import stat import time -import pytest + import numpy as np +import pytest import mindspore.common.dtype as mstype import mindspore.nn as nn @@ -33,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load _exec_save_checkpoint, export, _save_graph from ..ut_filter import non_graph_engine -context.set_context(mode=context.GRAPH_MODE) +context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") class Net(nn.Cell): @@ -327,8 +328,52 @@ def test_binary_export(): export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY") +class PrintNet(nn.Cell): + def __init__(self): + super(PrintNet, self).__init__() + self.print = P.Print() + + def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, + scale1, scale2): + self.print('============tensor int8:==============', int8) + self.print('============tensor uint8:==============', uint8) + self.print('============tensor int16:==============', int16) + self.print('============tensor uint16:==============', uint16) + self.print('============tensor int32:==============', int32) + self.print('============tensor uint32:==============', uint32) + self.print('============tensor int64:==============', int64) + self.print('============tensor uint64:==============', uint64) + self.print('============tensor float16:==============', flt16) + self.print('============tensor float32:==============', flt32) + self.print('============tensor float64:==============', flt64) + self.print('============tensor bool:==============', bool_) + self.print('============tensor scale1:==============', scale1) + self.print('============tensor scale2:==============', scale2) + return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2 + + +def test_print(): + print_net = PrintNet() + int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8")) + uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8")) + int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16")) + uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16")) + int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32")) + uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32")) + int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64")) + uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64")) + float16 = Tensor(np.random.rand(224, 224).astype(np.float16)) + float32 = Tensor(np.random.rand(224, 224).astype(np.float32)) + float64 = Tensor(np.random.rand(224, 224).astype(np.float64)) + bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_)) + scale1 = Tensor(np.array(1)) + scale2 = Tensor(np.array(0.1)) + print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1, + scale2) + + def teardown_module(): - files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] + files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] for item in files: file_name = './' + item if not os.path.exists(file_name):