forked from mindspore-Ecosystem/mindspore
!19891 Sending heartbeat periodically to MI
Merge pull request !19891 from sabrinasun_59ee/dev
This commit is contained in:
commit
fe3aa8cfc2
|
@ -27,6 +27,7 @@ service EventListener {
|
|||
rpc SendTensors (stream TensorProto) returns (EventReply) {};
|
||||
rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {};
|
||||
rpc SendMultiGraphs (stream Chunk) returns (EventReply) {};
|
||||
rpc SendHeartbeat (Heartbeat) returns (EventReply) {};
|
||||
}
|
||||
|
||||
message Metadata {
|
||||
|
@ -136,3 +137,8 @@ message WatchpointHit {
|
|||
int32 id = 3;
|
||||
int32 error_code = 4;
|
||||
}
|
||||
|
||||
message Heartbeat {
|
||||
string message = 1;
|
||||
int32 period = 2;
|
||||
}
|
||||
|
|
|
@ -59,12 +59,14 @@ using debugger::WatchpointHit;
|
|||
namespace mindspore {
|
||||
|
||||
static constexpr auto g_chunk_size = 1024 * 1024 * 3;
|
||||
static constexpr int32_t heartbeat_period_second = 30;
|
||||
DebuggerPtr Debugger::debugger_ = nullptr;
|
||||
std::mutex Debugger::instance_lock_;
|
||||
|
||||
Debugger::Debugger()
|
||||
: grpc_client_(nullptr),
|
||||
debug_services_(nullptr),
|
||||
heartbeat_thread_(nullptr),
|
||||
device_id_(0),
|
||||
device_target_(""),
|
||||
num_step_(0),
|
||||
|
@ -132,6 +134,7 @@ void Debugger::EnableDebugger() {
|
|||
partial_memory_ = false;
|
||||
grpc_client_ = nullptr;
|
||||
debug_services_ = nullptr;
|
||||
heartbeat_thread_ = nullptr;
|
||||
|
||||
// see if dump using debugger backend is enabled
|
||||
bool dump_enabled = CheckDebuggerDumpEnabled();
|
||||
|
@ -184,6 +187,8 @@ void Debugger::EnableDebugger() {
|
|||
}
|
||||
// initialize grpc client
|
||||
grpc_client_ = std::make_unique<GrpcClient>(host, port);
|
||||
// initialize sending heartbeat
|
||||
heartbeat_thread_ = std::make_unique<std::thread>([&]() { SendHeartbeat(heartbeat_period_second); });
|
||||
}
|
||||
debug_services_ = std::make_unique<DebugServices>();
|
||||
}
|
||||
|
@ -575,6 +580,38 @@ GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
|
|||
ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
|
||||
return model.graph();
|
||||
}
|
||||
|
||||
void Debugger::SendHeartbeat(int32_t period) {
|
||||
bool heartbeat_enabled_ = true;
|
||||
int num_heartbeat_fail = 0;
|
||||
const int max_num_heartbeat_fail = 5;
|
||||
const int retry_period = 500;
|
||||
|
||||
Heartbeat heartbeat;
|
||||
heartbeat.set_message("Debugger is alive");
|
||||
heartbeat.set_period(heartbeat_period_second);
|
||||
|
||||
bool run_ = CheckDebuggerEnabled() && heartbeat_enabled_;
|
||||
while (run_) {
|
||||
EventReply reply = grpc_client_->SendHeartbeat(heartbeat);
|
||||
|
||||
if (reply.status() != reply.OK) {
|
||||
MS_LOG(ERROR) << "Error: SendHeartbeat failed";
|
||||
num_heartbeat_fail++;
|
||||
if (num_heartbeat_fail >= max_num_heartbeat_fail) {
|
||||
MS_LOG(ERROR) << "Maximum number of failure for SendHeartbeat reached : exiting training session.";
|
||||
Exit();
|
||||
run_ = false;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Number of consecutive SendHeartbeat fail:" << num_heartbeat_fail;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(retry_period));
|
||||
}
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(period * 1000));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
|
||||
if (SendMetadata(true)) {
|
||||
// send graph to Mindinsight server
|
||||
|
|
|
@ -195,6 +195,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
|
|||
// serialize graph and get proto
|
||||
GraphProto GetGraphProto(const KernelGraphPtr &graph_ptr) const;
|
||||
|
||||
// send heartbeat message to UI once per 30 second by default
|
||||
void SendHeartbeat(int32_t period);
|
||||
|
||||
// send graph and enter command wait loop
|
||||
void SendGraphAndSuspend(const GraphProto &graph_proto);
|
||||
|
||||
|
@ -244,6 +247,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
|
|||
|
||||
std::unique_ptr<GrpcClient> grpc_client_;
|
||||
std::unique_ptr<DebugServices> debug_services_;
|
||||
std::unique_ptr<std::thread> heartbeat_thread_;
|
||||
KernelGraphPtr graph_ptr_;
|
||||
uint32_t device_id_;
|
||||
std::string device_target_;
|
||||
|
|
|
@ -24,6 +24,7 @@ using debugger::EventListener;
|
|||
using debugger::EventReply;
|
||||
using debugger::EventReply_Status_FAILED;
|
||||
using debugger::GraphProto;
|
||||
using debugger::Heartbeat;
|
||||
using debugger::Metadata;
|
||||
using debugger::TensorProto;
|
||||
using debugger::WatchpointHit;
|
||||
|
@ -185,4 +186,18 @@ EventReply GrpcClient::SendWatchpointHits(const std::list<WatchpointHit> &watchp
|
|||
}
|
||||
return reply;
|
||||
}
|
||||
|
||||
EventReply GrpcClient::SendHeartbeat(const Heartbeat &heartbeat) {
|
||||
EventReply reply;
|
||||
grpc::ClientContext context;
|
||||
|
||||
grpc::Status status = stub_->SendHeartbeat(&context, heartbeat, &reply);
|
||||
|
||||
if (!status.ok()) {
|
||||
MS_LOG(ERROR) << "RPC failed: SendHeartbeat";
|
||||
MS_LOG(ERROR) << status.error_code() << ": " << status.error_message();
|
||||
reply.set_status(EventReply_Status_FAILED);
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,7 @@ using debugger::Chunk;
|
|||
using debugger::EventListener;
|
||||
using debugger::EventReply;
|
||||
using debugger::GraphProto;
|
||||
using debugger::Heartbeat;
|
||||
using debugger::Metadata;
|
||||
using debugger::TensorProto;
|
||||
using debugger::WatchpointHit;
|
||||
|
@ -60,6 +61,8 @@ class GrpcClient {
|
|||
|
||||
std::vector<std::string> ChunkString(std::string str, int graph_size);
|
||||
|
||||
EventReply SendHeartbeat(const Heartbeat &heartbeat);
|
||||
|
||||
private:
|
||||
std::unique_ptr<EventListener::Stub> stub_;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue