graceful exit for embedding cache mode

lizhenyu 2022-06-13 21:49:26 +08:00
parent ce05ec064a
commit ea0fe63c04
25 changed files with 546 additions and 182 deletions

View File

@ -30,6 +30,9 @@
"mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
"mindspore/mindspore/ccsrc/pipeline/jit/action.cc" "unreadVariable"
"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.c" "unreadVariable"
"mindspore/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "variableScope"
# MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -933,10 +933,16 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) {
bool need_construct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
bool is_embedding_cache_server = false;
#ifdef WITH_BACKEND
is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
#endif
if (need_construct_output) {
// Update device address for output node of graph.
// Summary processing uses the output device address, so this must run after summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
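// The embedding cache server runs a resident service graph whose outputs are fake placeholders (see
// PsEmbeddingCacheInserter below), so there is no real output device address to update.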
if (!is_embedding_cache_server) {
actor_set->output_actor_->UpdateOutputDeviceAddress();
}
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);

View File

@ -28,10 +28,10 @@ namespace mindspore {
namespace distributed {
namespace cluster {
// The timeout in milliseconds for one lookup.
constexpr uint32_t kDefaultLookupTimeout = 60000;
constexpr uint32_t kDefaultLookupTimeout = 300000;
// The time in milliseconds between two lookup operations.
constexpr uint32_t kLookupInterval = 100;
constexpr uint32_t kLookupInterval = 3000;
// Actor route table proxy for nodes such as workers and servers. This class helps update the actor route table
// in the scheduler across the network.
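A minimal sketch of how a caller presumably consumes these constants (the loop shape and the TryLookup helper are assumptions for illustration; only the constants appear in this hunk). With the new values a route lookup is retried every 3 s for up to 5 min, i.e. at most 300000 / 3000 = 100 attempts, instead of 600 attempts at 100 ms intervals:

#include <chrono>
#include <thread>

// Hypothetical retry loop around a single route lookup.
bool LookupRouteWithRetry(const std::string &actor_id, ActorAddress *addr) {
  const uint32_t max_attempts = kDefaultLookupTimeout / kLookupInterval;
  for (uint32_t attempt = 0; attempt < max_attempts; ++attempt) {
    if (TryLookup(actor_id, addr)) {  // TryLookup: placeholder for one lookup RPC.
      return true;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(kLookupInterval));
  }
  return false;  // Timed out after kDefaultLookupTimeout milliseconds in total.
}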

View File

@ -49,7 +49,7 @@ const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpd
constexpr char kFinalizeMuxRecvActor[] = "FINALIZE_MUX_RECV_ACTOR";
// The distributed execution mode enum.
enum class DistExecutionMode { kPSMode = 0, kInvalidMode };
enum class DistExecutionMode { kPSMode = 0, kEmbeddingCacheMode, kInvalidMode };
// The operator's label in distributed execution.
constexpr char kOpLabelRankId[] = "rank_id";

View File

@ -24,6 +24,7 @@
#include "kernel/kernel.h"
#include "distributed/embedding_cache/embedding_hash_map.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace runtime {
@ -114,7 +115,7 @@ struct EmbeddingCacheStatisticsInfo {
// The EmbeddingCacheTableManager class is used to save all Parameter information for enabling cache, such as device
// cache size, host cache size, etc., and can allocate memory for the embedding cache table.
class EmbeddingCacheTableManager {
class BACKEND_EXPORT EmbeddingCacheTableManager {
public:
static EmbeddingCacheTableManager &GetInstance();

View File

@ -33,14 +33,71 @@ const ShapeVector kOneDimShape = {1};
// Two dimensional shape placeholder.
const ShapeVector kTwoDimsShape = {1, 1};
// One dimensional dynamic shape placeholder.
const ShapeVector kOneDimDynamicShape = {-1};
// Two dimensional dynamic shape placeholder.
const ShapeVector kTwoDimsDynamicShape = {-1, -1};
// The number of output tensors of the recv node.
const size_t kRecvNodeOutputNum = 3;
// The input index of the offset input of the EmbeddingLookup kernel.
constexpr size_t kEmbeddingLookupOffsetIdx = 2;
// The dims of embedding table.
constexpr size_t kEmbeddingTableDims = 2;
constexpr char kEmbeddingRemoteCacheNode[] = "EmbeddingRemoteCacheNode";
constexpr char kEmbeddingLocalCacheNode[] = "EmbeddingLocalCacheNode";
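// Note: a -1 dimension marks that axis as dynamic. Where such a shape is wrapped in an abstract::Shape below,
// the second and third arguments are the min- and max-shape hints consumed by shape inference, e.g.:
//   std::make_shared<abstract::Shape>(kOneDimDynamicShape /* {-1} */,
//                                     kOneDimShape /* min: {1} */, kOneDimShape /* max: {1} */);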
namespace {
ValueNodePtr CreateFakeValueNode(const AnfNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
abstract::AbstractTensorPtr origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(origin_abstract);
tensor::TensorPtr fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
origin_abstract->shape()->shape());
MS_EXCEPTION_IF_NULL(fake_tensor);
fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
auto fake_value = NewValueNode(fake_tensor);
MS_EXCEPTION_IF_NULL(fake_value);
fake_value->set_abstract(fake_tensor->ToAbstract());
return fake_value;
}
AnfNodePtr CreateOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(origin_output);
MS_EXCEPTION_IF_NULL(origin_output->abstract());
if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
abstract::AbstractBasePtrList new_elements_abs;
std::vector<ValuePtr> new_elements_values;
auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
for (const auto &element : tuple_elements) {
MS_EXCEPTION_IF_NULL(element);
auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
if (!tensor_abstract) {
MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
}
auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
tensor_abstract->shape()->shape());
MS_EXCEPTION_IF_NULL(fake_tensor);
new_elements_abs.push_back(fake_tensor->ToAbstract());
new_elements_values.push_back(fake_tensor);
}
ValueTuplePtr value_tuple = std::make_shared<ValueTuple>(new_elements_values);
auto value_tuple_abs = std::make_shared<abstract::AbstractTuple>(new_elements_abs);
auto value_tuple_node = NewValueNode(value_tuple);
MS_EXCEPTION_IF_NULL(value_tuple_node);
value_tuple_node->set_abstract(value_tuple_abs);
return value_tuple_node;
} else {
return CreateFakeValueNode(origin_output);
}
}
} // namespace
void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
MS_EXCEPTION_IF_NULL(root_graph_);
@ -214,15 +271,21 @@ FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const An
input_param->set_abstract(param->abstract()->Clone());
ParameterPtr input_indices = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_indices);
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));
// 2. Create EmbeddingLookup node.
auto offset_node = common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx);
MS_EXCEPTION_IF_NULL(offset_node);
auto offset_value_node = offset_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(offset_value_node);
int64_t offset = GetValue<int64_t>(offset_value_node->value());
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
std::vector<AnfNodePtr> emb_lookup_inputs{
NewValueNode(emb_lookup_primitive), input_param, input_indices,
common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx)};
std::vector<AnfNodePtr> emb_lookup_inputs{NewValueNode(emb_lookup_primitive), input_param, input_indices};
auto embedding_cache_lookup_node = graph->NewCNode(emb_lookup_inputs);
MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);
common::AnfAlgo::SetNodeAttr(kAttrOffset, MakeValue(offset), embedding_cache_lookup_node);
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
@ -258,35 +321,36 @@ FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const Pa
MS_EXCEPTION_IF_NULL(input_param);
MS_EXCEPTION_IF_NULL(param->abstract());
input_param->set_abstract(param->abstract()->Clone());
ParameterPtr input_indices = graph->add_parameter();
MS_EXCEPTION_IF_NULL(input_indices);
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));
ParameterPtr update_values = graph->add_parameter();
MS_EXCEPTION_IF_NULL(update_values);
update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsShape));
std::vector<size_t> emb_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
if (emb_shape.size() != kEmbeddingTableDims) {
MS_LOG(EXCEPTION) << "Embedding table should be 2 dims for embedding cache mode, but got: " << emb_shape.size()
<< " dims";
}
int64_t emb_dim = SizeToLong(emb_shape.back());
ShapeVector update_values_shape = {-1, emb_dim};
ShapeVector update_values_min_shape = {1, emb_dim};
ShapeVector update_values_max_shape = {1, emb_dim};
update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(
kFloat32,
std::make_shared<abstract::Shape>(update_values_shape, update_values_min_shape, update_values_max_shape)));
// 2. Create Sub node.
PrimitivePtr sub_primitive = std::make_shared<Primitive>(kSubOpName);
auto sub_value_node = common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx);
int64_t sub_value = GetValue<int64_t>(GetValueNode(sub_value_node));
// The input of Sub must be tensor type.
auto sub_value_tensor = std::make_shared<tensor::Tensor>(sub_value, kInt32);
std::vector<AnfNodePtr> sub_inputs{NewValueNode(sub_primitive), input_indices, NewValueNode(sub_value_tensor)};
auto sub_node = graph->NewCNode(sub_inputs);
MS_EXCEPTION_IF_NULL(sub_node);
sub_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), sub_node);
common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), sub_node);
// 3. Create ScatterUpdate node.
// 2. Create ScatterUpdate node.
PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
std::vector<AnfNodePtr> embedding_cache_update_inputs{NewValueNode(embedding_cache_update_primitive), input_param,
sub_node, update_values};
input_indices, update_values};
auto embedding_cache_update_node = graph->NewCNode(embedding_cache_update_inputs);
MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);
// 4. Create return node.
// 3. Create return node.
CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
MS_EXCEPTION_IF_NULL(return_node);
graph->set_return(return_node);
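Note: the Sub node is dropped because the worker now shifts each id into the target server's local range before sending (see the PartitionIdsAndEmbeddings hunk further down), so the indices reaching the server are already cache-local and ScatterUpdate can consume input_indices directly:

// Worker side, from the embedding cache prefetch actor change below:
int offset = SizeToInt(remote_embedding_slice_bounds_.at(i).first);
slice_ids.push_back(ids[j] - offset);  // ids arrive cache-local at the server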
@ -304,17 +368,25 @@ CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
MS_EXCEPTION_IF_NULL(root_graph_);
ParameterPtr input_indices = root_graph_->add_parameter();
MS_EXCEPTION_IF_NULL(input_indices);
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));
auto fake_input_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimShape);
input_indices->set_default_param(fake_input_indices_tensor);
// The update values input.
ParameterPtr update_values = root_graph_->add_parameter();
MS_EXCEPTION_IF_NULL(update_values);
update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsShape));
update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(
kFloat32, std::make_shared<abstract::Shape>(kTwoDimsDynamicShape, kTwoDimsShape, kTwoDimsShape)));
auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, kTwoDimsShape);
update_values->set_default_param(fake_update_values_tensor);
// The service id input, used to choose which service to execute.
ParameterPtr service_id = root_graph_->add_parameter();
MS_EXCEPTION_IF_NULL(service_id);
service_id->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
auto fake_id_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimDynamicShape);
service_id->set_default_param(fake_id_tensor);
// 2. Create a RpcRecv node.
std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
@ -431,12 +503,15 @@ bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
MS_EXCEPTION_IF_NULL(call_node);
// 4. Replace useless nodes of origin function graph.
// 4. Replace origin output and useless nodes of origin function graph.
AnfNodePtr old_output = root_graph_->output();
AnfNodePtr new_output = CreateOutputNode(root_graph_, old_output);
auto final_output_node = root_graph_->NewCNode({NewValueNode(prim::kPrimDepend), new_output, call_node});
MS_EXCEPTION_IF_NULL(final_output_node);
auto graph_manager = root_graph_->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
graph_manager->Replace(root_graph_->output(), call_node);
auto return_node = root_graph_->get_return();
MS_EXCEPTION_IF_NULL(return_node);
graph_manager->Replace(root_graph_->output(), final_output_node);
return true;
}
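For reference, Depend returns its first input and only keeps an execution edge to its second, so after this change the server graph yields the fake placeholder outputs while the switch-layer call (the actual lookup/update services) is still guaranteed to run. Conceptually:

//   %fake = CreateOutputNode(root_graph_, old_output)  // placeholder values only
//   %call = call @switch_layer(...)                    // embedding cache services
//   return Depend(%fake, %call)                        // yields %fake, forces %call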
@ -452,6 +527,10 @@ bool PsEmbeddingCacheInserter::Run() {
// Set attr(device target attr and graph split label) for all CNodes.
SetAttrForAllNodes();
MS_EXCEPTION_IF_NULL(root_graph_);
// Need renormalize to infer shape and set abstract.
root_graph_->set_flag(kFlagNeedRenormalize, true);
return true;
}
} // namespace parallel

View File

@ -27,6 +27,9 @@
#include "mindspore/core/utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/debug/draw.h"
#ifdef WITH_BACKEND
#include "ps/ps_context.h"
#endif
namespace mindspore {
namespace parallel {
@ -835,14 +838,87 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
: func_graph_(func_graph),
rank_id_(rank_id),
role_(role),
mode_(distributed::DistExecutionMode::kPSMode),
exec_mode_(nullptr),
this_process_label_({rank_id, role}),
node_labels_{},
need_fuse_rpc_nodes_(true) {
bool enable_embedding_cache = false;
#ifdef WITH_BACKEND
enable_embedding_cache = ps::PSContext::instance()->cache_enable();
#endif
mode_ = enable_embedding_cache ? distributed::DistExecutionMode::kEmbeddingCacheMode
: distributed::DistExecutionMode::kPSMode;
default_label_ = {0, distributed::kEnvRoleOfWorker};
}
void EmbeddingCacheMode::PreBuildDistributedGraph() {
// Only the remote cache (the parameter server side) needs the embedding cache ops added.
if (role_ != distributed::kEnvRoleOfPServer) {
return;
}
// 1. Add embedding cache ops of remote cache, and build service-side graph.
AddEmbeddingCacheOps();
// 2. Get node labels.
MS_EXCEPTION_IF_NULL(node_labels_);
node_labels_->clear();
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
(void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorLabel label = GetNodeLabel(cnode);
node_labels_->emplace(node, label);
}
});
}
void EmbeddingCacheMode::AddEmbeddingCacheOps() const {
uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
if (worker_num == 0) {
MS_LOG(EXCEPTION) << "In embedding cache mode, worker number should be greater than 0.";
}
// Build service-side graph.
std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph_, static_cast<int64_t>(rank_id_), role_,
worker_num);
if (!embedding_cache_inserter->Run()) {
MS_LOG(EXCEPTION) << "Insert ps embedding cache failed.";
}
}
OperatorLabel EmbeddingCacheMode::GetNodeLabel(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
}
CNodePtr cnode = node->cast<CNodePtr>();
auto prim_node = cnode->input(0);
if (IsValueNode<Primitive>(prim_node)) {
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
return {rank_id, ms_role};
}
} else {
// Get the label for a call node: a 'call' node has no primitive to hold attributes, so read them from the cnode.
if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
return {rank_id, ms_role};
}
}
return {rank_id_, role_};
}
GraphSplitter::~GraphSplitter() { node_labels_.clear(); }
void GraphSplitter::Run() {
@ -852,9 +928,7 @@ void GraphSplitter::Run() {
// Step 1: Dye all the nodes of the whole func_graph_.
DyeGraph();
// If all nodes are on this process, there is no need to split the graph, so return.
if (std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
return node_to_label.second != this_process_label_;
}) == node_labels_.end()) {
if (!NeedSplitGraph()) {
MS_LOG(INFO) << "No need to build and split distributed graph.";
return;
}
@ -865,6 +939,11 @@ void GraphSplitter::Run() {
// Step 3: Prebuild the distributed graph before it gets split.
exec_mode_->PreBuildDistributedGraph();
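// Check again: the prebuild step may have re-dyed node labels (e.g. EmbeddingCacheMode rebuilds node_labels_),
// so only now can we tell whether any node actually belongs to another process.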
if (!NeedSplitGraph()) {
MS_LOG(INFO) << "No need to build and split distributed graph.";
return;
}
// Step 4: Create inter-process operators for segments with different labels.
InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
@ -924,6 +1003,8 @@ void GraphSplitter::CreateExecutionMode() {
}
if (mode_ == distributed::DistExecutionMode::kPSMode) {
exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
} else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
}
MS_EXCEPTION_IF_NULL(exec_mode_);
}
@ -1350,5 +1431,11 @@ bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodeP
}
return node_labels_[node1] == node_labels_[node2];
}
bool GraphSplitter::NeedSplitGraph() const {
return std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
return node_to_label.second != this_process_label_;
}) != node_labels_.end();
}
} // namespace parallel
} // namespace mindspore

View File

@ -34,6 +34,7 @@
#else
#include "distributed/cluster/dummy_cluster_context.h"
#endif
#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
namespace mindspore {
namespace parallel {
@ -314,6 +315,21 @@ class ParameterServerMode : public DistributedExecutionMode {
std::vector<size_t> ps_optimizer_fusion_segments_;
};
class EmbeddingCacheMode : public DistributedExecutionMode {
public:
explicit EmbeddingCacheMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
const std::string &role)
: DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
~EmbeddingCacheMode() = default;
void PreBuildDistributedGraph() override;
private:
void AddEmbeddingCacheOps() const;
OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
};
// The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
// cluster.
class GraphSplitter {
@ -388,6 +404,9 @@ class GraphSplitter {
// Judge whether two nodes have the same distributed label.
bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);
// Check whether the distributed graph needs to be split.
bool NeedSplitGraph() const;
FuncGraphPtr func_graph_;
// Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,

View File

@ -28,6 +28,8 @@
#ifdef WITH_BACKEND
#include "ps/ps_cache/ps_cache_manager.h"
#include "utils/ms_context.h"
#include "ps/ps_context.h"
#include "distributed/embedding_cache/embedding_cache_utils.h"
#endif
namespace mindspore {
@ -715,7 +717,11 @@ Status GatherInfo::InferBias() {
}
#ifdef WITH_BACKEND
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
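// With the MindRT distributed runtime the lower bound comes from the new EmbeddingCacheTableManager;
// the legacy PsCacheManager path is kept for the old runtime.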
if (ps::PSContext::instance()->enable_distributed_mindrt()) {
bias_ = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
} else {
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
}
return SUCCESS;
}
#endif

View File

@ -30,6 +30,7 @@
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#ifdef WITH_BACKEND
#include "ps/ps_cache/ps_cache_manager.h"
#include "distributed/embedding_cache/embedding_cache_utils.h"
#endif
namespace mindspore {
@ -106,7 +107,13 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
int64_t bias = 0;
if (ps::PSContext::instance()->enable_distributed_mindrt()) {
bias = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
} else {
bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
}
auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size());
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)});

View File

@ -601,6 +601,7 @@ constexpr auto kFlagsIsCutGraph = "is_cut_graph";
constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";
constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
constexpr auto kFlagNeedRenormalize = "need_renormalize";
// TODO(dsj): for ms_function running in graph_mode. Should be deleted later.
constexpr auto kAttrMSFunction = "ms_function_graph";

View File

@ -110,6 +110,12 @@ void DisableMindRT(const ResourcePtr &resource) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == false) {
return;
}
#ifdef WITH_BACKEND
if (ps::PSContext::instance()->cache_enable()) {
return;
}
#endif
auto func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto parallel_context = parallel::ParallelContext::GetInstance();
@ -915,6 +921,13 @@ void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
return;
}
#ifdef WITH_BACKEND
if (ps::PSContext::instance()->cache_enable()) {
set_ctx(true, false, false);
return;
}
#endif
// GRAPH | normal network and if/for/switch scenario etc : MultiGraph path in MindRT.
MS_LOG(INFO) << "Run graph mode with multigraph sink.";
set_ctx(true, true, true);
@ -1202,6 +1215,17 @@ bool DistributedSplitAction(const ResourcePtr &resource) {
std::make_shared<parallel::GraphSplitter>(func_graph, node->rank_id(), node_role);
MS_EXCEPTION_IF_NULL(splitter);
splitter->Run();
// Renormalize: infer shapes and set abstracts for all nodes in the graph.
if (func_graph->has_flag(kFlagNeedRenormalize)) {
abstract::AbstractBasePtrList args_abs;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
resource->set_func_graph(new_fg);
resource->set_args_abs(args_abs);
}
return true;
}
#endif

View File

@ -889,7 +889,6 @@ bool AddEmbeddingCachePass(const ResourcePtr &resource) {
}
std::vector<PassItem> kVmPasses = {
{"add_embedding_cache", AddEmbeddingCachePass},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},

View File

@ -79,6 +79,7 @@
#include "fl/server/server.h"
#include "fl/worker/fl_worker.h"
#include "distributed/cluster/cluster_context.h"
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#endif
#ifdef ENABLE_D
@ -1655,6 +1656,7 @@ void ClearResAtexit() {
RecordExitStatus();
#ifdef WITH_BACKEND
if (distributed::cluster::ClusterContext::instance()->initialized()) {
runtime::EmbeddingCacheScheduler::GetInstance().Finalize();
(void)distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX);
} else if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {

View File

@ -394,11 +394,14 @@ bool AscendKernelExecutor::LaunchKernel(const CNodePtr &kernel, const vector<Add
bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(kernel);
if (!is_dynamic_shape || !(common::AnfAlgo::GetBooleanAttr(kernel, kAttrMSFunction))) {
std::lock_guard<std::mutex> locker(launch_mutex_);
// launch atomic clean
if (!LaunchAtomicClean(kernel, workspace, outputs)) {
MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
return false;
auto iter = node_atomics_persistent_cache_.find(kernel);
if (iter != node_atomics_persistent_cache_.end()) {
std::lock_guard<std::mutex> locker(launch_mutex_);
// launch atomic clean
if (!LaunchAtomicClean(kernel, workspace, outputs)) {
MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
return false;
}
}
}

View File

@ -18,145 +18,152 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, double, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherV2FwdGpuKernelMod, int, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
GatherV2FwdGpuKernelMod, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherV2FwdGpuKernelMod, int16_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int16_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherV2FwdGpuKernelMod, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int16_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherV2FwdGpuKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int8_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherV2FwdGpuKernelMod, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, int8_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherV2FwdGpuKernelMod, uint, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, uint, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherV2FwdGpuKernelMod, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, uint, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherV2FwdGpuKernelMod, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, uint8_t, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherV2FwdGpuKernelMod, uint8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, uint8_t, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, bool, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int64_t)
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
// dynamic shape
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int64_t)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int)
MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int, int64_t)
MS_REG_GPU_KERNEL_THREE(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
// dynamic shape ends
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(
SparseGatherV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int)
GatherV2FwdGpuKernelMod, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherV2FwdGpuKernelMod, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherV2FwdGpuKernelMod, half, int, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace kernel {
template <typename T, typename S>
template <typename T, typename S, typename G>
class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
GatherV2FwdGpuKernelMod() { ResetResource(); }
@ -43,16 +43,16 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_shape_) {
int64_t *axis_device_address = GetDeviceAddress<int64_t>(inputs, 2); // only get this if in dynamic mode
G *axis_device_address = GetDeviceAddress<G>(inputs, 2); // only get this if in dynamic mode
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost,
cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync axis_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
"cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
Reshape();
}
auto input_dim1 = input_shapes_[IntToSize(axis_)];
auto input_dim1 = input_shapes_[axis_];
MS_EXCEPTION_IF_NULL(input_addr);
MS_EXCEPTION_IF_NULL(indices_addr);
@ -85,7 +85,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
if (!is_dynamic_shape_) {
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
@ -115,7 +115,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
input_size_list_.push_back(size);
if (is_dynamic_shape_) {
input_size_list_.push_back(sizeof(int64_t));
input_size_list_.push_back(sizeof(G));
}
size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
output_size_list_.push_back(size);
@ -148,7 +148,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
std::vector<size_t> indices_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
int64_t axis_;
G axis_;
bool is_dynamic_shape_;
bool is_null_input_;
};

View File

@ -618,7 +618,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
DeviceAddressPtr device_address = GetInternalDeviceAddress(graph, item);
#ifdef WITH_BACKEND
const std::string &param_name = item->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
if (ps::ps_cache_instance.IsHashTable(param_name) && !ps::PSContext::instance()->enable_distributed_mindrt()) {
MS_LOG(INFO) << "Parameter(" << param_name << ")"
<< " enables the embeddingLookup cache in parameter server training mode.";
// PS embeddingLookup cache check.

View File

@ -15,10 +15,12 @@
*/
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
#include <limits>
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "kernel/common_utils.h"
#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
#include "proto/topology.pb.h"
#include "distributed/constants.h"
namespace mindspore {
namespace runtime {
@ -44,6 +46,13 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape);
param->set_abstract(abstract);
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
std::vector<std::string> formats = {kOpFormat_DEFAULT};
std::vector<TypeId> types = {type->type_id()};
kernel_build_info_builder->SetOutputsFormat(formats);
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
auto mutable_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(mutable_inputs);
mutable_inputs->push_back(param);
@ -51,6 +60,25 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
return param;
}
ValueNodePtr NewValueNode(int64_t value) {
auto tensor = std::make_shared<tensor::Tensor>(value, kInt32);
auto value_node = NewValueNode(tensor);
value_node->set_abstract(tensor->ToAbstract());
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
std::vector<std::string> formats = {kOpFormat_DEFAULT};
std::vector<TypeId> types = {kInt32->type_id()};
kernel_build_info_builder->SetOutputsFormat(formats);
kernel_build_info_builder->SetOutputsDeviceType(types);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
value_node->set_kernel_info(kernel_info);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
return value_node;
}
bool InferOpShape(const CNodePtr &kernel) {
MS_EXCEPTION_IF_NULL(kernel);
opt::dynamic_shape::InferOp(kernel);
@ -171,7 +199,11 @@ bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const Devi
} // namespace
void EmbeddingCachePrefetchActor::Initialize() {
if (initialized_) {
return;
}
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
if (!device_context_->device_res_manager_->CreateStream(&stream_id_)) {
MS_LOG(EXCEPTION) << "Create stream failed.";
}
@ -196,30 +228,49 @@ void EmbeddingCachePrefetchActor::Initialize() {
// Build and link rpc operators.
BuildRpcOperators();
LinkRpcOperators();
initialized_ = true;
}
void EmbeddingCachePrefetchActor::Finalize() {
if (!initialized_ || finalized_) {
return;
}
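// Graceful exit sequence: flush the cached embeddings back to the remote servers first, then stop the
// prefetch loop, ask the servers to finalize, and unblock any thread still waiting on data prefetch.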
SyncEmbeddingTable();
running_ = false;
FinalizeRemote();
PsDataPrefetch::GetInstance().NotifyFinalize();
data_parser_.notify_all();
embedding_cache_lookup_node_ = nullptr;
embedding_cache_update_node_ = nullptr;
rpc_operators_.clear();
finalized_ = true;
initialized_ = false;
}
void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
auto graph = std::make_shared<KernelGraph>();
MS_EXCEPTION_IF_NULL(graph);
graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
embedding_cache_graphs_.push_back(graph);
// 1. Create parameter nodes which are inputs of the embedding cache lookup kernel (operator name: 'Gather').
ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
ValueNodePtr axis_value_node = NewValueNode(0);
// 2. Create a CNode for operator Gather.
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kGatherV2OpName);
emb_lookup_primitive->set_attr(kAttrAxis, MakeValue<int64_t>(0));
emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
emb_lookup_primitive->set_attr(kAttrStream, MakeValue(stream_id_));
std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices};
std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices,
axis_value_node};
embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
@ -227,11 +278,15 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
// 3. Kernel build process.
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_context_->kernel_executor_);
device_context_->kernel_executor_->CreateKernel({embedding_cache_lookup_node_});
}
void EmbeddingCachePrefetchActor::BuildEmbeddingCacheUpdateKernel() {
auto graph = std::make_shared<KernelGraph>();
MS_EXCEPTION_IF_NULL(graph);
graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
embedding_cache_graphs_.push_back(graph);
// 1. Create parameter nodes which are inputs of the embedding cache update kernel (operator name: 'ScatterUpdate').
ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
@ -252,6 +307,7 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheUpdateKernel() {
// 3. Kernel build process.
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_context_->kernel_executor_);
device_context_->kernel_executor_->CreateKernel({embedding_cache_update_node_});
}
@ -288,6 +344,7 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
AddressPtrList kernel_outputs = {std::make_shared<Address>(outputs, indices_num * embedding_size * sizeof(float))};
MS_ERROR_IF_NULL(device_context_);
MS_ERROR_IF_NULL(device_context_->kernel_executor_);
auto ret =
device_context_->kernel_executor_->LaunchKernel(embedding_cache_lookup_node_, kernel_inputs, {}, kernel_outputs);
if (!ret) {
@ -338,6 +395,7 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
std::make_shared<Address>(embedding_cache, cache_size * embedding_size * sizeof(float))};
MS_ERROR_IF_NULL(device_context_);
MS_ERROR_IF_NULL(device_context_->kernel_executor_);
auto ret =
device_context_->kernel_executor_->LaunchKernel(embedding_cache_update_node_, kernel_inputs, {}, kernel_outputs);
if (!ret) {
@ -809,6 +867,7 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
"Memcpy device to host asynchronously failed.");
MS_ERROR_IF_NULL(device_context_);
MS_ERROR_IF_NULL(device_context_->device_res_manager_);
RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
RETURN_IF_FALSE_WITH_LOG(
InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
@ -882,6 +941,7 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
"Update device embedding cache failed.");
MS_ERROR_IF_NULL(device_context_);
MS_ERROR_IF_NULL(device_context_->device_res_manager_);
RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
return true;
}
@ -1162,9 +1222,11 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
std::vector<int> &slice_ids = slice_ids_list->at(i);
std::vector<float> &slice_embeddings = slice_embeddings_list->at(i);
// Ids range offset for multi server.
int offset = SizeToInt(remote_embedding_slice_bounds_.at(i).first);
for (size_t j = 0; j < ids_num; j++) {
if (ids[j] >= begin && ids[j] <= end) {
slice_ids.push_back(ids[j]);
slice_ids.push_back(ids[j] - offset);
slice_embeddings.insert(slice_embeddings.end(), embeddings + (j * embedding_dim),
embeddings + (j * embedding_dim) + embedding_dim);
}
@ -1175,7 +1237,8 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size
bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
size_t server_rank_id, size_t embedding_dim, const void *keys,
size_t keys_len, const void *values, size_t values_len) {
size_t keys_len, const void *values, size_t values_len,
bool finalize_remote) {
MS_ERROR_IF_NULL(keys);
// Find sender corresponding to cache operation and parameter key.
auto iter = rpc_operators_.find(cache_operation);
@ -1215,7 +1278,7 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
std::make_shared<Address>(&service_id, sizeof(int32_t))};
// Send data.
return sender->Send(shapes, data_types, data_list);
return sender->Send(shapes, data_types, data_list, finalize_remote);
}
std::unique_ptr<std::vector<char>> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation,
@ -1394,6 +1457,19 @@ bool EmbeddingCachePrefetchActor::SyncDeviceEmbeddingTable() {
return true;
}
bool EmbeddingCachePrefetchActor::FinalizeRemote() {
for (size_t i = 0; i < server_num_; i++) {
size_t embedding_dim = 1;
int id = 0;
float value = 0.0;
RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, 0, i, embedding_dim, &id, sizeof(int),
&value, sizeof(float), true),
"Send finalize request to remote failed.");
}
return true;
}
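The finalize request is an ordinary lookup message (one dummy id and value) with a marker appended to its body. A sketch of the wire layout as assembled by Sender::BuildRpcMessage below:

//   [shape / dtype / data segments of the dummy lookup]
//   "FINALIZE_MUX_RECV_ACTOR"  <- distributed::kFinalizeMuxRecvActor
//   0x01                       <- the bool finalize_remote flag
// MuxRecvActor::ParseFinalizeReqData locates the marker, reads the flag behind it,
// and counts one finalize per worker before shutting the server down.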
std::string EmbeddingCachePrefetchActor::channel_name() {
std::lock_guard<std::mutex> locker(channel_mutex_);
return channel_name_;
@ -1490,9 +1566,9 @@ void EmbeddingCachePrefetchActor::LinkRpcOperators() {
}
bool Sender::Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
const AddressPtrList &data_list) const {
const AddressPtrList &data_list, bool finalize_remote) const {
MS_ERROR_IF_NULL(receiver_);
auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_);
auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_, finalize_remote);
MS_ERROR_IF_NULL(message);
MS_ERROR_IF_NULL(client_);
client_->SendAsync(std::move(message));
@ -1533,7 +1609,7 @@ bool Sender::ConnectServer() {
std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVector> &shapes,
const std::vector<TypeId> data_types,
const AddressPtrList &data_list, const std::string &from_url,
const std::string &to_url) const {
const std::string &to_url, bool finalize_remote) const {
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
message->from = AID("", from_url);
@ -1571,6 +1647,13 @@ std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVect
// 4. The real data buffer need to be sent.
message->body.append(static_cast<char *>(data->addr), data->size);
}
// 5. Finalize remote command.
if (finalize_remote) {
message->body.append(distributed::kFinalizeMuxRecvActor);
message->body.append(reinterpret_cast<char *>(&finalize_remote), sizeof(finalize_remote));
}
return message;
}

View File

@ -176,7 +176,8 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// Send content to remote, such as ids or embeddings.
// The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache.
bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim,
const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0);
const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0,
bool finalize_remote = false);
// Wait response of remote and get return result.
// The parameter 'cache_operation' is cache operation name such as LookupEmbeddingCache and UpdateEmbeddingCache.
std::unique_ptr<std::vector<char>> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key,
@ -251,6 +252,9 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// The embedding cache update kernel node(operator name: 'ScatterUpdate').
CNodePtr embedding_cache_update_node_;
// Cache the kernel graphs of the embedding cache ops.
std::vector<KernelGraphPtr> embedding_cache_graphs_;
// Row count of the full embedding table, not less than the total number of feature ids.
size_t vocab_size_{0};
@ -291,6 +295,8 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// The flag which indicates whether this actor is initialized.
bool initialized_{false};
// The flag which indicates whether this actor is finalized.
bool finalized_{false};
// The flag which indicates whether the embedding table sync has finished.
bool finish_sync_embedding_table_{false};
@ -357,7 +363,7 @@ class Sender : public RpcOperator {
// Send buffer to peer.
bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
const AddressPtrList &data_list) const;
const AddressPtrList &data_list, bool finalize_remote = false) const;
// Set the receiver paired with the sender to get the 'from url' from the receiver.
void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }
@ -373,7 +379,8 @@ class Sender : public RpcOperator {
// The message.from (from url) must be set.
std::unique_ptr<MessageBase> BuildRpcMessage(const std::vector<ShapeVector> &shapes,
const std::vector<TypeId> data_types, const AddressPtrList &data_list,
const std::string &from_url, const std::string &to_url) const;
const std::string &from_url, const std::string &to_url,
bool finalize_remote) const;
// The url of the peer receiver's tcp server.
std::string server_url_;

View File

@ -72,9 +72,18 @@ void MuxRecvActor::ParseFinalizeReqData(size_t data_len, const MessageBase *cons
}
const void *need_finalize_actor_data = msg_body.c_str() + data_len + finalize_header_size;
*need_finalize = *(reinterpret_cast<const bool *>(need_finalize_actor_data));
if (*need_finalize) {
MS_EXCEPTION_IF_NULL(need_finalize_actor_data);
bool finalize_in_msg = *(reinterpret_cast<const bool *>(need_finalize_actor_data));
MS_LOG(INFO) << "Received a message which contains finalize command: " << finalize_in_msg;
if (!finalize_in_msg) {
return;
}
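// Each worker sends exactly one finalize request (EmbeddingCachePrefetchActor::FinalizeRemote),
// so only finalize once every worker has checked in.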
recv_finalize_msg_cnt_++;
if (recv_finalize_msg_cnt_ == ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker)) {
*need_finalize = true;
// Finalize loop of runtime.
MS_EXCEPTION_IF_NULL(op_context_);
SET_OPCONTEXT_SUCCESS_RET((*op_context_));
}
}
@ -88,6 +97,8 @@ void MuxRecvActor::UpdateStatus() {
void MuxRecvActor::Finalize() {
std::unique_lock<std::mutex> lock(context_mtx_);
finalized_ = true;
is_ready_ = true;
is_context_valid_ = true;
op_context_ = nullptr;
context_cv_.notify_all();

View File

@ -73,7 +73,9 @@ class MuxRecvActor : public RecvActor {
std::condition_variable is_ready_cv_;
// Whether the actor is finalized.
bool finalized_{false};
std::atomic_bool finalized_{false};
uint32_t recv_finalize_msg_cnt_{0};
};
using MuxRecvActorPtr = std::shared_ptr<MuxRecvActor>;

View File

@ -67,7 +67,9 @@ void GetFirstEmbeddingCacheTableInfo(const KernelGraph &graph, AnfNodePtr *const
continue;
}
auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
while (input_index.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
while (input_index.first->isa<CNode>() &&
((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
@ -114,7 +116,8 @@ void CheckSparseModeForEmbeddingCache(const CNodePtr &node) {
pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
while (pre_node.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
while (pre_node.first->isa<CNode>() && ((common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(pre_node.first) == kTensorMoveOpName))) {
pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
}
@ -147,7 +150,9 @@ void CheckGraphValidForEmbeddingCache(const KernelGraph &graph) {
if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
CheckSparseModeForEmbeddingCache(kernel);
}
while (input_index.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
while (input_index.first->isa<CNode>() &&
((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}

View File

@ -23,6 +23,7 @@
#include "utils/ms_utils.h"
#include "backend/common/session/kernel_graph.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace runtime {
@ -33,7 +34,7 @@ class EmbeddingCachePrefetchActor;
// to cache large embedding table of a large recommendation network model. The cache level is:
// Device Cache->Local Host Cache->Remote Cache. The embedding cache prefetch actor is used to perform Local
// and Device Cache hit analysis and cache prefetching.
class EmbeddingCacheScheduler {
class BACKEND_EXPORT EmbeddingCacheScheduler {
public:
static EmbeddingCacheScheduler &GetInstance();

View File

@ -553,6 +553,9 @@ class Model:
cb_params.network = self._network
if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
epoch = 1
# The embedding cache server only runs one step.
if (_is_role_pserver() or _is_role_sched()) and _cache_enable():
epoch = 1
cb_params.last_save_ckpt_step = None
cb_params.latest_ckpt_file = None
@ -613,6 +616,8 @@ class Model:
self._check_enable_recovery()
# Check whether recovery needs to be performed for the restarted process.
self._check_need_load_ckpt(cb_params, train_dataset.get_dataset_size(), sink_size)
# Check whether this process is embedding cache server.
is_embedding_cache_server = _is_role_pserver() and _cache_enable()
while self.epoch_iter < (epoch - initial_epoch):
cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
@ -651,6 +656,9 @@ class Model:
if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
os._exit(0)
# The embedding cache server only runs one step.
if is_embedding_cache_server:
break
dataset_helper.continue_send()
@ -681,7 +689,10 @@ class Model:
cb_params.net_outputs = train_net_outputs
# In disaster recovery scenarios, there is no need to execute callbacks if this epoch failed.
need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
# The embedding cache server need not run the epoch-end callback, since this process only runs one step.
need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
or is_embedding_cache_server)
if need_exec_callback_epoch_end:
list_callback.on_train_epoch_end(run_context)
if "metrics" in cb_params or "eval_results" in cb_params: