diff --git a/CMakeLists.txt b/CMakeLists.txt index 37c3288f12b..324eca867b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows") endif () if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") + set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") else() set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") endif() diff --git a/build.sh b/build.sh index 70718bf89b0..608b7234252 100755 --- a/build.sh +++ b/build.sh @@ -25,7 +25,7 @@ usage() echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" - echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K]" + echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off]" echo "" echo "Options:" echo " -d Debug mode" @@ -54,6 +54,7 @@ usage() echo " -I Compile predict, default off" echo " -K Compile with AKG, default off" echo " -s Enable serving module, default off" + echo " -B Enable debugger, default off" } # check value of input is 'on' or 'off' @@ -94,8 +95,10 @@ checkopts() PREDICT_PLATFORM="" ENABLE_AKG="on" ENABLE_SERVING="off" + ENABLE_DEBUGGER="off" + # Process the options - 
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:s' opt + while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:sB:' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -240,6 +243,11 @@ checkopts() ENABLE_SERVING="on" echo "enable serving" ;; + B) + check_on_off $OPTARG B + ENABLE_DEBUGGER="on" + echo "enable debugger" + ;; *) echo "Unknown option ${opt}!" usage @@ -322,6 +330,9 @@ build_mindspore() if [[ "X$ENABLE_SERVING" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON" fi + if [[ "X$ENABLE_DEBUGGER" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DEBUGGER=ON" + fi echo "${CMAKE_ARGS}" if [[ "X$INC_BUILD" = "Xoff" ]]; then diff --git a/cmake/external_libs/absl.cmake b/cmake/external_libs/absl.cmake new file mode 100644 index 00000000000..6087b651289 --- /dev/null +++ b/cmake/external_libs/absl.cmake @@ -0,0 +1,14 @@ +mindspore_add_pkg(absl + VER 20200225.2 + LIBS absl_strings absl_throw_delegate absl_raw_logging_internal absl_int128 absl_bad_optional_access + URL https://github.com/abseil/abseil-cpp/archive/20200225.2.tar.gz + MD5 73f2b6e72f1599a9139170c29482ddc4 + CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=TRUE) + +include_directories(${absl_INC}) + +add_library(mindspore::absl_strings ALIAS absl::absl_strings) +add_library(mindspore::absl_throw_delegate ALIAS absl::absl_throw_delegate) +add_library(mindspore::absl_raw_logging_internal ALIAS absl::absl_raw_logging_internal) +add_library(mindspore::absl_int128 ALIAS absl::absl_int128) +add_library(mindspore::absl_bad_optional_access ALIAS absl::absl_bad_optional_access) diff --git a/cmake/external_libs/c-ares.cmake b/cmake/external_libs/c-ares.cmake new file mode 100644 index 00000000000..9bb547f2db8 --- /dev/null +++ b/cmake/external_libs/c-ares.cmake @@ -0,0 +1,12 @@ +mindspore_add_pkg(c-ares + VER 1.15.0 + LIBS cares + URL https://github.com/c-ares/c-ares/releases/download/cares-1_15_0/c-ares-1.15.0.tar.gz + MD5 
d2391da274653f7643270623e822dff7 + CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release + -DCARES_SHARED:BOOL=OFF + -DCARES_STATIC:BOOL=ON + -DCARES_STATIC_PIC:BOOL=ON) + +include_directories(${c-ares_INC}) +add_library(mindspore::cares ALIAS c-ares::cares) diff --git a/cmake/external_libs/grpc.cmake b/cmake/external_libs/grpc.cmake new file mode 100644 index 00000000000..7496cfd88e5 --- /dev/null +++ b/cmake/external_libs/grpc.cmake @@ -0,0 +1,110 @@ +set(grpc_USE_STATIC_LIBS ON) +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(grpc_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") +elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + set(grpc_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") +else() + set(grpc_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") +endif() + +set(grpc_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") + + +if (EXISTS ${protobuf_ROOT}/lib64) + set(_FINDPACKAGE_PROTOBUF_CONFIG_DIR "${protobuf_ROOT}/lib64/cmake/protobuf") +else() + set(_FINDPACKAGE_PROTOBUF_CONFIG_DIR "${protobuf_ROOT}/lib/cmake/protobuf") +endif() +message("grpc using Protobuf_DIR : " ${_FINDPACKAGE_PROTOBUF_CONFIG_DIR}) + +if (EXISTS ${absl_ROOT}/lib64) + set(_FINDPACKAGE_ABSL_CONFIG_DIR "${absl_ROOT}/lib64/cmake/absl") +else() + set(_FINDPACKAGE_ABSL_CONFIG_DIR "${absl_ROOT}/lib/cmake/absl") +endif() +message("grpc using absl_DIR : " ${_FINDPACKAGE_ABSL_CONFIG_DIR}) + +set(_CMAKE_ARGS_OPENSSL_ROOT_DIR "") +if (OPENSSL_ROOT_DIR) + set(_CMAKE_ARGS_OPENSSL_ROOT_DIR "-DOPENSSL_ROOT_DIR:PATH=${OPENSSL_ROOT_DIR}") +endif() + +mindspore_add_pkg(grpc + VER 1.27.3 + LIBS grpc++ grpc gpr upb address_sorting + EXE grpc_cpp_plugin + URL https://github.com/grpc/grpc/archive/v1.27.3.tar.gz + MD5 0c6c3fc8682d4262dd0e5e6fabe1a7e2 + 
CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release + -DgRPC_INSTALL:BOOL=ON + -DgRPC_BUILD_TESTS:BOOL=OFF + -DgRPC_PROTOBUF_PROVIDER:STRING=package + -DgRPC_PROTOBUF_PACKAGE_TYPE:STRING=CONFIG + -DProtobuf_DIR:PATH=${_FINDPACKAGE_PROTOBUF_CONFIG_DIR} + -DgRPC_ZLIB_PROVIDER:STRING=package + -DZLIB_ROOT:PATH=${zlib_ROOT} + -DgRPC_ABSL_PROVIDER:STRING=package + -Dabsl_DIR:PATH=${_FINDPACKAGE_ABSL_CONFIG_DIR} + -DgRPC_CARES_PROVIDER:STRING=package + -Dc-ares_DIR:PATH=${c-ares_ROOT}/lib/cmake/c-ares + -DgRPC_SSL_PROVIDER:STRING=package + ${_CMAKE_ARGS_OPENSSL_ROOT_DIR} + ) + +include_directories(${grpc_INC}) + +add_library(mindspore::grpc++ ALIAS grpc::grpc++) + +# link other grpc libs +target_link_libraries(grpc::grpc++ INTERFACE grpc::grpc grpc::gpr grpc::upb grpc::address_sorting) + +# link built dependencies +target_link_libraries(grpc::grpc++ INTERFACE mindspore::z) +target_link_libraries(grpc::grpc++ INTERFACE mindspore::cares) +target_link_libraries(grpc::grpc++ INTERFACE mindspore::absl_strings mindspore::absl_throw_delegate + mindspore::absl_raw_logging_internal mindspore::absl_int128 mindspore::absl_bad_optional_access) + +# link system openssl +find_package(OpenSSL REQUIRED) +target_link_libraries(grpc::grpc++ INTERFACE OpenSSL::SSL OpenSSL::Crypto) + + +function(ms_grpc_generate c_var h_var) + if(NOT ARGN) + message(SEND_ERROR "Error: ms_grpc_generate() called without any proto files") + return() + endif() + + set(${c_var}) + set(${h_var}) + + foreach(file ${ARGN}) + get_filename_component(abs_file ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) + + list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/proto/${file_name}.pb.cc") + list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/proto/${file_name}.pb.h") + list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/proto/${file_name}.grpc.pb.cc") + list(APPEND ${h_var} 
"${CMAKE_BINARY_DIR}/proto/${file_name}.grpc.pb.h") + + add_custom_command( + OUTPUT "${CMAKE_BINARY_DIR}/proto/${file_name}.pb.cc" + "${CMAKE_BINARY_DIR}/proto/${file_name}.pb.h" + "${CMAKE_BINARY_DIR}/proto/${file_name}.grpc.pb.cc" + "${CMAKE_BINARY_DIR}/proto/${file_name}.grpc.pb.h" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/proto" + COMMAND protobuf::protoc --version + COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/proto + --grpc_out=${CMAKE_BINARY_DIR}/proto --plugin=protoc-gen-grpc=$<TARGET_FILE:grpc::grpc_cpp_plugin> ${abs_file} + DEPENDS protobuf::protoc grpc::grpc_cpp_plugin ${abs_file} + COMMENT "Running C++ gRPC compiler on ${file}" VERBATIM) + endforeach() + + set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) + set(${c_var} ${${c_var}} PARENT_SCOPE) + set(${h_var} ${${h_var}} PARENT_SCOPE) + +endfunction() diff --git a/cmake/external_libs/zlib.cmake b/cmake/external_libs/zlib.cmake new file mode 100644 index 00000000000..06532ed8d73 --- /dev/null +++ b/cmake/external_libs/zlib.cmake @@ -0,0 +1,9 @@ +mindspore_add_pkg(zlib + VER 1.2.11 + LIBS z + URL https://github.com/madler/zlib/archive/v1.2.11.tar.gz + MD5 0095d2d2d1f3442ce1318336637b695f + CMAKE_OPTION -DCMAKE_BUILD_TYPE:STRING=Release) + +include_directories(${zlib_INC}) +add_library(mindspore::z ALIAS zlib::z) diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index 86337c1dd2d..403316ac478 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -14,6 +14,16 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake) include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) + +if (ENABLE_DEBUGGER) + # build dependencies of gRPC + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) + 
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zlib.cmake) + # build gRPC + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/grpc.cmake) +endif() + include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pybind11.cmake) MESSAGE("go to link flatbuffers") include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake) diff --git a/cmake/options.cmake b/cmake/options.cmake index 3e03ed33395..33e4b47ef35 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -17,6 +17,7 @@ option(ENABLE_DUMP_E2E "Enable dump e2e file, default on" OFF) option(ENABLE_DUMP_IR "Enable dump funciton graph ir, default on" ON) option(ENABLE_MPI "enable mpi" OFF) option(ENABLE_AKG "enable akg" OFF) +option(ENABLE_DEBUGGER "enable debugger" OFF) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if (WIN32) @@ -112,3 +113,7 @@ endif() if(ENABLE_DUMP_E2E) add_compile_definitions(ENABLE_DUMP_E2E) endif() + +if(ENABLE_DEBUGGER) + add_compile_definitions(ENABLE_DEBUGGER) +endif() diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 80f82fd7ea8..0dc68783e8e 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -71,6 +71,17 @@ message("onnx proto path is :" ${ONNX_PROTO}) ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) +if (ENABLE_DEBUGGER) + # debugger: compile proto files + include_directories("${CMAKE_BINARY_DIR}/debug/debugger") + file(GLOB_RECURSE DEBUGGER_PROTO_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "debug/debugger/debug_graph.proto") + ms_protobuf_generate(DEBUGGER_PROTO_SRCS DEBUGGER_PROTO_HDRS ${DEBUGGER_PROTO_LIST}) + file(GLOB_RECURSE DEBUGGER_GRPC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "debug/debugger/debug_grpc.proto") + ms_grpc_generate(DEBUGGER_GRPC_SRCS DEBUGGER_GRPC_HDRS ${DEBUGGER_GRPC_LIST}) + list(APPEND MINDSPORE_PROTO_LIST ${DEBUGGER_PROTO_SRCS}) + list(APPEND MINDSPORE_PROTO_LIST 
${DEBUGGER_GRPC_SRCS}) +endif () + if (ENABLE_DUMP_PROTO) include_directories(${CMAKE_BINARY_DIR}) file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "utils/node_strategy.proto") @@ -125,6 +136,14 @@ endforeach () set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) + +target_link_libraries(proto_input mindspore::protobuf) + +if (ENABLE_DEBUGGER) + # debugger: link grpc + target_link_libraries(proto_input mindspore::grpc++) +endif() + target_link_libraries(mindspore proto_input) if (ENABLE_CPU AND ENABLE_MPI) target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi) @@ -217,6 +236,7 @@ if (USE_GLOG) endif () if (ENABLE_DUMP_PROTO) + message("add protobuf lib to c_expression") target_link_libraries(_c_expression PRIVATE mindspore::protobuf) endif () diff --git a/mindspore/ccsrc/debug/CMakeLists.txt b/mindspore/ccsrc/debug/CMakeLists.txt index 30b10a17fd4..ba0c5e07ac2 100644 --- a/mindspore/ccsrc/debug/CMakeLists.txt +++ b/mindspore/ccsrc/debug/CMakeLists.txt @@ -10,6 +10,15 @@ set(_DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/trace.cc" ) +if (ENABLE_DEBUGGER) + list(APPEND _DEBUG_SRC_LIST + "${CMAKE_CURRENT_SOURCE_DIR}/debugger/debugger.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/debugger/grpc_client.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/debugger/proto_exporter.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/debug_services.cc" + ) +endif (ENABLE_DEBUGGER) + if (ENABLE_DUMP_E2E) list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/e2e_dump.cc") endif (ENABLE_DUMP_E2E) diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc new file mode 100644 index 00000000000..8d46e00f192 --- /dev/null +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "debug/debug_services.h" +namespace mindspore { + +DebugServices::DebugServices() { + tensor_loader_ = new TensorLoader(); + uint32_t iter_num = -1; + tensor_loader_->set_iter_num(iter_num); +} + +DebugServices::DebugServices(const DebugServices &other) { + tensor_loader_ = other.tensor_loader_; + watchpoint_table = other.watchpoint_table; +} + +DebugServices &DebugServices::operator=(const DebugServices &other) { + if (this != &other) { + tensor_loader_ = other.tensor_loader_; + watchpoint_table = other.watchpoint_table; + } + return *this; +} + +DebugServices::~DebugServices() { delete tensor_loader_; } + +void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition, + const std::vector<std::tuple<std::string, bool>> &check_node_list) { + std::lock_guard<std::mutex> lg(lock_); + + watchpoint_t watchpoint_item; + + watchpoint_item.id = id; + + if (watch_condition == 0) { + watchpoint_item.conditions.nan.enabled = true; + } else if (watch_condition == 1) { + watchpoint_item.conditions.inf.enabled = true; + watchpoint_item.conditions.neg_inf.enabled = true; + } + + watchpoint_item.check_node_list = check_node_list; + + watchpoint_table[id] = watchpoint_item; +} + +void DebugServices::remove_watchpoint(unsigned int id) { + std::lock_guard<std::mutex> lg(lock_); + watchpoint_table.erase(id); +} + +void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, + std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size, + std::vector<int> *condition, std::vector<unsigned int> *wacthpoint_id) { + std::lock_guard<std::mutex> lg(lock_); + + std::vector<std::shared_ptr<TensorData>> tensor_list 
= tensor_loader_->GetTensor(); + + std::string current_tensor_name; + std::unordered_map<unsigned int, watchpoint_t> watchpoints_to_check_table; + + for (std::size_t i = 0; i < tensor_list.size(); i++) { + current_tensor_name = tensor_list[i]->GetName(); + mindspore::tensor::TensorPtr tensor_ptr = tensor_list[i]->GetTensor(); + int tensor_data_type = tensor_ptr->data_type_c(); + + // check if we need to analyze this node and for which watchpoints we will check + // create a list of watchpoints to check + watchpoints_to_check_table.clear(); + for (auto w_table_item : watchpoint_table) { + // if the watchpoint is checking for a nan or inf and the current tensor is not of a float type, then + // don't check the watchpoint for this tensor + if (std::get<1>(w_table_item).conditions.inf.enabled || std::get<1>(w_table_item).conditions.neg_inf.enabled || + std::get<1>(w_table_item).conditions.nan.enabled) { + if (tensor_data_type != kNumberTypeFloat16 && tensor_data_type != kNumberTypeFloat && + tensor_data_type != kNumberTypeFloat32 && tensor_data_type != kNumberTypeFloat64) { + continue; + } + } + + auto check_node_list = std::get<1>(w_table_item).check_node_list; + + for (auto check_node : check_node_list) { + std::string w_name = std::get<0>(check_node); + bool w_type = std::get<1>(check_node); + + // check if the current node tensor name is included the watchpoint + std::string current_node_name = current_tensor_name.substr(0, current_tensor_name.find_first_of(":")); + if ((w_type == true && (current_tensor_name.find(w_name) != string::npos || w_name == "*")) || + (w_type == false && current_node_name == w_name)) { + watchpoints_to_check_table[w_table_item.second.id] = w_table_item.second; + break; + } + } + } + + // check if no watchpoints are valid for the current tensor + if (watchpoints_to_check_table.empty()) { + continue; + } + + // need to add support for float16 and float64, and other types when we support conditions beyond inf and nan + if (tensor_data_type != kNumberTypeFloat && 
tensor_data_type != kNumberTypeFloat32) { + continue; + } + + float *start_addr = reinterpret_cast<float *>(tensor_ptr->data_c(false)); + unsigned int num_elements = (tensor_ptr->data().nbytes()) / sizeof(float); + + std::unordered_map<unsigned int, watchpoint_t>::iterator it_w_table_check; + std::vector<unsigned int> hit_encountered; + + for (unsigned int index = 0; index < num_elements; index++) { + float x = start_addr[index]; + it_w_table_check = watchpoints_to_check_table.begin(); + + while (it_w_table_check != watchpoints_to_check_table.end()) { + if ((it_w_table_check->second.conditions.inf.enabled || it_w_table_check->second.conditions.neg_inf.enabled) && + isinf(x)) { + hit_encountered.push_back(it_w_table_check->second.id); + } else if (it_w_table_check->second.conditions.nan.enabled && isnan(x)) { + hit_encountered.push_back(it_w_table_check->second.id); + } + + ++it_w_table_check; + } + + if (hit_encountered.size()) { + for (auto it_hit_id = hit_encountered.begin(); it_hit_id != hit_encountered.end(); ++it_hit_id) { + std::string name_no_slot = current_tensor_name.substr(0, current_tensor_name.find_first_of(":")); + name->push_back(name_no_slot); + + slot->push_back(std::to_string(tensor_list[i]->GetSlot())); + data_ptr->push_back(reinterpret_cast<char *>(tensor_ptr->data_c(false))); + data_size->push_back(tensor_ptr->data().nbytes()); + + int condition_item = -1; + if (watchpoint_table[*it_hit_id].conditions.nan.enabled) { + condition_item = 0; + } else if (watchpoint_table[*it_hit_id].conditions.inf.enabled || + watchpoint_table[*it_hit_id].conditions.neg_inf.enabled) { + condition_item = 1; + } + condition->push_back(condition_item); + + wacthpoint_id->push_back(*it_hit_id); + + watchpoints_to_check_table.erase(*it_hit_id); + } + + hit_encountered.clear(); + } + + if (watchpoints_to_check_table.empty()) { + break; + } + } + } +} + +void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name, + std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size, + std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape) { + 
std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list; + tensor_loader_->SearchTensors(name, &result_list); + + for (auto result : result_list) { + if (!std::get<1>(result)) { + continue; + } + ret_name->push_back(std::get<0>(result)); + data_ptr->push_back(reinterpret_cast<char *>(std::get<1>(result)->GetTensor()->data_c(false))); + data_size->push_back(std::get<1>(result)->GetTensor()->data().nbytes()); + dtype->push_back(std::get<1>(result)->GetTensor()->Dtype()); + shape->push_back(std::get<1>(result)->GetTensor()->shape()); + } +} + +TensorLoader *DebugServices::get_tensor_loader() const { return tensor_loader_; } + +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h new file mode 100644 index 00000000000..b2fd41cd683 --- /dev/null +++ b/mindspore/ccsrc/debug/debug_services.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ +#define MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ + +#include <math.h> +#include <vector> +#include <string> +#include <mutex> +#include <tuple> +#include <unordered_map> +#include "debug/tensor_load.h" +#include "debug/tensor_data.h" +#include "ir/dtype.h" + +namespace mindspore { +class DebugServices { + public: + DebugServices(); + + DebugServices(const DebugServices &other); + + DebugServices &operator=(const DebugServices &other); + + ~DebugServices(); + + void add_watchpoint(unsigned int id, unsigned int watch_condition, + const std::vector<std::tuple<std::string, bool>> &check_node_list); + + void remove_watchpoint(unsigned int id); + + void check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<char *> *data_ptr, + std::vector<unsigned int> *data_size, std::vector<int> *condition, + std::vector<unsigned int> *wacthpoint_id); + + void read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name, + std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size, + std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape); + + TensorLoader *get_tensor_loader() const; + + private: + typedef struct condition_no_param { + bool enabled = false; + } condition_no_param_t; + + typedef struct condition_with_param { + bool enabled = false; + float parameter = 0; + } condition_with_param_t; + + typedef struct conditions { + condition_no_param_t inf; + condition_no_param_t neg_inf; + condition_no_param_t nan; + condition_with_param_t max_below; + condition_with_param_t max_above; + condition_with_param_t min_below; + condition_with_param_t min_above; + condition_with_param_t max_minus_min_below; + condition_with_param_t max_minus_min_above; + condition_with_param_t mean_below; + condition_with_param_t mean_above; + condition_with_param_t std_dev_below; + condition_with_param_t std_dev_above; + } conditions_t; + + typedef struct watchpoint { + unsigned int id; + conditions_t conditions; + std::vector<std::tuple<std::string, bool>> check_node_list; + } watchpoint_t; + + std::mutex lock_; + + std::unordered_map<unsigned int, watchpoint_t> watchpoint_table; + + TensorLoader *tensor_loader_; +}; +} // namespace mindspore + +#endif 
// MINDSPORE_CCSRC_DEBUG_DEBUG_SERVICES_H_ diff --git a/mindspore/ccsrc/debug/debugger/debug_graph.proto b/mindspore/ccsrc/debug/debugger/debug_graph.proto new file mode 100644 index 00000000000..042360fac37 --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/debug_graph.proto @@ -0,0 +1,316 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package debugger; + +// Versioning +enum Version { + // unknown version + UNKNOWWN_VERSION = 0; + + // Initial version (IR VERSION 1), published on Sep 23, 2019 + IR_VERSION = 0x0000000000000001; +} + +// Data type definition +enum DataType { + DT_UNDEFINED = 0; + // Basic types. 
+ DT_BOOL = 1; // bool + + DT_INT8 = 2; // int8_t + DT_INT16 = 3; // int16_t + DT_INT32 = 4; // int32_t + DT_INT64 = 5; // int64_t + + DT_UINT8 = 6; // uint8_t + DT_UINT16 = 7; // uint16_t + DT_UINT32 = 8; // uint32_t + DT_UINT64 = 9; // uint64_t + + DT_FLOAT16 = 10; // float 16 + DT_FLOAT32 = 11; // float 32 + DT_FLOAT64 = 12; // float 64 + + DT_STRING = 13; // string + DT_TENSOR = 14; // tensor + DT_GRAPH = 15; // graph + + // list type + DT_BOOLS = 16; // list of bool + + DT_INTS8 = 17; // list of int8_t + DT_INTS16 = 18; // list of int16_t + DT_INTS32 = 19; // list of int32_t + DT_INTS64 = 20; // list of int64_t + + DT_UINTS8 = 21; // list of uint8_t + DT_UINTS16 = 22; // list of uint16_t + DT_UINTS32 = 23; // list of uint32_t + DT_UINTS64 = 24; // list of uint64_t + + DT_FLOATS16 = 25; // list of float16 + DT_FLOATS32 = 26; // list of float32 + DT_FLOATS64 = 27; // list of float64 + + DT_STRINGS = 28; // list of string + DT_TENSORS = 29; // list of tensor + DT_GRAPHS = 30; // list of graph + + DT_TUPLE = 31; // tuple + DT_LIST = 32; // list + DT_DICT = 33; // dictionary + + // other types + DT_NONE = 34; // None + DT_SYM_INST = 35; // Symbolic Key Instance + + // type related type + DT_BASE_INT = 36; // type generic int + DT_BASE_UINT = 37; // type generate unsigned int + DT_BASE_FLOAT = 38; // type generate float + DT_TYPE = 39; // type type + DT_ANYTHING = 40; // type anything + DT_REFKEY = 41; // type refkey + DT_REF = 42; // type ref +} + +// Value definition for attribute value or parameter default value +message ValueProto { + // data type of value + optional DataType dtype = 1; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional bool bool_val = 2; // bool + optional int64 int_val = 3; // int + optional uint64 uint_val = 4; // uint + optional float float_val = 5; // float + optional double double_val = 6; // double + optional string str_val = 7; // 
string + optional TensorProto tensor_val = 8; // tensor value + optional GraphProto graph = 9; // graph + + repeated bool bool_vals = 10; // list of bool + repeated int64 int_vals = 11; // list of int + repeated uint64 uint_vals = 12; // list of uint + repeated float float_vals = 13; // list of float + repeated double double_vals = 14; // list of double + repeated string str_vals = 15; // list of string + repeated TensorProto tensor_vals = 16; // list of tensor value + repeated GraphProto graphs = 17; // list of graph + + // tuple or list + repeated ValueProto values = 18; // tuple, list of value + + // dictionary + repeated NamedValueProto dict_val = 19; // dictionary info + + // filed for type type + optional TypeProto type_val = 20; // type type info +} + +message AttributeProto { + optional string name = 1; // attribute name + optional ValueProto value = 2; // attribute value +} + +message NamedValueProto { + optional string key = 1; // attribute name + optional ValueProto value = 2; // attribute value +} + +// Defines a tensor shape. +message TensorShapeProto { + // One dimension of the tensor. + message Dimension { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). + optional int64 size = 1; + + // Optional name of the tensor dimension. + optional string name = 2; + }; + + repeated Dimension dim = 1; +} + +// Types for graph input(parameter) and output +message TypeProto { + + message Tensor { + // This field MUST have a valid DataType value except DT_TENSOR + optional DataType elem_type = 1; + optional TensorShapeProto shape = 2; // for scalar, this field is not set + } + + // tuple type + message Sequence { + // The type and optional shape of elements of the tuple. + repeated TypeProto elem_types = 1; + }; + + // data type + optional DataType data_type = 1; + + oneof value { + // The type of a tensor. 
+ Tensor tensor_type = 2; + + // The type of a tuple. + Sequence sequence_type = 3; + } +} + +// Defines information on graph parameters, including the name, the type, and +// the default value of parameter if exists. +message ParameterProto { + optional string name = 1; // parameter name + optional TypeProto type = 2; // parameter type + optional ValueProto default_val = 3; // default value of parameter if exists +} + +// Defines graph output information +message OutputProto { + optional string name = 1; // output node name + optional TypeProto type = 2; // output node type +} + +// Define node input information +message InputProto { + enum EdgeType { + DATA_EDGE = 0; // data edge + CONTROL_EDGE = 1; // control edge + } + + optional string name = 1; + optional EdgeType type = 2; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated InputProto input = 1; // namespace Value + optional string name = 2; // namespace Value + + // The symbolic identifier of the Operator to execute. + optional string op_type = 3; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string scope = 4; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // Optional type info of this node + optional TypeProto output_type = 6; + + // other fields for debug + optional uint64 output_i = 7; + + // for debugger, full name with scope + optional string debug_name = 8; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. 
+// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // ir version + optional int64 ir_version = 1; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 2; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 3; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 4; + + // metadata info of opeartors + optional OperatorSetProto metadata_operators = 5; +}; + +message OperatorProto { + optional string name = 1; // used as key, must be distinct + optional bytes config = 2; // operator config info + optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name +}; + +message OperatorSetProto { + repeated OperatorProto operators = 1; + optional string summary = 2; // summary info of operators, e.g. file position of operators file +} + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // The parameters(inputs) and outputs of the graph. + repeated ParameterProto parameters = 3; + repeated OutputProto outputs = 4; + + // Constants used in this graph + repeated NamedValueProto const_vals = 5; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + // The node name of the tensor. 
+ optional string node_name = 1; + + // The slot of the tensor in its node. + optional string slot = 2; + + // The serialized tensor content. + optional bytes tensor_content = 3; + + // The shape of the tensor. + repeated int64 dims = 4; + + // The data type of the tensor. + // This field MUST have a valid DataType value except DT_TENSOR + optional DataType data_type = 5; + + // If the tensor content transferring is finished. + optional bool finished = 6; +} \ No newline at end of file diff --git a/mindspore/ccsrc/debug/debugger/debug_grpc.proto b/mindspore/ccsrc/debug/debugger/debug_grpc.proto new file mode 100644 index 00000000000..f742987a4ed --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/debug_grpc.proto @@ -0,0 +1,81 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package debugger; + +import "debug_graph.proto"; + +service EventListener { + rpc WaitCMD (Metadata) returns (EventReply) {}; + rpc SendMetadata (Metadata) returns (EventReply) {}; + rpc SendGraph (GraphProto) returns (EventReply) {}; + rpc SendTensors (stream TensorProto) returns (EventReply) {}; + rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {}; +} + +message Metadata { + string device_name = 1; + int32 cur_step = 2; +} + +message EventReply { + enum Status { + OK = 0; + FAILED = 1; + PENDING = 2; + } + + Status status = 1; + + oneof cmd { + bool exit = 2; + int32 run_cmd = 3; + SetCMD set_cmd = 4; + ViewCMD view_cmd = 5; + } +} + +message SetCMD { + repeated WatchNode watch_nodes = 1; + WatchCondition watch_condition = 2; + bool delete = 3; + int32 id = 4; +} + +message ViewCMD { + repeated TensorProto tensors = 1; +} + +message WatchCondition { + enum Condition { + nan = 0; + inf = 1; + } + Condition condition = 1; +} + +message WatchNode { + string node_name = 1; + string node_type = 2; +} + +message WatchpointHit { + TensorProto tensor = 1; + WatchCondition watch_condition = 2; + int32 id = 3; +} diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc new file mode 100644 index 00000000000..ea147a929f5 --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -0,0 +1,488 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "debug/debugger/debugger.h" +#include "pipeline/pipeline.h" +#include "session/anf_runtime_algorithm.h" + +using debugger::EventReply; +using debugger::GraphProto; +using debugger::ModelProto; +using debugger::TensorProto; +using debugger::WatchCondition; +using debugger::WatchCondition_Condition_inf; +using debugger::WatchCondition_Condition_nan; +using debugger::WatchNode; +using debugger::WatchpointHit; + +namespace mindspore { + +DebuggerPtr Debugger::debugger_ = nullptr; +std::mutex Debugger::instance_lock_; + +Debugger::Debugger() + : grpc_client_(nullptr), + debug_services_(nullptr), + device_id_(0), + num_step_(0), + debugger_enabled_(false), + is_dataset_graph_(false) {} + +void Debugger::Init(const uint32_t device_id) { + // access lock for public method + std::lock_guard a_lock(access_lock_); + // save device_id + MS_LOG(INFO) << "Debugger got device_id: " << device_id; + device_id_ = device_id; +} + +void Debugger::EnableDebugger() { + // reset some of the class members + num_step_ = 0; + debugger_enabled_ = false; + grpc_client_ = nullptr; + debug_services_ = nullptr; + + // get env variables to configure debugger + const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER"); + if (env_enable_str != nullptr) { + MS_LOG(INFO) << "Getenv ENABLE_MS_DEBUGGER: " << env_enable_str; + if (std::strcmp(env_enable_str, "1") == 0) { + debugger_enabled_ = true; + } + } + if (!debugger_enabled_) { + MS_LOG(WARNING) << "Not enabling debugger. 
Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger."; + return; + } + // configure host + const char *env_host_str = std::getenv("MS_DEBUGGER_HOST"); + std::string host; + if (env_host_str != nullptr) { + MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str; + host = std::string(env_host_str); + } else { + MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost"; + host = "localhost"; + } + // configure port + const char *env_port_str = std::getenv("MS_DEBUGGER_PORT"); + std::string port; + if (env_port_str != nullptr) { + MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str; + port = std::string(env_port_str); + } else { + MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051"; + port = "50051"; + } + + // initialize grpc client + grpc_client_ = std::make_unique(host, port); + debug_services_ = std::make_unique(); +} + +void Debugger::Reset() { + // access lock for public method + std::lock_guard a_lock(access_lock_); + // reset components + device_id_ = 0; + num_step_ = 0; + debugger_enabled_ = false; + is_dataset_graph_ = false; + graph_ptr_ = nullptr; + grpc_client_ = nullptr; + debug_services_ = nullptr; +} + +void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) { + // access lock for public method + std::lock_guard a_lock(access_lock_); + // check and save graph_ptr, suspend if graph is new + CheckGraphPtr(graph_ptr); +} + +void Debugger::PostExecute() { + // access lock for public method + std::lock_guard a_lock(access_lock_); + // analyze tensor data and send the watchpoints been hit + if (debugger_enabled_ && !is_dataset_graph_) { + num_step_++; + MS_LOG(INFO) << "Debugger suspend at end of step; number of steps executed: " << num_step_; + SendWatchpointsAndSuspend(CheckWatchpoints()); + } +} + +void Debugger::PostDebugOp() { + // access lock for public method + std::lock_guard a_lock(access_lock_); + // suspend if 
debugger is enabled + if (debugger_enabled_ && !is_dataset_graph_) { + MS_LOG(INFO) << "Debugger suspend at debug_op"; + CommandLoop(); + } +} + +void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) { + if (graph_ptr_ != graph_ptr) { + MS_LOG(INFO) << "Debugger got new graph: " << graph_ptr->graph_id(); + // save new graph_ptr + graph_ptr_ = graph_ptr; + // check if it is dataset graph + CheckDatasetGraph(); + if (!is_dataset_graph_) { + // only try to enable debugger if it is not a dataset graph + EnableDebugger(); + if (debugger_enabled_) { + // get graph proto and send to mindinsight + SendGraphAndSuspend(GetGraphProto()); + } + } + } +} + +void Debugger::CheckDatasetGraph() { + // print parameter node names + const auto ¶ms = graph_ptr_->inputs(); + for (const auto ¶m : params) { + MS_LOG(INFO) << "param: " << param->fullname_with_scope(); + } + // check if there is GetNext or InitDataSetQueue node + const auto &nodes = graph_ptr_->execution_order(); + for (const auto &node : nodes) { + auto node_name = AnfAlgo::GetCNodeName(node); + MS_LOG(INFO) << "node: " << node->fullname_with_scope(); + if (node_name == "GetNext" || node_name == "InitDataSetQueue") { + MS_LOG(WARNING) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node " + << node_name; + is_dataset_graph_ = true; + return; + } + } + is_dataset_graph_ = false; +} + +GraphProto Debugger::GetGraphProto() { + // convert kernel graph to debugger modelproto + ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_); + return model.graph(); +} + +void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) { + // prepare metadata + std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id()); + Metadata metadata; + metadata.set_device_name(device_name); + metadata.set_cur_step(num_step_); + EventReply reply_metadata = grpc_client_->SendMetadata(metadata); + if (reply_metadata.status() != reply_metadata.OK) { + 
MS_LOG(ERROR) << "Error: SendMetadata failed"; + } + // send graph to mindinght server + EventReply reply = grpc_client_->SendGraph(graph_proto); + if (reply.status() != reply.OK) { + MS_LOG(ERROR) << "Error: SendGraph failed"; + } + // enter command loop, wait and process commands + CommandLoop(); +} + +void Debugger::CommandLoop() { + // prepare metadata + std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id()); + Metadata metadata; + metadata.set_device_name(device_name); + metadata.set_cur_step(num_step_); + + // loop exit flag + bool run = false; + int num_wait_fail = 0; + const int max_num_wait_fail = 5; + + while (!run) { + // wait for command + EventReply reply = grpc_client_->WaitForCommand(metadata); + if (reply.status() != reply.OK) { + MS_LOG(ERROR) << "Error: WaitForCommand failed"; + num_wait_fail++; + if (num_wait_fail > max_num_wait_fail) { + MS_LOG(ERROR) << "Maximum number of WaitForCommand retry reached: exiting training session"; + Exit(); + } + MS_LOG(ERROR) << "Number of consecutive WaitForCommand fail:" << num_wait_fail << "; Retry after " + << num_wait_fail << "s"; + std::this_thread::sleep_for(std::chrono::milliseconds(1000 * num_wait_fail)); + continue; + } + + // get type of the command in reply + DebuggerCommand cmd = GetCommand(reply); + if (cmd == DebuggerCommand::kUnknownCMD) { + MS_LOG(ERROR) << "Error: debugger recieved unknown command"; + continue; + } + + MS_LOG(INFO) << "recieved command: "; + switch (cmd) { + case DebuggerCommand::kUnknownCMD: + MS_LOG(INFO) << "UnknownCMD"; + break; + case DebuggerCommand::kExitCMD: + MS_LOG(INFO) << "ExitCMD"; + Exit(); + break; + case DebuggerCommand::kRunCMD: + MS_LOG(INFO) << "RunCMD"; + // exit loop + run = true; + break; + case DebuggerCommand::kSetCMD: + MS_LOG(INFO) << "SetCMD"; + { + // print set cmd content + ProtoVector recieved_nodes = GetWatchnodes(reply); + for (auto node : recieved_nodes) { + MS_LOG(INFO) << "node name: " << 
node.node_name(); + MS_LOG(INFO) << "node type: " << node.node_type(); + } + WatchCondition recieved_condition = GetWatchcondition(reply); + MS_LOG(INFO) << "condition: " << recieved_condition.condition(); + int32_t id = GetWatchpointID(reply); + MS_LOG(INFO) << "id: " << id; + bool delete_ = GetWatchpointDelete(reply); + MS_LOG(INFO) << "delete: " << delete_; + } + MS_LOG(INFO) << "Setting watchpoint"; + if (GetWatchpointDelete(reply)) { + RemoveWatchpoint(GetWatchpointID(reply)); + } else { + SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply)); + } + break; + case DebuggerCommand::kViewCMD: + MS_LOG(INFO) << "ViewCMD"; + { + // print view cmd content + ProtoVector received_tensors = GetTensors(reply); + for (auto tensor : received_tensors) { + MS_LOG(INFO) << "tensor node name: " << tensor.node_name(); + MS_LOG(INFO) << "tensor slot: " << tensor.slot(); + MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha; + } + } + MS_LOG(INFO) << "Sending tensors"; + std::list tensors = LoadTensors(GetTensors(reply)); + { + for (auto tensor : tensors) { + MS_LOG(INFO) << "tensor node name: " << tensor.node_name(); + MS_LOG(INFO) << "tensor slot: " << tensor.slot(); + MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha; + MS_LOG(INFO) << "tensor dims: "; + for (auto dim : tensor.dims()) { + MS_LOG(INFO) << dim << ","; + } + MS_LOG(INFO) << "tensor dtype: " << tensor.data_type(); + } + } + EventReply send_tensors_reply = grpc_client_->SendTensors(tensors); + if (send_tensors_reply.status() != send_tensors_reply.OK) { + MS_LOG(ERROR) << "Error: SendTensors failed"; + } + break; + } + } +} + +DebuggerCommand Debugger::GetCommand(const EventReply &reply) { + DebuggerCommand cmd = DebuggerCommand::kUnknownCMD; + switch (reply.cmd_case()) { + case debugger::EventReply::CmdCase::kExit: + cmd = DebuggerCommand::kExitCMD; + break; + case 
debugger::EventReply::CmdCase::kRunCmd: + cmd = DebuggerCommand::kRunCMD; + break; + case debugger::EventReply::CmdCase::kSetCmd: + cmd = DebuggerCommand::kSetCMD; + break; + case debugger::EventReply::CmdCase::kViewCmd: + cmd = DebuggerCommand::kViewCMD; + break; + default: + MS_LOG(ERROR) << "Error: UnknownCMD"; + break; + } + return cmd; +} + +ProtoVector Debugger::GetWatchnodes(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector()."; + return ProtoVector(); + } + return reply.set_cmd().watch_nodes(); +} + +WatchCondition Debugger::GetWatchcondition(const EventReply &reply) { + if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) { + MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition()."; + return WatchCondition(); + } + return reply.set_cmd().watch_condition(); +} + +int32_t Debugger::GetWatchpointID(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0."; + return 0; + } + return reply.set_cmd().id(); +} + +bool Debugger::GetWatchpointDelete(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false."; + return false; + } + return reply.set_cmd().delete_(); +} + +ProtoVector Debugger::GetTensors(const EventReply &reply) { + if (!reply.has_view_cmd()) { + MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. 
Returning default value: ProtoVector()."; + return ProtoVector(); + } + return reply.view_cmd().tensors(); +} + +void Debugger::SetWatchpoint(const ProtoVector &nodes, const WatchCondition &condition, const int32_t id) { + std::vector> check_node_list; + std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list), + [](WatchNode node) -> std::tuple { + return make_tuple(node.node_name(), node.node_type() == "scope"); + }); + + debug_services_->add_watchpoint(id, condition.condition(), check_node_list); +} + +void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); } + +std::list Debugger::LoadTensors(const ProtoVector &tensors) { + std::vector name; + std::vector ret_name; + std::vector data_ptr; + std::vector data_size; + std::vector dtype; + std::vector> shape; + + std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), + [](TensorProto tensor) -> std::string { return tensor.node_name() + ":" + tensor.slot(); }); + + debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape); + + std::list tensor_list; + unsigned int result_index = 0; + TensorProto tensor_item; + + for (auto tensor : tensors) { + tensor_item.set_node_name(tensor.node_name()); + tensor_item.set_slot(tensor.slot()); + tensor_item.set_finished(true); + + // return empty tensor if didn't find the requested tensor + if (result_index >= ret_name.size() || ret_name[result_index] != tensor.node_name() + ":" + tensor.slot()) { + tensor_list.push_back(tensor_item); + continue; + } + + tensor_item.set_tensor_content(data_ptr[result_index], data_size[result_index]); + tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index])); + tensor_item.clear_dims(); + for (auto &elem : shape[result_index]) { + tensor_item.add_dims(elem); + } + + tensor_list.push_back(tensor_item); + + result_index++; + } + + return tensor_list; +} + +void Debugger::Exit() { + // clear resource before exit + 
pipeline::ClearResAtexit(); + std::exit(EXIT_FAILURE); +} + +std::list Debugger::CheckWatchpoints() { + std::vector name; + std::vector slot; + std::vector data_ptr; + std::vector data_size; + std::vector condition; + std::vector watchpoint_id; + + debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id); + + std::list points; + + for (unsigned int i = 0; i < name.size(); i++) { + TensorProto *tensor_item; + tensor_item = new TensorProto(); + tensor_item->set_node_name(name[i]); + tensor_item->set_slot(slot[i]); + tensor_item->set_tensor_content(data_ptr[i], data_size[i]); + + // finished in TensorProto will always be true before we implement big tensor splitting + tensor_item->set_finished(true); + + WatchCondition *condition_item; + condition_item = new WatchCondition(); + condition_item->set_condition(debugger::WatchCondition_Condition(condition[i])); + + WatchpointHit point; + point.set_allocated_tensor(tensor_item); + point.set_allocated_watch_condition(condition_item); + point.set_id(watchpoint_id[i]); + + points.push_back(point); + } + + return points; +} + +void Debugger::SendWatchpointsAndSuspend(const std::list &points) { + // send info about watchpoint + if (!points.empty()) { + EventReply reply = grpc_client_->SendWatchpointHits(points); + if (reply.status() != reply.OK) { + MS_LOG(ERROR) << "Error: SendWatchpointHits failed"; + } + } + // enter command loop + CommandLoop(); +} + +DebugServices *Debugger::get_debug_services() { return debug_services_.get(); } + +bool Debugger::debugger_enabled() { return debugger_enabled_; } + +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h new file mode 100644 index 00000000000..6ce7d036257 --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ +#define MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ + +#include +#include +#include +#include "session/kernel_graph.h" +#include "debug/debugger/grpc_client.h" +#include "debug/debug_services.h" + +using debugger::DataType; +using debugger::EventReply; +using debugger::GraphProto; +using debugger::ModelProto; +using debugger::TensorProto; +using debugger::WatchCondition; +using debugger::WatchNode; +using debugger::WatchpointHit; + +template +using ProtoVector = google::protobuf::RepeatedPtrField; + +namespace mindspore { +// different types of command recieved by debugger +// need to keep sync with client-side proto and server-side proto +enum class DebuggerCommand { kExitCMD = 2, kRunCMD = 3, kSetCMD = 4, kViewCMD = 5, kUnknownCMD = -1 }; + +class Debugger : public std::enable_shared_from_this { + public: + static std::shared_ptr GetInstance() { + std::lock_guard i_lock(instance_lock_); + if (debugger_ == nullptr) { + debugger_ = std::shared_ptr(new (std::nothrow) Debugger()); + } + return debugger_; + } + + // deconstructor + ~Debugger() = default; + + // init + // only save device_id + void Init(const uint32_t device_id); + + // reset debugger + void Reset(); + + // enable debugger + // send graph and wait for command + // do nothing if graph is set already + void PreExecute(const KernelGraphPtr &graph_ptr); + + // analyze tensors and wait for command + // don't need a graph_ptr because it is saved during pre_execute + void 
PostExecute(); + + // suspend the execution after a debug_op + void PostDebugOp(); + + DebugServices *get_debug_services(); + + bool debugger_enabled(); + + private: + // private constructor for singleton + Debugger(); + + // enable debugger + // instantiate class members + // read env variable for grpc client + void EnableDebugger(); + + // check and save graph pointer + void CheckGraphPtr(const KernelGraphPtr &graph_ptr); + + // check if the graph is a dataset graph + void CheckDatasetGraph(); + + // serialize graph and get proto + GraphProto GetGraphProto(); + + // send graph and enter command wait loop + void SendGraphAndSuspend(const GraphProto &graph_proto); + + // wait for command and process command + // send command request and process reply in a loop + // break if RunCMD + void CommandLoop(); + + // process reply and command type + DebuggerCommand GetCommand(const EventReply &reply); + + // parse other data out of EventReply + ProtoVector GetWatchnodes(const EventReply &reply); + WatchCondition GetWatchcondition(const EventReply &reply); + int32_t GetWatchpointID(const EventReply &reply); + bool GetWatchpointDelete(const EventReply &reply); + ProtoVector GetTensors(const EventReply &reply); + + // set what nodes and conditions to watch + void SetWatchpoint(const ProtoVector &nodes, const WatchCondition &condition, const int32_t id); + + // remove watchpoint with id + void RemoveWatchpoint(const int32_t id); + + // load tensor for view command + std::list LoadTensors(const ProtoVector &tensors); + + // terminate training process + void Exit(); + + // analyze tensors and check watchpoint conditions + // return names of tensors and what condition they hit + std::list CheckWatchpoints(); + + // send watchpoints that hit and enter command wait loop + void SendWatchpointsAndSuspend(const std::list &points); + + // class members + std::unique_ptr grpc_client_; + std::unique_ptr debug_services_; + KernelGraphPtr graph_ptr_; + uint32_t device_id_; + int32_t 
num_step_; + bool debugger_enabled_; + bool is_dataset_graph_; + std::mutex access_lock_; + + // singleton + static std::mutex instance_lock_; + static std::shared_ptr debugger_; +}; + +using DebuggerPtr = std::shared_ptr; + +// get debugger ModelProto +std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph); +ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph); + +// for getting proto DataType from Type of Tensor +DataType GetDebuggerNumberDataType(const TypePtr &type); + +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ diff --git a/mindspore/ccsrc/debug/debugger/grpc_client.cc b/mindspore/ccsrc/debug/debugger/grpc_client.cc new file mode 100644 index 00000000000..7709f4c0d1a --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/grpc_client.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "debug/debugger/grpc_client.h" +#include "utils/log_adapter.h" + +using debugger::EventListener; +using debugger::EventReply; +using debugger::EventReply_Status_FAILED; +using debugger::GraphProto; +using debugger::Metadata; +using debugger::TensorProto; +using debugger::WatchpointHit; + +namespace mindspore { +GrpcClient::GrpcClient(const std::string &host, const std::string &port) : stub_(nullptr) { Init(host, port); } + +void GrpcClient::Init(const std::string &host, const std::string &port) { + std::string target_str = host + ":" + port; + MS_LOG(INFO) << "GrpcClient connecting to: " << target_str; + + std::shared_ptr channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()); + stub_ = EventListener::NewStub(channel); +} + +void GrpcClient::Reset() { stub_ = nullptr; } + +EventReply GrpcClient::WaitForCommand(const Metadata &metadata) { + EventReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->WaitCMD(&context, metadata, &reply); + + if (!status.ok()) { + MS_LOG(ERROR) << "RPC failed: WaitForCommand"; + MS_LOG(ERROR) << status.error_code() << ": " << status.error_message(); + reply.set_status(EventReply_Status_FAILED); + } + return reply; +} + +EventReply GrpcClient::SendMetadata(const Metadata &metadata) { + EventReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->SendMetadata(&context, metadata, &reply); + + if (!status.ok()) { + MS_LOG(ERROR) << "RPC failed: SendMetadata"; + MS_LOG(ERROR) << status.error_code() << ": " << status.error_message(); + reply.set_status(EventReply_Status_FAILED); + } + return reply; +} + +EventReply GrpcClient::SendGraph(const GraphProto &graph) { + EventReply reply; + grpc::ClientContext context; + grpc::Status status = stub_->SendGraph(&context, graph, &reply); + + if (!status.ok()) { + MS_LOG(ERROR) << "RPC failed: SendGraph"; + MS_LOG(ERROR) << status.error_code() << ": " << status.error_message(); + 
reply.set_status(EventReply_Status_FAILED); + } + return reply; +} + +EventReply GrpcClient::SendTensors(const std::list &tensors) { + EventReply reply; + grpc::ClientContext context; + + std::unique_ptr > writer(stub_->SendTensors(&context, &reply)); + for (const auto &tensor : tensors) { + if (!writer->Write(tensor)) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + + if (!status.ok()) { + MS_LOG(ERROR) << "RPC failed: SendTensors"; + MS_LOG(ERROR) << status.error_code() << ": " << status.error_message(); + reply.set_status(EventReply_Status_FAILED); + } + return reply; +} + +EventReply GrpcClient::SendWatchpointHits(const std::list &watchpoints) { + EventReply reply; + grpc::ClientContext context; + + std::unique_ptr > writer(stub_->SendWatchpointHits(&context, &reply)); + for (const auto &watchpoint : watchpoints) { + if (!writer->Write(watchpoint)) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + + if (!status.ok()) { + MS_LOG(ERROR) << "RPC failed: SendWatchpointHits"; + MS_LOG(ERROR) << status.error_code() << ": " << status.error_message(); + reply.set_status(EventReply_Status_FAILED); + } + return reply; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/debugger/grpc_client.h b/mindspore/ccsrc/debug/debugger/grpc_client.h new file mode 100644 index 00000000000..0b5359e4447 --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/grpc_client.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_ +#define MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_ + +#include +#include +#include +#include +#include "proto/debug_grpc.grpc.pb.h" + +using debugger::EventListener; +using debugger::EventReply; +using debugger::GraphProto; +using debugger::Metadata; +using debugger::TensorProto; +using debugger::WatchpointHit; + +namespace mindspore { +class GrpcClient { + public: + // constructor + GrpcClient(const std::string &host, const std::string &port); + + // deconstructor + ~GrpcClient() = default; + + // init + void Init(const std::string &host, const std::string &port); + + // reset + void Reset(); + + EventReply WaitForCommand(const Metadata &metadata); + + EventReply SendMetadata(const Metadata &metadata); + + EventReply SendGraph(const GraphProto &graph); + + EventReply SendTensors(const std::list &tensors); + + EventReply SendWatchpointHits(const std::list &watchpoints); + + private: + std::unique_ptr stub_; +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_ diff --git a/mindspore/ccsrc/debug/debugger/proto_exporter.cc b/mindspore/ccsrc/debug/debugger/proto_exporter.cc new file mode 100644 index 00000000000..b4b4de9d994 --- /dev/null +++ b/mindspore/ccsrc/debug/debugger/proto_exporter.cc @@ -0,0 +1,542 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "debug/debugger/debugger.h" +#include "proto/debug_graph.pb.h" +#include "utils/graph_utils.h" +#include "utils/symbolic.h" + +namespace mindspore { +class DebuggerProtoExporter { + public: + DebuggerProtoExporter() {} + ~DebuggerProtoExporter() {} + + std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); + debugger::ModelProto GetFuncGraphProto(const FuncGraphPtr &func_graph); + + private: + void InitModelInfo(); + void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, debugger::NodeProto *node_proto); + std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr); + void SetValueToProto(const ValuePtr &attr_value, debugger::ValueProto *value_proto); + void SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto); + void SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto); + void SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto); + void SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto); + void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, debugger::TypeProto *type_proto); + + void ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto); + void ExportCNodes(const FuncGraphPtr &func_graph, 
debugger::GraphProto *graph_proto, + std::map *const_map_ptr); + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, + std::map *const_map_ptr, debugger::GraphProto *graph_proto); + void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, std::map *const_map_ptr, + debugger::GraphProto *graph_proto); + void ExportValueNodes(const std::map &const_map, debugger::GraphProto *graph_proto); + + static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } + + debugger::ModelProto model_; +}; + +void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, + debugger::TypeProto *type_proto) { + if (type_proto == nullptr) { + return; + } + + if (type == nullptr) { + type_proto->set_data_type(debugger::DT_UNDEFINED); + } else if (type->isa()) { + type_proto->set_data_type(GetDebuggerNumberDataType(type)); + } else if (type->isa()) { + TypePtr elem_type = dyn_cast(type)->element(); + type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); + type_proto->set_data_type(debugger::DT_TENSOR); + if (shape != nullptr && shape->isa()) { + abstract::ShapePtr shape_info = dyn_cast(shape); + for (const auto &elem : shape_info->shape()) { + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); + } + } + } else if (type->isa()) { + TuplePtr tuple_type = dyn_cast(type); + type_proto->set_data_type(debugger::DT_TUPLE); + for (const auto &elem_type : tuple_type->elements()) { + SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); + } + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_TYPE); + } else if (type->isa()) { + ListPtr list_type = dyn_cast(type); + type_proto->set_data_type(debugger::DT_LIST); + for (const auto &elem_type : list_type->elements()) { + SetNodeOutputType(elem_type, nullptr, 
type_proto->mutable_sequence_type()->add_elem_types()); + } + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_ANYTHING); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_REFKEY); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_REF); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_GRAPH); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_NONE); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_STRING); + } else if (type->isa()) { + // Do Nothing. + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); + } +} + +void DebuggerProtoExporter::SetNodeOutputType(const AnfNodePtr &node, debugger::TypeProto *type_proto) { + if (node == nullptr || type_proto == nullptr) { + return; + } + SetNodeOutputType(node->Type(), node->Shape(), type_proto); +} + +void DebuggerProtoExporter::SetValueToProto(const ValuePtr &val, debugger::ValueProto *value_proto) { + if (val == nullptr || value_proto == nullptr) { + return; + } + + if (val->isa()) { + const StringImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_STRING); + value_proto->set_str_val(value->value()); + } else if (val->isa()) { + SetScalarToProto(dyn_cast(val), value_proto); + } else if (val->isa()) { + value_proto->set_dtype(debugger::DT_TYPE); + value_proto->mutable_type_val()->set_data_type(debugger::DT_BOOL); + } else if (val->isa()) { + value_proto->set_dtype(debugger::DT_TYPE); + value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_INT); + } else if (val->isa()) { + value_proto->set_dtype(debugger::DT_TYPE); + value_proto->mutable_type_val()->set_data_type(debugger::DT_BASE_FLOAT); + } else if (val->isa()) { + SetSequenceToProto(dyn_cast(val), value_proto); + } else if (val->isa()) { + value_proto->set_dtype(debugger::DT_NONE); + value_proto->set_str_val("None"); + } else if (val->isa()) { + SymbolicKeyInstancePtr sym_inst = dyn_cast(val); + ParameterPtr 
sym_node = dyn_cast(sym_inst->node()); + value_proto->set_dtype(debugger::DT_SYM_INST); + value_proto->set_str_val(sym_node == nullptr ? std::string("nullptr") : sym_node->ToString()); + } else if (val->isa()) { + SetDictionaryToProto(dyn_cast(val), value_proto); + } else if (val->isa()) { + tensor::TensorPtr tensor_ptr = dyn_cast(val); + value_proto->set_dtype(debugger::DT_TENSOR); + debugger::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); + tensor_proto->set_data_type(GetDebuggerNumberDataType(tensor_ptr->Dtype())); + for (auto &elem : tensor_ptr->shape()) { + tensor_proto->add_dims(elem); + } + tensor_proto->set_tensor_content(tensor_ptr->data_c(), tensor_ptr->data().nbytes()); + } else if (val->isa()) { + value_proto->set_dtype(debugger::DT_TYPE); + + debugger::TypeProto *type_proto = value_proto->mutable_type_val(); + type_proto->set_data_type(debugger::DT_TENSOR); + TypePtr elem_type = dyn_cast(val)->element(); + type_proto->mutable_tensor_type()->set_elem_type(GetDebuggerNumberDataType(elem_type)); + } else { + MS_LOG(WARNING) << "Unsupported type " << val->type_name(); + } +} + +void DebuggerProtoExporter::SetScalarToProto(const ScalarPtr &val, debugger::ValueProto *value_proto) { + if (val == nullptr || value_proto == nullptr) { + return; + } + + if (val->isa()) { + const BoolImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_BOOL); + value_proto->set_bool_val(value->value()); + } else if (val->isa()) { + const Int8ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_INT8); + value_proto->set_int_val(value->value()); + } else if (val->isa()) { + const Int16ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_INT16); + value_proto->set_int_val(value->value()); + } else if (val->isa()) { + const Int32ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_INT32); + value_proto->set_int_val(value->value()); + } else if (val->isa()) { + const Int64ImmPtr &value = dyn_cast(val); + 
value_proto->set_dtype(debugger::DT_INT64); + value_proto->set_int_val(value->value()); + } else if (val->isa()) { + const UInt8ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_UINT8); + value_proto->set_uint_val(value->value()); + } else if (val->isa()) { + const UInt16ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_UINT16); + value_proto->set_uint_val(value->value()); + } else if (val->isa()) { + const UInt32ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_UINT32); + value_proto->set_uint_val(value->value()); + } else if (val->isa()) { + const UInt64ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_UINT64); + value_proto->set_uint_val(value->value()); + } else if (val->isa()) { + const FP32ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_FLOAT32); + value_proto->set_float_val(value->value()); + } else if (val->isa()) { + const FP64ImmPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_FLOAT64); + value_proto->set_double_val(value->value()); + } else { + MS_LOG(EXCEPTION) << "Unknown scalar type " << val->ToString(); + } +} + +void DebuggerProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, debugger::ValueProto *value_proto) { + if (val == nullptr || value_proto == nullptr) { + return; + } + + if (val->isa()) { + const ValueTuplePtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_TUPLE); + for (const auto &item : value->value()) { + SetValueToProto(item, value_proto->add_values()); + } + } else if (val->isa()) { + const ValueListPtr &value = dyn_cast(val); + value_proto->set_dtype(debugger::DT_LIST); + for (const auto &item : value->value()) { + SetValueToProto(item, value_proto->add_values()); + } + } +} + +void DebuggerProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, debugger::ValueProto *value_proto) { + if (val == nullptr || value_proto == nullptr) { + return; + } + + 
value_proto->set_dtype(debugger::DT_DICT); + for (const auto &item : val->value()) { + debugger::NamedValueProto *named_val = value_proto->add_dict_val(); + named_val->set_key(item.first); + SetValueToProto(item.second, named_val->mutable_value()); + } +} + +void DebuggerProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, + debugger::NodeProto *node_proto) { + if (node == nullptr || node_proto == nullptr) { + return; + } + + if (node->isa() || node->isa() || IsValueNode(node)) { + MS_LOG(EXCEPTION) << "Op node can not be CNode, Parameter or ValueNode Graph. But got " << node->ToString(); + } + + if (!IsValueNode(node)) { + MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); + } + + const PrimitivePtr &prim = GetValueNode(node); + node_proto->set_op_type(prim->name()); + for (const auto &attr : prim->attrs()) { + debugger::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name(attr.first); + SetValueToProto(attr.second, attr_proto->mutable_value()); + } + node_proto->set_scope(node->scope()->name()); +} + +std::string DebuggerProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr) { + if (node == nullptr || const_map_ptr == nullptr) { + return ""; + } + + if (node->isa()) { + auto iter = apply_map.find(node); + if (iter == apply_map.end()) { + MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in apply_map"; + } + return std::to_string(iter->second); + } + + if (node->isa()) { + return node->ToString(); + } + + if (node->isa()) { + auto iter = const_map_ptr->find(node); + if (iter == const_map_ptr->end()) { + // Start index number from 1 + auto const_idx = const_map_ptr->size() + 1; + (*const_map_ptr)[node] = const_idx; + } + return GetConstNodeId((*const_map_ptr)[node]); + } + + MS_LOG(EXCEPTION) << "Unknown node type. 
node is '" << node->ToString() << "'"; +} + +std::string DebuggerProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + return ""; + } + + InitModelInfo(); + debugger::GraphProto *graph_proto = model_.mutable_graph(); + ExportFuncGraph(func_graph, graph_proto); + return model_.SerializeAsString(); +} + +debugger::ModelProto DebuggerProtoExporter::GetFuncGraphProto(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + return ModelProto(); + } + + InitModelInfo(); + debugger::GraphProto *graph_proto = model_.mutable_graph(); + ExportFuncGraph(func_graph, graph_proto); + return model_; +} + +void DebuggerProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) { + if (func_graph == nullptr || graph_proto == nullptr) { + return; + } + + // map for store ValueNodes of this graph + std::map const_map; + + // set graph name + graph_proto->set_name(func_graph->ToString()); + + ExportParameters(func_graph, graph_proto); + + ExportCNodes(func_graph, graph_proto, &const_map); + + ExportValueNodes(const_map, graph_proto); +} + +void DebuggerProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto) { + if (func_graph == nullptr || graph_proto == nullptr) { + return; + } + + // cast FuncGraph to KernelGraph to access inputs() + std::vector parameters = static_cast(func_graph.get())->inputs(); + + for (auto ¶m : parameters) { + debugger::ParameterProto *param_proto = graph_proto->add_parameters(); + param_proto->set_name(param->ToString()); + + SetNodeOutputType(param, param_proto->mutable_type()); + + const ParameterPtr param_ptr = dyn_cast(param); + if (param_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; + } + } +} + +void DebuggerProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, debugger::GraphProto *graph_proto, + std::map *const_map_ptr) { + if (func_graph 
== nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { + return; + } + // topo sort nodes + std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + std::map apply_map; + for (const AnfNodePtr &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (cnode != func_graph->get_return()) { + ExportCNode(func_graph, cnode, &apply_map, const_map_ptr, graph_proto); + } else { + ExportFuncGraphOutput(func_graph, cnode, apply_map, const_map_ptr, graph_proto); + } + } +} + +void DebuggerProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *apply_map_ptr, + std::map *const_map_ptr, + debugger::GraphProto *graph_proto) { + if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || + graph_proto == nullptr) { + return; + } + + auto apply_idx = apply_map_ptr->size() + 1; + (*apply_map_ptr)[node] = apply_idx; + + auto &inputs = node->inputs(); + if (inputs.size() < 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + AnfNodePtr op = inputs[0]; + debugger::NodeProto *node_proto = graph_proto->add_node(); + + // CNode/ConstGraph/Const/Parameter + if (op->isa() || IsValueNode(op) || op->isa()) { + MS_LOG(WARNING) << "Operator must be a primitive"; + } else { + GetOpNodeTypeAndAttrs(func_graph, op, node_proto); + node_proto->set_name(std::to_string(apply_idx)); + node_proto->set_scope(node->scope()->name()); + + // add debug_name for debugger + node_proto->set_debug_name(node->fullname_with_scope()); + + // process OP inputs + for (size_t i = 1; i < inputs.size(); ++i) { + debugger::InputProto *input_proto = node_proto->add_input(); + input_proto->set_type(debugger::InputProto_EdgeType_DATA_EDGE); + std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); + input_proto->set_name(id); + } + + // set node output type + SetNodeOutputType(node, 
node_proto->mutable_output_type()); + } +} + +void DebuggerProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, + std::map *const_map_ptr, + debugger::GraphProto *graph_proto) { + if (ret_node == nullptr || !ret_node->isa()) { + MS_LOG(EXCEPTION) << "Graph return node is illegal"; + } + AnfNodePtr arg = ret_node->input(1); + if (graph_proto == nullptr) { + MS_LOG(EXCEPTION) << "graph_proto is nullptr"; + } + debugger::OutputProto *output_proto = graph_proto->add_outputs(); + if (output_proto == nullptr) { + MS_LOG(EXCEPTION) << "output_proto is nullptr"; + } + std::string id = GetOpNodeInputId(func_graph, arg, apply_map, const_map_ptr); + output_proto->set_name(id); + SetNodeOutputType(arg, output_proto->mutable_type()); +} + +static bool CompareValue(const std::pair &x, const std::pair &y) { + return x.second < y.second; +} + +void DebuggerProtoExporter::ExportValueNodes(const std::map &const_map, + debugger::GraphProto *graph_proto) { + std::vector> nodes; + (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), + [](const std::pair &item) { return item; }); + + sort(nodes.begin(), nodes.end(), CompareValue); + + for (auto &item : nodes) { + if (graph_proto == nullptr) { + MS_LOG(EXCEPTION) << "graph_proto is nullptr"; + } + debugger::NamedValueProto *named_value = graph_proto->add_const_vals(); + MS_EXCEPTION_IF_NULL(named_value); + named_value->set_key(GetConstNodeId(item.second)); + SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); + } +} + +void DebuggerProtoExporter::InitModelInfo() { model_.set_ir_version(debugger::IR_VERSION); } + +std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph) { + DebuggerProtoExporter exporter; + return exporter.GetFuncGraphProtoString(func_graph); +} + +debugger::ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph) { + DebuggerProtoExporter exporter; + return 
exporter.GetFuncGraphProto(func_graph); +} + +debugger::DataType GetDebuggerNumberDataType(const TypePtr &type) { + switch (type->type_id()) { + case kNumberTypeBool: + return debugger::DT_BOOL; + case kNumberTypeInt8: + return debugger::DT_INT8; + case kNumberTypeInt16: + return debugger::DT_INT16; + case kNumberTypeInt32: + return debugger::DT_INT32; + case kNumberTypeInt64: + return debugger::DT_INT64; + case kNumberTypeUInt8: + return debugger::DT_UINT8; + case kNumberTypeUInt16: + return debugger::DT_UINT16; + case kNumberTypeUInt32: + return debugger::DT_UINT32; + case kNumberTypeUInt64: + return debugger::DT_UINT64; + case kNumberTypeFloat16: + return debugger::DT_FLOAT16; + case kNumberTypeFloat32: + return debugger::DT_FLOAT32; + case kNumberTypeFloat64: + return debugger::DT_FLOAT64; + case kNumberTypeInt: + return debugger::DT_BASE_INT; + case kNumberTypeUInt: + return debugger::DT_BASE_UINT; + case kNumberTypeFloat: + return debugger::DT_BASE_FLOAT; + default: + MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); + } +} + +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/tensor_data.h b/mindspore/ccsrc/debug/tensor_data.h new file mode 100644 index 00000000000..9704d69089b --- /dev/null +++ b/mindspore/ccsrc/debug/tensor_data.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_ +#define MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_ + +#include +#include +#include +#include +#include "ir/tensor.h" + +namespace mindspore { +class TensorData { + private: + mindspore::tensor::TensorPtr tensor_ptr; + std::string name; + size_t slot; + int execution_order; + + public: + TensorData() : slot(0), execution_order(-1) {} + + TensorData(const TensorData &obj) { + std::cout << "Copy Constructor" << std::endl; + this->name = obj.name; + this->execution_order = obj.execution_order; + this->slot = obj.slot; + this->tensor_ptr = obj.tensor_ptr; + } + + ~TensorData() {} + + std::string GetName() { return this->name; } + + mindspore::tensor::TensorPtr GetTensor() { return this->tensor_ptr; } + + size_t GetSlot() { return this->slot; } + + int GetExecutionOrder() { return this->execution_order; } + + int SetExecutionOrder(int execution_order) { + this->execution_order = execution_order; + return true; + } + + int SetName(const std::string &name) { + this->name = name; + return true; + } + + bool SetTensor(mindspore::tensor::TensorPtr out_tensor) { + this->tensor_ptr = out_tensor; + return true; + } + + bool SetSlot(size_t slot) { + this->slot = slot; + return true; + } +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_ diff --git a/mindspore/ccsrc/debug/tensor_load.h b/mindspore/ccsrc/debug/tensor_load.h new file mode 100644 index 00000000000..6c3ea67a785 --- /dev/null +++ b/mindspore/ccsrc/debug/tensor_load.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_ +#define MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_ + +#include +#include +#include +#include +#include +#include "debug/tensor_data.h" +namespace mindspore { +class TensorLoader { + public: + TensorLoader() : iter_num(-1) {} + + ~TensorLoader() {} + + bool LoadNewTensor(std::shared_ptr tensor) { + tensor_list.push_back(tensor); + tensor_list_map.insert({tensor->GetName(), tensor}); + return true; + } + std::vector> GetTensor() { return tensor_list; } + + uint32_t GetIterNum() { return iter_num; } + + std::map> GetTensorMap() { return tensor_list_map; } + void SearchTensors(const std::vector &search_list, + std::vector>> *result_list) { + for (auto i : search_list) { + std::map>::iterator iter; + iter = tensor_list_map.find(i); + if (iter != tensor_list_map.end()) { + result_list->push_back(std::make_tuple(i, iter->second)); + } else { + result_list->push_back(std::make_tuple(i, nullptr)); + } + } + } + + bool EmptyTensor() { + tensor_list_map.clear(); + tensor_list.clear(); + return true; + } + + void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; } + + private: + std::vector> tensor_list; + std::map> tensor_list_map; + uint32_t iter_num; +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index a47c482c0e0..71a16607ef9 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ 
b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -30,6 +30,10 @@ #ifdef ENABLE_DUMP_E2E #include "debug/e2e_dump.h" #endif +#ifdef ENABLE_DEBUGGER +#include "debug/tensor_load.h" +#endif + namespace mindspore { namespace device { namespace ascend { @@ -346,6 +350,52 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file return ret; } #endif + +#ifdef ENABLE_DEBUGGER +bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, + const std::string &host_fmt, const std::vector &host_shape, + TypeId host_type, size_t slot, Debugger *debugger) const { + bool ret = false; + + DebugServices *debug_services = debugger->get_debug_services(); + TensorLoader *tensor_loader = debug_services->get_tensor_loader(); + + if (trans_flag) { + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); + size_t host_size = out_tensor->data().nbytes(); + ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c(true)); + if (!ret) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetTensor(out_tensor); + tensor_data->SetSlot(slot); + ret = tensor_loader->LoadNewTensor(tensor_data); + + } else { + mindspore::tensor::TensorPtr out_tensor = std::make_shared(type_id_, host_shape); + size_t host_size = out_tensor->data().nbytes(); + auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(true), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); + + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetTensor(out_tensor); + tensor_data->SetSlot(slot); + ret = tensor_loader->LoadNewTensor(tensor_data); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_LOG(ERROR) 
<< "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; + } + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + } + return ret; +} +#endif + } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.h b/mindspore/ccsrc/device/ascend/ascend_device_address.h index 364f9e95fda..6871abfe1b1 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.h +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.h @@ -25,6 +25,9 @@ #include "ir/dtype.h" namespace mindspore { +#ifdef ENABLE_DEBUGGER +class Debugger; +#endif namespace device { namespace ascend { class AscendDeviceAddress : public DeviceAddress { @@ -39,6 +42,10 @@ class AscendDeviceAddress : public DeviceAddress { #ifdef ENABLE_DUMP_E2E bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, const std::vector &host_shape, TypeId host_type) const; +#endif +#ifdef ENABLE_DEBUGGER + bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const; #endif private: bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index fb2a3f350b0..5ec1e90a61b 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -41,6 +41,7 @@ #include "kernel/tbe/tbe_python_funcs.h" #include "pre_activate/mem_reuse/mem_reuse_checker.h" #include "device/ascend/ascend_memory_manager.h" +#include "debug/tensor_load.h" using mindspore::device::ascend::ProfilingManager; using mindspore::device::ascend::ProfilingUtils; @@ -293,6 +294,91 @@ bool 
AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { return true; } +#ifdef ENABLE_DEBUGGER +namespace { +void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + bool trans_flag = false; + const auto &apply_kernels = graph->execution_order(); + // for kernels, execution order starts from 1 + int exec_order = 1; + for (const auto &node : apply_kernels) { + MS_EXCEPTION_IF_NULL(node); + auto node_name = AnfAlgo::GetCNodeName(node); + std::string kernel_name = node->fullname_with_scope(); + auto output_size = AnfAlgo::GetOutputTensorNum(node); + for (size_t j = 0; j < output_size; ++j) { + auto addr = AnfAlgo::GetOutputAddr(node, j); + auto type = AnfAlgo::GetOutputInferDataType(node, j); + auto format = kOpFormat_DEFAULT; + string tensor_name = kernel_name + ':' + std::to_string(j); + auto ascend_addr = dynamic_cast(addr); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(node, j); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(node, j); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name + << ", host_format:" << format << ".!"; + } + } + exec_order = exec_order + 1; + } +} + +void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + bool trans_flag = false; + const auto ¶meters = graph->inputs(); + // for parameters, set its execution order to be 0; + int exec_order = 0; + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); + auto type 
= AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + string tensor_name = parameter_name + ':' + "0"; + auto ascend_addr = dynamic_cast(addr); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name + << ", host_format:" << format << ".!"; + } + } +} +} // namespace +#endif + +bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); +#ifdef ENABLE_DEBUGGER + MS_LOG(INFO) << "start load step"; + uint32_t cur_iter = 0; + MS_LOG(INFO) << "cur iter is " << cur_iter; + // load output + LoadOutput(graph, debugger); + // load parameters + LoadParameters(graph, debugger); +#endif + return true; +} + bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { if (AnfAlgo::OutputAddrExist(kernel, index)) { auto address = AnfAlgo::GetOutputAddr(kernel, index); diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h index 28076f95b78..69ba8b295a8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h @@ -37,6 +37,7 @@ class AscendKernelRuntime : public KernelRuntime { ~AscendKernelRuntime() override; bool Init() override; bool DumpData(session::KernelGraph *graph) override; + bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; bool GenTask(const 
session::KernelGraph *graph) override; bool RunTask(const session::KernelGraph *graph) override; bool LoadTask(const session::KernelGraph *graph) override; diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index aae21aac72f..9fb6a88076d 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -79,6 +79,14 @@ bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { return false; } +// for D to impl +bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + if (graph != nullptr) { + return true; + } + return false; +} + // for D to impl bool KernelRuntime::GenTask(const session::KernelGraph *graph) { if (graph != nullptr) { diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index bfe857f61b9..8442342e322 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -27,6 +27,9 @@ #ifdef ENABLE_DUMP_E2E #include "debug/e2e_dump.h" #endif +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif #include "session/kernel_graph.h" #include "session/anf_runtime_algorithm.h" #include "kernel/kernel.h" @@ -34,11 +37,15 @@ #include "device/memory_manager.h" using mindspore::tensor::Tensor; +using std::vector; using TensorPtr = std::shared_ptr; using mindspore::kernel::AddressPtr; using AddressPtrList = std::vector; namespace mindspore { +#ifndef ENABLE_DEBUGGER +class Debugger; +#endif namespace device { class KernelRuntime { public: @@ -50,6 +57,7 @@ class KernelRuntime { void RunOpClearMemory(session::KernelGraph *graph); virtual bool Run(session::KernelGraph *graph); virtual bool DumpData(session::KernelGraph *graph); + virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); virtual bool RunTask(const session::KernelGraph *graph); virtual bool GenTask(const session::KernelGraph *graph); bool LaunchKernel(const 
session::KernelGraph *graph); diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc new file mode 100644 index 00000000000..a1dcaca3f31 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel/cpu/debug_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" +#include "common/utils.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif + +namespace mindspore { +namespace kernel { +void DebugCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } + +bool DebugCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 1 || outputs.empty()) { + MS_LOG(EXCEPTION) << " input or output empty!"; + } + auto val = reinterpret_cast(inputs[0]->addr); + MS_LOG(DEBUG) << " launch DebugCountCPUKernel val " << *val; + + auto output = reinterpret_cast(outputs[0]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + for (size_t i = 0; i < elem_num; i++) { + output[i] = val[i]; + } + +#ifdef ENABLE_DEBUGGER + // debugger will suspend execution is neccessary + Debugger::GetInstance()->PostDebugOp(); +#endif + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h 
b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h new file mode 100644 index 00000000000..da9f3286b95 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ + +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DebugCPUKernel : public CPUKernel { + public: + DebugCPUKernel() = default; + ~DebugCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), DebugCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index f86cbd7fd2a..e6545d311ce 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -275,5 +275,6 @@ const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSumma const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); const PrimitivePtr kPrimTensorSummary = 
std::make_shared("TensorSummary"); const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); +const PrimitivePtr kPrimDebug = std::make_shared("Debug"); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 65327cf407d..01812a55291 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -276,6 +276,7 @@ extern const PrimitivePtr kPrimNotInDict; extern const PrimitivePtr kPrimMixedPrecisionCast; extern const PrimitivePtr kPrimIsConsant; extern const PrimitivePtr kPrimEquivFormat; +extern const PrimitivePtr kPrimDebug; // Comm ops extern const PrimitivePtr kPrimAllReduce; diff --git a/mindspore/ccsrc/operator/prim_debug.cc b/mindspore/ccsrc/operator/prim_debug.cc index a9962c6d140..5e6cdcc3183 100644 --- a/mindspore/ccsrc/operator/prim_debug.cc +++ b/mindspore/ccsrc/operator/prim_debug.cc @@ -21,5 +21,21 @@ #include "utils/symbolic.h" namespace mindspore { -namespace abstract {} // namespace abstract +namespace abstract { +AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor(value) + const std::string op_name = primitive->name(); + + CheckArgsSize(op_name, args_spec_list, 1); + auto tensor_value = CheckArg(op_name, args_spec_list, 0); + + int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); + if (tensor_rank == 0) { + MS_LOG(EXCEPTION) << op_name << " evaluator first arg should be a tensor, but got a scalar, rank is 0"; + } + + return std::make_shared(AbstractBasePtrList({tensor_value->Broaden()})); +} +} // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc index 6f30a8ec1c4..6b920f82ec6 100644 --- a/mindspore/ccsrc/parallel/node_check.cc +++ b/mindspore/ccsrc/parallel/node_check.cc @@ -66,6 +66,7 @@ const std::set BLACK_LIST =
{TUPLE_GETITEM, SCALARSUMMARY, IMAGESUMMARY, TENSORSUMMARY, + DEBUG, HISTOGRAMSUMMARY, COL2IMV1, RESOLVE, diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 4b8f61bb2e4..f9af7e26263 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -269,6 +269,7 @@ constexpr char SCALARSUMMARY[] = "ScalarSummary"; constexpr char IMAGESUMMARY[] = "ImageSummary"; constexpr char TENSORSUMMARY[] = "TensorSummary"; constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; +constexpr char DEBUG[] = "Debug"; constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; constexpr char INVERTPERMUTATION[] = "InvertPermutation"; constexpr char CONTROLDEPEND[] = "ControlDepend"; diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index bb1f693c6bb..d346c980eab 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -445,7 +445,10 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons std::string backend = MsContext::GetInstance()->backend_policy(); if (use_vm && backend != "ge") { // Create backend and session - resource->results()[kBackend] = compile::CreateBackend(); + auto backend_ptr = compile::CreateBackend(); + // Connect session to debugger + backend_ptr->SetDebugger(); + resource->results()[kBackend] = backend_ptr; p_actions = VmPipeline(); } else { p_actions = GePipeline(); diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 82b83959331..bf1f319ae28 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -130,6 +130,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimDepend, {InferImplDepend, true}}, {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, {prim::kPrimControlDepend, 
{InferImplControlDepend, true}}, + // Debug + {prim::kPrimDebug, {InferImplDebug, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 5b910f81942..5b3972088a0 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -346,6 +346,9 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index bae10ed9438..7ef6551f2bb 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "operator/ops.h" #include "ir/tensor.h" @@ -45,6 +46,7 @@ #include "kernel/tbe/tbe_python_funcs.h" #include "utils/config_manager.h" #include "utils/base_ref_extends.h" +#include "debug/tensor_load.h" namespace mindspore { namespace session { @@ -450,6 +452,12 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vectorPreExecute(kernel_graph); + } +#endif { py::gil_scoped_release release; // run task on device @@ -459,8 +467,20 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vectordebugger_enabled()) { + LoadTensor(kernel_graph); + } +#endif // dump used for debug Dump(kernel_graph); +#ifdef ENABLE_DEBUGGER + // debugger post-execution processing + if (debugger_) { + debugger_->PostExecute(); + } +#endif MS_LOG(INFO) << "Finish!"; } @@ -757,6 +777,22 @@ void 
AscendSession::ExportChildGraphs(const GraphId graph_id) { #endif } +void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); +#ifdef ENABLE_DEBUGGER + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + DebugServices *debug_services = debugger_->get_debug_services(); + TensorLoader *tensor_loader = debug_services->get_tensor_loader(); + tensor_loader->EmptyTensor(); + uint32_t iter_num = tensor_loader->GetIterNum(); + tensor_loader->set_iter_num(++iter_num); + (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get()); +#endif + MS_LOG(INFO) << "Finish!"; +} + GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { MS_LOG(INFO) << "Start! Args size " << args.size(); auto final_graph = NewKernelGraph(); diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 7857330115d..eaa01b8f804 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -87,6 +87,7 @@ class AscendSession : public SessionBasic { void ExecTask(const std::shared_ptr &kernel_graph) const; void Dump(const std::shared_ptr &kernel_graph) const; void ExportChildGraphs(const GraphId graph_id); + void LoadTensor(const std::shared_ptr &kernel_graph) const; // below functions are used for run op void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; void RunOpExecTask(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/session/cpu_session.cc b/mindspore/ccsrc/session/cpu_session.cc index e70e5510227..49728bc4c27 100644 --- a/mindspore/ccsrc/session/cpu_session.cc +++ b/mindspore/ccsrc/session/cpu_session.cc @@ -25,6 +25,9 @@ #include "predict/predict.h" #include "kernel/cpu/cpu_kernel_factory.h" #include "device/cpu/kernel_select_cpu.h" +#ifdef ENABLE_DEBUGGER +#include 
"debug/debugger/debugger.h" +#endif namespace mindspore { namespace session { @@ -78,7 +81,12 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectorsummary_nodes(); runtime_.IncreaseSummaryRefCount(summary_outputs); } - +#ifdef ENABLE_DEBUGGER + // debugger pre-execution processing + if (debugger_) { + debugger_->PreExecute(kernel_graph); + } +#endif bool ret = runtime_.Run(kernel_graph.get()); if (!ret) { MS_LOG(EXCEPTION) << "Run graph failed"; @@ -92,6 +100,12 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectorPostExecute(); + } +#endif MS_LOG(INFO) << "Run graph end"; } diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 27171b75894..1ea62d8df90 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -32,6 +32,9 @@ #include "utils/contract.h" #include "pynative/pynative_execute.h" #include "device/kernel_info.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif namespace mindspore { using GraphId = uint32_t; @@ -48,7 +51,11 @@ using OpRunInfoPtr = std::shared_ptr; class SessionBasic { public: - SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {} + SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { +#ifdef ENABLE_DEBUGGER + debugger_ = nullptr; +#endif + } virtual void Init(uint32_t device_id) { device_id_ = device_id; } @@ -92,6 +99,14 @@ class SessionBasic { virtual void SetActive(GraphId, GraphId) {} virtual void GetSummaryNodes(KernelGraph *graph); +#ifdef ENABLE_DEBUGGER + // set debugger + void SetDebugger() { + debugger_ = Debugger::GetInstance(); + debugger_->Init(device_id_); + } +#endif + protected: virtual void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; @@ -123,6 +138,9 @@ class SessionBasic { CallBackFunc summary_callback_; static GraphId graph_sum_; uint32_t device_id_; +#ifdef ENABLE_DEBUGGER + 
std::shared_ptr debugger_; +#endif }; using SessionPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index a5726b078a6..1d61b0050fe 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -371,6 +371,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, {prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimDebug->name(), ADPT_DESC(Summary)}, {prim::kPrimTensorAdd->name(), std::make_shared(std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})), std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})))}, diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 9f283319a75..d385ec7a3f6 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -69,7 +69,11 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { enable_task_sink_ = true; ir_fusion_flag_ = true; enable_hccl_ = false; +#ifdef ENABLE_DEBUGGER + enable_mem_reuse_ = false; +#else enable_mem_reuse_ = true; +#endif enable_gpu_summary_ = true; precompile_only_ = false; auto_mixed_precision_flag_ = false; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 3fde263c9db..1a27fcb63a1 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -362,5 +362,9 @@ GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_- VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } +#ifdef ENABLE_DEBUGGER +void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } +#endif + } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 0e0b02c055d..3a93cf930f9 100644 --- 
a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -69,6 +69,8 @@ class Backend { bool is_switch_call() const { return is_switch_call_; } void set_simu_flag(bool simu) { simu_flag_ = simu; } + virtual void SetDebugger() {} + protected: std::string name_; LinkFuncType convert_fn_; @@ -109,6 +111,10 @@ class MsBackend : public Backend { VectorRef RunGraph(GraphId graph_id, const VectorRef &args); void CreateOtherSession(const std::string &target); +#ifdef ENABLE_DEBUGGER + void SetDebugger() override; +#endif + private: session::SessionPtr target_sess_; session::SessionPtr other_sess_; diff --git a/mindspore/ops/_grad/grad_debug_ops.py b/mindspore/ops/_grad/grad_debug_ops.py index 1cb756219a8..6e31556b149 100644 --- a/mindspore/ops/_grad/grad_debug_ops.py +++ b/mindspore/ops/_grad/grad_debug_ops.py @@ -66,3 +66,12 @@ def get_bprop_insert_gradient_of(self): def bprop(x, out, dout): return (f(dout),) return bprop + + +@bprop_getters.register(P.Debug) +def get_bprop_debug(self): + """Generate bprop for Debug""" + + def bprop(x, out, dout): + return dout + return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 792381a15f3..61932923160 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -37,7 +37,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast _VirtualDiv, _GetTensorSlice, HostAllGather, HostReduceScatter) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, - TensorSummary, HistogramSummary, Print) + TensorSummary, HistogramSummary, Debug, Print) from .control_ops import ControlDepend, GeSwitch, Merge from .inner_ops import ScalarCast @@ -173,6 +173,7 @@ __all__ = [ 'ImageSummary', 'TensorSummary', 'HistogramSummary', + "Debug", "Print", 'InsertGradientOf', 'HookBackward', diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 
c6b635a69fb..4bb51f2564b 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -17,7 +17,7 @@ from types import FunctionType, MethodType from ..._checkparam import Validator as validator from ...common import dtype as mstype -from ..primitive import prim_attr_register, PrimitiveWithInfer +from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive def _check_summary_param(name, value, class_name): @@ -340,3 +340,29 @@ class Print(PrimitiveWithInfer): for dtype in inputs: validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name) return mstype.int32 + + +class Debug(Primitive): + """ + Print tensor value. + + Inputs: + - **value** (Tensor) - The value of tensor. + + Examples: + >>> class DebugNN(nn.Cell): + >>> def __init__(self): + >>> self.debug = P.Debug() + >>> + >>> def construct(self, x, y): + >>> x = self.add(x, y) + >>> self.debug(x) + >>> return x + """ + + @prim_attr_register + def __init__(self): + """init""" + + def __call__(self, *args, **kwargs): + pass diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 13f961fa246..0ba778f5c5e 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -114,6 +114,12 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/node_strateg list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_converter.cc") +# remove files for debugger +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/debugger/debugger.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/debugger/grpc_client.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/debug_services.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/debugger/proto_exporter.cc") + file(GLOB_RECURSE UT_SUTB_SRC_LIST RELATIVE
${CMAKE_CURRENT_SOURCE_DIR} "stub/aicpu/*.cc" "stub/cce/*.cc"