forked from mindspore-Ecosystem/mindspore
add ps cache manager
This commit is contained in:
parent
1033166d8a
commit
e3f7ae61db
|
@ -194,6 +194,14 @@ if (ENABLE_GPU)
|
|||
)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
install(
|
||||
TARGETS ps_cache
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif()
|
||||
|
||||
if (ENABLE_SERVING OR ENABLE_TESTCASES)
|
||||
file(GLOB_RECURSE LIBEVENT_LIB_LIST
|
||||
${libevent_LIBPATH}/libevent*
|
||||
|
|
|
@ -308,7 +308,7 @@ elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|||
else ()
|
||||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core)
|
||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
||||
if (${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
endif()
|
||||
|
|
|
@ -75,6 +75,9 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs
|
|||
if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) {
|
||||
node_so_ = CUST_AICPU_OPS_SO_NAME;
|
||||
node_name_ = kCustRunApi;
|
||||
} else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) {
|
||||
node_so_ = AICPU_OPS_SO_NAME;
|
||||
node_name_ = kCustRunApi;
|
||||
} else {
|
||||
node_so_ = AICPU_OPS_SO_NAME;
|
||||
}
|
||||
|
@ -161,6 +164,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
|
|||
if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) {
|
||||
node_so_ = CUST_AICPU_OPS_SO_NAME;
|
||||
node_name_ = kCustRunApi;
|
||||
} else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) {
|
||||
node_so_ = AICPU_OPS_SO_NAME;
|
||||
node_name_ = kCustRunApi;
|
||||
} else {
|
||||
node_so_ = AICPU_OPS_SO_NAME;
|
||||
}
|
||||
|
|
|
@ -49,6 +49,7 @@ constexpr auto kIdentity = "Identity";
|
|||
constexpr auto kUpdateCache = "UpdateCache";
|
||||
constexpr auto kCustRunApi = "RunCpuKernel";
|
||||
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity};
|
||||
const std::set<std::string> kCacheKernelOps{kUpdateCache};
|
||||
|
||||
struct AicpuParamHead {
|
||||
uint32_t length; // Total length: include cunstom message
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "ps/worker.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -38,10 +39,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
|
||||
}
|
||||
std::vector<size_t> keys{key_, key_, key_};
|
||||
std::vector<size_t> values;
|
||||
values.insert(values.end(), input_shape.begin(), input_shape.end());
|
||||
values.insert(values.end(), indices_shape.begin(), indices_shape.end());
|
||||
values.insert(values.end(), output_shape.begin(), output_shape.end());
|
||||
std::vector<float> values;
|
||||
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(values),
|
||||
[](size_t dim) -> float { return SizeToFloat(dim); });
|
||||
std::transform(indices_shape.begin(), indices_shape.end(), std::back_inserter(values),
|
||||
[](size_t dim) -> float { return SizeToFloat(dim); });
|
||||
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(values),
|
||||
[](size_t dim) -> float { return SizeToFloat(dim); });
|
||||
MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape
|
||||
<< ", indices_shape:" << indices_shape << ", output_shape:" << output_shape;
|
||||
std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()),
|
||||
|
|
|
@ -72,6 +72,23 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, con
|
|||
return Launch(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids,
|
||||
const float *update_vals, size_t ids_size) {
|
||||
size_t copy_lens = outer_dim_size_ * sizeof(float);
|
||||
for (size_t i = 0; i < ids_size; ++i) {
|
||||
int index = lookup_ids[i] - offset_;
|
||||
if (index >= 0 && index < SizeToInt(first_dim_size_)) {
|
||||
auto ret =
|
||||
memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "UpdateEmbeddings index invalid.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; }
|
||||
|
||||
const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); }
|
||||
|
|
|
@ -35,7 +35,8 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK
|
|||
|
||||
bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals,
|
||||
size_t ids_size) override;
|
||||
const std::vector<size_t> &input_sizes() const override;
|
||||
const std::vector<size_t> &output_sizes() const override;
|
||||
const std::vector<size_t> &workspace_sizes() const override;
|
||||
|
|
|
@ -38,7 +38,8 @@ class PServerKernel {
|
|||
virtual void ReInit(const std::vector<std::vector<size_t>> &) {}
|
||||
virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) = 0;
|
||||
|
||||
virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals,
|
||||
size_t ids_size) {}
|
||||
virtual const std::vector<size_t> &input_sizes() const = 0;
|
||||
virtual const std::vector<size_t> &output_sizes() const = 0;
|
||||
virtual const std::vector<size_t> &workspace_sizes() const = 0;
|
||||
|
|
|
@ -56,6 +56,7 @@
|
|||
#include "toolchain/adx_datadump_server.h"
|
||||
#if ENABLE_CPU && ENABLE_D
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -487,11 +488,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
|
|||
// adjust kernel
|
||||
AdjustKernel(root_graph);
|
||||
#if ENABLE_CPU && ENABLE_D
|
||||
if (ps::Util::IsParamServerMode()) {
|
||||
CheckPSModeConsistence(root_graph);
|
||||
// Assign parameter keys.
|
||||
AssignParamKey(root_graph);
|
||||
}
|
||||
InitPsWorker(root_graph);
|
||||
#endif
|
||||
// assign stream
|
||||
AssignStream(NOT_NULL(root_graph));
|
||||
|
@ -568,6 +565,9 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
|
|||
}
|
||||
// adjust execution order because merge child graph and other special operations
|
||||
AdjustKernel(graph);
|
||||
#if ENABLE_CPU && ENABLE_D
|
||||
InitPsWorker(graph);
|
||||
#endif
|
||||
// Reorder optimizer order
|
||||
auto execution_order = graph->execution_order();
|
||||
Reorder(&execution_order);
|
||||
|
@ -644,6 +644,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
|
|||
#if ENABLE_CPU && ENABLE_D
|
||||
// Initialize parameter server
|
||||
InitPSParamAndOptim(kernel_graph, inputs);
|
||||
std::string channel_name;
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) {
|
||||
ps::ps_cache_instance.IncreaseGraphStep(channel_name);
|
||||
}
|
||||
#endif
|
||||
{
|
||||
// run task on device
|
||||
|
|
|
@ -21,6 +21,9 @@
|
|||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "utils/comm_manager.h"
|
||||
#include "utils/scoped_long_running.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
|
|
@ -64,6 +64,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -237,6 +238,12 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
auto input_node = input_nodes[i];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
const std::string ¶m_name = input_node->fullname_with_scope();
|
||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
auto pk_node = input_node->cast<ParameterPtr>();
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
|
@ -300,16 +307,11 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
|||
HardwareOptimize(graph);
|
||||
// Graph kernel fusion optimization
|
||||
GraphKernelOptimize(graph);
|
||||
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
if (ps::Util::IsParamServerMode()) {
|
||||
CheckPSModeConsistence(graph);
|
||||
// Assign parameter keys.
|
||||
AssignParamKey(graph);
|
||||
}
|
||||
#endif
|
||||
// Start gpu kernel runtime
|
||||
StartKernelRT();
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
InitPsWorker(graph);
|
||||
#endif
|
||||
// Assign CUDA streams
|
||||
AssignStream(graph);
|
||||
// Dump .pb graph before remove nop nodes
|
||||
|
@ -374,6 +376,12 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
|
|||
int kernel_num = kernel_graph->execution_order().size();
|
||||
int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1;
|
||||
for (int64_t i = 0; i < loopsize; i++) {
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
std::string channel_name;
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) {
|
||||
ps::ps_cache_instance.IncreaseGraphStep(channel_name);
|
||||
}
|
||||
#endif
|
||||
Execute(kernel_graph);
|
||||
}
|
||||
// In pynative mode, device addresses of tensors in value nodes need be clean.
|
||||
|
|
|
@ -41,8 +41,10 @@
|
|||
#include "utils/trace_base.h"
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/worker.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#include "ps/common.h"
|
||||
#include "ps/util.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -1125,6 +1127,12 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
|
||||
}
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
const std::string ¶m_name = input_node->fullname_with_scope();
|
||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size,
|
||||
|
@ -1715,8 +1723,64 @@ void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<ten
|
|||
}
|
||||
}
|
||||
|
||||
bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) {
|
||||
auto kernel_graph = graphs_[graph_id];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name == kGetNextOpName) {
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
*channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
|
||||
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
||||
if (!ps::Util::IsRoleOfWorker()) {
|
||||
return;
|
||||
}
|
||||
CheckPSModeConsistence(kernel_graph);
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (!ps::ps_cache_instance.initialized_ps_cache()) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(devcie_target, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto context = runtime_instance->context();
|
||||
const auto &kernels = kernel_graph->execution_order();
|
||||
if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
|
||||
GetBatchElements(kernels[0]);
|
||||
ps::ps_cache_instance.Initialize();
|
||||
}
|
||||
ps::ps_cache_instance.DoProcessData(device_id_, context);
|
||||
}
|
||||
} else {
|
||||
// Assign parameter keys.
|
||||
AssignParamKey(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
|
||||
auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
|
||||
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
|
||||
if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
|
||||
<< types;
|
||||
}
|
||||
size_t batch_elements = 1;
|
||||
const auto &shape = shapes[0];
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
batch_elements *= shape[i];
|
||||
}
|
||||
ps::ps_cache_instance.set_batch_elements(batch_elements);
|
||||
}
|
||||
|
||||
void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const {
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
for (const auto &input_node : input_nodes) {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
|
@ -1725,8 +1789,9 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
|
|||
auto pk_node = input_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
auto param_info_ptr = pk_node->param_info();
|
||||
if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
|
||||
const std::string ¶m_name = pk_node->fullname_with_scope();
|
||||
const std::string ¶m_name = pk_node->fullname_with_scope();
|
||||
if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
|
||||
!ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name
|
||||
<< "] in server, this parameter is used by kernel which executes in device";
|
||||
}
|
||||
|
@ -1734,10 +1799,6 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) {
|
|||
}
|
||||
|
||||
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
|
||||
if (!ps::Util::IsRoleOfWorker()) {
|
||||
MS_LOG(INFO) << "Not parameter server mode.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
@ -1775,16 +1836,8 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
|
|||
return;
|
||||
}
|
||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||
size_t input_ctrl_size = 1;
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
if (kernel_graph->input_ctrl_tensors()) {
|
||||
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
|
||||
}
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||
<< ", input_ctrl_size:" << input_ctrl_size;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
|
|
|
@ -99,9 +99,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
// get graph id in child graphs by ME front anf node pointer
|
||||
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
|
||||
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
||||
void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph);
|
||||
void AssignParamKey(const KernelGraphPtr &kernel_graph);
|
||||
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
|
||||
bool IsGetNextGraph(const GraphId &graph_id, std::string *channel_name);
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
std::string *error_msg) const {
|
||||
return true;
|
||||
|
@ -195,6 +195,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
|
||||
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
|
||||
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
|
||||
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
||||
void InitPsWorker(const KernelGraphPtr &kernel_graph);
|
||||
#endif
|
||||
|
||||
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
|
@ -207,6 +212,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
std::shared_ptr<Debugger> debugger_;
|
||||
#endif
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
bool initialized_ps_cache_{false};
|
||||
#endif
|
||||
};
|
||||
|
||||
using SessionPtr = std::shared_ptr<session::SessionBasic>;
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -514,6 +517,12 @@ Status GatherV2PInfo::InferBias() {
|
|||
if (repeated_calc_num_ > 1) {
|
||||
rank = rank / repeated_calc_num_;
|
||||
}
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
bias_ = 0;
|
||||
return SUCCESS;
|
||||
}
|
||||
#endif
|
||||
bias_ = rank / params_strategy.at(1) * slice_size_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -46,10 +46,18 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/tensor.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
|
||||
|
|
|
@ -44,6 +44,9 @@
|
|||
#include "utils/comm_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
|
||||
using mindspore::tensor::Tensor;
|
||||
|
||||
|
@ -3036,6 +3039,11 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
|
|||
}
|
||||
|
||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
|
|
|
@ -201,6 +201,7 @@ else ()
|
|||
if (${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm)
|
||||
endif ()
|
||||
target_link_libraries(_c_dataengine PRIVATE ps_cache)
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
|
|
|
@ -322,6 +322,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
|
|||
bool profiling, int32_t *push_time) {
|
||||
std::vector<device::DataItemGpu> items;
|
||||
double start_time;
|
||||
bool ps_data_prefetch = false;
|
||||
for (int i = 0; i < data_size.size(); i++) {
|
||||
device::DataItemGpu data_item;
|
||||
data_item.data_len_ = data_size[i];
|
||||
|
@ -334,6 +335,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
|
|||
if (profiling) {
|
||||
start_time = ProfilingTime::GetCurMilliSecond();
|
||||
}
|
||||
// Data prefetch only when PS mode enables cache.
|
||||
if ((!ps_data_prefetch) && (items.size() > 0)) {
|
||||
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_);
|
||||
ps_data_prefetch = true;
|
||||
}
|
||||
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
|
||||
if (profiling) {
|
||||
double end_time = ProfilingTime::GetCurMilliSecond();
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
|
||||
#ifdef ENABLE_TDTQUE
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
#include "minddata/dataset/util/log_adapter.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
|
||||
|
@ -48,6 +50,10 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
|
|||
if (profiling) {
|
||||
start_time = ProfilingTime::GetCurMilliSecond();
|
||||
}
|
||||
// Data prefetch only when PS mode enables cache.
|
||||
if (items.size() > 0) {
|
||||
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
|
||||
}
|
||||
if (tdt::TdtHostPushData(channel_name, items) != 0) {
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -308,7 +308,14 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
|
||||
.def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
|
||||
.def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
|
||||
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
|
||||
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.")
|
||||
.def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.")
|
||||
.def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize,
|
||||
"Insert hash table size with new parameter name.")
|
||||
.def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
|
||||
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
|
||||
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
|
||||
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.");
|
||||
|
||||
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||
.def(py::init())
|
||||
|
|
|
@ -52,6 +52,7 @@
|
|||
#include "ps/common.h"
|
||||
#include "ps/util.h"
|
||||
#include "ps/worker.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#endif
|
||||
|
||||
#if (ENABLE_GE || ENABLE_D)
|
||||
|
@ -921,6 +922,11 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
|
|||
bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
|
||||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, bool need_run) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
MS_LOG(INFO) << "Start InitDataSet Entry";
|
||||
ShapeVector int_input_indexes;
|
||||
(void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
|
||||
|
@ -966,7 +972,17 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
backend->Link(runner.graph_id);
|
||||
}
|
||||
ConfigManager::GetInstance().set_iter_num(size);
|
||||
// PS mode does not support loop sink.
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsRoleOfWorker()) {
|
||||
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
|
||||
ConfigManager::GetInstance().set_iter_num(1);
|
||||
} else {
|
||||
#endif
|
||||
ConfigManager::GetInstance().set_iter_num(size);
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!(*runner.run)) {
|
||||
// empty function
|
||||
|
@ -981,7 +997,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
}
|
||||
MS_LOG(DEBUG) << "InitDataSetVm End.";
|
||||
return true;
|
||||
}
|
||||
} // namespace pipeline
|
||||
|
||||
void ResetOpId() { mindspore::id_generator::reset_id(); }
|
||||
|
||||
|
|
|
@ -14,22 +14,20 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||
endif ()
|
||||
|
||||
if (NOT ENABLE_D)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||
endif()
|
||||
|
||||
if (NOT ENABLE_GPU)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
|
||||
add_subdirectory(ps_cache)
|
||||
|
||||
|
||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
||||
|
|
|
@ -64,6 +64,7 @@ constexpr int64_t kInitWeightToOptimIdCmd = 11;
|
|||
constexpr int64_t kInitOptimInputsShapeCmd = 12;
|
||||
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
|
||||
constexpr int64_t kInitEmbeddingsCmd = 20;
|
||||
constexpr int64_t kUpdateEmbeddingsCmd = 21;
|
||||
constexpr int64_t kCheckReadyForPushCmd = 25;
|
||||
constexpr int64_t kCheckReadyForPullCmd = 26;
|
||||
constexpr int64_t kEmbeddingLookupCmd = 30;
|
||||
|
|
|
@ -51,6 +51,8 @@
|
|||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "ps/random_normal/random_normal.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -100,6 +102,7 @@ class ParameterServer {
|
|||
void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
||||
void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
||||
void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
||||
void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
||||
void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
|
||||
|
||||
ParameterServer *ps_;
|
||||
|
@ -118,13 +121,15 @@ class ParameterServer {
|
|||
void InitWeight(const Key &key, const WeightPtr &weight);
|
||||
void InitGrad(const Key &key, const GradPtr &grad);
|
||||
void InitEmbeddingTable(const Key &key,
|
||||
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
|
||||
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
|
||||
const ParamInitInfo ¶m_init_info);
|
||||
bool HasWeight(const Key &key);
|
||||
void Finalize();
|
||||
void UpdateWeights();
|
||||
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
|
||||
WeightPtr weight(const Key &key);
|
||||
void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
|
||||
void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
|
||||
bool ReadyForUpdateWeights();
|
||||
bool ReadyForPush(const Key &key);
|
||||
bool ReadyForPull(const Key &key);
|
||||
|
@ -193,6 +198,7 @@ void ParameterServer<T>::ServerHandler::Init() {
|
|||
handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
|
||||
handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
|
||||
handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
|
||||
handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
|
||||
handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
|
||||
}
|
||||
|
||||
|
@ -302,7 +308,17 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
|
|||
for (int64_t k = 0; k < lens[2]; k++) {
|
||||
output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
|
||||
}
|
||||
ps_->InitEmbeddingTable(key, shapes);
|
||||
ParamInitInfo param_init_info;
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
param_init_info.param_type_ = static_cast<ParamType>(lens[3]);
|
||||
if (param_init_info.param_type_ == kWeight) {
|
||||
param_init_info.global_seed_ = static_cast<size_t>(lens[4]);
|
||||
param_init_info.op_seed_ = static_cast<size_t>(lens[5]);
|
||||
} else if (param_init_info.param_type_ == kAccumulation) {
|
||||
param_init_info.init_val_ = req_data.vals[index];
|
||||
}
|
||||
}
|
||||
ps_->InitEmbeddingTable(key, shapes, param_init_info);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -338,6 +354,18 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
|
|||
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta,
|
||||
const ::ps::KVPairs<T> &req_data,
|
||||
::ps::KVPairs<T> *res) {
|
||||
std::unique_lock<std::mutex> lock(ps_->mutex());
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
const Key &key = req_data.keys[0];
|
||||
const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size());
|
||||
const Values &update_vals = req_data.vals;
|
||||
ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
||||
::ps::KVPairs<T> *res) {
|
||||
|
@ -476,7 +504,8 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
|
|||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::InitEmbeddingTable(
|
||||
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
||||
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
|
||||
const ParamInitInfo ¶m_init_info) {
|
||||
MS_EXCEPTION_IF_NULL(shapes);
|
||||
if (weights_.count(key) == 0) {
|
||||
std::shared_ptr<PServerKernel> lookup =
|
||||
|
@ -493,8 +522,18 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|||
T *embedding_data = embedding->data();
|
||||
std::default_random_engine engine;
|
||||
std::normal_distribution<float> random(0, 0.01);
|
||||
for (size_t i = 0; i < total_dims; i++) {
|
||||
embedding_data[i] = random(engine);
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (param_init_info.param_type_ == kWeight) {
|
||||
InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data);
|
||||
} else if (param_init_info.param_type_ == kAccumulation) {
|
||||
for (size_t i = 0; i < total_dims; i++) {
|
||||
embedding_data[i] = param_init_info.init_val_;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < total_dims; i++) {
|
||||
embedding_data[i] = random(engine);
|
||||
}
|
||||
}
|
||||
weights_[key] = embedding;
|
||||
tokens_[key] = 0;
|
||||
|
@ -673,6 +712,23 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
|
|||
res->lens.push_back(res->vals.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
|
||||
if (weights_.count(key) == 0) {
|
||||
MS_LOG(ERROR) << "Invalid embedding table key " << key;
|
||||
return;
|
||||
}
|
||||
if (embedding_lookup_ops_.count(key) == 0) {
|
||||
MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
|
||||
return;
|
||||
}
|
||||
WeightPtr table_ptr = weights_[key];
|
||||
MS_EXCEPTION_IF_NULL(table_ptr);
|
||||
std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
|
||||
MS_EXCEPTION_IF_NULL(table_lookup_op);
|
||||
table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool ParameterServer<T>::ReadyForUpdateWeights() {
|
||||
return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
|
||||
|
|
|
@ -70,9 +70,16 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t
|
|||
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
|
||||
}
|
||||
auto &hash_table_info = iter->second;
|
||||
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
|
||||
return;
|
||||
}
|
||||
hash_table_info.param_init_info_.param_type_ = kWeight;
|
||||
hash_table_info.param_init_info_.global_seed_ = global_seed;
|
||||
hash_table_info.param_init_info_.op_seed_ = op_seed;
|
||||
if (CheckFinishInsertInitInfo()) {
|
||||
finish_insert_init_info_ = true;
|
||||
insert_init_info_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) {
|
||||
|
@ -81,8 +88,26 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i
|
|||
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
|
||||
}
|
||||
auto &hash_table_info = iter->second;
|
||||
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
|
||||
return;
|
||||
}
|
||||
hash_table_info.param_init_info_.param_type_ = kAccumulation;
|
||||
hash_table_info.param_init_info_.init_val_ = init_val;
|
||||
if (CheckFinishInsertInitInfo()) {
|
||||
finish_insert_init_info_ = true;
|
||||
insert_init_info_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
bool PsCacheManager::CheckFinishInsertInitInfo() const {
|
||||
for (const auto &item : hash_tables_) {
|
||||
const auto &hash_table_info = item.second;
|
||||
const auto ¶m_init_info = hash_table_info.param_init_info_;
|
||||
if (param_init_info.param_type_ == kUnKnown) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
|
||||
|
@ -113,35 +138,49 @@ void PsCacheManager::Initialize() {
|
|||
}
|
||||
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
|
||||
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
|
||||
InitParameterServer();
|
||||
AddEmbeddingTable();
|
||||
AllocMemForHashTable();
|
||||
SetLocalIdRank();
|
||||
initialized_ps_cache_ = true;
|
||||
}
|
||||
|
||||
void PsCacheManager::InitParameterServer() {
|
||||
void PsCacheManager::AddEmbeddingTable() const {
|
||||
for (const auto &item : hash_tables_) {
|
||||
const auto ¶m_name = item.first;
|
||||
size_t key = worker.SetParamKey(param_name);
|
||||
size_t row_count = item.second.vocab_size;
|
||||
std::vector<size_t> keys{key, key, key, key};
|
||||
// if worker role
|
||||
worker.AddEmbeddingTable(key, row_count);
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::InitParameterServer() {
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });
|
||||
|
||||
for (const auto &item : hash_tables_) {
|
||||
const auto ¶m_name = item.first;
|
||||
size_t key = worker.SetParamKey(param_name);
|
||||
std::vector<size_t> keys{key, key, key, key, key, key};
|
||||
std::vector<float> values{
|
||||
SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1};
|
||||
std::vector<int64_t> lens{2, 2, 3};
|
||||
const auto &hash_table_info = item.second;
|
||||
const auto ¶m_init_info = hash_table_info.param_init_info_;
|
||||
if (param_init_info.param_type_ == kWeight) {
|
||||
lens.push_back(0);
|
||||
values.push_back(SizeToFloat(param_init_info.global_seed_));
|
||||
values.push_back(SizeToFloat(param_init_info.op_seed_));
|
||||
} else if (param_init_info.param_type_ == kAccumulation) {
|
||||
lens.push_back(1);
|
||||
values.push_back(param_init_info.init_val_);
|
||||
} else if (param_init_info.param_type_ == kAccumulation) {
|
||||
lens.push_back(2);
|
||||
}
|
||||
values.push_back(param_init_info.init_val_);
|
||||
lens.push_back(param_init_info.global_seed_);
|
||||
lens.push_back(param_init_info.op_seed_);
|
||||
// if worker role
|
||||
worker.AddEmbeddingTable(key, row_count);
|
||||
worker.InitPSEmbeddingTable(keys, values, lens);
|
||||
}
|
||||
|
||||
finish_init_parameter_server_ = true;
|
||||
data_prase_.notify_one();
|
||||
}
|
||||
|
||||
void PsCacheManager::AllocMemForHashTable() {
|
||||
|
@ -208,10 +247,538 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
|
|||
if (graph_step_ >= UINT64_MAX) {
|
||||
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
|
||||
}
|
||||
if (graph_step_ == 0) {
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
|
||||
}
|
||||
graph_step_++;
|
||||
set_channel_name(channel_name);
|
||||
PsDataPrefetch::GetInstance().TryWakeChannel(channel_name);
|
||||
data_prase_.notify_one();
|
||||
}
|
||||
|
||||
void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
|
||||
if (!initialized_ps_cache_) {
|
||||
MS_LOG(EXCEPTION) << "PS cache does not init.";
|
||||
}
|
||||
auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
|
||||
process_data_thread.detach();
|
||||
}
|
||||
|
||||
void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
embedding_device_cache_->cache_->InitDevice(device_id, context);
|
||||
InitParameterServer();
|
||||
while (true) {
|
||||
ProcessData();
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::ProcessData() {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
struct timeval start_time, end_time;
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
auto channel = channel_name();
|
||||
if (channel.empty()) {
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_prase_.wait(locker, [this] { return !channel_name_.empty(); });
|
||||
}
|
||||
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
|
||||
if (data == nullptr) {
|
||||
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
(void)data_prase_.wait_for(locker, std::chrono::milliseconds(100));
|
||||
return;
|
||||
}
|
||||
IncreaseStep();
|
||||
auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_);
|
||||
auto batch_ids = reinterpret_cast<int *>(data);
|
||||
auto batch_ids_len = data_size / sizeof(int);
|
||||
std::unique_ptr<int[]> hash_index(new int[batch_ids_len]);
|
||||
if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
|
||||
MS_LOG(EXCEPTION) << "Process data memset failed.";
|
||||
}
|
||||
// Get hash swap in/out index and ids.
|
||||
ParseData(batch_ids, batch_ids_len, hash_index.get());
|
||||
for (const auto &item : hash_tables_) {
|
||||
auto key = worker.GetParamKey(item.first);
|
||||
auto hash_info = item.second;
|
||||
HashSwapHostToServer(key, hash_info);
|
||||
HashSwapDeviceToHost(hash_info);
|
||||
HashSwapServerToHost(key, hash_info);
|
||||
HashSwapHostToDevice(hash_info);
|
||||
}
|
||||
// Replace the batch_ids by hash index for getNext-op getting hash index as input.
|
||||
if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Process data memcpy failed.";
|
||||
}
|
||||
embedding_device_cache_->cache_->SynchronizeStream();
|
||||
// Finish the data process and notify data prefetch.
|
||||
PsDataPrefetch::GetInstance().FinalizeData(channel_name_);
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_
|
||||
<< ",graph step:" << graph_running_step_ << " channel name:" << channel_name_
|
||||
<< ", time cost:" << cost / 1000 << "ms).";
|
||||
}
|
||||
|
||||
void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
|
||||
MS_EXCEPTION_IF_NULL(batch_ids);
|
||||
MS_EXCEPTION_IF_NULL(hash_index);
|
||||
for (size_t i = 0; i < batch_ids_len; i++) {
|
||||
bool need_swap_host_to_device = true;
|
||||
bool need_swap_device_to_host = true;
|
||||
auto id = batch_ids[i];
|
||||
if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
|
||||
hash_index[i] = -1;
|
||||
continue;
|
||||
}
|
||||
hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
|
||||
if (need_swap_host_to_device) {
|
||||
ParseHostDataHostToDevice(id);
|
||||
}
|
||||
if (need_swap_device_to_host) {
|
||||
ParseHostDataDeviceToHost(id);
|
||||
}
|
||||
}
|
||||
// Each 1000 step prints ps cache hit rate.
|
||||
if (data_step_ % 1000 == 0) {
|
||||
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
|
||||
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
|
||||
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%.";
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::WaitGraphRun() {
|
||||
MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) {
|
||||
MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
|
||||
<< ", graph running step:" << graph_running_step_ << ").";
|
||||
}
|
||||
set_current_graph_step();
|
||||
}
|
||||
|
||||
int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) {
|
||||
MS_EXCEPTION_IF_NULL(need_swap_device_to_host);
|
||||
MS_EXCEPTION_IF_NULL(need_swap_host_to_device);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
|
||||
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
|
||||
MS_EXCEPTION_IF_NULL(device_to_host_index);
|
||||
MS_EXCEPTION_IF_NULL(device_to_host_ids);
|
||||
MS_EXCEPTION_IF_NULL(host_to_device_index);
|
||||
MS_EXCEPTION_IF_NULL(host_to_device_ids);
|
||||
|
||||
auto device_hash_map = embedding_device_cache_->device_hash_map_;
|
||||
MS_EXCEPTION_IF_NULL(device_hash_map);
|
||||
int index = 0;
|
||||
auto iter = device_hash_map->id_iter(id);
|
||||
if (device_hash_map->IsIdExist(iter)) {
|
||||
*need_swap_device_to_host = false;
|
||||
*need_swap_host_to_device = false;
|
||||
index = iter->second;
|
||||
if (device_hash_map->hash_step(index) != data_step_) {
|
||||
statistics_info_.hash_hit_count_++;
|
||||
device_hash_map->set_hash_step(index, data_step_);
|
||||
}
|
||||
} else {
|
||||
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
|
||||
while (true) {
|
||||
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
|
||||
&(statistics_info_.device_to_host_size_));
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
WaitGraphRun();
|
||||
continue;
|
||||
}
|
||||
host_to_device_index[statistics_info_.host_to_device_size_] = index;
|
||||
host_to_device_ids[statistics_info_.host_to_device_size_] = id;
|
||||
statistics_info_.host_to_device_size_++;
|
||||
*need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
void PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
|
||||
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
|
||||
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
|
||||
MS_EXCEPTION_IF_NULL(host_to_server_index);
|
||||
MS_EXCEPTION_IF_NULL(host_to_server_ids);
|
||||
MS_EXCEPTION_IF_NULL(server_to_host_index);
|
||||
MS_EXCEPTION_IF_NULL(server_to_host_ids);
|
||||
MS_EXCEPTION_IF_NULL(host_to_device_index);
|
||||
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
MS_EXCEPTION_IF_NULL(host_hash_map);
|
||||
auto iter = host_hash_map->id_iter(id);
|
||||
if (host_hash_map->IsIdExist(iter)) {
|
||||
auto index = iter->second;
|
||||
if (host_hash_map->hash_step(index) != data_step_) {
|
||||
host_hash_map->set_hash_step(index, data_step_);
|
||||
}
|
||||
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
|
||||
} else {
|
||||
while (true) {
|
||||
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
||||
graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
WaitGraphRun();
|
||||
continue;
|
||||
}
|
||||
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
|
||||
server_to_host_index[statistics_info_.server_to_host_size_] = index;
|
||||
server_to_host_ids[statistics_info_.server_to_host_size_++] = id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
|
||||
MS_EXCEPTION_IF_NULL(device_to_host_ids);
|
||||
MS_EXCEPTION_IF_NULL(host_to_server_index);
|
||||
MS_EXCEPTION_IF_NULL(host_to_server_ids);
|
||||
MS_EXCEPTION_IF_NULL(device_to_host_index);
|
||||
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
MS_EXCEPTION_IF_NULL(host_hash_map);
|
||||
int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
|
||||
auto iter = host_hash_map->id_iter(swap_device_to_host_id);
|
||||
if (host_hash_map->IsIdExist(iter)) {
|
||||
auto index = iter->second;
|
||||
if (host_hash_map->hash_step(index) != data_step_) {
|
||||
host_hash_map->set_hash_step(index, data_step_);
|
||||
}
|
||||
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
|
||||
} else {
|
||||
while (true) {
|
||||
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
||||
graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
WaitGraphRun();
|
||||
continue;
|
||||
}
|
||||
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
|
||||
const float *input_addr, const int *indices_addr, float *output_addr) {
|
||||
auto type_size = sizeof(float);
|
||||
size_t lens = outer_dim_size * type_size;
|
||||
for (size_t i = 0; i < indices_lens; ++i) {
|
||||
int index = indices_addr[i];
|
||||
if (index >= 0 && index < SizeToInt(first_dim_size)) {
|
||||
size_t pos = index * outer_dim_size;
|
||||
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
|
||||
}
|
||||
} else {
|
||||
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
|
||||
}
|
||||
}
|
||||
output_addr += outer_dim_size;
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
|
||||
const int *indices_addr, float *output_addr) {
|
||||
size_t first_dim_size = host_cache_vocab_size_;
|
||||
size_t outer_dim_size = embedding_size;
|
||||
|
||||
size_t thread_num = indices_lens / 10000 + 1;
|
||||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t task_offset = 0;
|
||||
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
if (task_offset >= indices_lens) {
|
||||
break;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
|
||||
threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size,
|
||||
hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size);
|
||||
task_offset += task_proc_lens;
|
||||
if (task_offset + task_proc_lens > indices_lens) {
|
||||
task_proc_lens = indices_lens - task_offset;
|
||||
}
|
||||
}
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
threads[j].join();
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
|
||||
float *insert_data, float *hash_table_addr) {
|
||||
size_t first_dim_size = host_cache_vocab_size_;
|
||||
size_t thread_num = insert_indices_size / 10000 + 1;
|
||||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t task_offset = 0;
|
||||
|
||||
auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
|
||||
int *insert_indices, float *insert_data, float *hash_table_addr) {
|
||||
auto type_size = sizeof(float);
|
||||
size_t lens = outer_dim_size * type_size;
|
||||
for (size_t i = 0; i < insert_indices_size; ++i) {
|
||||
int index = insert_indices[i];
|
||||
if (index >= 0 && index < SizeToInt(first_dim_size)) {
|
||||
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
if (task_offset >= insert_indices_size) {
|
||||
break;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
|
||||
threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size,
|
||||
insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr);
|
||||
task_offset += task_proc_lens;
|
||||
if (task_offset + task_proc_lens > insert_indices_size) {
|
||||
task_proc_lens = insert_indices_size - task_offset;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
threads[j].join();
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
|
||||
auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get();
|
||||
auto swap_indices_size = statistics_info_.host_to_device_size_;
|
||||
if (swap_indices_size == 0) {
|
||||
return;
|
||||
}
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto hash_table_size = hash_info.device_address.size;
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index,
|
||||
swap_out_data.get());
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
|
||||
swap_out_data.get(),
|
||||
swap_indices_size * embedding_size * sizeof(float));
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(
|
||||
embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int));
|
||||
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
|
||||
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
|
||||
embedding_size, swap_indices_size);
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
auto swap_indices_size = statistics_info_.device_to_host_size_;
|
||||
auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
|
||||
auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get();
|
||||
if (swap_indices_size == 0) {
|
||||
return;
|
||||
}
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto hash_table_size = hash_info.device_address.size;
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(
|
||||
embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int));
|
||||
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
|
||||
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
|
||||
embedding_size, swap_indices_size);
|
||||
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(),
|
||||
embedding_device_cache_->hash_swap_value_addr_,
|
||||
swap_indices_size * embedding_size * sizeof(float));
|
||||
embedding_device_cache_->cache_->SynchronizeStream();
|
||||
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
|
||||
swap_out_data.get(), host_hash_table_addr);
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
auto swap_indices_size = statistics_info_.host_to_server_size_;
|
||||
if (swap_indices_size == 0) {
|
||||
return;
|
||||
}
|
||||
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
|
||||
::ps::SArray<float> swap_out_data;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
swap_out_data.resize(swap_indices_size * embedding_size);
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
|
||||
swap_out_data.data());
|
||||
|
||||
auto copy_len = swap_indices_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
|
||||
}
|
||||
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
|
||||
auto swap_indices_size = statistics_info_.server_to_host_size_;
|
||||
auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
|
||||
auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
|
||||
if (swap_indices_size == 0) {
|
||||
return;
|
||||
}
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
::ps::SArray<int> lengths{swap_indices_size};
|
||||
::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0);
|
||||
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
|
||||
auto copy_len = swap_indices_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
|
||||
}
|
||||
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(),
|
||||
host_hash_table_addr);
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
|
||||
const HashTableInfo &hash_info) {
|
||||
MS_EXCEPTION_IF_NULL(swap_out_index);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_data);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
auto swap_out_index_size = statistics_info_.device_to_host_size_;
|
||||
if (swap_out_index_size == 0) {
|
||||
return;
|
||||
}
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto hash_table_size = hash_info.device_address.size;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
swap_out_data->resize(swap_out_index_size * embedding_size);
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index,
|
||||
swap_out_index_size * sizeof(int));
|
||||
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
|
||||
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
|
||||
embedding_size, swap_out_index_size);
|
||||
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(),
|
||||
embedding_device_cache_->hash_swap_value_addr_,
|
||||
swap_out_index_size * embedding_size * sizeof(float));
|
||||
embedding_device_cache_->cache_->RecordEvent();
|
||||
}
|
||||
|
||||
void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info,
|
||||
size_t key) {
|
||||
MS_EXCEPTION_IF_NULL(swap_in_ids);
|
||||
MS_EXCEPTION_IF_NULL(swap_in_index);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
auto swap_in_ids_size = statistics_info_.host_to_device_size_;
|
||||
if (swap_in_ids_size == 0) {
|
||||
return;
|
||||
}
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto hash_table_size = hash_info.device_address.size;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
|
||||
::ps::SArray<int> lengths{swap_in_ids_size};
|
||||
::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0);
|
||||
::ps::SArray<int> lookup_ids(swap_in_ids_size, 0);
|
||||
auto copy_len = swap_in_ids_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
|
||||
}
|
||||
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
|
||||
// Hash swap-in in device.
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
|
||||
lookup_result.data(),
|
||||
swap_in_ids_size * embedding_size * sizeof(float));
|
||||
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index,
|
||||
swap_in_ids_size * sizeof(int));
|
||||
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
|
||||
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
|
||||
embedding_size, swap_in_ids_size);
|
||||
}
|
||||
|
||||
void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_ids);
|
||||
auto swap_out_ids_size = statistics_info_.device_to_host_size_;
|
||||
if (swap_out_ids_size == 0) {
|
||||
return;
|
||||
}
|
||||
::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
|
||||
auto copy_len = swap_out_ids_size * sizeof(int);
|
||||
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
|
||||
}
|
||||
// Need synchronize event to ensure that the swap-out in device is completed.
|
||||
embedding_device_cache_->cache_->SynchronizeEvent();
|
||||
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
|
||||
}
|
||||
|
||||
void PsCacheManager::DumpHashTables() const {
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
|
||||
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
|
||||
for (const auto &item : hash_tables_) {
|
||||
const auto ¶m_name = item.first;
|
||||
size_t cache_vocab_size = item.second.cache_vocab_size;
|
||||
size_t embedding_size = item.second.embedding_size;
|
||||
size_t vocab_size = item.second.vocab_size;
|
||||
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size
|
||||
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr)
|
||||
<< " || " << reinterpret_cast<void *>(item.second.host_address.get());
|
||||
float *output = new float[item.second.device_address.size / 4];
|
||||
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr,
|
||||
item.second.device_address.size);
|
||||
embedding_device_cache_->cache_->SynchronizeStream();
|
||||
for (size_t i = 0; i < cache_vocab_size; i++) {
|
||||
for (size_t j = 0; j < embedding_size; j++) {
|
||||
std::cout << output[i * embedding_size + j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
delete[] output;
|
||||
}
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,6 +49,7 @@ struct HashTableInfo {
|
|||
size_t vocab_size{0};
|
||||
Address device_address{nullptr, 0};
|
||||
std::shared_ptr<int[]> host_address{nullptr};
|
||||
ParamInitInfo param_init_info_;
|
||||
};
|
||||
|
||||
struct EmbeddingDeviceCache {
|
||||
|
@ -158,6 +159,8 @@ class PsCacheManager {
|
|||
void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
|
||||
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
|
||||
const int *indices_addr, float *output_addr);
|
||||
bool CheckFinishInsertInitInfo() const;
|
||||
void AddEmbeddingTable() const;
|
||||
|
||||
bool initialized_ps_cache_{false};
|
||||
std::string channel_name_;
|
||||
|
@ -167,6 +170,7 @@ class PsCacheManager {
|
|||
size_t data_step_{0};
|
||||
std::mutex data_mutex_;
|
||||
std::condition_variable data_prase_;
|
||||
std::condition_variable insert_init_info_;
|
||||
|
||||
std::map<std::string, HashTableInfo> hash_tables_;
|
||||
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
|
||||
|
@ -178,6 +182,8 @@ class PsCacheManager {
|
|||
size_t batch_elements_{0};
|
||||
PsCacheStatisticsInfo statistics_info_;
|
||||
std::pair<size_t, size_t> range_bound_;
|
||||
std::atomic_bool finish_insert_init_info_{false};
|
||||
std::atomic_bool finish_init_parameter_server_{false};
|
||||
};
|
||||
|
||||
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class PsDataPrefetch {
|
||||
class EXPORT PsDataPrefetch {
|
||||
public:
|
||||
EXPORT static PsDataPrefetch &GetInstance() {
|
||||
static PsDataPrefetch instance;
|
||||
|
|
|
@ -17,6 +17,11 @@
|
|||
#include "ps/ps_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -80,5 +85,43 @@ bool PSContext::is_role_sched() const { return is_sched_; }
|
|||
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
|
||||
|
||||
int PSContext::ps_rank_id() const { return rank_id_; }
|
||||
|
||||
void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
|
||||
size_t vocab_size) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
||||
size_t cache_vocab_size, size_t embedding_size) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PSContext::set_cache_enable(bool cache_enable) const {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
|
||||
#endif
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,14 @@ class PSContext {
|
|||
bool is_role_sched() const;
|
||||
void SetPSRankId(int rank_id);
|
||||
int ps_rank_id() const;
|
||||
void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
|
||||
size_t vocab_size) const;
|
||||
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
||||
size_t cache_vocab_size, size_t embedding_size) const;
|
||||
void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const;
|
||||
void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const;
|
||||
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
|
||||
void set_cache_enable(bool cache_enable) const;
|
||||
|
||||
private:
|
||||
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ps/random_normal/random_normal.h"
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
#include <memory>
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "pybind_api/random_normal/random_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed,
|
||||
float *output_data) {
|
||||
if (out_shape.size() == 0) {
|
||||
std::cout << "output data shape is error" << std::endl;
|
||||
}
|
||||
int64_t total_count = 1;
|
||||
for (uint32_t i = 0; i < out_shape.size(); i++) {
|
||||
total_count *= SizeToLong(out_shape[i]);
|
||||
}
|
||||
uint32_t thread_num = 16;
|
||||
if (total_count <= thread_num) {
|
||||
thread_num = 1;
|
||||
}
|
||||
float *start_ptr = output_data;
|
||||
if (start_ptr == nullptr) {
|
||||
std::cout << "start_ptr is nullptr" << std::endl;
|
||||
return false;
|
||||
}
|
||||
int64_t batchSize = total_count / thread_num;
|
||||
std::vector<std::thread> threads(thread_num);
|
||||
int64_t seed = SizeToLong(global_seed);
|
||||
int64_t seed2 = SizeToLong(op_seed);
|
||||
seed = (seed == 0 && seed2 == 0) ? clock() : seed;
|
||||
PhiloxGenerator generator = PhiloxGenerator(seed, seed2);
|
||||
if (thread_num != 1) {
|
||||
for (uint32_t i = 0; i < thread_num - 1; i++) {
|
||||
float *offset_ptr = start_ptr + batchSize * i;
|
||||
threads[i] =
|
||||
std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, offset_ptr, batchSize, i);
|
||||
}
|
||||
float *offset_ptr = start_ptr + batchSize * (thread_num - 1);
|
||||
threads[thread_num - 1] = std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator,
|
||||
offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1);
|
||||
} else {
|
||||
threads[0] =
|
||||
std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, start_ptr, total_count, 0);
|
||||
}
|
||||
for (uint32_t i = 0; i < thread_num; i++) {
|
||||
threads[i].join();
|
||||
}
|
||||
for (int64_t i = 0; i < total_count; i++) {
|
||||
output_data[i] *= stddev;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_
|
||||
#define MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed,
|
||||
float *output_data);
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_
|
|
@ -26,6 +26,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 };
|
||||
|
||||
struct ParamInitInfo {
|
||||
ParamType param_type_{kUnKnown};
|
||||
size_t global_seed_{0};
|
||||
size_t op_seed_{0};
|
||||
float init_val_{0};
|
||||
};
|
||||
|
||||
class Util {
|
||||
public:
|
||||
static bool IsParamServerMode();
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "ps/common.h"
|
||||
#include "ps/worker_proxy.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -47,15 +48,19 @@ class Worker {
|
|||
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
|
||||
void Pull(const size_t key, void *dev_addr, const size_t size);
|
||||
size_t SetParamKey(const std::string ¶m_name);
|
||||
size_t GetParamKey(const std::string ¶m_name);
|
||||
void SetParamInitInServer(const std::string ¶m_name, bool init_in_server);
|
||||
bool GetParamInitInServer(const std::string ¶m_name);
|
||||
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
|
||||
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
|
||||
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
||||
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes);
|
||||
void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes);
|
||||
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
|
||||
void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd);
|
||||
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals);
|
||||
bool running() { return running_; }
|
||||
void Finalize();
|
||||
|
||||
private:
|
||||
|
@ -65,7 +70,6 @@ class Worker {
|
|||
Worker &operator=(const Worker &) = delete;
|
||||
|
||||
bool IsKeyInit(const size_t key);
|
||||
size_t GetParamKey(const std::string ¶m_name);
|
||||
void InitPSOptimId(const size_t param_key);
|
||||
void InitPSOptimInputShapes(const size_t key);
|
||||
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
|
||||
|
@ -187,6 +191,12 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
|
|||
kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals) {
|
||||
kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::Finalize() {
|
||||
if (running_) {
|
||||
|
@ -286,7 +296,7 @@ size_t Worker<T>::GetParamKey(const std::string ¶m_name) {
|
|||
size_t key = kInvalidKey;
|
||||
if (param_to_key_.find(param_name) != param_to_key_.end()) {
|
||||
key = param_to_key_[param_name];
|
||||
MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key;
|
||||
MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
@ -310,8 +320,7 @@ void Worker<T>::InitPSOptimId(const size_t param_key) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes,
|
||||
const ShapeVector &sizes) {
|
||||
void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) {
|
||||
bool has_init = IsKeyInit(keys[0]);
|
||||
if (has_init) {
|
||||
MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized.";
|
||||
|
@ -319,7 +328,7 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
|
|||
}
|
||||
::ps::SArray<T> shapes_val;
|
||||
for (auto dim : shapes) {
|
||||
shapes_val.push_back(static_cast<T>(dim));
|
||||
shapes_val.push_back(dim);
|
||||
}
|
||||
std::vector<int> sizes_int;
|
||||
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
|
||||
|
@ -337,9 +346,6 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::
|
|||
const std::string ¶m_name = pk_node->fullname_with_scope();
|
||||
void *param_data = tensor->data_c();
|
||||
size_t param_size = LongToSize(tensor->data().nbytes());
|
||||
if (param_size > INT_MAX) {
|
||||
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
|
||||
}
|
||||
|
||||
size_t param_key = GetParamKey(param_name);
|
||||
if (param_key == kInvalidKey) {
|
||||
|
@ -357,11 +363,17 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::
|
|||
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
|
||||
<< ", whether init in server: " << init_in_server;
|
||||
kv_worker_->AddKeyToServerId(param_key);
|
||||
if (!init_in_server) {
|
||||
InitPSParamData({param_key}, param_data, param_size);
|
||||
if (!PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (!init_in_server) {
|
||||
if (param_size > INT_MAX) {
|
||||
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
|
||||
<< param_size;
|
||||
}
|
||||
InitPSParamData({param_key}, param_data, param_size);
|
||||
}
|
||||
InitPSOptimId(param_key);
|
||||
InitPSOptimInputShapes(param_key);
|
||||
}
|
||||
InitPSOptimId(param_key);
|
||||
InitPSOptimInputShapes(param_key);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|||
explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
|
||||
: Worker(app_id, customer_id) {
|
||||
server_num_ = ::ps::NumServers();
|
||||
MS_LOG(INFO) << "Server num:" << server_num_;
|
||||
PSContext::instance()->SetPSRankId(::ps::MyRank());
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
|
@ -60,6 +61,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|||
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
|
||||
round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
|
||||
worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
|
||||
update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
|
||||
}
|
||||
~WorkerProxy() override = default;
|
||||
|
||||
|
@ -70,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|||
const Callback &cb = nullptr, int64_t priority = 0);
|
||||
int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
||||
const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
|
||||
void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
|
||||
bool IsReadyForPush(const Key &key);
|
||||
bool IsReadyForPull(const Key &key);
|
||||
void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
|
||||
|
@ -98,6 +102,9 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|||
void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs);
|
||||
void ProcessLookupResult(const ::ps::Message &msg);
|
||||
void ProcessResponse(const ::ps::Message &msg);
|
||||
void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
|
||||
|
@ -122,6 +129,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|||
Slicer broadcast_slicer_;
|
||||
Slicer round_robin_slicer_;
|
||||
Slicer worker_init_embedding_slicer_;
|
||||
Slicer update_embedding_slicer_;
|
||||
std::unordered_map<int64_t, Callback> lookup_callbacks_;
|
||||
std::unordered_map<int64_t, Callback> general_callbacks_;
|
||||
std::unordered_map<int64_t, int64_t> expected_result_count_;
|
||||
|
@ -195,6 +203,24 @@ int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys,
|
|||
return ts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
||||
const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
|
||||
int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
|
||||
::ps::KVPairs<T> kvs;
|
||||
kvs.keys = keys;
|
||||
kvs.lens = lookup_ids;
|
||||
kvs.vals = vals;
|
||||
kvs.priority = priority;
|
||||
expected_result_count_[ts] = 0;
|
||||
Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
|
||||
if (expected_result_count_[ts] < server_num_) {
|
||||
general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
|
||||
}
|
||||
general_customer_->WaitRequest(ts);
|
||||
expected_result_count_.erase(ts);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
|
||||
::ps::SArray<T> result(1, 0);
|
||||
|
@ -724,6 +750,47 @@ void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KV
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
|
||||
const std::vector<::ps::Range> &,
|
||||
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
|
||||
const std::map<int64_t, int64_t> &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(sliced);
|
||||
T *embedding_vals = send.vals.data();
|
||||
int *lookup_ids = send.lens.data();
|
||||
size_t val_size = send.vals.size();
|
||||
size_t id_size = send.lens.size();
|
||||
size_t embedding_dim = val_size / id_size;
|
||||
|
||||
const Key &key = send.keys[0];
|
||||
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
||||
sliced->resize(ranges.size());
|
||||
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
const ::ps::Range &range = ranges[i];
|
||||
const auto &begin = range.begin();
|
||||
const auto &end = range.end();
|
||||
auto &kvs = sliced->at(i).second;
|
||||
kvs.keys.push_back(key);
|
||||
for (size_t j = 0; j < id_size; j++) {
|
||||
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
||||
if (lookup_id >= begin && lookup_id <= end) {
|
||||
kvs.keys.push_back(lookup_id);
|
||||
for (size_t k = 0; k < embedding_dim; k++) {
|
||||
kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (kvs.keys.size() <= 1) {
|
||||
sliced->at(i).first = false;
|
||||
} else {
|
||||
sliced->at(i).first = true;
|
||||
expected_result_count_[timestamp] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
|
||||
int64_t ts = msg.meta.timestamp;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <thread>
|
||||
#include <memory>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
#include "pybind_api/random_normal/philox_generator.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
bool SyncStream() override;
|
||||
void SetContext() override;
|
||||
void CreateContext() override;
|
||||
void *context() const override { return rt_context_; }
|
||||
|
||||
protected:
|
||||
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
|
@ -38,6 +39,9 @@ void GPUMemoryManager::MallocDeviceMemory() {
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
// If use the dynamic memory pool, then alloc the first memory block to init.
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
|
||||
if (ps::ps_cache_instance.initialized_ps_cache()) {
|
||||
return;
|
||||
}
|
||||
auto device_addr = MallocMemFromMemPool(1);
|
||||
if (!device_addr) {
|
||||
MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
|
||||
|
|
|
@ -30,6 +30,10 @@
|
|||
#include "utils/ms_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
|
||||
|
@ -331,15 +335,27 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
|
||||
continue;
|
||||
}
|
||||
DeviceAddressPtr device_address = nullptr;
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
const std::string ¶m_name = item->fullname_with_scope();
|
||||
if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
|
||||
MS_EXCEPTION_IF_NULL(address.addr);
|
||||
device_address =
|
||||
CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
|
||||
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
auto tensor_size = CountNodeDeviceMemorySize(item, index);
|
||||
auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
|
||||
device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
|
||||
MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
|
||||
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
||||
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address) == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
||||
}
|
||||
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
|
||||
<< " index: " << index << " size: " << tensor_size;
|
||||
AnfAlgo::SetOutputAddr(address, index, item.get());
|
||||
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "AssignStaticMemoryInput end";
|
||||
|
|
|
@ -78,6 +78,7 @@ class KernelRuntime {
|
|||
virtual void ClearGlobalIdleMem() {}
|
||||
virtual void CreateContext() {}
|
||||
virtual void SetContext() {}
|
||||
virtual void *context() const { return nullptr; }
|
||||
uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
|
||||
return mem_manager_->MallocMem(type, size, address);
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Parameter for cell."""
|
||||
from copy import copy
|
||||
import numbers
|
||||
from .._c_expression import ParamInfo
|
||||
from .._c_expression import MetaTensor as MetaTensor_
|
||||
from . import dtype as mstype
|
||||
|
@ -23,7 +24,10 @@ from .tensor import Tensor, MetaTensor
|
|||
from .._checkparam import Validator
|
||||
from ..parallel._tensor import _get_slice_index
|
||||
from ..parallel._auto_parallel_context import auto_parallel_context
|
||||
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
|
||||
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
|
||||
from ..parallel._ps_context import _reinsert_hash_table_size
|
||||
from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
|
||||
from .seed import _get_global_and_op_seed
|
||||
|
||||
__all__ = ['Parameter', 'ParameterTuple']
|
||||
|
||||
|
@ -35,6 +39,18 @@ def _is_in_parallel_mode():
|
|||
"""Get parallel mode."""
|
||||
return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
|
||||
|
||||
def init_to_value(init):
|
||||
"""Get value of initializer."""
|
||||
if isinstance(init, str):
|
||||
if init == 'zeros':
|
||||
return 0.0
|
||||
if init == 'ones':
|
||||
return 1.0
|
||||
raise ValueError("init should be one of values in 'zeros', 'ones'.")
|
||||
if isinstance(init, numbers.Number):
|
||||
return float(init)
|
||||
raise ValueError("init should be number or string")
|
||||
|
||||
|
||||
class Parameter(MetaTensor_):
|
||||
"""
|
||||
|
@ -118,6 +134,8 @@ class Parameter(MetaTensor_):
|
|||
|
||||
def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
|
||||
self._param_info = ParamInfo()
|
||||
self.init_in_server = False
|
||||
self.cache_enable = False
|
||||
self.name = name
|
||||
self.requires_grad = requires_grad
|
||||
self.layerwise_parallel = layerwise_parallel
|
||||
|
@ -129,7 +147,6 @@ class Parameter(MetaTensor_):
|
|||
self._sliced = False
|
||||
self.is_param_ps = False
|
||||
self._cast_type = None
|
||||
self.init_in_server = False
|
||||
self._unique = False
|
||||
self.is_in_parallel = _is_in_parallel_mode()
|
||||
if isinstance(default_input, (MetaTensor, Tensor)):
|
||||
|
@ -155,7 +172,7 @@ class Parameter(MetaTensor_):
|
|||
if isinstance(data, bool):
|
||||
raise ValueError('Parameter data can not be `bool`')
|
||||
if isinstance(data, MetaTensor):
|
||||
if _is_in_parallel_mode() or _is_role_worker():
|
||||
if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched():
|
||||
# do not init data while in auto parallel.
|
||||
return (MetaTensor_, data.dtype, data.shape)
|
||||
data = data.to_tensor()
|
||||
|
@ -189,18 +206,18 @@ class Parameter(MetaTensor_):
|
|||
init_in_server (bool): Whether trainable parameter updated by parameter server is
|
||||
initialized on server. Default: False.
|
||||
"""
|
||||
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
|
||||
if init_in_server and (not self.name.endswith("embedding_table")):
|
||||
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
|
||||
"sparse operator support initialization in server.".format(self.name))
|
||||
self.is_param_ps = True
|
||||
self.init_in_server = init_in_server
|
||||
self._param_info.init_in_server = init_in_server
|
||||
else:
|
||||
if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()):
|
||||
raise RuntimeError("Must complete following two steps before calling set_param_ps: \
|
||||
1. set_ps_context(enable_ps=True) \
|
||||
2. export MS_ROLE environment variable.")
|
||||
|
||||
if init_in_server and (not self.name.endswith("embedding_table")):
|
||||
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
|
||||
"sparse operator support initialization in server.".format(self.name))
|
||||
self.is_param_ps = True
|
||||
self.init_in_server = init_in_server
|
||||
self._param_info.init_in_server = init_in_server
|
||||
|
||||
|
||||
@property
|
||||
def inited_param(self):
|
||||
|
@ -238,6 +255,13 @@ class Parameter(MetaTensor_):
|
|||
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
|
||||
else:
|
||||
raise ValueError("The type of the name should be `str` or `None`.")
|
||||
|
||||
if _is_role_worker() and self.cache_enable:
|
||||
if len(self.shape) != 2:
|
||||
raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
|
||||
.format(self.name, len(self.shape)))
|
||||
_reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1])
|
||||
|
||||
self._param_info.name = name_
|
||||
|
||||
@property
|
||||
|
@ -297,6 +321,7 @@ class Parameter(MetaTensor_):
|
|||
x.is_init = False
|
||||
x.is_param_ps = self.is_param_ps
|
||||
x.init_in_server = self.init_in_server
|
||||
x.cache_enable = self.cache_enable
|
||||
if init != 'same':
|
||||
shape = self.shape
|
||||
dtype = self.dtype
|
||||
|
@ -431,15 +456,18 @@ class Parameter(MetaTensor_):
|
|||
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
|
||||
if _is_role_worker():
|
||||
if _is_role_worker() or _is_role_sched():
|
||||
data = self.init_mode.to_tensor(0, [1])
|
||||
else:
|
||||
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
|
||||
else:
|
||||
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
|
||||
else:
|
||||
if _is_role_worker() and self.cache_enable:
|
||||
global_seed, op_seed = _get_global_and_op_seed()
|
||||
_insert_weight_init_info(self.name, global_seed, op_seed)
|
||||
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
|
||||
if _is_role_worker():
|
||||
if _is_role_worker() or _is_role_sched():
|
||||
data = self.init_mode.to_tensor(0, [1])
|
||||
else:
|
||||
data = self.init_mode.to_tensor()
|
||||
|
@ -502,6 +530,16 @@ class ParameterTuple(tuple):
|
|||
x1 = x.clone(init)
|
||||
x1.name = prefix + "." + x1.name
|
||||
new.append(x1)
|
||||
|
||||
if not x1.cache_enable:
|
||||
continue
|
||||
if not x1.name.endswith("embedding_table"):
|
||||
raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of "
|
||||
"sparse operator support enable cache.".format(x1.name))
|
||||
|
||||
if _is_role_worker():
|
||||
_clone_hash_table(x.name, x1.name)
|
||||
_insert_accumu_init_info(x1.name, init_to_value(init))
|
||||
return ParameterTuple(new)
|
||||
|
||||
def __parameter_tuple__(self):
|
||||
|
|
|
@ -195,6 +195,20 @@ def _get_op_seed(op_seed, kernel_name):
|
|||
return _KERNEL_SEED[(kernel_name, op_seed)]
|
||||
|
||||
|
||||
def _get_global_and_op_seed():
|
||||
"""Get global_seed and op_seed."""
|
||||
global_seed = get_seed()
|
||||
op_seed = get_seed()
|
||||
if global_seed == 0:
|
||||
global_seed = DEFAULT_GRAPH_SEED
|
||||
elif global_seed is None:
|
||||
global_seed = 0
|
||||
Validator.check_non_negative_int(op_seed, "seed", "init")
|
||||
temp_seed = _get_op_seed(op_seed, "init")
|
||||
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
|
||||
return seeds
|
||||
|
||||
|
||||
def _get_graph_seed(op_seed, kernel_name):
|
||||
"""
|
||||
Get the graph-level seed.
|
||||
|
|
|
@ -73,6 +73,8 @@ inline size_t FloatToSize(float u) {
|
|||
}
|
||||
inline float IntToFloat(int32_t v) { return static_cast<float>(v); }
|
||||
|
||||
inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
|
||||
|
||||
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
|
||||
|
||||
inline double FloatToDouble(float v) { return static_cast<double>(v); }
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel._utils import _get_parallel_mode
|
||||
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -156,6 +157,7 @@ class EmbeddingLookup(Cell):
|
|||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
||||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings.
|
||||
|
||||
Inputs:
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
|
@ -185,7 +187,7 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
def __init__(self, vocab_size, embedding_size, param_init='normal',
|
||||
target='CPU', slice_mode='batch_slice', manual_shapes=None,
|
||||
max_norm=None, sparse=True):
|
||||
max_norm=None, sparse=True, vocab_cache_size=0):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
self.target = target
|
||||
if target not in ('CPU', 'DEVICE'):
|
||||
|
@ -199,11 +201,23 @@ class EmbeddingLookup(Cell):
|
|||
self.gatherv2 = P.GatherV2()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
|
||||
self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
|
||||
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
|
||||
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
|
||||
name='embedding_table')
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.cache_enable = self.vocab_cache_size > 0
|
||||
if self.cache_enable:
|
||||
if is_auto_parallel:
|
||||
self.vocab_cache_size = self.vocab_cache_size * get_group_size()
|
||||
self.vocab_size = self.vocab_cache_size
|
||||
|
||||
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
|
||||
name='embedding_table')
|
||||
if self.cache_enable:
|
||||
self.embedding_table.cache_enable = True
|
||||
_set_cache_enable(True)
|
||||
if _is_role_worker():
|
||||
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
|
||||
self.forward_unique = False
|
||||
self.gather_revert = P.GatherV2()
|
||||
self.unique = P.Unique().shard(((1,),))
|
||||
|
@ -222,7 +236,7 @@ class EmbeddingLookup(Cell):
|
|||
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
||||
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
|
||||
elif slice_mode == "table_row_slice" and is_auto_parallel:
|
||||
if target == 'DEVICE':
|
||||
if target == 'DEVICE' and not self.cache_enable:
|
||||
indices_shape_size = 1
|
||||
self.gather_revert.shard(((1, 1), (get_group_size(),)))
|
||||
self.forward_unique = True
|
||||
|
|
|
@ -88,14 +88,14 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
|
|||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
||||
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter):
|
||||
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
|
||||
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter:
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(param), op_shape(m), op_shape(v),
|
||||
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
||||
|
@ -158,12 +158,13 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
|
|||
|
||||
|
||||
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
||||
beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter):
|
||||
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
|
||||
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
|
||||
moment1, moment2, ps_parameter, cache_enable):
|
||||
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
if ps_parameter:
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
||||
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
|
||||
|
@ -338,12 +339,12 @@ class Adam(Optimizer):
|
|||
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
|
||||
lr, gradients, params, moment1, moment2, self.ps_parameters)
|
||||
lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
else:
|
||||
success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.use_locking, self.use_nesterov, self._is_device,
|
||||
beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
|
||||
gradients, params, moment1, moment2, self.ps_parameters)
|
||||
gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
|
|
@ -24,14 +24,14 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
|
|||
|
||||
|
||||
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
|
||||
"RowTensor", "Tensor", "Tensor", "Bool")
|
||||
"RowTensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
|
||||
gradient, weight, moment, ps_parameter):
|
||||
gradient, weight, moment, ps_parameter, cache_enable):
|
||||
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
|
||||
success = True
|
||||
indices = gradient.indices
|
||||
values = gradient.values
|
||||
if ps_parameter:
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
|
||||
success = F.depend(success, pull(push((values, indices), shapes), weight))
|
||||
|
@ -41,12 +41,12 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, le
|
|||
|
||||
|
||||
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
|
||||
"Tensor", "Tensor", "Tensor", "Bool")
|
||||
"Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||
def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
|
||||
gradient, weight, moment, ps_parameter):
|
||||
gradient, weight, moment, ps_parameter, cache_enable):
|
||||
"""Apply ftrl optimizer to the weight parameter."""
|
||||
success = True
|
||||
if ps_parameter:
|
||||
if ps_parameter and not cache_enable:
|
||||
op_shape = P.Shape()
|
||||
success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
|
||||
(op_shape(weight), op_shape(moment), op_shape(linear))), weight))
|
||||
|
@ -185,7 +185,7 @@ class FTRL(Optimizer):
|
|||
|
||||
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
|
||||
self.l1, self.l2, self.lr_power, lr),
|
||||
linear, grads, params, moments, self.ps_parameters)
|
||||
linear, grads, params, moments, self.ps_parameters, self.cache_enable)
|
||||
return success
|
||||
|
||||
@Optimizer.target.setter
|
||||
|
|
|
@ -156,6 +156,8 @@ class Optimizer(Cell):
|
|||
break
|
||||
ps_filter = lambda x: x.is_param_ps
|
||||
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
|
||||
ps_cache_filter = lambda x: x.cache_enable
|
||||
self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
|
||||
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
|
||||
self.need_scale = loss_scale != 1.0
|
||||
self.global_step_increase_tensor = Tensor(1, mstype.int32)
|
||||
|
|
|
@ -117,3 +117,21 @@ def _is_role_pserver():
|
|||
|
||||
def _is_role_sched():
|
||||
return ps_context().is_role_sched()
|
||||
|
||||
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
|
||||
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
|
||||
|
||||
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
|
||||
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
|
||||
|
||||
def _insert_weight_init_info(name, global_seed, op_seed):
|
||||
ps_context().insert_weight_init_info(name, global_seed, op_seed)
|
||||
|
||||
def _insert_accumu_init_info(name, init_val):
|
||||
ps_context().insert_accumu_init_info(name, init_val)
|
||||
|
||||
def _clone_hash_table(dest_param_name, src_param_name):
|
||||
ps_context().clone_hash_table(dest_param_name, src_param_name)
|
||||
|
||||
def _set_cache_enable(cache_enable):
|
||||
ps_context().set_cache_enable(cache_enable)
|
||||
|
|
|
@ -92,6 +92,10 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
if isinstance(dataset_iter, _DatasetIterNormal):
|
||||
raise RuntimeError("Dataset should be connected with network only in sink mode.")
|
||||
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
return network
|
||||
|
||||
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
||||
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
||||
and context.get_context("device_target") == "Ascend" \
|
||||
|
@ -166,14 +170,14 @@ class DatasetHelper:
|
|||
iterclass = _DatasetIterGE
|
||||
else:
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
iterclass = _DatasetIterPSServer
|
||||
elif ms_role == "MS_WORKER":
|
||||
iterclass = _DatasetIterPSWork
|
||||
elif (context.get_context("device_target") == "Ascend") or \
|
||||
(context.get_context("device_target") == "GPU"):
|
||||
iterclass = _DatasetIterMSLoopSink
|
||||
elif context.get_context("device_target") == "GPU":
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
iterclass = _DatasetIterPSLite
|
||||
else:
|
||||
iterclass = _DatasetIterMSLoopSink
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError(
|
||||
"Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
|
@ -218,7 +222,10 @@ class _DatasetIter:
|
|||
|
||||
if not hasattr(dataset, '__transfer_dataset__'):
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
self.sink_size = dataset.__loop_size__
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
# PS mode does not support loop sink and need get the real sink size.
|
||||
if ms_role != "MS_WORKER":
|
||||
self.sink_size = dataset.__loop_size__
|
||||
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
|
||||
"device_target") == "Ascend")
|
||||
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
|
||||
|
@ -260,8 +267,12 @@ class _DatasetIter:
|
|||
def get_sink_size(self):
|
||||
"""get sink_size to device"""
|
||||
sink_size = 1
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if hasattr(self.dataset, '__loop_size__'):
|
||||
sink_size = self.dataset.__loop_size__
|
||||
elif ms_role == "MS_WORKER":
|
||||
# PS mode does not support loop sink.
|
||||
sink_size = 1
|
||||
else:
|
||||
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \
|
||||
or context.get_context("device_target") == "GPU":
|
||||
|
@ -311,9 +322,6 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = self.get_sink_count(dataset)
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
self.sink_count = 1
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
|
||||
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
|
||||
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
|
@ -341,8 +349,8 @@ class _DatasetIterMS(_DatasetIter):
|
|||
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
||||
|
||||
|
||||
class _DatasetIterPSLite(_DatasetIter):
|
||||
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
|
||||
class _DatasetIterPSServer(_DatasetIter):
|
||||
"""Iter for context on MS_PSERVER or MS_SCHED"""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
|
@ -355,6 +363,20 @@ class _DatasetIterPSLite(_DatasetIter):
|
|||
|
||||
self.op = op
|
||||
|
||||
class _DatasetIterPSWork(_DatasetIter):
|
||||
"""Iter for context on MS_WORKER"""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
if sink_size > 0:
|
||||
self.sink_count = sink_size
|
||||
else:
|
||||
self.sink_count = dataset.get_dataset_size()
|
||||
|
||||
def op():
|
||||
return tuple()
|
||||
|
||||
self.op = op
|
||||
|
||||
class _DatasetIterNormal:
|
||||
"""Iter for normal(non sink) mode, feed the data from host."""
|
||||
|
|
|
@ -30,6 +30,7 @@ def argparse_init():
|
|||
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
|
||||
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
|
||||
parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.")
|
||||
parser.add_argument("--vocab_cache_size", type=int, default=0, help="The total features of hash table.")
|
||||
parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.")
|
||||
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128],
|
||||
help="The dimension of all deep layers.")
|
||||
|
@ -66,6 +67,7 @@ class WideDeepConfig():
|
|||
self.eval_batch_size = 16000
|
||||
self.field_size = 39
|
||||
self.vocab_size = 200000
|
||||
self.vocab_cache_size = 100000
|
||||
self.emb_dim = 80
|
||||
self.deep_layer_dim = [1024, 512, 256, 128]
|
||||
self.deep_layer_act = 'relu'
|
||||
|
@ -103,6 +105,7 @@ class WideDeepConfig():
|
|||
self.eval_batch_size = args.eval_batch_size
|
||||
self.field_size = args.field_size
|
||||
self.vocab_size = args.vocab_size
|
||||
self.vocab_cache_size = args.vocab_cache_size
|
||||
self.emb_dim = args.emb_dim
|
||||
self.deep_layer_dim = args.deep_layer_dim
|
||||
self.deep_layer_act = args.deep_layer_act
|
||||
|
|
|
@ -147,6 +147,7 @@ class WideDeepModel(nn.Cell):
|
|||
sparse = config.sparse
|
||||
self.field_size = config.field_size
|
||||
self.vocab_size = config.vocab_size
|
||||
self.vocab_cache_size = config.vocab_cache_size
|
||||
self.emb_dim = config.emb_dim
|
||||
self.deep_layer_dims_list = config.deep_layer_dim
|
||||
self.deep_layer_act = config.deep_layer_act
|
||||
|
@ -237,8 +238,20 @@ class WideDeepModel(nn.Cell):
|
|||
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
elif parameter_server:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
cache_enable = self.vocab_cache_size > 0
|
||||
target = 'DEVICE' if cache_enable else 'CPU'
|
||||
if is_auto_parallel and config.full_batch and cache_enable:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
|
||||
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
|
||||
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
|
||||
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
|
||||
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
|
||||
else:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
|
||||
sparse=sparse, vocab_cache_size=self.vocab_cache_size)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, sparse=sparse,
|
||||
vocab_cache_size=self.vocab_cache_size)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
self.deep_embeddinglookup.embedding_table.set_param_ps()
|
||||
self.wide_embeddinglookup.embedding_table.set_param_ps()
|
||||
|
@ -344,7 +357,7 @@ class TrainStepWrap(nn.Cell):
|
|||
"""
|
||||
|
||||
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False,
|
||||
sparse=False):
|
||||
sparse=False, cache_enable=False):
|
||||
super(TrainStepWrap, self).__init__()
|
||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
|
@ -361,7 +374,7 @@ class TrainStepWrap(nn.Cell):
|
|||
self.weights_w = ParameterTuple(weights_w)
|
||||
self.weights_d = ParameterTuple(weights_d)
|
||||
|
||||
if (sparse and is_auto_parallel) or parameter_server:
|
||||
if (sparse and is_auto_parallel) or (parameter_server and not cache_enable):
|
||||
self.optimizer_d = LazyAdam(
|
||||
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
|
||||
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
|
||||
|
@ -417,10 +430,17 @@ class TrainStepWrap(nn.Cell):
|
|||
|
||||
|
||||
class PredictWithSigmoid(nn.Cell):
|
||||
"""
|
||||
Predict definition
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(PredictWithSigmoid, self).__init__()
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
if is_auto_parallel:
|
||||
self.sigmoid.shard(((1, 1),))
|
||||
|
||||
def construct(self, batch_ids, batch_wts, labels):
|
||||
logits, _, = self.network(batch_ids, batch_wts)
|
||||
|
|
|
@ -39,7 +39,8 @@ def get_WideDeep_net(config):
|
|||
"""
|
||||
WideDeep_net = WideDeepModel(config)
|
||||
loss_net = NetWithLossClass(WideDeep_net, config)
|
||||
train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server))
|
||||
train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server),
|
||||
cache_enable=bool(config.vocab_cache_size > 0))
|
||||
eval_net = PredictWithSigmoid(WideDeep_net)
|
||||
return train_net, eval_net
|
||||
|
||||
|
@ -81,6 +82,7 @@ def train_and_eval(config):
|
|||
else:
|
||||
dataset_type = DataType.H5
|
||||
parameter_server = bool(config.parameter_server)
|
||||
cache_enable = bool(config.vocab_cache_size > 0)
|
||||
print("epochs is {}".format(epochs))
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=batch_size, rank_id=get_rank(),
|
||||
|
@ -111,7 +113,7 @@ def train_and_eval(config):
|
|||
callback_list.append(ckpoint_cb)
|
||||
model.train(epochs, ds_train,
|
||||
callbacks=callback_list,
|
||||
dataset_sink_mode=(not parameter_server))
|
||||
dataset_sink_mode=(parameter_server and cache_enable))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue