!34765 Add embedding cache pass for func graph

Merge pull request !34765 from zyli2020/master
i-robot 2022-05-30 03:09:23 +00:00 committed by Gitee
commit 378e36f26b
5 changed files with 570 additions and 0 deletions


@@ -21,6 +21,7 @@
#include <map>
#include <chrono>
#include <string>
#include <vector>
namespace mindspore {
namespace distributed {
@@ -37,6 +38,14 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
const std::set<std::string> kValidRoleName = {kEnvRoleOfServer, kEnvRoleOfPServer, kEnvRoleOfWorker,
kEnvRoleOfScheduler};
// Used in parameter server embedding cache scenarios to identify the same Parameter between Worker and Server.
constexpr char kParameterKey[] = "parameter_key";
// Embedding cache lookup operation.
constexpr char kLookupEmbeddingCache[] = "LookupEmbeddingCache";
// Embedding cache update operation.
constexpr char kUpdateEmbeddingCache[] = "UpdateEmbeddingCache";
const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpdateEmbeddingCache};
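// These operation names, together with kParameterKey, are embedded into the inter-process edge names so that
// lookup and update traffic for the same parameter can be distinguished.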
// The distributed execution mode enum.
enum class DistExecutionMode { kPSMode = 0, kInvalidMode };


@@ -0,0 +1,402 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
#include <memory>
#include <string>
#include <algorithm>
#include "ir/func_graph.h"
#include "abstract/abstract_function.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace parallel {
// One-dimensional dynamic shape placeholder.
const ShapeVector kOneDimDynamicShape = {-1};
// Two-dimensional dynamic shape placeholder.
const ShapeVector kTwoDimsDynamicShape = {-1, -1};
// The number of output tensors of the recv node.
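// The three outputs are: indices, update values and service id (see CreateRecvNode).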
const size_t kRecvNodeOutputNum = 3;
void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
MS_EXCEPTION_IF_NULL(root_graph_);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
(void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!(node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName)) {
return;
}
const PrimitivePtr &prim = common::AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
if (!(prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole))) {
return;
}
int64_t rank_id_attr = GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId));
std::string node_role_attr = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
if (rank_id_attr == rank_id_ && node_role_attr == node_role_) {
std::vector<size_t> shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
shapes_to_nodes_[shape] = node;
}
});
}
void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
MS_EXCEPTION_IF_NULL(node);
// Set attrs for the call node; a call node has no primitive to save attrs, so save the attrs into the CNode.
if (common::AnfAlgo::IsCallNode(node)) {
node->AddAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice));
node->AddAttr(distributed::kOpLabelRankId, MakeValue(rank_id_));
node->AddAttr(distributed::kOpLabelRole, MakeValue(node_role));
} else {
common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), node);
common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRankId, MakeValue(rank_id_), node);
common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRole, MakeValue(node_role), node);
}
}
void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key,
const std::string &embedding_cache_op,
const std::string &dst_role) const {
MS_EXCEPTION_IF_NULL(send_node);
SetNodeAttr(send_node);
std::vector<uint32_t> dst_ranks;
std::vector<std::string> dst_roles = {dst_role};
std::vector<std::string> inter_process_edges;
// Set the inter-process edges, send dst ranks and send dst roles.
for (uint32_t i = 0; i < worker_num_; i++) {
dst_ranks.push_back(i);
dst_roles.push_back(dst_role);
// Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
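// For example (the role name strings are illustrative): "MS_SERVER0->MS_WORKER1_LookupEmbeddingCache_parameter_key3".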
inter_process_edges.push_back(distributed::kEnvRoleOfServer + std::to_string(rank_id_) + "->" + dst_role +
std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
std::to_string(param_key));
}
common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
}
void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
MS_EXCEPTION_IF_NULL(recv_node);
SetNodeAttr(recv_node);
std::vector<uint32_t> src_ranks;
std::vector<std::string> src_roles;
std::vector<std::string> inter_process_edges;
// Set the inter-process edges, recv src ranks and recv src roles.
// Each server has only one Recv node, which needs to receive all requests from each worker. For example, different
// parameters on each worker have two operations: look up embedding and update embedding. Each operation will be
// performed by an independent Send node, so the Recv node on the server side will have multiple edges.
for (uint32_t i = 0; i < worker_num_; i++) {
for (const auto &item : keys_to_params_) {
int32_t param_key = item.first;
for (uint32_t k = 0; k < distributed::kEmbeddingCacheOps.size(); k++) {
src_ranks.push_back(i);
src_roles.push_back(src_role);
// Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
// key.
inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfServer +
std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
distributed::kParameterKey + std::to_string(param_key));
}
}
}
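// In total, the recv node has worker_num_ * keys_to_params_.size() * kEmbeddingCacheOps.size() incoming edges.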
common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
}
CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(output_node);
// Create fake output value node to make sure the output abstract is the same for each subgraph.
auto fake_output_tensor = std::make_shared<tensor::Tensor>(1.0);
auto fake_output_value = NewValueNode(fake_output_tensor);
MS_EXCEPTION_IF_NULL(fake_output_value);
fake_output_value->set_abstract(fake_output_tensor->ToAbstract());
// Create Depend node: it outputs the fake value while keeping a dependency on the real output node.
auto depend_node = graph->NewCNode({NewValueNode(prim::kPrimDepend), fake_output_value, output_node});
MS_EXCEPTION_IF_NULL(depend_node);
// Create return node.
std::vector<AnfNodePtr> return_inputs;
return_inputs.push_back(NewValueNode(prim::kPrimReturn));
return_inputs.push_back(depend_node);
auto return_node = graph->NewCNode(return_inputs);
MS_EXCEPTION_IF_NULL(return_node);
return return_node;
}
FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node,
const ParameterPtr &param,
int32_t param_key) const {
MS_EXCEPTION_IF_NULL(param);
MS_EXCEPTION_IF_NULL(node);
// 1. Create subgraph and parameters.
auto graph = std::make_shared<FuncGraph>();
ParameterPtr input_param = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_param);
MS_EXCEPTION_IF_NULL(param->abstract());
input_param->set_abstract(param->abstract()->Clone());
ParameterPtr input_indices = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_indices);
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape));
// 2. Create EmbeddingLookup node.
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
std::vector<AnfNodePtr> emb_lookup_inputs{NewValueNode(emb_lookup_primitive), input_param, input_indices};
auto embedding_cache_lookup_node = graph->NewCNode(emb_lookup_inputs);
MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);
common::AnfAlgo::CopyNodeAttrs(node, embedding_cache_lookup_node);
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
SetNodeAttr(embedding_cache_lookup_node);
// 3. Create RpcSend node.
std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
send_inputs.push_back(embedding_cache_lookup_node);
CNodePtr send_node = graph->NewCNode(send_inputs);
MS_EXCEPTION_IF_NULL(send_node);
SetSendNodeAttr(send_node, param_key, distributed::kLookupEmbeddingCache);
// 4. Create return node.
CNodePtr return_node = CreateReturnNode(graph, send_node);
MS_EXCEPTION_IF_NULL(return_node);
graph->set_return(return_node);
MS_EXCEPTION_IF_NULL(root_graph_);
auto manager = root_graph_->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(graph);
return graph;
}
FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param,
const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(param);
MS_EXCEPTION_IF_NULL(node);
// 1. Create subgraph and parameters.
auto graph = std::make_shared<FuncGraph>();
ParameterPtr input_param = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_param);
MS_EXCEPTION_IF_NULL(param->abstract());
input_param->set_abstract(param->abstract()->Clone());
ParameterPtr input_indices = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_indices);
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape));
ParameterPtr update_values = graph->add_parameter();
MS_EXCEPTION_IF_NULL(update_values);
update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsDynamicShape));
// 2. Create Sub node to rectify the indices via the offset of the embedding slice.
int32_t offset = LongToInt(common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrOffset));
PrimitivePtr sub_primitive = std::make_shared<Primitive>(kSubOpName);
std::vector<AnfNodePtr> sub_inputs{NewValueNode(sub_primitive), input_indices, NewValueNode(MakeValue(offset))};
auto sub_node = graph->NewCNode(sub_inputs);
MS_EXCEPTION_IF_NULL(sub_node);
SetNodeAttr(sub_node);
// 3. Create ScatterUpdate node.
PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
std::vector<AnfNodePtr> embedding_cache_update_inputs{NewValueNode(embedding_cache_update_primitive), input_param,
sub_node, update_values};
auto embedding_cache_update_node = graph->NewCNode(embedding_cache_update_inputs);
MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
common::AnfAlgo::CopyNodeAttrs(node, embedding_cache_update_node);
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);
SetNodeAttr(embedding_cache_update_node);
// 4. Create return node.
CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
MS_EXCEPTION_IF_NULL(return_node);
graph->set_return(return_node);
MS_EXCEPTION_IF_NULL(root_graph_);
auto manager = root_graph_->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(graph);
return graph;
}
CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
// 1. Create abstract for RpcRecv node.
auto indices_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimDynamicShape);
auto update_values_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsDynamicShape);
auto fake_id_tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(0));
// 2. Create fake input nodes for RpcRecv node.
// The indices input.
MS_EXCEPTION_IF_NULL(indices_abstract->element());
MS_EXCEPTION_IF_NULL(indices_abstract->element()->BuildType());
MS_EXCEPTION_IF_NULL(indices_abstract->shape());
auto fake_indices_tensor = std::make_shared<tensor::Tensor>(indices_abstract->element()->BuildType()->type_id(),
indices_abstract->shape()->shape());
auto fake_indices_value = NewValueNode(fake_indices_tensor);
MS_EXCEPTION_IF_NULL(fake_indices_value);
fake_indices_value->set_abstract(fake_indices_tensor->ToAbstract());
// The update values input.
MS_EXCEPTION_IF_NULL(update_values_abstract->element());
MS_EXCEPTION_IF_NULL(update_values_abstract->element()->BuildType());
MS_EXCEPTION_IF_NULL(update_values_abstract->shape());
auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(
update_values_abstract->element()->BuildType()->type_id(), update_values_abstract->shape()->shape());
auto fake_update_values_value = NewValueNode(fake_update_values_tensor);
MS_EXCEPTION_IF_NULL(fake_update_values_value);
fake_update_values_value->set_abstract(fake_update_values_tensor->ToAbstract());
// The id input; the id is used to choose the service.
auto fake_id_value = NewValueNode(fake_id_tensor);
MS_EXCEPTION_IF_NULL(fake_id_value);
fake_id_value->set_abstract(fake_id_tensor->ToAbstract());
// 3. Create the RpcRecv node.
std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
recv_inputs.push_back(fake_indices_value);
recv_inputs.push_back(fake_update_values_value);
recv_inputs.push_back(fake_id_value);
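// Note: the output order of the recv node (indices, update values, service id) must match the TupleGetItem
// indices used in ConstructEmbeddingCacheGraph.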
MS_EXCEPTION_IF_NULL(root_graph_);
CNodePtr recv_node = root_graph_->NewCNode(recv_inputs);
MS_EXCEPTION_IF_NULL(recv_node);
SetRecvNodeAttr(recv_node);
return recv_node;
}
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheServicesSubGraphs(
const std::vector<CNodePtr> &recv_outputs, std::vector<AnfNodePtr> *make_tuple_inputs) const {
MS_EXCEPTION_IF_NULL(root_graph_);
MS_EXCEPTION_IF_NULL(make_tuple_inputs);
if (recv_outputs.size() != kRecvNodeOutputNum) {
MS_LOG(ERROR) << "The output tensor number of recv node is not equal to " << kRecvNodeOutputNum;
return false;
}
for (const auto &item : keys_to_params_) {
int32_t key = item.first;
ParameterPtr param = item.second;
MS_EXCEPTION_IF_NULL(param);
auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
auto iter = shapes_to_nodes_.find(shape);
if (iter == shapes_to_nodes_.end()) {
MS_LOG(ERROR) << "Cannot find cnode for parameter(key[" << key << "]) with shape: " << shape;
return false;
}
AnfNodePtr node = iter->second;
// 1. Construct embedding lookup service sub graph.
auto emb_lookup_sub_graph = ConstructEmbeddingLookupSubGraph(node, param, key);
MS_EXCEPTION_IF_NULL(emb_lookup_sub_graph);
auto emb_lookup_graph_value = NewValueNode(emb_lookup_sub_graph);
MS_EXCEPTION_IF_NULL(emb_lookup_graph_value);
auto emb_lookup_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
emb_lookup_sub_graph, abstract::AnalysisContext::DummyContext());
emb_lookup_graph_value->set_abstract(emb_lookup_graph_value_abstract);
CNodePtr emb_lookup_partial_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPartial), emb_lookup_graph_value, param, recv_outputs[0]});
MS_EXCEPTION_IF_NULL(emb_lookup_partial_node);
make_tuple_inputs->push_back(emb_lookup_partial_node);
// 2. Construct updating embedding service sub graph.
auto update_emb_sub_graph = ConstructUpdateEmbeddingSubGraph(param, node);
MS_EXCEPTION_IF_NULL(update_emb_sub_graph);
auto update_emb_graph_value = NewValueNode(update_emb_sub_graph);
MS_EXCEPTION_IF_NULL(update_emb_graph_value);
auto update_emb_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
update_emb_sub_graph, abstract::AnalysisContext::DummyContext());
update_emb_graph_value->set_abstract(update_emb_graph_value_abstract);
CNodePtr update_emb_partial_node = root_graph_->NewCNode(
{NewValueNode(prim::kPrimPartial), update_emb_graph_value, param, recv_outputs[0], recv_outputs[1]});
MS_EXCEPTION_IF_NULL(update_emb_partial_node);
make_tuple_inputs->push_back(update_emb_partial_node);
}
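// For each parameter key, two branches are appended in order: the lookup partial, then the update partial. The
// received service id selects among these branches via SwitchLayer.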
return true;
}
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
// 1. Create recv node for server.
CNodePtr recv_node = CreateRecvNode();
MS_EXCEPTION_IF_NULL(recv_node);
auto value_node_0 = NewValueNode(static_cast<int64_t>(0));
auto value_node_1 = NewValueNode(static_cast<int64_t>(1));
auto value_node_2 = NewValueNode(static_cast<int64_t>(2));
std::vector<AnfNodePtr> getitem_input0{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_0};
std::vector<AnfNodePtr> getitem_input1{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_1};
std::vector<AnfNodePtr> getitem_input2{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_2};
MS_EXCEPTION_IF_NULL(root_graph_);
auto getitem_0 = root_graph_->NewCNode(getitem_input0);
auto getitem_1 = root_graph_->NewCNode(getitem_input1);
auto getitem_2 = root_graph_->NewCNode(getitem_input2);
// The tuple_getitem nodes are used to get the outputs of the recv node.
std::vector<CNodePtr> getitems = {getitem_0, getitem_1, getitem_2};
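// getitem_0: indices, getitem_1: update values, getitem_2: service id.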
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
// 2. Construct the embedding cache service subgraphs, including the embedding lookup and update operations, and
// wrap the subgraph of each operation into a Partial node.
RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheServicesSubGraphs(getitems, &make_tuple_inputs),
"Construct embedding cache services sub graphs failed.");
auto make_tuple_node = root_graph_->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple_node);
// 3. Create SwitchLayer and call node, used to select and execute the subgraph corresponding to the requested
// service.
std::vector<AnfNodePtr> switch_layer_inputs = {NewValueNode(prim::kPrimSwitchLayer), getitem_2, make_tuple_node};
auto switch_layer_node = root_graph_->NewCNode(switch_layer_inputs);
CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
MS_EXCEPTION_IF_NULL(call_node);
// 4. Replace the output of the origin function graph with the call node; the original nodes become unused.
auto graph_manager = root_graph_->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
graph_manager->Replace(root_graph_->output(), call_node);
auto return_node = root_graph_->get_return();
MS_EXCEPTION_IF_NULL(return_node);
return true;
}
bool PsEmbeddingCacheInserter::Run() {
// Get the EmbeddingLookup nodes that are executed on the server from the origin function graph.
GetEmbeddingLookupNodes();
// Construct the embedding cache graph of server.
RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");
return true;
}
} // namespace parallel
} // namespace mindspore


@@ -0,0 +1,112 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
#include <string>
#include <map>
#include <vector>
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "distributed/constants.h"
namespace mindspore {
namespace parallel {
// Build the service-side graph for the distributed embedding cache based on the Parameter Server,
// and remove all nodes of the origin func graph.
class PsEmbeddingCacheInserter {
public:
PsEmbeddingCacheInserter(const FuncGraphPtr &root_graph, int64_t rank_id, const std::string &node_role,
uint32_t worker_num)
: root_graph_(root_graph), rank_id_(rank_id), node_role_(node_role), worker_num_(worker_num) {}
~PsEmbeddingCacheInserter() {
root_graph_ = nullptr;
keys_to_params_.clear();
shapes_to_nodes_.clear();
}
// Insert embedding cache subgraphs to replace all nodes of the origin func graph.
bool Run();
private:
// Construct the embedding cache graph of server:
// Recv --> SwitchLayer --> Call --> Return
// The SwitchLayer is used to select the subgraph corresponding to the requested service.
bool ConstructEmbeddingCacheGraph() const;
// Create RpcRecv node for server to receive request.
CNodePtr CreateRecvNode() const;
// Construct the embedding cache service subgraphs, including the embedding lookup and update operations, and wrap
// the subgraph of each operation into a Partial node.
bool ConstructEmbeddingCacheServicesSubGraphs(const std::vector<CNodePtr> &recv_outputs,
std::vector<AnfNodePtr> *make_tuple_inputs) const;
// Construct embedding lookup service sub graph:
// Input(param, indices) --> EmbeddingLookup --> RpcSend --> Return
// RpcSend is used to send the embeddings to the service caller.
FuncGraphPtr ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node, const ParameterPtr &param,
int32_t param_key) const;
// Construct updating embedding service sub graph:
// Input(param, indices, update_values) --> Sub --> ScatterUpdate --> Return
// The Sub node rectifies the indices via the offset of the embedding slice.
FuncGraphPtr ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param, const AnfNodePtr &node) const;
// Create the return node for a subgraph, using a Depend node to return a fake value node so that the output
// abstract of each subgraph is the same.
CNodePtr CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const;
// Set the device target attr to CPU, and set the graph split label (rank id and node role, such as (0, "MS_SERVER")).
void SetNodeAttr(const CNodePtr &node, const std::string &node_role = distributed::kEnvRoleOfServer) const;
// Set attrs for the send node, such as: inter-process edges, send dst ranks, send dst roles.
void SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key, const std::string &embedding_cache_op,
const std::string &dst_role = distributed::kEnvRoleOfWorker) const;
// Set attrs for the recv node, such as: inter-process edges, recv src ranks, recv src roles.
void SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role = distributed::kEnvRoleOfWorker) const;
// Get the EmbeddingLookup nodes that are executed on the server from the origin function graph.
void GetEmbeddingLookupNodes();
// Get the parameters with embedding cache enabled from the origin function graph.
void GetCachedParameters();
// Origin root function graph.
FuncGraphPtr root_graph_;
// The rank id of this process.
int64_t rank_id_;
// The node role of this process.
std::string node_role_;
// The number of workers in the cluster.
uint32_t worker_num_;
// Record the parameters with embedding cache enabled in the origin function graph.
// Key: parameter key, Value: ParameterPtr
mindspore::HashMap<int32_t, ParameterPtr> keys_to_params_;
// Record the EmbeddingLookup nodes that are executed on the server from the origin function graph.
// Key: shape of EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
std::map<std::vector<size_t>, AnfNodePtr> shapes_to_nodes_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_


@@ -1045,6 +1045,13 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
return {rank_id, ms_role};
}
} else {
// Get the label for a call node; a 'call' node has no primitive to save attrs, so get the attrs of 'call' from the cnode.
if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
return {rank_id, ms_role};
}
}
return default_label_;
}


@@ -38,6 +38,7 @@
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "frontend/optimizer/slice_activation_in_recompute.h"
@@ -844,7 +845,46 @@ bool EnvironConversionPass(const ResourcePtr &resource) {
return true;
}
// Build the service-side graph for the distributed embedding cache based on the Parameter Server.
bool AddEmbeddingCachePass(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
!ps::PSContext::instance()->is_server()) {
return true;
}
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
// 1. Build the service-side graph.
auto node_role = distributed::cluster::ClusterContext::instance()->node_role();
uint32_t worker_num = ps::PSContext::instance()->worker_num();
std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph, static_cast<int64_t>(node->rank_id()), node_role,
worker_num);
if (!embedding_cache_inserter->Run()) {
MS_LOG(ERROR) << "Insert ps embedding cache failed.";
return false;
}
// 2. Renormalize: infer shape and set abstract for all nodes in the graph.
abstract::AbstractBasePtrList args_abs;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
resource->set_func_graph(new_fg);
resource->set_args_abs(args_abs);
#endif
return true;
}
std::vector<PassItem> kVmPasses = {
{"add_embedding_cache", AddEmbeddingCachePass},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},