graceful exit for embedding cache mode

parent ce05ec064a
commit ea0fe63c04
@@ -30,6 +30,9 @@
 "mindspore/mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc" "containerOutOfBounds"
+"mindspore/mindspore/ccsrc/pipeline/jit/action.cc" "unreadVariable"
+"mindspore/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc" "unreadVariable"
+"mindspore/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc" "knownConditionTrueFalse"
 "mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "knownConditionTrueFalse"
 "mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "variableScope"

 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
@@ -933,10 +933,16 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
 void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) {
   bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
                                 distributed::recovery::RecoveryContext::GetInstance()->need_reset());
+  bool is_embedding_cache_server = false;
+#ifdef WITH_BACKEND
+  is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
+#endif
   if (need_contruct_output) {
     // Update device address for output node of graph.
     // Summary processing will use the output device address, so must be after the summary processing.
-    actor_set->output_actor_->UpdateOutputDeviceAddress();
+    if (!is_embedding_cache_server) {
+      actor_set->output_actor_->UpdateOutputDeviceAddress();
+    }

     // Fetch outputs.
     MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
@@ -28,10 +28,10 @@ namespace mindspore {
 namespace distributed {
 namespace cluster {
 // The timeout in milliseconds for one lookup.
-constexpr uint32_t kDefaultLookupTimeout = 60000;
+constexpr uint32_t kDefaultLookupTimeout = 300000;

 // The time in milliseconds between two lookup operations.
-constexpr uint32_t kLookupInterval = 100;
+constexpr uint32_t kLookupInterval = 3000;

 // Actor route table proxy for nodes like workers and server. This class helps update actor route table in scheduler
 // across the network.
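The two constants above govern how long a node polls the scheduler for an actor route before giving up, and how long it sleeps between polls; the commit widens both so servers draining an embedding cache are not declared dead prematurely. A minimal, self-contained sketch of such a poll-until-timeout loop (the real ActorRouteTableProxy logic is not shown in this diff; Lookup here is a stand-in predicate):

  #include <chrono>
  #include <functional>
  #include <thread>

  // Poll lookup() every interval_ms until it succeeds or timeout_ms elapses,
  // mirroring the intent of kDefaultLookupTimeout / kLookupInterval above.
  bool LookupWithRetry(const std::function<bool()> &lookup,
                       uint32_t timeout_ms = 300000, uint32_t interval_ms = 3000) {
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(timeout_ms);
    while (std::chrono::steady_clock::now() < deadline) {
      if (lookup()) {
        return true;  // Route found before the deadline.
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
    }
    return false;  // Timed out; the caller decides how to fail.
  }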
@@ -49,7 +49,7 @@ const std::vector<std::string> kEmbeddingCacheOps = {kLookupEmbeddingCache, kUpd
 constexpr char kFinalizeMuxRecvActor[] = "FINALIZE_MUX_RECV_ACTOR";

 // The distributed execution mode enum.
-enum class DistExecutionMode { kPSMode = 0, kInvalidMode };
+enum class DistExecutionMode { kPSMode = 0, kEmbeddingCacheMode, kInvalidMode };

 // The operator's label in distributed execution.
 constexpr char kOpLabelRankId[] = "rank_id";
@@ -24,6 +24,7 @@
 #include "kernel/kernel.h"
 #include "distributed/embedding_cache/embedding_hash_map.h"
 #include "runtime/hardware/device_context.h"
+#include "include/backend/visible.h"

 namespace mindspore {
 namespace runtime {
@@ -114,7 +115,7 @@ struct EmbeddingCacheStatisticsInfo {

 // The EmbeddingCacheTableManager class is used to save all Parameter information for enabling cache, such as device
 // cache size, host cache size, etc., and can allocate memory for the embedding cache table.
-class EmbeddingCacheTableManager {
+class BACKEND_EXPORT EmbeddingCacheTableManager {
  public:
   static EmbeddingCacheTableManager &GetInstance();
@@ -33,14 +33,71 @@ const ShapeVector kOneDimShape = {1};
 // Two dimensional shape placeholder.
 const ShapeVector kTwoDimsShape = {1, 1};

+// One dimensional dynamic shape placeholder.
+const ShapeVector kOneDimDynamicShape = {-1};
+// Two dimensional dynamic shape placeholder.
+const ShapeVector kTwoDimsDynamicShape = {-1, -1};
+
 // The output tensor number of recv node.
 const size_t kRecvNodeOutputNum = 3;

 // The input index of offset of EmbeddingLookup kernel.
 constexpr size_t kEmbeddingLookupOffsetIdx = 2;

+// The dims of embedding table.
+constexpr size_t kEmbeddingTableDims = 2;
+
 constexpr char kEmbeddingRemoteCacheNode[] = "EmbeddingRemoteCacheNode";
 constexpr char kEmbeddingLocalCacheNode[] = "EmbeddingLocalCacheNode";
+namespace {
+ValueNodePtr CreateFakeValueNode(const AnfNodePtr &origin_node) {
+  MS_EXCEPTION_IF_NULL(origin_node);
+  abstract::AbstractTensorPtr origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(origin_abstract);
+  tensor::TensorPtr fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
+                                                                   origin_abstract->shape()->shape());
+  MS_EXCEPTION_IF_NULL(fake_tensor);
+  fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
+
+  auto fake_value = NewValueNode(fake_tensor);
+  MS_EXCEPTION_IF_NULL(fake_value);
+  fake_value->set_abstract(fake_tensor->ToAbstract());
+  return fake_value;
+}
+
+AnfNodePtr CreateOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(origin_output);
+  MS_EXCEPTION_IF_NULL(origin_output->abstract());
+  if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
+    abstract::AbstractBasePtrList new_elements_abs;
+    std::vector<ValuePtr> new_elements_values;
+
+    auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
+    for (const auto &element : tuple_elements) {
+      MS_EXCEPTION_IF_NULL(element);
+      auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
+      if (!tensor_abstract) {
+        MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
+      }
+      auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
+                                                          tensor_abstract->shape()->shape());
+      MS_EXCEPTION_IF_NULL(fake_tensor);
+      new_elements_abs.push_back(fake_tensor->ToAbstract());
+      new_elements_values.push_back(fake_tensor);
+    }
+    ValueTuplePtr value_tuple = std::make_shared<ValueTuple>(new_elements_values);
+    auto value_tuple_abs = std::make_shared<abstract::AbstractTuple>(new_elements_abs);
+    auto value_tuple_node = NewValueNode(value_tuple);
+    MS_EXCEPTION_IF_NULL(value_tuple_node);
+    value_tuple_node->set_abstract(value_tuple_abs);
+    return value_tuple_node;
+  } else {
+    return CreateFakeValueNode(origin_output);
+  }
+}
+}  // namespace

 void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
   MS_EXCEPTION_IF_NULL(root_graph_);
@@ -214,15 +271,21 @@ FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const An
   input_param->set_abstract(param->abstract()->Clone());
   ParameterPtr input_indices = graph->add_parameter();
   MS_EXCEPTION_IF_NULL(input_indices);
-  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
+  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
+    kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));

   // 2. Create EmbeddingLookup node.
+  auto offset_node = common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx);
+  MS_EXCEPTION_IF_NULL(offset_node);
+  auto offset_value_node = offset_node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(offset_value_node);
+  int64_t offset = GetValue<int64_t>(offset_value_node->value());
+
   PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
-  std::vector<AnfNodePtr> emb_lookup_inputs{
-    NewValueNode(emb_lookup_primitive), input_param, input_indices,
-    common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx)};
+  std::vector<AnfNodePtr> emb_lookup_inputs{NewValueNode(emb_lookup_primitive), input_param, input_indices};
   auto embedding_cache_lookup_node = graph->NewCNode(emb_lookup_inputs);
   MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);
+  common::AnfAlgo::SetNodeAttr(kAttrOffset, MakeValue(offset), embedding_cache_lookup_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
+  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
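Throughout this patch, static parameter abstracts are replaced by the three-argument abstract::Shape(shape, min_shape, max_shape) form, where -1 marks a dimension whose size is only known at run time and the min/max vectors bound it for memory planning. A condensed sketch of the recurring pattern, using the same calls and constants as the hunk above:

  // Indices arrive with a varying length, so the first dimension is dynamic (-1).
  const ShapeVector kOneDimShape = {1};          // bound/placeholder shape
  const ShapeVector kOneDimDynamicShape = {-1};  // dynamic dimension marker
  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
    kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));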
@@ -258,35 +321,36 @@ FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const Pa
   MS_EXCEPTION_IF_NULL(input_param);
   MS_EXCEPTION_IF_NULL(param->abstract());
   input_param->set_abstract(param->abstract()->Clone());

   ParameterPtr input_indices = graph->add_parameter();
   MS_EXCEPTION_IF_NULL(input_indices);
-  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
+  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
+    kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));

   ParameterPtr update_values = graph->add_parameter();
   MS_EXCEPTION_IF_NULL(update_values);
-  update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsShape));
+  std::vector<size_t> emb_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
+  if (emb_shape.size() != kEmbeddingTableDims) {
+    MS_LOG(EXCEPTION) << "Embedding table should be 2 dims for embedding cache mode, but got: " << emb_shape.size()
+                      << " dims";
+  }
+  int64_t emb_dim = SizeToLong(emb_shape.back());
+  ShapeVector update_values_shape = {-1, emb_dim};
+  ShapeVector update_values_min_shape = {1, emb_dim};
+  ShapeVector update_values_max_shape = {1, emb_dim};
+  update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(
+    kFloat32,
+    std::make_shared<abstract::Shape>(update_values_shape, update_values_min_shape, update_values_max_shape)));

-  // 2. Create Sub node.
-  PrimitivePtr sub_primitive = std::make_shared<Primitive>(kSubOpName);
-  auto sub_value_node = common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kEmbeddingLookupOffsetIdx);
-  int64_t sub_value = GetValue<int64_t>(GetValueNode(sub_value_node));
-  // The input of Sub must be tensor type.
-  auto sub_value_tensor = std::make_shared<tensor::Tensor>(sub_value, kInt32);
-  std::vector<AnfNodePtr> sub_inputs{NewValueNode(sub_primitive), input_indices, NewValueNode(sub_value_tensor)};
-  auto sub_node = graph->NewCNode(sub_inputs);
-  MS_EXCEPTION_IF_NULL(sub_node);
-  sub_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
-  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), sub_node);
-  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), sub_node);

-  // 3. Create ScatterUpdate node.
+  // 2. Create ScatterUpdate node.
   PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
   std::vector<AnfNodePtr> embedding_cache_update_inputs{NewValueNode(embedding_cache_update_primitive), input_param,
-                                                        sub_node, update_values};
+                                                        input_indices, update_values};
   auto embedding_cache_update_node = graph->NewCNode(embedding_cache_update_inputs);
   MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
+  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);

-  // 4. Create return node.
+  // 3. Create return node.
   CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
   MS_EXCEPTION_IF_NULL(return_node);
   graph->set_return(return_node);
@@ -304,17 +368,25 @@ CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
   MS_EXCEPTION_IF_NULL(root_graph_);
   ParameterPtr input_indices = root_graph_->add_parameter();
   MS_EXCEPTION_IF_NULL(input_indices);
-  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
+  input_indices->set_abstract(std::make_shared<abstract::AbstractTensor>(
+    kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape, kOneDimShape, kOneDimShape)));
+  auto fake_input_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimShape);
+  input_indices->set_default_param(fake_input_indices_tensor);

   // The update values input.
   ParameterPtr update_values = root_graph_->add_parameter();
   MS_EXCEPTION_IF_NULL(update_values);
-  update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimsShape));
+  update_values->set_abstract(std::make_shared<abstract::AbstractTensor>(
+    kFloat32, std::make_shared<abstract::Shape>(kTwoDimsDynamicShape, kTwoDimsShape, kTwoDimsShape)));
+  auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, kTwoDimsShape);
+  update_values->set_default_param(fake_update_values_tensor);

   // The service id input, used to choose service to execute.
   ParameterPtr service_id = root_graph_->add_parameter();
   MS_EXCEPTION_IF_NULL(service_id);
   service_id->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
+  auto fake_id_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimDynamicShape);
+  service_id->set_default_param(fake_id_tensor);

   // 2. Create a RpcRecv node.
   std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
@@ -431,12 +503,15 @@ bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
   CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
   MS_EXCEPTION_IF_NULL(call_node);

-  // 4. Replace useless nodes of origin function graph.
+  // 4. Replace origin output and useless nodes of origin function graph.
+  AnfNodePtr old_output = root_graph_->output();
+  AnfNodePtr new_output = CreateOutputNode(root_graph_, old_output);
+  auto final_output_node = root_graph_->NewCNode({NewValueNode(prim::kPrimDepend), new_output, call_node});
+  MS_EXCEPTION_IF_NULL(final_output_node);
+
   auto graph_manager = root_graph_->manager();
   MS_EXCEPTION_IF_NULL(graph_manager);
-  graph_manager->Replace(root_graph_->output(), call_node);
   auto return_node = root_graph_->get_return();
   MS_EXCEPTION_IF_NULL(return_node);
+  graph_manager->Replace(root_graph_->output(), final_output_node);
   return true;
 }
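The replacement output above leans on the Depend primitive: assuming standard MindSpore semantics, Depend(value, expr) forwards value as its result while keeping expr alive in the graph so the optimizer cannot prune it. Here the fake output built by CreateOutputNode supplies the original output's type and shape, while the switch-layer call node keeps the cache service loop reachable. In sketch form:

  // final_output_node == new_output at run time; call_node is retained as a
  // side-effect dependency and therefore survives graph cleanup.
  auto final_output_node =
    root_graph_->NewCNode({NewValueNode(prim::kPrimDepend), new_output, call_node});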
@@ -452,6 +527,10 @@ bool PsEmbeddingCacheInserter::Run() {

   // Set attr(device target attr and graph split label) for all CNodes.
   SetAttrForAllNodes();

+  MS_EXCEPTION_IF_NULL(root_graph_);
+  // Need renormalize to infer shape and set abstract.
+  root_graph_->set_flag(kFlagNeedRenormalize, true);
   return true;
 }
 }  // namespace parallel
@@ -27,6 +27,9 @@
 #include "mindspore/core/utils/ms_context.h"
 #include "include/common/utils/anfalgo.h"
 #include "include/common/debug/draw.h"
+#ifdef WITH_BACKEND
+#include "ps/ps_context.h"
+#endif

 namespace mindspore {
 namespace parallel {
@@ -835,14 +838,87 @@ GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, c
     : func_graph_(func_graph),
       rank_id_(rank_id),
       role_(role),
-      mode_(distributed::DistExecutionMode::kPSMode),
       exec_mode_(nullptr),
       this_process_label_({rank_id, role}),
       node_labels_{},
       need_fuse_rpc_nodes_(true) {
+  bool enable_embedding_cache = false;
+#ifdef WITH_BACKEND
+  enable_embedding_cache = ps::PSContext::instance()->cache_enable();
+#endif
+  mode_ = enable_embedding_cache ? distributed::DistExecutionMode::kEmbeddingCacheMode
+                                 : distributed::DistExecutionMode::kPSMode;
   default_label_ = {0, distributed::kEnvRoleOfWorker};
 }

+void EmbeddingCacheMode::PreBuildDistributedGraph() {
+  // Only need add embedding cache ops of remote cache.
+  if (role_ != distributed::kEnvRoleOfPServer) {
+    return;
+  }
+
+  // 1. Add embedding cache ops of remote cache, and build service-side graph.
+  AddEmbeddingCacheOps();
+
+  // 2. Get node labels.
+  MS_EXCEPTION_IF_NULL(node_labels_);
+  node_labels_->clear();
+
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
+  (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      CNodePtr cnode = node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      OperatorLabel label = GetNodeLabel(cnode);
+      node_labels_->emplace(node, label);
+    }
+  });
+}
+
+void EmbeddingCacheMode::AddEmbeddingCacheOps() const {
+  uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
+  if (worker_num == 0) {
+    MS_LOG(EXCEPTION) << "In embedding cache mode, worker number should be greater than 0.";
+  }
+
+  // Build service-side graph.
+  std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
+    std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph_, static_cast<int64_t>(rank_id_), role_,
+                                                         worker_num);
+  if (!embedding_cache_inserter->Run()) {
+    MS_LOG(EXCEPTION) << "Insert ps embedding cache failed.";
+  }
+}
+
+OperatorLabel EmbeddingCacheMode::GetNodeLabel(const AnfNodePtr &node) const {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
+  }
+
+  CNodePtr cnode = node->cast<CNodePtr>();
+  auto prim_node = cnode->input(0);
+  if (IsValueNode<Primitive>(prim_node)) {
+    auto prim = GetValueNode<PrimitivePtr>(prim_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
+      MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
+      uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
+      std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
+      return {rank_id, ms_role};
+    }
+  } else {
+    // Get label for call node, 'call' node hasn't primitive to save attrs, so get attrs of 'call' from cnode.
+    if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
+      uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
+      std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
+      return {rank_id, ms_role};
+    }
+  }
+  return {rank_id_, role_};
+}
+
 GraphSplitter::~GraphSplitter() { node_labels_.clear(); }

 void GraphSplitter::Run() {
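GetNodeLabel above reads the rank-id and role attributes either from a CNode's primitive or, for 'call' nodes that have no primitive, from the CNode itself. The producer side of that contract is outside this diff; a hedged sketch of what tagging a server-bound node would look like, using the attribute names defined by kOpLabelRankId and kOpLabelRole in this patch:

  // Tag a node so EmbeddingCacheMode::GetNodeLabel() routes it to a server
  // (hypothetical call site; server_rank_id is an assumed variable).
  common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRankId,
                               MakeValue(static_cast<int64_t>(server_rank_id)), cnode);
  common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRole,
                               MakeValue(std::string(distributed::kEnvRoleOfPServer)), cnode);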
@@ -852,9 +928,7 @@ void GraphSplitter::Run() {
   // Step 1: Dye all the nodes of the whole func_graph_.
   DyeGraph();
   // If all nodes are all on this process, no need to split the graph. So return.
-  if (std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
-        return node_to_label.second != this_process_label_;
-      }) == node_labels_.end()) {
+  if (!NeedSplitGraph()) {
     MS_LOG(INFO) << "No need to build and split distributed graph.";
     return;
   }
@@ -865,6 +939,11 @@ void GraphSplitter::Run() {
   // Step 3: Prebuild the distributed graph before it gets split.
   exec_mode_->PreBuildDistributedGraph();

+  if (!NeedSplitGraph()) {
+    MS_LOG(INFO) << "No need to build and split distributed graph.";
+    return;
+  }
+
   // Step 4: Create inter-process operators for segments with different labels.
   InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
@@ -924,6 +1003,8 @@ void GraphSplitter::CreateExecutionMode() {
   }
   if (mode_ == distributed::DistExecutionMode::kPSMode) {
     exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
+  } else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
+    exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
   }
   MS_EXCEPTION_IF_NULL(exec_mode_);
 }
@@ -1350,5 +1431,11 @@ bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodeP
   }
   return node_labels_[node1] == node_labels_[node2];
 }

+bool GraphSplitter::NeedSplitGraph() const {
+  return std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
+           return node_to_label.second != this_process_label_;
+         }) != node_labels_.end();
+}
 }  // namespace parallel
 }  // namespace mindspore
@@ -34,6 +34,7 @@
 #else
 #include "distributed/cluster/dummy_cluster_context.h"
 #endif
+#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"

 namespace mindspore {
 namespace parallel {
@@ -314,6 +315,21 @@ class ParameterServerMode : public DistributedExecutionMode {
   std::vector<size_t> ps_optimizer_fusion_segments_;
 };

+class EmbeddingCacheMode : public DistributedExecutionMode {
+ public:
+  explicit EmbeddingCacheMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
+                              const std::string &role)
+      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
+  ~EmbeddingCacheMode() = default;
+
+  void PreBuildDistributedGraph() override;
+
+ private:
+  void AddEmbeddingCacheOps() const;
+
+  OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
+};
+
 // The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
 // cluster.
 class GraphSplitter {
@@ -388,6 +404,9 @@ class GraphSplitter {
   // Judge whether two nodes have the same distributed label.
   bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);

+  // Check whether need split distributed graph.
+  bool NeedSplitGraph() const;
+
   FuncGraphPtr func_graph_;

   // Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,
@@ -28,6 +28,8 @@
 #ifdef WITH_BACKEND
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "utils/ms_context.h"
+#include "ps/ps_context.h"
+#include "distributed/embedding_cache/embedding_cache_utils.h"
 #endif

 namespace mindspore {
@@ -715,7 +717,11 @@ Status GatherInfo::InferBias() {
   }
 #ifdef WITH_BACKEND
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
-    bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+    if (ps::PSContext::instance()->enable_distributed_mindrt()) {
+      bias_ = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
+    } else {
+      bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+    }
     return SUCCESS;
   }
 #endif
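This hunk and the UniqueInfo change below select the cache lower bound the same way: the new MindRT-based embedding cache table manager when the distributed MindRT runtime is enabled, otherwise the legacy PS cache manager. A hypothetical shared helper (not part of the patch) would express the rule once:

  // Pick the embedding cache indices lower bound for the active runtime.
  int64_t CacheIndicesLowerBound() {
    if (ps::PSContext::instance()->enable_distributed_mindrt()) {
      return static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
    }
    return static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
  }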
@@ -30,6 +30,7 @@
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
 #ifdef WITH_BACKEND
 #include "ps/ps_cache/ps_cache_manager.h"
+#include "distributed/embedding_cache/embedding_cache_utils.h"
 #endif

 namespace mindspore {
@@ -106,7 +107,13 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
     MS_LOG(ERROR) << "GenerateGraph Init failed";
     return FAILED;
   }
-  auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+
+  int64_t bias = 0;
+  if (ps::PSContext::instance()->enable_distributed_mindrt()) {
+    bias = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
+  } else {
+    bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+  }
   auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size());

   auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)});
@@ -601,6 +601,7 @@ constexpr auto kFlagsIsCutGraph = "is_cut_graph";
 constexpr auto kFlagIsDynamicStructure = "is_dynamic_structure";
 constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
 constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
+constexpr auto kFlagNeedRenormalize = "need_renormalize";

 // TODO(dsj): for ms_function running in graph_mode. should be delete later
 constexpr auto kAttrMSFunction = "ms_function_graph";
@@ -110,6 +110,12 @@ void DisableMindRT(const ResourcePtr &resource) {
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == false) {
     return;
   }
+#ifdef WITH_BACKEND
+  if (ps::PSContext::instance()->cache_enable()) {
+    return;
+  }
+#endif

   auto func_graph = resource->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
   auto parallel_context = parallel::ParallelContext::GetInstance();
@@ -915,6 +921,13 @@ void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
     return;
   }

+#ifdef WITH_BACKEND
+  if (ps::PSContext::instance()->cache_enable()) {
+    set_ctx(true, false, false);
+    return;
+  }
+#endif
+
   // GRAPH | normal network and if/for/switch scenario etc : MultiGraph path in MindRT.
   MS_LOG(INFO) << "Run graph mode with multigraph sink.";
   set_ctx(true, true, true);
@@ -1202,6 +1215,17 @@ bool DistributedSplitAction(const ResourcePtr &resource) {
       std::make_shared<parallel::GraphSplitter>(func_graph, node->rank_id(), node_role);
   MS_EXCEPTION_IF_NULL(splitter);
   splitter->Run();
+
+  // Renormalize: Infer shape and Set abstract for all nodes in graph.
+  if (func_graph->has_flag(kFlagNeedRenormalize)) {
+    abstract::AbstractBasePtrList args_abs;
+    auto parameters = func_graph->parameters();
+    (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
+                         [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
+    FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
+    resource->set_func_graph(new_fg);
+    resource->set_args_abs(args_abs);
+  }
   return true;
 }
 #endif
|
@ -889,7 +889,6 @@ bool AddEmbeddingCachePass(const ResourcePtr &resource) {
|
|||
}
|
||||
|
||||
std::vector<PassItem> kVmPasses = {
|
||||
{"add_embedding_cache", AddEmbeddingCachePass},
|
||||
{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
{"clean_after_opta", CleanAfterOptAPass},
|
||||
|
|
|
@@ -79,6 +79,7 @@
 #include "fl/server/server.h"
 #include "fl/worker/fl_worker.h"
 #include "distributed/cluster/cluster_context.h"
+#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
 #endif

 #ifdef ENABLE_D
@@ -1655,6 +1656,7 @@ void ClearResAtexit() {
   RecordExitStatus();
 #ifdef WITH_BACKEND
   if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    runtime::EmbeddingCacheScheduler::GetInstance().Finalize();
     (void)distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX);
   } else if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) {
     if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
@@ -394,11 +394,14 @@ bool AscendKernelExecutor::LaunchKernel(const CNodePtr &kernel, const vector<Add

   bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(kernel);
   if (!is_dynamic_shape || !(common::AnfAlgo::GetBooleanAttr(kernel, kAttrMSFunction))) {
-    std::lock_guard<std::mutex> locker(launch_mutex_);
-    // launch atomic clean
-    if (!LaunchAtomicClean(kernel, workspace, outputs)) {
-      MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
-      return false;
+    auto iter = node_atomics_persistent_cache_.find(kernel);
+    if (iter != node_atomics_persistent_cache_.end()) {
+      std::lock_guard<std::mutex> locker(launch_mutex_);
+      // launch atomic clean
+      if (!LaunchAtomicClean(kernel, workspace, outputs)) {
+        MS_LOG(ERROR) << "Launch AtomicClean failed, pre kernel full name: " << kernel->fullname_with_scope();
+        return false;
+      }
     }
   }
@@ -18,145 +18,152 @@

 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-  GatherV2FwdGpuKernelMod, double, int)
+  GatherV2FwdGpuKernelMod, double, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-  GatherV2FwdGpuKernelMod, double, int64_t)
+  GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  GatherV2FwdGpuKernelMod, float, int)
+  GatherV2FwdGpuKernelMod, float, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  GatherV2FwdGpuKernelMod, float, int64_t)
+  GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  GatherV2FwdGpuKernelMod, half, int)
+  GatherV2FwdGpuKernelMod, half, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-  GatherV2FwdGpuKernelMod, half, int64_t)
+  GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  GatherV2FwdGpuKernelMod, int, int)
+  GatherV2FwdGpuKernelMod, int, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  GatherV2FwdGpuKernelMod, int, int64_t)
+  GatherV2FwdGpuKernelMod, int, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  GatherV2FwdGpuKernelMod, int16_t, int)
+  GatherV2FwdGpuKernelMod, int16_t, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  GatherV2FwdGpuKernelMod, int16_t, int64_t)
+  GatherV2FwdGpuKernelMod, int16_t, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-  GatherV2FwdGpuKernelMod, int8_t, int)
+  GatherV2FwdGpuKernelMod, int8_t, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-  GatherV2FwdGpuKernelMod, int8_t, int64_t)
+  GatherV2FwdGpuKernelMod, int8_t, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  GatherV2FwdGpuKernelMod, uint, int)
+  GatherV2FwdGpuKernelMod, uint, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  GatherV2FwdGpuKernelMod, uint, int64_t)
+  GatherV2FwdGpuKernelMod, uint, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-  GatherV2FwdGpuKernelMod, uint8_t, int)
+  GatherV2FwdGpuKernelMod, uint8_t, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-  GatherV2FwdGpuKernelMod, uint8_t, int64_t)
+  GatherV2FwdGpuKernelMod, uint8_t, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
-  GatherV2FwdGpuKernelMod, bool, int)
+  GatherV2FwdGpuKernelMod, bool, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
-  GatherV2FwdGpuKernelMod, bool, int64_t)
+  GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
 // dynamic shape
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      GatherV2FwdGpuKernelMod, double, int)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      GatherV2FwdGpuKernelMod, double, int64_t)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      GatherV2FwdGpuKernelMod, float, int)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      GatherV2FwdGpuKernelMod, float, int64_t)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      GatherV2FwdGpuKernelMod, half, int)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      GatherV2FwdGpuKernelMod, half, int64_t)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeBool)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeBool),
-                      GatherV2FwdGpuKernelMod, bool, int)
-MS_REG_GPU_KERNEL_TWO(Gather,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeBool)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeBool),
-                      GatherV2FwdGpuKernelMod, bool, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat64)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat64),
+                        GatherV2FwdGpuKernelMod, double, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat64),
+                        GatherV2FwdGpuKernelMod, double, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        GatherV2FwdGpuKernelMod, float, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        GatherV2FwdGpuKernelMod, float, int, int)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        GatherV2FwdGpuKernelMod, float, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        GatherV2FwdGpuKernelMod, half, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        GatherV2FwdGpuKernelMod, half, int64_t, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeBool)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeBool),
+                        GatherV2FwdGpuKernelMod, bool, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(Gather,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeBool)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeBool),
+                        GatherV2FwdGpuKernelMod, bool, int64_t, int64_t)
 // dynamic shape ends
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   SparseGatherV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  GatherV2FwdGpuKernelMod, float, int)
+  GatherV2FwdGpuKernelMod, float, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_THREE(
   SparseGatherV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  GatherV2FwdGpuKernelMod, half, int)
+  GatherV2FwdGpuKernelMod, half, int, int64_t)
-MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      GatherV2FwdGpuKernelMod, float, int)
-MS_REG_GPU_KERNEL_TWO(SparseGatherV2,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      GatherV2FwdGpuKernelMod, half, int)
+MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        GatherV2FwdGpuKernelMod, float, int, int64_t)
+MS_REG_GPU_KERNEL_THREE(SparseGatherV2,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeInt32)
+                          .AddInputAttr(kNumberTypeInt64)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        GatherV2FwdGpuKernelMod, half, int, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
public:
|
||||
GatherV2FwdGpuKernelMod() { ResetResource(); }
|
||||
|
@@ -43,16 +43,16 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
     if (is_dynamic_shape_) {
-      int64_t *axis_device_address = GetDeviceAddress<int64_t>(inputs, 2);  // only get this if in dynamic mode
+      G *axis_device_address = GetDeviceAddress<G>(inputs, 2);  // only get this if in dynamic mode
       CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                                 cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost,
+                                 cudaMemcpyAsync(&axis_, axis_device_address, sizeof(G), cudaMemcpyDeviceToHost,
                                                  reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync axis_ failed");
       CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(),
                                  "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
       Reshape();
     }
-    auto input_dim1 = input_shapes_[IntToSize(axis_)];
+    auto input_dim1 = input_shapes_[axis_];

     MS_EXCEPTION_IF_NULL(input_addr);
     MS_EXCEPTION_IF_NULL(indices_addr);
@@ -85,7 +85,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     }
     if (!is_dynamic_shape_) {
       int dims = SizeToInt(input_shapes_.size());
-      axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
+      axis_ = static_cast<G>(GetAttr<int64_t>(kernel_node, "axis"));
       if (axis_ < -dims || axis_ >= dims) {
         MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
                           << "), but got " << axis_;
@@ -115,7 +115,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     size = common::AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
     input_size_list_.push_back(size);
     if (is_dynamic_shape_) {
-      input_size_list_.push_back(sizeof(int64_t));
+      input_size_list_.push_back(sizeof(G));
     }
     size = common::AnfAlgo::TensorSizeInByte<T>(output_shapes_);
     output_size_list_.push_back(size);
@@ -148,7 +148,7 @@ class GatherV2FwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
   std::vector<size_t> indices_shapes_;
   std::vector<size_t> output_shapes_;
   size_t dims_[3] = {};
-  int64_t axis_;
+  G axis_;
   bool is_dynamic_shape_;
   bool is_null_input_;
 };
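The new third template parameter G is the dtype of the axis operand. In dynamic-shape mode the axis arrives as a third device input, and Launch() copies exactly sizeof(G) bytes back to the host before Reshape(), so the axis dtype in each MS_REG_GPU_KERNEL_THREE registration (int vs int64_t above) must match the copy width. A condensed, self-contained sketch of that device-to-host fetch (stand-in helper using only the CUDA runtime API, not the kernel itself):

  #include <cuda_runtime.h>

  // Copy the axis scalar (of registered type G) from device to host; the axis
  // must be resident on the host before output shapes can be recomputed.
  template <typename G>
  cudaError_t FetchAxis(const G *axis_device, G *axis_host, cudaStream_t stream) {
    cudaError_t err = cudaMemcpyAsync(axis_host, axis_device, sizeof(G),
                                      cudaMemcpyDeviceToHost, stream);
    if (err != cudaSuccess) return err;
    return cudaStreamSynchronize(stream);  // block until the scalar is valid
  }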
@@ -618,7 +618,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
     DeviceAddressPtr device_address = GetInternalDeviceAddress(graph, item);
 #ifdef WITH_BACKEND
     const std::string &param_name = item->fullname_with_scope();
-    if (ps::ps_cache_instance.IsHashTable(param_name)) {
+    if (ps::ps_cache_instance.IsHashTable(param_name) && !ps::PSContext::instance()->enable_distributed_mindrt()) {
       MS_LOG(INFO) << "Parameter(" << param_name << ")"
                    << " enables the embeddingLookup cache in parameter server training mode.";
       // PS embeddingLookup cache check.
@@ -15,10 +15,12 @@
  */

 #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
+#include <limits>
 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
 #include "kernel/common_utils.h"
 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
 #include "proto/topology.pb.h"
+#include "distributed/constants.h"

 namespace mindspore {
 namespace runtime {
@@ -44,6 +46,13 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
   auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape);
   param->set_abstract(abstract);

+  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  std::vector<std::string> formats = {kOpFormat_DEFAULT};
+  std::vector<TypeId> types = {type->type_id()};
+  kernel_build_info_builder->SetOutputsFormat(formats);
+  kernel_build_info_builder->SetOutputsDeviceType(types);
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
+
   auto mutable_inputs = graph->MutableInputs();
   MS_EXCEPTION_IF_NULL(mutable_inputs);
   mutable_inputs->push_back(param);
@@ -51,6 +60,25 @@ ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const Shape
   return param;
 }

+ValueNodePtr NewValueNode(int64_t value) {
+  auto tensor = std::make_shared<tensor::Tensor>(value, kInt32);
+  auto value_node = NewValueNode(tensor);
+  value_node->set_abstract(tensor->ToAbstract());
+
+  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  std::vector<std::string> formats = {kOpFormat_DEFAULT};
+  std::vector<TypeId> types = {kInt32->type_id()};
+  kernel_build_info_builder->SetOutputsFormat(formats);
+  kernel_build_info_builder->SetOutputsDeviceType(types);
+
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  value_node->set_kernel_info(kernel_info);
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
+
+  return value_node;
+}
+
 bool InferOpShape(const CNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(kernel);
   opt::dynamic_shape::InferOp(kernel);
@@ -171,7 +199,11 @@ bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const Devi
 }  // namespace

 void EmbeddingCachePrefetchActor::Initialize() {
+  if (initialized_) {
+    return;
+  }
   MS_EXCEPTION_IF_NULL(device_context_);
   MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
   if (!device_context_->device_res_manager_->CreateStream(&stream_id_)) {
     MS_LOG(EXCEPTION) << "Create stream failed.";
   }
@@ -196,30 +228,49 @@ void EmbeddingCachePrefetchActor::Initialize() {
   // Build and link rpc operators.
   BuildRpcOperators();
   LinkRpcOperators();
+
+  initialized_ = true;
 }

 void EmbeddingCachePrefetchActor::Finalize() {
+  if (!initialized_ || finalized_) {
+    return;
+  }
   SyncEmbeddingTable();

   running_ = false;
+  FinalizeRemote();

   PsDataPrefetch::GetInstance().NotifyFinalize();
   data_parser_.notify_all();

+  embedding_cache_lookup_node_ = nullptr;
+  embedding_cache_update_node_ = nullptr;
+
+  rpc_operators_.clear();
+  finalized_ = true;
+  initialized_ = false;
 }

 void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {
   auto graph = std::make_shared<KernelGraph>();
+  MS_EXCEPTION_IF_NULL(graph);
+  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
+  embedding_cache_graphs_.push_back(graph);

   // 1. Create parameter nodes which are inputs of embedding cache look up kernel(operator name: 'Gather').
   ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
   ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
+  ValueNodePtr axis_value_node = NewValueNode(0);

   // 2. Create a CNode for operator Gather.
   PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kGatherV2OpName);
   emb_lookup_primitive->set_attr(kAttrAxis, MakeValue<int64_t>(0));
   emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
   emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
   emb_lookup_primitive->set_attr(kAttrStream, MakeValue(stream_id_));

-  std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices};
+  std::vector<AnfNodePtr> emb_lookup_input_nodes{NewValueNode(emb_lookup_primitive), input_param, input_indices,
+                                                 axis_value_node};
   embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
   MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
   auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
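Initialize() and Finalize() are now guarded by the initialized_/finalized_ flags declared in the header below, making both idempotent: teardown paths such as ClearResAtexit() can finalize unconditionally, and a repeated call is a no-op. The contract in miniature (a self-contained sketch of the lifecycle pattern, not the actor itself):

  class LifecycleGuard {
   public:
    void Initialize() {
      if (initialized_) return;  // repeated init is a no-op
      // ... create streams, build kernels, link rpc operators ...
      initialized_ = true;
    }
    void Finalize() {
      if (!initialized_ || finalized_) return;  // never started, or already done
      // ... sync tables, stop loops, notify waiters, drop node refs ...
      finalized_ = true;
      initialized_ = false;
    }
   private:
    bool initialized_{false};
    bool finalized_{false};
  };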
@@ -227,11 +278,15 @@ void EmbeddingCachePrefetchActor::BuildEmbeddingCacheLookupKernel() {

   // 3. Kernel build process.
   MS_EXCEPTION_IF_NULL(device_context_);
+  MS_EXCEPTION_IF_NULL(device_context_->kernel_executor_);
   device_context_->kernel_executor_->CreateKernel({embedding_cache_lookup_node_});
 }

 void EmbeddingCachePrefetchActor::BuildEmbeddingCacheUpdateKernel() {
   auto graph = std::make_shared<KernelGraph>();
+  MS_EXCEPTION_IF_NULL(graph);
+  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
+  embedding_cache_graphs_.push_back(graph);

   // 1. Create parameter nodes which are inputs of embedding cache update kernel(operator name: 'ScatterUpdate').
   ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
@@ -288,6 +344,7 @@ bool EmbeddingCachePrefetchActor::LookupDeviceCache(void *indices, void *embeddi
   AddressPtrList kernel_outputs = {std::make_shared<Address>(outputs, indices_num * embedding_size * sizeof(float))};

   MS_ERROR_IF_NULL(device_context_);
+  MS_ERROR_IF_NULL(device_context_->kernel_executor_);
   auto ret =
     device_context_->kernel_executor_->LaunchKernel(embedding_cache_lookup_node_, kernel_inputs, {}, kernel_outputs);
   if (!ret) {
@@ -338,6 +395,7 @@ bool EmbeddingCachePrefetchActor::UpdateDeviceCache(void *indices, void *update_
     std::make_shared<Address>(embedding_cache, cache_size * embedding_size * sizeof(float))};

   MS_ERROR_IF_NULL(device_context_);
+  MS_ERROR_IF_NULL(device_context_->kernel_executor_);
   auto ret =
     device_context_->kernel_executor_->LaunchKernel(embedding_cache_update_node_, kernel_inputs, {}, kernel_outputs);
   if (!ret) {
@@ -809,6 +867,7 @@ bool EmbeddingCachePrefetchActor::PushCacheFromDeviceToLocalHost(const HashTable
                            "Memcpy device to host asynchronously failed.");

   MS_ERROR_IF_NULL(device_context_);
+  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
   RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
   RETURN_IF_FALSE_WITH_LOG(
     InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
@@ -882,6 +941,7 @@ bool EmbeddingCachePrefetchActor::PullCacheFromLocalHostToDevice(const HashTable
                                              swap_indices_size, cache_vocab_size, embedding_size, hash_table_addr),
                            "Update device embedding cache failed.");
   MS_ERROR_IF_NULL(device_context_);
+  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
   RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
   return true;
 }
@@ -1162,9 +1222,11 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size

     std::vector<int> &slice_ids = slice_ids_list->at(i);
     std::vector<float> &slice_embeddings = slice_embeddings_list->at(i);
+    // Ids range offset for multi server.
+    int offset = SizeToInt(remote_embedding_slice_bounds_.at(i).first);
     for (size_t j = 0; j < ids_num; j++) {
       if (ids[j] >= begin && ids[j] <= end) {
-        slice_ids.push_back(ids[j]);
+        slice_ids.push_back(ids[j] - offset);
         slice_embeddings.insert(slice_embeddings.end(), embeddings + (j * embedding_dim),
                                 embeddings + (j * embedding_dim) + embedding_dim);
       }
@@ -1175,7 +1237,8 @@ bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size

 bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
                                                size_t server_rank_id, size_t embedding_dim, const void *keys,
-                                               size_t keys_len, const void *values, size_t values_len) {
+                                               size_t keys_len, const void *values, size_t values_len,
+                                               bool finalize_remote) {
   MS_ERROR_IF_NULL(keys);
   // Find sender corresponding to cache operation and parameter key.
   auto iter = rpc_operators_.find(cache_operation);
@@ -1215,7 +1278,7 @@ bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operatio
                                   std::make_shared<Address>(&service_id, sizeof(int32_t))};

   // Send data.
-  return sender->Send(shapes, data_types, data_list);
+  return sender->Send(shapes, data_types, data_list, finalize_remote);
 }

 std::unique_ptr<std::vector<char>> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation,
@@ -1394,6 +1457,19 @@ bool EmbeddingCachePrefetchActor::SyncDeviceEmbeddingTable() {
   return true;
 }

+bool EmbeddingCachePrefetchActor::FinalizeRemote() {
+  for (size_t i = 0; i < server_num_; i++) {
+    size_t embedding_dim = 1;
+    int id = 0;
+    float value = 0.0;
+    RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, 0, i, embedding_dim, &id, sizeof(int),
+                                          &value, sizeof(float), true),
+                             "Send finalize request to remote failed.");
+  }
+
+  return true;
+}
+
 std::string EmbeddingCachePrefetchActor::channel_name() {
   std::lock_guard<std::mutex> locker(channel_mutex_);
   return channel_name_;
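FinalizeRemote() implements the graceful exit this commit is named for: rather than a dedicated control channel, it reuses the lookup path and sends each server one minimal, well-formed request (a single id and a single float value) whose real payload is the finalize flag, so a server parked in its mux recv actor wakes up and leaves its service loop instead of hanging at shutdown. A hypothetical receiving side is sketched below; the server-side handler is not part of this diff and all names here are assumptions:

  // Hypothetical handler sketch: finalize requests short-circuit dispatch.
  void OnMessage(const MessageBase &msg) {
    if (HasFinalizeCommand(msg.body)) {  // trailer format shown after BuildRpcMessage below
      StopServiceLoop();                 // exit gracefully instead of blocking forever
      return;
    }
    DispatchCacheOperation(msg);
  }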
@@ -1490,9 +1566,9 @@ void EmbeddingCachePrefetchActor::LinkRpcOperators() {
 }

 bool Sender::Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
-                  const AddressPtrList &data_list) const {
+                  const AddressPtrList &data_list, bool finalize_remote) const {
   MS_ERROR_IF_NULL(receiver_);
-  auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_);
+  auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_, finalize_remote);
   MS_ERROR_IF_NULL(message);
   MS_ERROR_IF_NULL(client_);
   client_->SendAsync(std::move(message));
@ -1533,7 +1609,7 @@ bool Sender::ConnectServer() {
std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVector> &shapes,
const std::vector<TypeId> data_types,
const AddressPtrList &data_list, const std::string &from_url,
const std::string &to_url) const {
const std::string &to_url, bool finalize_remote) const {
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
message->from = AID("", from_url);
@ -1571,6 +1647,13 @@ std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVect
// 4. The real data buffer needs to be sent.
message->body.append(static_cast<char *>(data->addr), data->size);
}

// 5. Finalize remote command.
if (finalize_remote) {
message->body.append(distributed::kFinalizeMuxRecvActor);
message->body.append(reinterpret_cast<char *>(&finalize_remote), sizeof(finalize_remote));
}

return message;
}
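
With the change above, an RPC message body carrying the finalize command ends with the marker string followed by one bool. Under that assumption, the trailing section can be located and decoded like this (a sketch, not the actual MuxRecvActor parser; kMarker mirrors distributed::kFinalizeMuxRecvActor):

#include <cstring>
#include <iostream>
#include <string>

constexpr char kMarker[] = "FINALIZE_MUX_RECV_ACTOR";

// Sketch: check whether the body ends with "<marker><bool>" and read the flag.
bool HasFinalizeCommand(const std::string &body) {
  const size_t tail_size = sizeof(kMarker) - 1 + sizeof(bool);
  if (body.size() < tail_size) return false;
  const size_t tail_pos = body.size() - tail_size;
  if (body.compare(tail_pos, sizeof(kMarker) - 1, kMarker) != 0) return false;
  bool flag = false;
  std::memcpy(&flag, body.data() + tail_pos + sizeof(kMarker) - 1, sizeof(bool));
  return flag;
}

int main() {
  std::string body = "payload";
  bool flag = true;
  body.append(kMarker);
  body.append(reinterpret_cast<char *>(&flag), sizeof(flag));
  std::cout << std::boolalpha << HasFinalizeCommand(body) << '\n';  // true
}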
@ -176,7 +176,8 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// Send content to remote, such as ids or embeddings.
// The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and UpdateEmbeddingCache.
bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim,
const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0);
const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0,
bool finalize_remote = false);
// Wait for the response from the remote and get the returned result.
// The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and UpdateEmbeddingCache.
std::unique_ptr<std::vector<char>> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key,
@ -251,6 +252,9 @@ class EmbeddingCachePrefetchActor : public ActorBase {
// The embedding cache update kernel node (operator name: 'ScatterUpdate').
CNodePtr embedding_cache_update_node_;

// Cache the kernel graphs of the embedding cache ops.
std::vector<KernelGraphPtr> embedding_cache_graphs_;

// Full embedding table row count, not less than the total number of feature ids.
size_t vocab_size_{0};
@ -291,6 +295,8 @@ class EmbeddingCachePrefetchActor : public ActorBase {

// The flag which indicates whether this actor is initialized.
bool initialized_{false};
// The flag which indicates whether this actor is finalized.
bool finalized_{false};

// The flag which indicates whether the embedding table sync has finished.
bool finish_sync_embedding_table_{false};
@ -357,7 +363,7 @@ class Sender : public RpcOperator {

// Send buffer to peer.
bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
const AddressPtrList &data_list) const;
const AddressPtrList &data_list, bool finalize_remote = false) const;

// Set the receiver paired with the sender to get the 'from url' from the receiver.
void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }
@ -373,7 +379,8 @@ class Sender : public RpcOperator {
// The message.from (from url) must be set.
std::unique_ptr<MessageBase> BuildRpcMessage(const std::vector<ShapeVector> &shapes,
const std::vector<TypeId> data_types, const AddressPtrList &data_list,
const std::string &from_url, const std::string &to_url) const;
const std::string &from_url, const std::string &to_url,
bool finalize_remote) const;

// The url of the peer receiver's tcp server.
std::string server_url_;
@ -72,9 +72,18 @@ void MuxRecvActor::ParseFinalizeReqData(size_t data_len, const MessageBase *cons
}

const void *need_finalize_actor_data = msg_body.c_str() + data_len + finalize_header_size;
*need_finalize = *(reinterpret_cast<const bool *>(need_finalize_actor_data));
if (*need_finalize) {
MS_EXCEPTION_IF_NULL(need_finalize_actor_data);
bool finalize_in_msg = *(reinterpret_cast<const bool *>(need_finalize_actor_data));
MS_LOG(INFO) << "Received a message which contains finalize command: " << finalize_in_msg;
if (!finalize_in_msg) {
return;
}

recv_finalize_msg_cnt_++;
if (recv_finalize_msg_cnt_ == ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker)) {
*need_finalize = true;
// Finalize the runtime loop.
MS_EXCEPTION_IF_NULL(op_context_);
SET_OPCONTEXT_SUCCESS_RET((*op_context_));
}
}
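
ParseFinalizeReqData above only sets *need_finalize once every worker has sent its finalize message, so a single fast worker cannot shut the server down while others are still training. A minimal sketch of that counting barrier, with a hypothetical worker count standing in for ClusterContext's node_num():

#include <cstdint>
#include <iostream>

constexpr uint32_t kWorkerNum = 4;  // hypothetical; the real count comes from ClusterContext

// Sketch: the server finalizes only after hearing from all workers.
class FinalizeBarrier {
 public:
  // Returns true once the last worker's finalize message arrives.
  bool OnFinalizeMsg() { return ++recv_cnt_ == kWorkerNum; }

 private:
  uint32_t recv_cnt_{0};
};

int main() {
  FinalizeBarrier barrier;
  for (uint32_t i = 0; i < kWorkerNum; ++i) {
    std::cout << "worker " << i << " -> finalize now: " << std::boolalpha
              << barrier.OnFinalizeMsg() << '\n';
  }
}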
@ -88,6 +97,8 @@ void MuxRecvActor::UpdateStatus() {
void MuxRecvActor::Finalize() {
std::unique_lock<std::mutex> lock(context_mtx_);
finalized_ = true;
is_ready_ = true;
is_context_valid_ = true;

op_context_ = nullptr;
context_cv_.notify_all();
@ -73,7 +73,9 @@ class MuxRecvActor : public RecvActor {
std::condition_variable is_ready_cv_;

// Whether the actor is finalized.
bool finalized_{false};
std::atomic_bool finalized_{false};

uint32_t recv_finalize_msg_cnt_{0};
};

using MuxRecvActorPtr = std::shared_ptr<MuxRecvActor>;
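
Switching finalized_ from bool to std::atomic_bool matters because Finalize() may run on a different thread from the receive loop that reads the flag; a concurrent read and write of a plain bool is a data race. A sketch of the safe pattern (illustrative names, not the actor's actual loop):

#include <atomic>
#include <iostream>
#include <thread>

std::atomic_bool finalized{false};

int main() {
  // One thread polls the flag, another sets it; std::atomic_bool makes the
  // cross-thread read/write well-defined without extra locking.
  std::thread worker([] {
    while (!finalized.load(std::memory_order_acquire)) {
      std::this_thread::yield();  // stand-in for receiving messages
    }
    std::cout << "receive loop observed finalize\n";
  });
  finalized.store(true, std::memory_order_release);  // e.g. from Finalize()
  worker.join();
}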
@ -67,7 +67,9 @@ void GetFirstEmbeddingCacheTableInfo(const KernelGraph &graph, AnfNodePtr *const
continue;
}
auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
while (input_index.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
while (input_index.first->isa<CNode>() &&
((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
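
The widened while-condition above (repeated in the two hunks that follow) walks backwards through pass-through nodes, now treating TensorMove like Cast, until it reaches the real producer. A generic sketch of the same traversal over a toy node type:

#include <iostream>
#include <memory>
#include <string>

// Sketch: skip pass-through ops ("Cast" and "TensorMove") while walking
// toward the real producer, mirroring the widened while-loops above.
struct Node {
  std::string name;
  std::shared_ptr<Node> input;  // single-input chain for simplicity
};

std::shared_ptr<Node> SkipPassThrough(std::shared_ptr<Node> node) {
  while (node && (node->name == "Cast" || node->name == "TensorMove")) {
    node = node->input;
  }
  return node;
}

int main() {
  auto gather = std::make_shared<Node>(Node{"Gather", nullptr});
  auto move = std::make_shared<Node>(Node{"TensorMove", gather});
  auto cast = std::make_shared<Node>(Node{"Cast", move});
  std::cout << SkipPassThrough(cast)->name << '\n';  // Gather
}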
@ -114,7 +116,8 @@ void CheckSparseModeForEmbeddingCache(const CNodePtr &node) {

pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
while (pre_node.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
while (pre_node.first->isa<CNode>() && ((common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(pre_node.first) == kTensorMoveOpName))) {
pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
}
@ -147,7 +150,9 @@ void CheckGraphValidForEmbeddingCache(const KernelGraph &graph) {
if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
CheckSparseModeForEmbeddingCache(kernel);
}
while (input_index.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
while (input_index.first->isa<CNode>() &&
((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
(common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
@ -23,6 +23,7 @@
#include "utils/ms_utils.h"
#include "backend/common/session/kernel_graph.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"

namespace mindspore {
namespace runtime {
@ -33,7 +34,7 @@ class EmbeddingCachePrefetchActor;
// to cache the large embedding table of a large recommendation network model. The cache level is:
// Device Cache->Local Host Cache->Remote Cache. The embedding cache prefetch actor is used to perform Local
// and Device Cache hit analysis and cache prefetching.
class EmbeddingCacheScheduler {
class BACKEND_EXPORT EmbeddingCacheScheduler {
public:
static EmbeddingCacheScheduler &GetInstance();
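
Marking the scheduler class BACKEND_EXPORT makes its symbols visible across the shared-library boundary, which is why include/backend/visible.h is now pulled in above. Visibility macros of this kind typically expand along these lines; this is an assumption about the general pattern, not MindSpore's exact definition (BACKEND_DLL is a hypothetical build flag):

// Sketch of a typical export macro; the actual definition may differ.
#if defined(_WIN32)
#  ifdef BACKEND_DLL
#    define BACKEND_EXPORT __declspec(dllexport)
#  else
#    define BACKEND_EXPORT __declspec(dllimport)
#  endif
#else
#  define BACKEND_EXPORT __attribute__((visibility("default")))
#endif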
@ -553,6 +553,9 @@ class Model:
cb_params.network = self._network
if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
epoch = 1
# The embedding cache server only runs one step.
if (_is_role_pserver() or _is_role_sched()) and _cache_enable():
epoch = 1
cb_params.last_save_ckpt_step = None
cb_params.latest_ckpt_file = None
@ -613,6 +616,8 @@ class Model:
self._check_enable_recovery()
# Used to check whether recovery needs to be performed for a restarted process.
self._check_need_load_ckpt(cb_params, train_dataset.get_dataset_size(), sink_size)
# Check whether this process is the embedding cache server.
is_embedding_cache_server = _is_role_pserver() and _cache_enable()

while self.epoch_iter < (epoch - initial_epoch):
cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
@ -651,6 +656,9 @@ class Model:

if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
os._exit(0)
# The embedding cache server only runs one step.
if is_embedding_cache_server:
break

dataset_helper.continue_send()
@ -681,7 +689,10 @@ class Model:
cb_params.net_outputs = train_net_outputs

# In disaster recovery scenarios, callbacks need not execute if this epoch failed.
need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
# The embedding cache server need not run the epoch-end callback; this process only runs one step.
need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
or is_embedding_cache_server)

if need_exec_callback_epoch_end:
list_callback.on_train_epoch_end(run_context)
if "metrics" in cb_params or "eval_results" in cb_params: