diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc
index 293a1fe1490..daac34850e0 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.cc
+++ b/mindspore/ccsrc/ps/core/abstract_node.cc
@@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
   CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
   comm_message.set_data(register_message.SerializeAsString());
+  comm_message.set_user_cmd("");
   if (!SendMessageSync(client, comm_message)) {
     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                       << " the node id:" << node_info_.node_id_ << " register timeout!";
@@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
   MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
 }
 
-bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
+bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
   if (node_role != NodeRole::SERVER) {
     MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
 
@@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(message);
     auto client = GetOrCreateTcpClient((*it).first.second);
     client->SendMessage(comm_message);
   }
@@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
   on_node_event_message_ = on_node_event_message;
 }
 
-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
                         const uint32_t &timeout) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageSync(client, comm_message, timeout);
 }
 
 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, const uint32_t &timeout) {
+                        const std::vector<CommMessage> &data, const uint32_t &timeout) {
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(data.size(), 0);
 
@@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));
 
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
   return Wait(request_id, timeout);
 }
 
-bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
-                        std::string *output, const uint32_t &timeout) {
+bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
+                        CommMessage *output, const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   uint64_t request_id = ++next_request_id_;
   message_tracker_[request_id] = std::make_pair(1, 0);
   set_message_callback(request_id, [&]() {
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
-    *output = res[rank_id].data();
+    *output = res[rank_id];
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
   });
@@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   client->SendMessage(comm_message);
   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
 }
 
 bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
-                        const std::vector<std::string> &data, std::vector<std::string> *output,
+                        const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
                         const uint32_t &timeout) {
   MS_EXCEPTION_IF_NULL(output);
   uint64_t request_id = ++next_request_id_;
@@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     receive_messages_mutex_.lock();
     auto res = receive_messages_[request_id];
     for (size_t it = 0; it < len; ++it) {
-      (*output).push_back(res[rank_ids.at(it)].data());
+      (*output).push_back(res[rank_ids.at(it)]);
     }
     receive_messages_.erase(request_id);
     receive_messages_mutex_.unlock();
@@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
     message_meta.set_rank_id(node_info_.rank_id_);
     message_meta.set_role(node_info_.node_role_);
 
-    CommMessage comm_message;
+    CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
     *comm_message.mutable_pb_meta() = {message_meta};
-    comm_message.set_data(data.at(it));
 
     auto client = GetOrCreateTcpClient(rank_ids.at(it));
     client->SendMessage(comm_message);
@@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
 }
 
 uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                           const std::string &message) {
+                                           const CommMessage &message) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
+  CommMessage &comm_message = const_cast<CommMessage &>(message);
+
   MessageMeta message_meta;
   message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
   message_meta.set_rank_id(node_info_.rank_id_);
   message_meta.set_role(node_info_.node_role_);
 
-  CommMessage comm_message;
   *comm_message.mutable_pb_meta() = {message_meta};
-  comm_message.set_data(message);
   auto client = GetOrCreateTcpClient(rank_id);
   return SendMessageAsync(client, comm_message);
 }
 
 std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
-                                                                   const uint32_t &rank_id, std::string *output) {
+                                                                   const uint32_t &rank_id, CommMessage *output) {
   if (!CommUtil::ValidateRankId(node_role, rank_id)) {
     MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
   }
 
   uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
   if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
-    *output = received_data_[std::make_pair(rank_id, rank_request_id)].data();
+    *output = received_data_[std::make_pair(rank_id, rank_request_id)];
     received_data_.erase(std::make_pair(rank_id, rank_request_id));
   } else {
     set_receive_callback(rank_id, rank_request_id, [=]() {
       receive_callbacks_mutex_.lock();
-      *output = received_data_[std::make_pair(rank_id, 1)].data();
+      *output = received_data_[std::make_pair(rank_id, rank_request_id)];
       received_data_.erase(std::make_pair(rank_id, rank_request_id));
       receive_callbacks_mutex_.unlock();
     });
@@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() {
   uint16_t scheduler_port = ClusterConfig::scheduler_port();
   client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
   client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
-    switch (message.pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeatResp(message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegisterResp(message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServersResp(message);
-        break;
-      case NodeCommand::FINISH:
-        MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!";
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message.pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
+    }
+    if (handlers_[message.pb_meta().cmd()] != nullptr) {
+      const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
+      (this->*handler_ptr)(message);
     }
     NotifyMessageArrival(message);
   });
@@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
   }
   return rank_request_id;
 }
