forked from mindspore-Ecosystem/mindspore
Add partial memory reuse support to debugger
move pre-execution of debugger from rungraph to build/compile graph support partial mem reuse for a scope of nodes set default mem reuse to be true for debugger remove some redundant lines remove redundant code and fix a bug for supporting partial no mem reuse a scope of nodes resolve CI errors Solve CI errors solve cpplint errors solve CI build error manually fix the CI compile UT error Optimize code for mem reuse support Debug optimization of debugger memory reuse debug code for debugger memory reuse part2 address clang-format errors Switch memory reuse on and off based on environment variable Fix typo Fix typo Load watchpoint value only fix bugs Addressed comments from lupengcheng fix typo Fix typo fix CI errors refactor some code fix typo addressed comments from canadian teamates remove locking from TensorLoader fix CI errors add lock to tensor_loader fix rebase-to-master conflict fix rebase conflicts fix rebase conflicts part 2 fix rebase conflicts part 3
This commit is contained in:
parent
bed93a9ead
commit
6bb2182134
|
@ -13,13 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
|
||||
#include "backend/optimizer/mem_reuse/mem_reuse.h"
|
||||
#include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
|
||||
#ifdef ENABLE_D
|
||||
#include "runtime/device/ascend/ascend_stream_assign.h"
|
||||
#endif
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
#include "debug/debugger/debugger.h"
|
||||
#include "debug/debug_services.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace memreuse {
|
||||
|
@ -75,6 +78,15 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
|
|||
MS_EXCEPTION_IF_NULL(mem_buf);
|
||||
auto kernel_prev = mem_buf->used_kernel_;
|
||||
MS_EXCEPTION_IF_NULL(kernel_prev);
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
auto debugger_ = mindspore::Debugger::GetInstance();
|
||||
DebugServices *debug_services = debugger_->debug_services();
|
||||
auto watchpoint_table = debug_services->GetWatchpointTable();
|
||||
std::string current_kernel_name = kernel_curr->scope_full_name();
|
||||
if (debug_services->IsWatchPoint(current_kernel_name, watchpoint_table)) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
auto curr_stream_id = kernel_curr->stream_id();
|
||||
auto prev_stream_id = kernel_prev->stream_id();
|
||||
if (curr_stream_id == prev_stream_id) {
|
||||
|
|
|
@ -331,6 +331,11 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
|
||||
// build kernel
|
||||
BuildKernel(root_graph);
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
if (debugger_) {
|
||||
debugger_->PreExecute(root_graph);
|
||||
}
|
||||
#endif
|
||||
// alloc mem
|
||||
MemoryAlloc(root_graph.get());
|
||||
// task generate
|
||||
|
@ -407,6 +412,11 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|||
BuildKernel(graph);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
if (debugger_) {
|
||||
debugger_->PreExecute(graph);
|
||||
}
|
||||
#endif
|
||||
if (ms_context->precompile_only()) {
|
||||
MS_LOG(INFO) << "Precompile only, stop in build kernel step";
|
||||
} else {
|
||||
|
@ -475,12 +485,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
|
|||
LoadInputData(kernel_graph, inputs);
|
||||
// convert inputs to model
|
||||
predictmodel::StepConvertWeight(inputs);
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// debugger pre-execution processing
|
||||
if (debugger_) {
|
||||
debugger_->PreExecute(kernel_graph);
|
||||
}
|
||||
#endif
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
// run task on device
|
||||
|
@ -791,7 +795,8 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
|
|||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
DebugServices *debug_services = debugger_->debug_services();
|
||||
TensorLoader *tensor_loader = debug_services->get_tensor_loader();
|
||||
TensorLoader *tensor_loader = debug_services->tensor_loader();
|
||||
// TensorData will be freed up here
|
||||
tensor_loader->EmptyTensor();
|
||||
uint32_t iter_num = tensor_loader->GetIterNum();
|
||||
tensor_loader->set_iter_num(++iter_num);
|
||||
|
|
|
@ -37,8 +37,8 @@ DebugServices &DebugServices::operator=(const DebugServices &other) {
|
|||
|
||||
DebugServices::~DebugServices() { delete tensor_loader_; }
|
||||
|
||||
void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition,
|
||||
const std::vector<std::tuple<std::string, bool>> &check_node_list) {
|
||||
void DebugServices::AddWatchpoint(unsigned int id, unsigned int watch_condition,
|
||||
const std::vector<std::tuple<std::string, bool>> &check_node_list) {
|
||||
std::lock_guard<std::mutex> lg(lock_);
|
||||
|
||||
watchpoint_t watchpoint_item;
|
||||
|
@ -57,14 +57,14 @@ void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition
|
|||
watchpoint_table[id] = watchpoint_item;
|
||||
}
|
||||
|
||||
void DebugServices::remove_watchpoint(unsigned int id) {
|
||||
void DebugServices::RemoveWatchpoint(unsigned int id) {
|
||||
std::lock_guard<std::mutex> lg(lock_);
|
||||
watchpoint_table.erase(id);
|
||||
}
|
||||
|
||||
void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<int> *condition, std::vector<unsigned int> *wacthpoint_id) {
|
||||
void DebugServices::CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<int> *condition, std::vector<unsigned int> *wacthpoint_id) {
|
||||
std::lock_guard<std::mutex> lg(lock_);
|
||||
|
||||
std::vector<std::shared_ptr<TensorData>> tensor_list = tensor_loader_->GetTensor();
|
||||
|
@ -171,9 +171,9 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto
|
|||
}
|
||||
}
|
||||
|
||||
void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape) {
|
||||
void DebugServices::ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape) {
|
||||
std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
|
||||
tensor_loader_->SearchTensors(name, &result_list);
|
||||
|
||||
|
@ -189,6 +189,28 @@ void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vecto
|
|||
}
|
||||
}
|
||||
|
||||
TensorLoader *DebugServices::get_tensor_loader() const { return tensor_loader_; }
|
||||
bool DebugServices::IsWatchPoint(std::string kernel_name,
|
||||
std::unordered_map<unsigned int, watchpoint_t> watchpoint_table) {
|
||||
bool ret = false;
|
||||
for (auto w_table_item : watchpoint_table) {
|
||||
auto check_node_list = std::get<1>(w_table_item).check_node_list;
|
||||
for (auto check_node : check_node_list) {
|
||||
std::string w_name = std::get<0>(check_node);
|
||||
bool w_type = std::get<1>(check_node);
|
||||
if ((w_type == true &&
|
||||
((kernel_name.find(w_name) != string::npos && kernel_name.rfind(w_name, 0) == 0) || w_name == "*")) ||
|
||||
(w_type == false && kernel_name == w_name)) {
|
||||
ret = true;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
TensorLoader *DebugServices::tensor_loader() const { return tensor_loader_; }
|
||||
std::unordered_map<unsigned int, DebugServices::watchpoint_t> DebugServices::GetWatchpointTable() {
|
||||
return watchpoint_table;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,22 +37,6 @@ class DebugServices {
|
|||
|
||||
~DebugServices();
|
||||
|
||||
void add_watchpoint(unsigned int id, unsigned int watch_condition,
|
||||
const std::vector<std::tuple<std::string, bool>> &check_node_list);
|
||||
|
||||
void remove_watchpoint(unsigned int id);
|
||||
|
||||
void check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<char *> *data_ptr,
|
||||
std::vector<unsigned int> *data_size, std::vector<int> *condition,
|
||||
std::vector<unsigned int> *wacthpoint_id);
|
||||
|
||||
void read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape);
|
||||
|
||||
TensorLoader *get_tensor_loader() const;
|
||||
|
||||
private:
|
||||
typedef struct condition_no_param {
|
||||
bool enabled = false;
|
||||
} condition_no_param_t;
|
||||
|
@ -84,6 +68,26 @@ class DebugServices {
|
|||
std::vector<std::tuple<std::string, bool>> check_node_list;
|
||||
} watchpoint_t;
|
||||
|
||||
void AddWatchpoint(unsigned int id, unsigned int watch_condition,
|
||||
const std::vector<std::tuple<std::string, bool>> &check_node_list);
|
||||
|
||||
void RemoveWatchpoint(unsigned int id);
|
||||
|
||||
void CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<char *> *data_ptr,
|
||||
std::vector<unsigned int> *data_size, std::vector<int> *condition,
|
||||
std::vector<unsigned int> *wacthpoint_id);
|
||||
|
||||
void ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
|
||||
std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
|
||||
std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape);
|
||||
|
||||
bool IsWatchPoint(std::string kernel_name, std::unordered_map<unsigned int, watchpoint_t> watchpoint_table);
|
||||
|
||||
TensorLoader *tensor_loader() const;
|
||||
|
||||
std::unordered_map<unsigned int, watchpoint_t> GetWatchpointTable();
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
|
||||
std::unordered_map<unsigned int, watchpoint_t> watchpoint_table;
|
||||
|
|
|
@ -43,7 +43,8 @@ Debugger::Debugger()
|
|||
device_id_(0),
|
||||
num_step_(0),
|
||||
debugger_enabled_(false),
|
||||
is_dataset_graph_(false) {}
|
||||
is_dataset_graph_(false),
|
||||
partial_memory_(false) {}
|
||||
|
||||
void Debugger::Init(const uint32_t device_id) {
|
||||
// access lock for public method
|
||||
|
@ -57,6 +58,7 @@ void Debugger::EnableDebugger() {
|
|||
// reset some of the class members
|
||||
num_step_ = 0;
|
||||
debugger_enabled_ = false;
|
||||
partial_memory_ = false;
|
||||
grpc_client_ = nullptr;
|
||||
debug_services_ = nullptr;
|
||||
|
||||
|
@ -72,7 +74,8 @@ void Debugger::EnableDebugger() {
|
|||
MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
|
||||
return;
|
||||
}
|
||||
// configure host
|
||||
|
||||
// configure grpc host
|
||||
const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
|
||||
std::string host;
|
||||
if (env_host_str != nullptr) {
|
||||
|
@ -82,7 +85,7 @@ void Debugger::EnableDebugger() {
|
|||
MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
|
||||
host = "localhost";
|
||||
}
|
||||
// configure port
|
||||
// configure grpc port
|
||||
const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
|
||||
std::string port;
|
||||
if (env_port_str != nullptr) {
|
||||
|
@ -93,6 +96,27 @@ void Debugger::EnableDebugger() {
|
|||
port = "50051";
|
||||
}
|
||||
|
||||
// configure partial memory reuse
|
||||
const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
|
||||
if (env_partial_mem_str != nullptr) {
|
||||
MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
|
||||
if (std::strcmp(env_partial_mem_str, "1") == 0) {
|
||||
partial_memory_ = true;
|
||||
}
|
||||
}
|
||||
// switch memory reuse on or off
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
context_ptr->set_enable_mem_reuse(partial_memory_);
|
||||
// print some message about memory reuse to user
|
||||
if (partial_memory_) {
|
||||
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
|
||||
"step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
|
||||
"usage for large models.";
|
||||
}
|
||||
|
||||
// initialize grpc client
|
||||
grpc_client_ = std::make_unique<GrpcClient>(host, port);
|
||||
debug_services_ = std::make_unique<DebugServices>();
|
||||
|
@ -106,6 +130,7 @@ void Debugger::Reset() {
|
|||
num_step_ = 0;
|
||||
debugger_enabled_ = false;
|
||||
is_dataset_graph_ = false;
|
||||
partial_memory_ = false;
|
||||
graph_ptr_ = nullptr;
|
||||
grpc_client_ = nullptr;
|
||||
debug_services_ = nullptr;
|
||||
|
@ -317,11 +342,10 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
|
|||
[](WatchNode node) -> std::tuple<std::string, bool> {
|
||||
return make_tuple(node.node_name(), node.node_type() == "scope");
|
||||
});
|
||||
|
||||
debug_services_->add_watchpoint(id, condition.condition(), check_node_list);
|
||||
debug_services_->AddWatchpoint(id, condition.condition(), check_node_list);
|
||||
}
|
||||
|
||||
void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); }
|
||||
void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }
|
||||
|
||||
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
|
||||
std::vector<std::string> name;
|
||||
|
@ -335,7 +359,7 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
|
|||
|
||||
// ret_name will contain tensor names that are found in TensorLoader
|
||||
// items in ret_name will be in the same order with tensors if found
|
||||
debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
|
||||
debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
|
||||
|
||||
std::list<TensorProto> tensor_list;
|
||||
unsigned int result_index = 0;
|
||||
|
@ -384,8 +408,7 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
|
|||
std::vector<int> condition;
|
||||
std::vector<unsigned int> watchpoint_id;
|
||||
|
||||
debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
|
||||
|
||||
debug_services_->CheckWatchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
|
||||
std::list<WatchpointHit> hits;
|
||||
for (unsigned int i = 0; i < name.size(); i++) {
|
||||
WatchpointHit hit;
|
||||
|
@ -494,4 +517,6 @@ std::string GetTensorFullName(const TensorProto &tensor) {
|
|||
return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
|
||||
}
|
||||
|
||||
bool Debugger::partial_memory() { return partial_memory_; }
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
|
|||
|
||||
bool debugger_enabled() const;
|
||||
|
||||
bool partial_memory();
|
||||
|
||||
private:
|
||||
// private constructor for singleton
|
||||
Debugger();
|
||||
|
@ -129,6 +131,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
|
|||
int32_t num_step_;
|
||||
bool debugger_enabled_;
|
||||
bool is_dataset_graph_;
|
||||
bool partial_memory_;
|
||||
std::mutex access_lock_;
|
||||
|
||||
// singleton
|
||||
|
|
|
@ -51,25 +51,13 @@ class TensorData {
|
|||
|
||||
int GetExecutionOrder() { return this->execution_order; }
|
||||
|
||||
int SetExecutionOrder(int execution_order) {
|
||||
this->execution_order = execution_order;
|
||||
return true;
|
||||
}
|
||||
void SetExecutionOrder(int execution_order) { this->execution_order = execution_order; }
|
||||
|
||||
int SetName(const std::string &name) {
|
||||
this->name = name;
|
||||
return true;
|
||||
}
|
||||
void SetName(const std::string &name) { this->name = name; }
|
||||
|
||||
bool SetTensor(mindspore::tensor::TensorPtr out_tensor) {
|
||||
this->tensor_ptr = out_tensor;
|
||||
return true;
|
||||
}
|
||||
void SetTensor(mindspore::tensor::TensorPtr out_tensor) { this->tensor_ptr = out_tensor; }
|
||||
|
||||
bool SetSlot(size_t slot) {
|
||||
this->slot = slot;
|
||||
return true;
|
||||
}
|
||||
void SetSlot(size_t slot) { this->slot = slot; }
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -28,9 +29,10 @@ class TensorLoader {
|
|||
public:
|
||||
TensorLoader() : iter_num(-1) {}
|
||||
|
||||
~TensorLoader() {}
|
||||
~TensorLoader() { EmptyTensor(); }
|
||||
|
||||
bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
|
||||
std::lock_guard<std::mutex> lg(lock_);
|
||||
if (keep_prev) {
|
||||
// add prev step tensor into current step map with ":prev" suffix
|
||||
auto handle = prev_tensor_list_map.extract(tensor->GetName());
|
||||
|
@ -61,11 +63,11 @@ class TensorLoader {
|
|||
}
|
||||
}
|
||||
|
||||
bool EmptyTensor() {
|
||||
void EmptyTensor() {
|
||||
std::lock_guard<std::mutex> lg(lock_);
|
||||
prev_tensor_list_map.clear();
|
||||
tensor_list_map.swap(prev_tensor_list_map);
|
||||
tensor_list.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
|
||||
|
@ -77,6 +79,7 @@ class TensorLoader {
|
|||
std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
|
||||
std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map;
|
||||
uint32_t iter_num;
|
||||
std::mutex lock_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
|
||||
|
|
|
@ -372,10 +372,13 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
|
|||
const std::string &host_fmt, const std::vector<int> &host_shape,
|
||||
TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const {
|
||||
bool ret = false;
|
||||
|
||||
DebugServices *debug_services = debugger->debug_services();
|
||||
TensorLoader *tensor_loader = debug_services->get_tensor_loader();
|
||||
|
||||
TensorLoader *tensor_loader = debug_services->tensor_loader();
|
||||
// TensorData is freed up in AscendSession class
|
||||
auto tensor_data = std::make_shared<mindspore::TensorData>();
|
||||
tensor_data->SetName(tensor_name);
|
||||
tensor_data->SetExecutionOrder(execution_order);
|
||||
tensor_data->SetSlot(slot);
|
||||
if (trans_flag) {
|
||||
MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
|
||||
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
|
||||
|
@ -385,28 +388,18 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
|
|||
MS_LOG(ERROR) << "Copy device mem to host failed";
|
||||
return ret;
|
||||
}
|
||||
auto tensor_data = std::make_shared<mindspore::TensorData>();
|
||||
tensor_data->SetName(tensor_name);
|
||||
tensor_data->SetExecutionOrder(execution_order);
|
||||
tensor_data->SetTensor(out_tensor);
|
||||
tensor_data->SetSlot(slot);
|
||||
ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
|
||||
} else {
|
||||
mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
|
||||
size_t host_size = out_tensor->data().nbytes();
|
||||
auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
|
||||
auto tensor_data = std::make_shared<mindspore::TensorData>();
|
||||
tensor_data->SetName(tensor_name);
|
||||
tensor_data->SetExecutionOrder(execution_order);
|
||||
tensor_data->SetTensor(out_tensor);
|
||||
tensor_data->SetSlot(slot);
|
||||
ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
|
||||
if (ret_rt_memcpy != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]";
|
||||
}
|
||||
MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
|
||||
tensor_data->SetTensor(out_tensor);
|
||||
}
|
||||
ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -311,15 +311,24 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
|
|||
namespace {
|
||||
void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// trans_flag: "true" means tensor values will be transfered to host format, otherwise not.
|
||||
bool trans_flag = false;
|
||||
const auto &apply_kernels = graph->execution_order();
|
||||
// for kernels, execution order starts from 1
|
||||
int exec_order = 1;
|
||||
auto debugger_ = mindspore::Debugger::GetInstance();
|
||||
DebugServices *debug_services = debugger_->debug_services();
|
||||
auto watchpoint_table = debug_services->GetWatchpointTable();
|
||||
for (const auto &node : apply_kernels) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_name = AnfAlgo::GetCNodeName(node);
|
||||
std::string kernel_name = node->fullname_with_scope();
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(node);
|
||||
if (debugger_->partial_memory()) {
|
||||
if (!debug_services->IsWatchPoint(kernel_name, watchpoint_table)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
for (size_t j = 0; j < output_size; ++j) {
|
||||
auto addr = AnfAlgo::GetOutputAddr(node, j);
|
||||
auto type = AnfAlgo::GetOutputInferDataType(node, j);
|
||||
|
@ -347,6 +356,7 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
|
|||
|
||||
void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// trans_flag: "true" means tensor values will be transfered to host format, otherwise not.
|
||||
bool trans_flag = false;
|
||||
const auto ¶meters = graph->inputs();
|
||||
// for parameters, set its execution order to be 0;
|
||||
|
|
Loading…
Reference in New Issue