diff --git a/mindspore/ccsrc/distributed/constants.h b/mindspore/ccsrc/distributed/constants.h
index ebdaecf1e11..0be9ecbb69a 100644
--- a/mindspore/ccsrc/distributed/constants.h
+++ b/mindspore/ccsrc/distributed/constants.h
@@ -21,6 +21,7 @@
 #include <map>
 #include <set>
 #include <string>
+#include <vector>
 
 namespace mindspore {
 namespace distributed {
@@ -37,6 +38,14 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
 const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfPServer, kEnvRoleOfWorker,
                                               kEnvRoleOfScheduler};
 
+// Used in parameter server embedding cache scenarios to identify the same Parameter between Worker and Server.
+constexpr char kParameterKey[] = "parameter_key";
+// Embedding cache lookup operation.
+constexpr char kLookupEmbeddingCache[] = "LookupEmbeddingCache";
+// Embedding cache update operation.
+constexpr char kUpdateEmbeddingCache[] = "UpdateEmbeddingCache";
+const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpdateEmbeddingCache};
+
 // The distributed execution mode enum.
 enum class DistExecutionMode { kPSMode = 0, kInvalidMode };
 
diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc
new file mode 100644
index 00000000000..c3036ef8f12
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.cc
@@ -0,0 +1,402 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ir/func_graph.h"
+#include "abstract/abstract_function.h"
+#include "include/common/utils/anfalgo.h"
+#include "include/common/utils/utils.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace parallel {
+// One-dimensional dynamic shape placeholder.
+const ShapeVector kOneDimDynamicShape = {-1};
+// Two-dimensional dynamic shape placeholder.
+const ShapeVector kTwoDimsDynamicShape = {-1, -1};
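+// Note: the -1 entries denote dynamic dimensions. The indices received from workers form a 1-D tensor whose length
+// varies per step, and the update values form a 2-D tensor, presumably (number of ids, embedding dimension).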
+// The number of output tensors of the recv node.
+const size_t kRecvNodeOutputNum = 3;
+
+void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
+  (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!(node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName)) {
+      return;
+    }
+    const PrimitivePtr &prim = common::AnfAlgo::GetCNodePrimitive(node);
+    MS_EXCEPTION_IF_NULL(prim);
+    if (!(prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole))) {
+      return;
+    }
+
+    int64_t rank_id_attr = GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId));
+    std::string node_role_attr = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
+    if (rank_id_attr == rank_id_ && node_role_attr == node_role_) {
+      std::vector<size_t> shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+      shapes_to_nodes_[shape] = node;
+    }
+  });
+}
+
+void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
+  MS_EXCEPTION_IF_NULL(node);
+
+  // Set attrs for the call node. A call node has no primitive to hold attrs, so save them on the CNode itself.
+  if (common::AnfAlgo::IsCallNode(node)) {
+    node->AddAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice));
+    node->AddAttr(distributed::kOpLabelRankId, MakeValue(rank_id_));
+    node->AddAttr(distributed::kOpLabelRole, MakeValue(node_role));
+  } else {
+    common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), node);
+    common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRankId, MakeValue(rank_id_), node);
+    common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRole, MakeValue(node_role), node);
+  }
+}
+
+void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key,
+                                               const std::string &embedding_cache_op,
+                                               const std::string &dst_role) const {
+  MS_EXCEPTION_IF_NULL(send_node);
+  SetNodeAttr(send_node);
+
+  std::vector<uint32_t> dst_ranks;
+  std::vector<std::string> dst_roles;
+  std::vector<std::string> inter_process_edges;
+
+  // Set inter-process edges, send dst ranks and send dst roles.
+  for (uint32_t i = 0; i < worker_num_; i++) {
+    dst_ranks.push_back(i);
+    dst_roles.push_back(dst_role);
+    // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
+    inter_process_edges.push_back(distributed::kEnvRoleOfServer + std::to_string(rank_id_) + "->" + dst_role +
+                                  std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
+                                  std::to_string(param_key));
+  }
+
+  common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
+  common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
+}
+
+void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
+  MS_EXCEPTION_IF_NULL(recv_node);
+  SetNodeAttr(recv_node);
+
+  std::vector<uint32_t> src_ranks;
+  std::vector<std::string> src_roles;
+  std::vector<std::string> inter_process_edges;
+
+  // Set inter-process edges, recv src ranks and recv src roles.
+  // Each server has only one Recv node, which receives all requests from every worker. For example, each parameter
+  // on a worker has two operations, embedding lookup and embedding update, and each operation is performed by an
+  // independent Send node, so the Recv node on the server side has multiple incoming edges.
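+  // For example, assuming the role strings are "MS_WORKER" and "MS_SERVER": with two workers, server rank id 0, and
+  // a single cached parameter whose key is 3, the expected incoming edges are:
+  //   "MS_WORKER0->MS_SERVER0_LookupEmbeddingCache_parameter_key3"
+  //   "MS_WORKER0->MS_SERVER0_UpdateEmbeddingCache_parameter_key3"
+  //   "MS_WORKER1->MS_SERVER0_LookupEmbeddingCache_parameter_key3"
+  //   "MS_WORKER1->MS_SERVER0_UpdateEmbeddingCache_parameter_key3"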
+  for (uint32_t i = 0; i < worker_num_; i++) {
+    for (const auto &item : keys_to_params_) {
+      int32_t param_key = item.first;
+      for (uint32_t k = 0; k < distributed::kEmbeddingCacheOps.size(); k++) {
+        src_ranks.push_back(i);
+        src_roles.push_back(src_role);
+        // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
+        // key.
+        inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfServer +
+                                      std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
+                                      distributed::kParameterKey + std::to_string(param_key));
+      }
+    }
+  }
+
+  common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
+  common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
+}
+
+CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(output_node);
+
+  // Create a fake output value node to make sure the output abstract is the same for each subgraph.
+  auto fake_output_tensor = std::make_shared<tensor::Tensor>(1.0);
+  auto fake_output_value = NewValueNode(fake_output_tensor);
+  MS_EXCEPTION_IF_NULL(fake_output_value);
+  fake_output_value->set_abstract(fake_output_tensor->ToAbstract());
+
+  // Create depend node.
+  auto depend_node = graph->NewCNode({NewValueNode(prim::kPrimDepend), fake_output_value, output_node});
+  MS_EXCEPTION_IF_NULL(depend_node);
+
+  // Create return node.
+  std::vector<AnfNodePtr> return_inputs;
+  return_inputs.push_back(NewValueNode(prim::kPrimReturn));
+  return_inputs.push_back(depend_node);
+  auto return_node = graph->NewCNode(return_inputs);
+  MS_EXCEPTION_IF_NULL(return_node);
+
+  return return_node;
+}
+
+FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node,
+                                                                        const ParameterPtr &param,
+                                                                        int32_t param_key) const {
+  MS_EXCEPTION_IF_NULL(param);
+  MS_EXCEPTION_IF_NULL(node);
+
+  // 1. Create subgraph and parameters.
+  auto graph = std::make_shared<FuncGraph>();
+  ParameterPtr input_param = graph->add_parameter();
+  MS_EXCEPTION_IF_NULL(input_param);
+  MS_EXCEPTION_IF_NULL(param->abstract());
+  input_param->set_abstract(param->abstract()->Clone());
+  ParameterPtr input_indices = graph->add_parameter();
+  MS_EXCEPTION_IF_NULL(input_indices);
+  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape));
+
+  // 2. Create EmbeddingLookup node.
+  PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
+  std::vector<AnfNodePtr> emb_lookup_inputs{NewValueNode(emb_lookup_primitive), input_param, input_indices};
+  auto embedding_cache_lookup_node = graph->NewCNode(emb_lookup_inputs);
+  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);
+  common::AnfAlgo::CopyNodeAttrs(node, embedding_cache_lookup_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
+  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
+  SetNodeAttr(embedding_cache_lookup_node);
+
+  // 3. Create RpcSend node.
+  std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
+  send_inputs.push_back(embedding_cache_lookup_node);
+  CNodePtr send_node = graph->NewCNode(send_inputs);
+  MS_EXCEPTION_IF_NULL(send_node);
+  SetSendNodeAttr(send_node, param_key, distributed::kLookupEmbeddingCache);
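+
+  // The resulting lookup service subgraph (sketch):
+  //   (param, indices) -> EmbeddingLookup -> RpcSend -> Depend(fake_output, RpcSend) -> Return
+  // The RpcSend node sends the looked-up embeddings back to the service caller, while the subgraph itself returns
+  // the fake value created in CreateReturnNode.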
+  // 4. Create return node.
+  CNodePtr return_node = CreateReturnNode(graph, send_node);
+  MS_EXCEPTION_IF_NULL(return_node);
+  graph->set_return(return_node);
+
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  auto manager = root_graph_->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(graph);
+  return graph;
+}
+
+FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param,
+                                                                        const AnfNodePtr &node) const {
+  MS_EXCEPTION_IF_NULL(param);
+  MS_EXCEPTION_IF_NULL(node);
+
+  // 1. Create subgraph and parameters.
+  auto graph = std::make_shared<FuncGraph>();
+  ParameterPtr input_param = graph->add_parameter();
+  MS_EXCEPTION_IF_NULL(input_param);
+  MS_EXCEPTION_IF_NULL(param->abstract());
+  input_param->set_abstract(param->abstract()->Clone());
+  ParameterPtr input_indices = graph->add_parameter();
+  MS_EXCEPTION_IF_NULL(input_indices);
+  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape));
+  ParameterPtr update_values = graph->add_parameter();
+  MS_EXCEPTION_IF_NULL(update_values);
+  update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsDynamicShape));
+
+  // 2. Create Sub node to rectify the ids via the offset of this server's embedding slice.
+  int32_t offset = LongToInt(common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrOffset));
+  PrimitivePtr sub_primitive = std::make_shared<Primitive>(kSubOpName);
+  std::vector<AnfNodePtr> sub_inputs{NewValueNode(sub_primitive), input_indices, NewValueNode(MakeValue(offset))};
+  auto sub_node = graph->NewCNode(sub_inputs);
+  MS_EXCEPTION_IF_NULL(sub_node);
+  SetNodeAttr(sub_node);
+
+  // 3. Create ScatterUpdate node.
+  PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
+  std::vector<AnfNodePtr> embedding_cache_update_inputs{NewValueNode(embedding_cache_update_primitive), input_param,
+                                                        sub_node, update_values};
+  auto embedding_cache_update_node = graph->NewCNode(embedding_cache_update_inputs);
+  MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
+  common::AnfAlgo::CopyNodeAttrs(node, embedding_cache_update_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);
+  SetNodeAttr(embedding_cache_update_node);
+
+  // 4. Create return node.
+  CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
+  MS_EXCEPTION_IF_NULL(return_node);
+  graph->set_return(return_node);
+
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  auto manager = root_graph_->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(graph);
+  return graph;
+}
+
+CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
+  // 1. Create abstracts for the RpcRecv node.
+  auto indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape);
+  auto update_values_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsDynamicShape);
+  auto fake_id_tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(0));
+
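+  // Note: the fake value nodes created below are placeholder inputs. They give the RpcRecv node typed,
+  // dynamic-shape abstracts for the data it will receive at runtime; they carry no real data themselves.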
+  // 2. Create fake input nodes for the RpcRecv node.
+  // The indices input.
+  MS_EXCEPTION_IF_NULL(indices_abstract->element());
+  MS_EXCEPTION_IF_NULL(indices_abstract->element()->BuildType());
+  MS_EXCEPTION_IF_NULL(indices_abstract->shape());
+  auto fake_indices_tensor = std::make_shared<tensor::Tensor>(indices_abstract->element()->BuildType()->type_id(),
+                                                              indices_abstract->shape()->shape());
+  auto fake_indices_value = NewValueNode(fake_indices_tensor);
+  MS_EXCEPTION_IF_NULL(fake_indices_value);
+  fake_indices_value->set_abstract(fake_indices_tensor->ToAbstract());
+
+  // The update values input.
+  MS_EXCEPTION_IF_NULL(update_values_abstract->element());
+  MS_EXCEPTION_IF_NULL(update_values_abstract->element()->BuildType());
+  MS_EXCEPTION_IF_NULL(update_values_abstract->shape());
+  auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(
+    update_values_abstract->element()->BuildType()->type_id(), update_values_abstract->shape()->shape());
+  auto fake_update_values_value = NewValueNode(fake_update_values_tensor);
+  MS_EXCEPTION_IF_NULL(fake_update_values_value);
+  fake_update_values_value->set_abstract(fake_update_values_tensor->ToAbstract());
+
+  // The id input; the id is used to choose the service to execute.
+  auto fake_id_value = NewValueNode(fake_id_tensor);
+  MS_EXCEPTION_IF_NULL(fake_id_value);
+  fake_id_value->set_abstract(fake_id_tensor->ToAbstract());
+
+  // 3. Create the RpcRecv node.
+  std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
+  recv_inputs.push_back(fake_indices_value);
+  recv_inputs.push_back(fake_update_values_value);
+  recv_inputs.push_back(fake_id_value);
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  CNodePtr recv_node = root_graph_->NewCNode(recv_inputs);
+  MS_EXCEPTION_IF_NULL(recv_node);
+  SetRecvNodeAttr(recv_node);
+
+  return recv_node;
+}
+
+bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheServicesSubGraphs(
+  const std::vector<CNodePtr> &recv_outputs, std::vector<AnfNodePtr> *make_tuple_inputs) const {
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  MS_EXCEPTION_IF_NULL(make_tuple_inputs);
+  if (recv_outputs.size() != kRecvNodeOutputNum) {
+    MS_LOG(ERROR) << "The output tensor number of recv node is not equal to " << kRecvNodeOutputNum;
+    return false;
+  }
+
+  for (const auto &item : keys_to_params_) {
+    int32_t key = item.first;
+    ParameterPtr param = item.second;
+    MS_EXCEPTION_IF_NULL(param);
+    auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
+    auto iter = shapes_to_nodes_.find(shape);
+    if (iter == shapes_to_nodes_.end()) {
+      MS_LOG(ERROR) << "Cannot find cnode for parameter(key[" << key << "]) with shape: " << shape;
+      return false;
+    }
+    AnfNodePtr node = iter->second;
+
+    // 1. Construct embedding lookup service subgraph.
+    auto emb_lookup_sub_graph = ConstructEmbeddingLookupSubGraph(node, param, key);
+    MS_EXCEPTION_IF_NULL(emb_lookup_sub_graph);
+    auto emb_lookup_graph_value = NewValueNode(emb_lookup_sub_graph);
+    MS_EXCEPTION_IF_NULL(emb_lookup_graph_value);
+    auto emb_lookup_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
+      emb_lookup_sub_graph, abstract::AnalysisContext::DummyContext());
+    emb_lookup_graph_value->set_abstract(emb_lookup_graph_value_abstract);
+
+    CNodePtr emb_lookup_partial_node =
+      root_graph_->NewCNode({NewValueNode(prim::kPrimPartial), emb_lookup_graph_value, param, recv_outputs[0]});
+    MS_EXCEPTION_IF_NULL(emb_lookup_partial_node);
+    make_tuple_inputs->push_back(emb_lookup_partial_node);
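+
+    // Each Partial node binds the parameter and the relevant recv outputs (the indices, plus the update values for
+    // the update service) to its subgraph, so that the SwitchLayer/Call created in ConstructEmbeddingCacheGraph can
+    // invoke the selected service without passing any extra arguments.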
+    // 2. Construct embedding update service subgraph.
+    auto update_emb_sub_graph = ConstructUpdateEmbeddingSubGraph(param, node);
+    MS_EXCEPTION_IF_NULL(update_emb_sub_graph);
+    auto update_emb_graph_value = NewValueNode(update_emb_sub_graph);
+    MS_EXCEPTION_IF_NULL(update_emb_graph_value);
+    auto update_emb_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
+      update_emb_sub_graph, abstract::AnalysisContext::DummyContext());
+    update_emb_graph_value->set_abstract(update_emb_graph_value_abstract);
+
+    CNodePtr update_emb_partial_node = root_graph_->NewCNode(
+      {NewValueNode(prim::kPrimPartial), update_emb_graph_value, param, recv_outputs[0], recv_outputs[1]});
+    MS_EXCEPTION_IF_NULL(update_emb_partial_node);
+    make_tuple_inputs->push_back(update_emb_partial_node);
+  }
+
+  return true;
+}
+
+bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
+  // 1. Create the recv node for the server.
+  CNodePtr recv_node = CreateRecvNode();
+  MS_EXCEPTION_IF_NULL(recv_node);
+  auto value_node_0 = NewValueNode(static_cast<int64_t>(0));
+  auto value_node_1 = NewValueNode(static_cast<int64_t>(1));
+  auto value_node_2 = NewValueNode(static_cast<int64_t>(2));
+  std::vector<AnfNodePtr> getitem_input0{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_0};
+  std::vector<AnfNodePtr> getitem_input1{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_1};
+  std::vector<AnfNodePtr> getitem_input2{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_2};
+
+  MS_EXCEPTION_IF_NULL(root_graph_);
+  auto getitem_0 = root_graph_->NewCNode(getitem_input0);
+  auto getitem_1 = root_graph_->NewCNode(getitem_input1);
+  auto getitem_2 = root_graph_->NewCNode(getitem_input2);
+  // The tuple_getitem nodes are used to get the outputs of the recv node.
+  std::vector<CNodePtr> getitems = {getitem_0, getitem_1, getitem_2};
+
+  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
+
+  // 2. Construct the embedding cache service subgraphs, including the embedding lookup and update operations, and
+  // wrap each subgraph into a Partial node.
+  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheServicesSubGraphs(getitems, &make_tuple_inputs),
+                           "Construct embedding cache services sub graphs failed.");
+
+  auto make_tuple_node = root_graph_->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple_node);
+
+  // 3. Create the switch layer and call node, used to select and execute the subgraph corresponding to the requested
+  // service.
+  std::vector<AnfNodePtr> switch_layer_inputs = {NewValueNode(prim::kPrimSwitchLayer), getitem_2, make_tuple_node};
+  auto switch_layer_node = root_graph_->NewCNode(switch_layer_inputs);
+
+  CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
+  MS_EXCEPTION_IF_NULL(call_node);
+
+  // 4. Replace the output of the origin function graph with the call node, so the original nodes become unused.
+  auto graph_manager = root_graph_->manager();
+  MS_EXCEPTION_IF_NULL(graph_manager);
+  graph_manager->Replace(root_graph_->output(), call_node);
+  auto return_node = root_graph_->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  return true;
+}
+
+bool PsEmbeddingCacheInserter::Run() {
+  // Get the EmbeddingLookup nodes which are executed on the server from the origin function graph.
+  GetEmbeddingLookupNodes();
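+
+  // Overall structure of the constructed service graph (sketch):
+  //   RpcRecv -> TupleGetItem{indices, update_values, service_id}
+  //           -> MakeTuple(Partial(lookup_subgraph), Partial(update_subgraph), ...)
+  //           -> SwitchLayer(service_id, tuple) -> Call -> Return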
+  // Construct the embedding cache graph of the server.
+  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");
+  return true;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h
new file mode 100644
index 00000000000..1d92260b38e
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "utils/hash_map.h"
+#include "ir/anf.h"
+#include "distributed/constants.h"
+
+namespace mindspore {
+namespace parallel {
+// Build the service-side graph for the embedding distributed cache based on Parameter Server,
+// and remove all nodes of the origin func graph.
+class PsEmbeddingCacheInserter {
+ public:
+  PsEmbeddingCacheInserter(const FuncGraphPtr &root_graph, int64_t rank_id, const std::string &node_role,
+                           uint32_t worker_num)
+      : root_graph_(root_graph), rank_id_(rank_id), node_role_(node_role), worker_num_(worker_num) {}
+
+  ~PsEmbeddingCacheInserter() {
+    root_graph_ = nullptr;
+    keys_to_params_.clear();
+    shapes_to_nodes_.clear();
+  }
+
+  // Insert embedding cache subgraphs to replace all nodes of the origin func graph.
+  bool Run();
+
+ private:
+  // Construct the embedding cache graph of the server:
+  // Recv --> SwitchLayer --> Call --> Return
+  // The SwitchLayer is used to select the subgraph corresponding to the requested service.
+  bool ConstructEmbeddingCacheGraph() const;
+
+  // Create the RpcRecv node for the server to receive requests.
+  CNodePtr CreateRecvNode() const;
+
+  // Construct the embedding cache service subgraphs, including the embedding lookup and update operations, and wrap
+  // each subgraph into a Partial node.
+  bool ConstructEmbeddingCacheServicesSubGraphs(const std::vector<CNodePtr> &recv_outputs,
+                                                std::vector<AnfNodePtr> *make_tuple_inputs) const;
+
+  // Construct the embedding lookup service subgraph:
+  // Input(param, indices) --> EmbeddingLookup --> RpcSend --> Return
+  // The RpcSend is used to send the embeddings back to the service caller.
+  FuncGraphPtr ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node, const ParameterPtr &param,
+                                                int32_t param_key) const;
+
+  // Construct the embedding update service subgraph:
+  // Input(param, indices, update_values) --> Sub --> ScatterUpdate --> Return
+  // The Sub is used to rectify the id via the offset of this server's embedding slice.
+  FuncGraphPtr ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param, const AnfNodePtr &node) const;
+
+  // Create the return node for a subgraph, using a depend node to return a fake value node so that the output
+  // abstract of each subgraph is the same.
+  CNodePtr CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const;
+
+  // Set the device target attr to cpu, and set the graph split label (rank id and node role, such as (0, "MS_SERVER")).
+  void SetNodeAttr(const CNodePtr &node, const std::string &node_role = distributed::kEnvRoleOfServer) const;
+
+  // Set attrs for the send node, such as: inter-process edges, send dst ranks, send dst roles.
+  void SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key, const std::string &embedding_cache_op,
+                       const std::string &dst_role = distributed::kEnvRoleOfWorker) const;
+
+  // Set attrs for the recv node, such as: inter-process edges, recv src ranks, recv src roles.
+  void SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role = distributed::kEnvRoleOfWorker) const;
+
+  // Get the EmbeddingLookup nodes which are executed on the server from the origin function graph.
+  void GetEmbeddingLookupNodes();
+
+  // Get the parameters which enable embedding cache from the origin function graph.
+  void GetCachedParameters();
+
+  // Origin root function graph.
+  FuncGraphPtr root_graph_;
+
+  // The rank id of this process.
+  int64_t rank_id_;
+  // The node role of this process.
+  std::string node_role_;
+  // The number of workers in the cluster.
+  uint32_t worker_num_;
+
+  // Record the parameters which enable embedding cache in the origin function graph.
+  // Key: parameter key, Value: ParameterPtr.
+  mindspore::HashMap<int32_t, ParameterPtr> keys_to_params_;
+
+  // Record the EmbeddingLookup nodes which are executed on the server from the origin function graph.
+  // Key: shape of the EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
+  std::map<std::vector<size_t>, AnfNodePtr> shapes_to_nodes_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
index be48627dcb1..b0d1123c10e 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
@@ -1045,6 +1045,13 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
       std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
       return {rank_id, ms_role};
     }
+  } else {
+    // Get the label for a call node. A call node has no primitive to save attrs, so get the attrs from the CNode.
+    if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
+      uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
+      std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
+      return {rank_id, ms_role};
+    }
   }
   return default_label_;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index ba767e7ab22..750ef636e66 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -38,6 +38,7 @@
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/step_auto_parallel.h"
 #include "frontend/parallel/cache_embedding/cache_embedding.h"
+#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
 #include "frontend/optimizer/recompute.h"
 #include "frontend/optimizer/slice_activation_in_recompute.h"
@@ -844,7 +845,46 @@ bool EnvironConversionPass(const ResourcePtr &resource) {
   return true;
 }
 
+// Build the service-side graph for the embedding distributed cache based on Parameter Server.
+bool AddEmbeddingCachePass(const ResourcePtr &resource) {
+  MS_EXCEPTION_IF_NULL(resource);
+#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
+  if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
+      !ps::PSContext::instance()->is_server()) {
+    return true;
+  }
+
+  FuncGraphPtr func_graph = resource->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto node = distributed::cluster::ClusterContext::instance()->node();
+  MS_EXCEPTION_IF_NULL(node);
+
+  // 1. Build the service-side graph.
+  auto node_role = distributed::cluster::ClusterContext::instance()->node_role();
+  uint32_t worker_num = ps::PSContext::instance()->worker_num();
+  std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
+    std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph, static_cast<int64_t>(node->rank_id()), node_role,
+                                                         worker_num);
+  if (!embedding_cache_inserter->Run()) {
+    MS_LOG(ERROR) << "Insert ps embedding cache failed.";
+    return false;
+  }
+
+  // 2. Renormalize: infer shapes and set abstracts for all nodes in the graph.
+  abstract::AbstractBasePtrList args_abs;
+  auto parameters = func_graph->parameters();
+  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
+                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
+  FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
+  resource->set_func_graph(new_fg);
+  resource->set_args_abs(args_abs);
+#endif
+
+  return true;
+}
+
 std::vector<PassItem> kVmPasses = {
+  {"add_embedding_cache", AddEmbeddingCachePass},
   {"simplify_data_structures", SimplifyDataStructuresPass},
   {"opt_a", OptPassAGroup},
   {"clean_after_opta", CleanAfterOptAPass},