+
+void AbstractNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
+  handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
+  handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
+  handlers_[NodeCommand::FINISH] = nullptr;
+}
 }  // namespace core
 }  // namespace ps
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h
index eea8eb773da..448e36489ac 100644
--- a/mindspore/ccsrc/ps/core/abstract_node.h
+++ b/mindspore/ccsrc/ps/core/abstract_node.h
@@ -34,23 +34,25 @@ class AbstractNode : public Node {
   AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
   ~AbstractNode() override = default;
 
-  bool Broadcast(const enum NodeRole &node_role, const std::string &message,
+  typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
+
+  bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
                  const uint32_t &timeout = kCommTimeoutInSeconds);
   void set_event_callback(const OnNodeEventMessage &on_node_event_message);
 
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, std::string *output,
+  bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
             const uint32_t &timeout = kCommTimeoutInSeconds);
-  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
-            std::vector<std::string> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
+  bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
+            std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
   bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
-  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message);
+  uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
-                                                       std::string *output);
+                                                       CommMessage *output);
   bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
 
  protected:
@@ -78,6 +80,7 @@ class AbstractNode : public Node {
   void RunReceiveCallback(const CommMessage &message);
   uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
   uint64_t NextActualRankRequestId(const uint32_t &rank_id);
+  void InitCommandHandler();
 
   std::unique_ptr<std::thread> heart_beat_thread_;
   std::unique_ptr<std::thread> client_to_scheduler_thread_;
@@ -115,6 +118,7 @@ class AbstractNode : public Node {
   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
   std::mutex rank_request_ids_mutex;
   timeval scheduler_time_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
 };
 }  // namespace core
 }  // namespace ps
diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto
index 4e24de8c580..81d10137120 100644
--- a/mindspore/ccsrc/ps/core/protos/comm.proto
+++ b/mindspore/ccsrc/ps/core/protos/comm.proto
@@ -95,5 +95,6 @@ message FinishMessage {
 message CommMessage {
   MessageMeta pb_meta = 1;
   bytes data = 2;
+  // User-defined commands
+  bytes user_cmd = 3;
 }
-
diff --git a/mindspore/ccsrc/ps/core/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto
index 9ae31a94c13..7f293663a12 100644
--- a/mindspore/ccsrc/ps/core/protos/ps.proto
+++ b/mindspore/ccsrc/ps/core/protos/ps.proto
@@ -14,17 +14,42 @@
  * limitations under the License.
  */
 syntax = "proto3";
-package mindspore.ps.core;
+package mindspore.ps;
 option optimize_for = LITE_RUNTIME;
 
-enum PSCommand {
+message Command {
+  CommandCode cmd = 1;
+}
+
+enum CommandCode {
   PUSH = 0;
   PULL = 1;
   INIT_EMBEDDING_TABLE = 2;
+  INIT_WEIGHT = 3;
+  INIT_WEIGHT_TO_OPTIM_ID = 4;
+  INIT_INPUTS_SHAPE = 5;
+  CHECK_READY_FOR_PUSH = 6;
+  CHECK_READY_FOR_PULL = 7;
+  EMBEDDING_LOOKUP = 8;
+  UPDATE_EMBEDDING = 9;
+  FINALIZE = 10;
 }
 
 message KVMessage {
-  PSCommand command = 1;
   repeated int32 keys = 2;
   repeated float values = 3;
+  repeated int32 len = 4;
+}
+
+message EmbeddingTableMeta {
+  uint64 key = 1;
+  repeated uint64 input_shape = 2;
+  repeated uint64 indices_shape = 3;
+  repeated uint64 output_shape = 4;
+}
+
+message EmbeddingTableLookup {
+  uint64 key = 2;
+  repeated int32 keys = 3;
+  repeated float values = 4;
 }
\ No newline at end of file
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc
index d84fc77dc47..a3a38519fbd 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.cc
+++ b/mindspore/ccsrc/ps/core/scheduler_node.cc
@@ -67,6 +67,7 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
 }
 
 void SchedulerNode::Initialize() {
+  InitCommandHandler();
   CreateTcpServer();
   is_already_stopped_ = false;
   node_info_.node_id_ = CommUtil::GenerateUUID();
@@ -75,6 +76,13 @@ void SchedulerNode::Initialize() {
                << ", the node id is:" << node_info_.node_id_;
 }
 
+void SchedulerNode::InitCommandHandler() {
+  handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
+  handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
+  handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
+  handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
+}
+
 void SchedulerNode::CreateTcpServer() {
   node_manager_.InitNodeNum();
 
@@ -82,22 +90,11 @@ void SchedulerNode::CreateTcpServer() {
   uint32_t scheduler_port = ClusterConfig::scheduler_port();
   server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
   server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
-    switch (message->pb_meta().cmd()) {
-      case NodeCommand::HEARTBEAT:
-        ProcessHeartbeat(server_, conn, message);
-        break;
-      case NodeCommand::REGISTER:
-        ProcessRegister(server_, conn, message);
-        break;
-      case NodeCommand::FINISH:
-        ProcessFinish(server_, conn, message);
-        break;
-      case NodeCommand::FETCH_SERVER:
-        ProcessFetchServers(server_, conn, message);
-        break;
-      default:
-        MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
+    if (handlers_.count(message->pb_meta().cmd()) == 0) {
+      MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
     }
+    const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
+    (this->*handler_ptr)(server_, conn, message);
   });
 
   server_->Init();
diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h
index a476caae53e..1c89d2398dd 100644
--- a/mindspore/ccsrc/ps/core/scheduler_node.h
+++ b/mindspore/ccsrc/ps/core/scheduler_node.h
@@ -25,29 +25,32 @@
 #include <vector>
 #include <thread>
 #include <mutex>
+#include <unordered_map>
 
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/node_manager.h"
 #include "ps/core/node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {
 namespace core {
-
 class SchedulerNode : public Node {
  public:
   SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
   ~SchedulerNode() override;
 
+  typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
+                                                 std::shared_ptr<CommMessage> message);
+
   bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
   bool Stop() override;
   bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;
 
  private:
   void Initialize();
+  void InitCommandHandler();
   void CreateTcpServer();
   void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
                         std::shared_ptr<CommMessage> message);
@@ -62,6 +65,7 @@ class SchedulerNode : public Node {
   std::shared_ptr<TcpServer> server_;
   std::unique_ptr<std::thread> scheduler_thread_;
   std::unique_ptr<std::thread> update_state_thread_;
+  std::unordered_map<NodeCommand, ResponseHandler> handlers_;
 
   NodeManager node_manager_;
 };
diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc
index 08d0b280b80..28d09570678 100644
--- a/mindspore/ccsrc/ps/core/server_node.cc
+++ b/mindspore/ccsrc/ps/core/server_node.cc
@@ -92,6 +92,7 @@ void ServerNode::Initialize() {
   node_info_.port_ = server_->BoundPort();
   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << " is generate uuid is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Server node init client timeout!";
   }
diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h
index 2a0d70e82b6..086358f56e5 100644
--- a/mindspore/ccsrc/ps/core/server_node.h
+++ b/mindspore/ccsrc/ps/core/server_node.h
@@ -24,13 +24,10 @@
 #include <thread>
 #include <utility>
 
-#include "proto/comm.pb.h"
-#include "proto/ps.pb.h"
 #include "ps/core/cluster_config.h"
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {
diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc
index ee162e070b4..1870a499241 100644
--- a/mindspore/ccsrc/ps/core/worker_node.cc
+++ b/mindspore/ccsrc/ps/core/worker_node.cc
@@ -50,6 +50,7 @@ void WorkerNode::Initialize() {
   node_info_.node_role_ = NodeRole::WORKER;
   MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_;
+  InitCommandHandler();
   if (!InitClientToScheduler()) {
     MS_LOG(EXCEPTION) << "Worker node init client timeout!";
   }
diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h
index a1343aa3623..8608ae430a9 100644
--- a/mindspore/ccsrc/ps/core/worker_node.h
+++ b/mindspore/ccsrc/ps/core/worker_node.h
@@ -28,7 +28,6 @@
 #include "ps/core/tcp_client.h"
 #include "ps/core/tcp_server.h"
 #include "ps/core/abstract_node.h"
-#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace ps {