forked from mindspore-Ecosystem/mindspore
ps cache support sparse
This commit is contained in:
parent 424e68a803
commit f17534af08
@@ -53,7 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimReduceAny->name(), {1});
   Register(prim::kPrimUnsortedSegmentMin->name(), {2});
   Register(prim::kPrimUnsortedSegmentMax->name(), {2});
-  Register(kSparseGatherV2, {2});
+  Register(kSparseGatherV2OpName, {2});
   Register(kUnsortedSegmentProdOpName, {2});
   Register(kSimpleMeanGradOpName, {1});
   Register(kMeanGradOpName, {1});
@@ -109,7 +109,7 @@ bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_nam
                                                        ConstInputToAttrInfoRegister *reg) const {
   if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) {
     *reg = op_input_to_attr_map_.at(op_name);
-    MS_LOG(DEBUG) << op_name << " const2attr find in registery.";
+    MS_LOG(DEBUG) << op_name << " const2attr find in registry.";
     return true;
   }
   return false;
@@ -31,15 +31,22 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
   // almost all ops are defined in two main paths
   const std::string ops_module = OP_PATH;
   const std::string inner_ops_module = INNER_OP_PATH;
+  const std::string functional_op_module = FUNCTIONAL_OP_PATH;
   py::module mod = py::module::import(common::SafeCStr(ops_module));
   py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
-  if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
-    if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
-    }
-    return ops_module;
-  }
-  return inner_ops_module;
+  py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
+
+  if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+    return inner_ops_module;
+  }
+  if (py::hasattr(mod, common::SafeCStr(op_name))) {
+    return ops_module;
+  }
+  if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
+    MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
+                      << " don't have op:" << op_name;
+  }
+  return functional_op_module;
 }
 
 ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
@@ -141,7 +148,7 @@ Status GenerateGraph::Init(const CNodePtr &cnode) {
 }
 
 AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
-  CNodePtr cnode = func_graph_->NewCNode(inputs);  // using NewCNode to creat anfnode
+  CNodePtr cnode = func_graph_->NewCNode(inputs);  // using NewCNode to create anfnode
   MS_EXCEPTION_IF_NULL(cnode);
   cnode->set_scope(scope_);
   if (inputs.size() < 2) {
@@ -24,8 +24,10 @@
 
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/context.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
+#include "ps/ps_cache/ps_cache_manager.h"
+#include "utils/ms_context.h"
 #endif
 
 namespace mindspore {
@@ -158,6 +160,15 @@ Status GatherV2PInfo::GetAttrs() {
   if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
     dynamic_shape_indices_ = true;
   }
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
+  bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
+  if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
+    dynamic_shape_indices_ = true;
+  }
+#endif
   return SUCCESS;
 }
 
@@ -531,7 +542,7 @@ Status GatherV2PInfo::InferBias() {
   }
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
-    bias_ = 0;
+    bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
     return SUCCESS;
   }
 #endif
@@ -68,6 +68,7 @@ constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char OP_PATH[] = "mindspore.ops.operations";
 constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
+constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
 constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
 constexpr char GET_OP_FUNCTION[] = "_get_python_op";
 constexpr char KEEP_DIMS[] = "keep_dims";
@@ -23,9 +23,13 @@
 
 #include "ir/value.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/strategy.h"
+#include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+#include "ps/ps_cache/ps_cache_manager.h"
+#endif
 
 namespace mindspore {
 namespace parallel {
@@ -186,5 +190,63 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
   }
   return SUCCESS;
 }
+
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  GenerateGraph gen_g = GenerateGraph();
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed";
+    return FAILED;
+  }
+  auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
+  auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size());
+
+  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)});
+  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
+  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
+  auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
+  auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)});
+  auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)});
+  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
+  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
+  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});
+
+  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
+  OperatorAttrs attrs = {attr_op};
+  AnfNodePtr reduce_op;
+  reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
+  auto make_tuple = gen_g.PushBack({gen_g.NewOpInst(MAKE_TUPLE), tuple_getitem_0, reduce_op});
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 1), std::make_pair(unique, 1)};
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, make_tuple));
+  return SUCCESS;
+}
+#endif
+
+ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+    auto inputs = cnode->inputs();
+    if (inputs.empty()) {
+      MS_LOG(EXCEPTION) << "Invalid inputs";
+    }
+    const auto &primitive = GetValueNode<PrimitivePtr>(inputs[0]);
+    const auto &attr = primitive->GetAttr("cache_enable");
+    if (attr == nullptr) {
+      return nullptr;
+    }
+    auto need_mask = GetValue<bool>(attr);
+    if (!need_mask) {
+      return nullptr;
+    }
+    if (ComputeReplaceGraph(cnode) != SUCCESS) {
+      MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
+    }
+    return replace_graph_;
+  }
+#endif
+  return nullptr;
+}
 }  // namespace parallel
 }  // namespace mindspore
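A rough NumPy sketch of what this replace graph computes on each worker (an illustration under assumptions, not code from the commit): Sub shifts every id by the worker's cache lower bound, ReLU and Minimum clamp the shifted ids into the local slice, Equal marks the ids this worker owns, and the owned inverse indices are combined across workers by AllReduce(SUM). The AllReduce is emulated below by adding the two per-worker results, since in full-batch mode every worker sees the same id batch.

import numpy as np

def unique_with_mask(ids, bias, slice_size):
    # bias: this worker's cache_indices_lower_bound; slice_size: its vocab_cache_size
    sub = ids - bias                                          # Sub
    clamped = np.minimum(np.maximum(sub, 0), slice_size - 1)  # ReLU then Minimum
    in_range = sub == clamped                                 # Equal: true iff 0 <= sub < slice_size
    uniq, inverse = np.unique(ids, return_inverse=True)       # Unique -> (values, inverse index)
    return uniq, inverse * in_range.astype(inverse.dtype)     # Mul: zero out ids owned elsewhere

ids = np.array([3, 7, 3, 12], dtype=np.int32)
_, masked0 = unique_with_mask(ids, bias=0, slice_size=8)  # worker 0 owns ids [0, 8)
_, masked1 = unique_with_mask(ids, bias=8, slice_size=8)  # worker 1 owns ids [8, 16)
print(masked0 + masked1)  # emulated AllReduce(SUM) -> [0 1 0 2], the full inverse index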
@@ -39,6 +39,7 @@ class UniqueInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int64_t stage_id) override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -50,8 +51,12 @@ class UniqueInfo : public OperatorInfo {
   Status InferMirrorOps() override;
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferAsLossDivisor() override { return SUCCESS; }
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
+#endif
+
  private:
   std::string replace_op_name_ = UNIQUE;
   int64_t dev_num_ = 1;
 };
 }  // namespace parallel
@@ -321,7 +321,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
     .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
     .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
-    .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.");
+    .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
+    .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
 
   (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
     .def(py::init())
@@ -773,12 +773,14 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() {
   for (auto cnode : cnodes) {
     MS_EXCEPTION_IF_NULL(cnode);
     std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
-    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) {
+    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
       auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
       MS_EXCEPTION_IF_NULL(embedding_table);
-      MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
-      embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
-      count++;
+      if (embedding_table->isa<Parameter>()) {
+        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
+        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
+        count++;
+      }
     }
   }
 }
@@ -35,11 +35,11 @@ void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t c
   if (vocab_size_ == 0) {
     vocab_size_ = vocab_size;
   }
-  if (cache_vocab_size_ == 0) {
-    cache_vocab_size_ = cache_vocab_size;
+  if (vocab_cache_size_ == 0) {
+    vocab_cache_size_ = cache_vocab_size;
   }
-  if (host_cache_vocab_size_ == 0) {
-    host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
+  if (host_vocab_cache_size_ == 0) {
+    host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
   }
 }
 
@@ -148,8 +148,8 @@ void PsCacheManager::Initialize() {
     Util::SetInternalEnvVar();
     worker.Run();
   }
-  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
-  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
+  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
+  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
   AddEmbeddingTable();
   AllocMemForHashTable();
   SetLocalIdRank();
@@ -220,13 +220,13 @@ void PsCacheManager::AllocMemForHashTable() {
   for (auto &item : hash_tables_) {
     size_t embedding_size = item.second.embedding_size;
     auto &device_address = item.second.device_address;
-    device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
+    device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
     auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
     MS_EXCEPTION_IF_NULL(addr);
     device_address.addr = addr;
 
     auto &host_address = item.second.host_address;
-    auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size];
+    auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size];
     MS_EXCEPTION_IF_NULL(host_address_ptr);
     host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>());
     MS_EXCEPTION_IF_NULL(host_address);
@@ -239,21 +239,28 @@ void PsCacheManager::AllocMemForHashTable() {
   embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
     embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
   MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
-  if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) {
+  if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
    MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
   }
 }
 
 void PsCacheManager::SetLocalIdRank() {
   auto worker_num = ::ps::NumWorkers();
-  auto worker_id = ::ps::MyRank();
-  auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num));
-  range_bound_.first = local_shard_size * worker_id;
-  range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
-  MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
-               << ", rank id end:" << range_bound_.second;
+  auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
+  vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
+  emb_table_slice_bounds_.first = local_shard_size * rank_id_;
+  emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
+  cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
+  cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
+  MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
+               << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
+               << ", cache indices begin: " << cache_indices_bounds_.first
+               << ", cache indices end: " << cache_indices_bounds_.second
+               << ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
 }
 
+int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
+
 std::string PsCacheManager::channel_name() {
   std::lock_guard<std::mutex> locker(channel_mutex_);
   return channel_name_;
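To make the new bounds concrete, here is a small worked example with invented numbers: with vocab_size = 100, vocab_cache_size = 30 and two workers, each worker owns a 50-id slice of the embedding table and a 30-slot slice of the global cache index space, and vocab_cache_size_diff_ = 20 is the gap between the two.

import math

vocab_size, vocab_cache_size, worker_num = 100, 30, 2  # hypothetical values

local_shard_size = math.ceil(vocab_size / worker_num)  # 50
for rank_id in range(worker_num):
    emb_table_slice = (local_shard_size * rank_id,
                       min(local_shard_size * (rank_id + 1), vocab_size))
    cache_indices = (vocab_cache_size * rank_id, vocab_cache_size * (rank_id + 1))
    diff = local_shard_size - vocab_cache_size
    print(rank_id, emb_table_slice, cache_indices, diff)
# rank 0 -> ids [0, 50),   cache indices [0, 30),  diff 20
# rank 1 -> ids [50, 100), cache indices [30, 60), diff 20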
@@ -398,8 +405,8 @@ bool PsCacheManager::ProcessData() {
   return true;
 }
 
-bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
-                                         bool *in_device, size_t *hash_hit_count) {
+bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
+                                                 bool *in_device, bool *out_range, size_t *hash_hit_count) {
   MS_ERROR_IF_NULL(batch_ids);
   MS_ERROR_IF_NULL(hash_index);
   MS_ERROR_IF_NULL(in_device);
@@ -410,9 +417,19 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
   const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
 
   for (size_t i = 0; i < batch_ids_len; ++i) {
+    if (batch_ids[i] < emb_table_slice_bounds_.first) {
+      hash_index[i] = batch_ids[i] - vocab_cache_size_diff_;
+      out_range[i] = true;
+      continue;
+    }
+    if (batch_ids[i] >= emb_table_slice_bounds_.second) {
+      hash_index[i] = batch_ids[i] + cache_indices_bounds_.second;
+      out_range[i] = true;
+      continue;
+    }
     auto iter = hash_id_to_index.find(batch_ids[i]);
     if (iter != hash_id_to_index.end()) {
-      hash_index[i] = iter->second;
+      hash_index[i] = iter->second + cache_indices_bounds_.first;
       if (device_hash_map->hash_step(iter->second) != data_step_) {
         ++(*hash_hit_count);
         device_hash_map->set_hash_step(iter->second, data_step_);
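The per-id branching above reduces to a small decision rule; the following sketch restates it (a simplified assumption-level paraphrase, not repository code): ids owned by another worker are remapped to a slot outside this worker's cache window and flagged out_range, while local cache hits are offset into the worker's window.

def remap_id(batch_id, slice_lo, slice_hi, cache_lo, cache_hi, diff, hash_id_to_index):
    # Sketch of CheckCacheHitOrOutRangeTask's per-id logic (simplified).
    if batch_id < slice_lo:                 # owned by a lower-ranked worker
        return batch_id - diff, True        # -> (hash_index, out_range)
    if batch_id >= slice_hi:                # owned by a higher-ranked worker
        return batch_id + cache_hi, True
    index = hash_id_to_index.get(batch_id)  # local cache hit?
    if index is not None:
        return index + cache_lo, False
    return None, False                      # miss: resolved later by ParseData

print(remap_id(7, 50, 100, 30, 60, 20, {}))        # (-13, True): below worker 1's slice
print(remap_id(55, 50, 100, 30, 60, 20, {55: 2}))  # (32, False): hit in slot 2, offset by 30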
@@ -423,11 +440,12 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
   return true;
 }
 
-bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
-                                     bool *in_device) {
+bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
+                                             bool *in_device, bool *out_range) {
   MS_ERROR_IF_NULL(batch_ids);
   MS_ERROR_IF_NULL(hash_index);
   MS_ERROR_IF_NULL(in_device);
+  MS_ERROR_IF_NULL(out_range);
 
   size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
@@ -441,8 +459,9 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id
       break;
     }
     size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
-    threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
-                             hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
+    threads[i] =
+      std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
+                  hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
     task_offset += task_proc_lens;
   }
   if (task_offset != batch_ids_len) {
@@ -477,27 +496,26 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
   MS_ERROR_IF_NULL(hash_index);
   statistics_info_.batch_id_count_ = batch_ids_len;
   std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
+  std::unique_ptr<bool[]> out_range(new bool[batch_ids_len]);
   if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
-    MS_LOG(EXCEPTION) << "Data in device memset failed.";
+    MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
   }
-  CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
+  if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
+    MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
+  }
+  RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
   RETURN_IF_FALSE(ResetEmbeddingHashMap());
   for (size_t i = 0; i < batch_ids_len; i++) {
-    if (in_device[i]) {
+    if (in_device[i] || out_range[i]) {
       continue;
     }
     bool need_swap_host_to_device = true;
     bool need_swap_device_to_host = true;
-    auto id = batch_ids[i];
-    if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
-      hash_index[i] = -1;
-      continue;
-    }
     int index = INVALID_INDEX_VALUE;
-    RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
-    hash_index[i] = index;
+    RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
+    hash_index[i] = index + cache_indices_bounds_.first;
    if (need_swap_host_to_device) {
-      RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
+      RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
     }
     if (need_swap_device_to_host) {
       RETURN_IF_FALSE(ParseHostDataDeviceToHost());
@@ -667,7 +685,7 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
 
 bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
                                          const int *indices_addr, float *output_addr) {
-  size_t first_dim_size = host_cache_vocab_size_;
+  size_t first_dim_size = host_vocab_cache_size_;
   size_t outer_dim_size = embedding_size;
 
   size_t thread_num = indices_lens / 10000 + 1;
@@ -697,7 +715,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
 
 bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
                                          float *insert_data, float *hash_table_addr) {
-  size_t first_dim_size = host_cache_vocab_size_;
+  size_t first_dim_size = host_vocab_cache_size_;
   size_t thread_num = insert_indices_size / 10000 + 1;
   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
   std::thread threads[kMaxThreadNum];
@@ -125,7 +125,10 @@ class PsCacheManager {
   const size_t &QueryHashTableSize(const std::string &param_name) const;
   bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
   void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
+  void set_rank_id(int rank_id) { rank_id_ = rank_id; }
   bool initialized_ps_cache() const { return initialized_ps_cache_; }
+  size_t vocab_cache_size() const { return vocab_cache_size_; }
+  int cache_indices_lower_bound() const;
   void DoProcessData(uint32_t device_id, void *context);
   void IncreaseGraphStep(const std::string &channel_name);
   void SyncEmbeddingTable();
@@ -170,10 +173,12 @@ class PsCacheManager {
   void DumpStatisticsInfo(size_t each_print_step = 1000);
   bool SyncHostEmbeddingTable();
   bool SyncDeviceEmbeddingTable();
-  bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
-                           size_t *hash_hit_count);
-  bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
+  bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
+                                   bool *out_range, size_t *hash_hit_count);
+  bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
+                               bool *out_range);
   bool ResetEmbeddingHashMap();
+
   bool initialized_ps_cache_{false};
   std::string channel_name_;
   std::mutex channel_mutex_;
@@ -190,11 +195,14 @@ class PsCacheManager {
   std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
 
   size_t vocab_size_{0};
-  size_t cache_vocab_size_{0};
-  size_t host_cache_vocab_size_{0};
+  size_t vocab_cache_size_{0};
+  size_t host_vocab_cache_size_{0};
   size_t batch_elements_{0};
   PsCacheStatisticsInfo statistics_info_;
-  std::pair<size_t, size_t> range_bound_;
+  std::pair<int, int> emb_table_slice_bounds_;
+  std::pair<int, int> cache_indices_bounds_;
+  int vocab_cache_size_diff_{0};
+  int rank_id_{0};
   std::atomic_bool finish_insert_init_info_{false};
   std::atomic_bool finish_init_parameter_server_{false};
   std::atomic_bool running_{false};
@@ -129,5 +129,11 @@ void PSContext::set_cache_enable(bool cache_enable) const {
   PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
 #endif
 }
+
+void PSContext::set_rank_id(int rank_id) const {
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  ps_cache_instance.set_rank_id(rank_id);
+#endif
+}
 }  // namespace ps
 }  // namespace mindspore
@@ -52,6 +52,7 @@ class PSContext {
   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
   void set_cache_enable(bool cache_enable) const;
+  void set_rank_id(int rank_id) const;
 
  private:
   PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
@@ -391,7 +391,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
 bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
   InnerSetContext();
   if (graph->is_dynamic_shape()) {
-    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
+    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
       MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
     }
     if (DumpJsonParser::GetInstance().async_dump_enabled()) {
@@ -851,7 +851,7 @@ void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_addre
       MS_LOG(WARNING) << "Unexpected device address status: " << status;
       break;
     default:
-      MS_LOG(EXCEPTION) << "Invaild device address status: " << status;
+      MS_LOG(EXCEPTION) << "Invalid device address status: " << status;
   }
 }
 
@@ -1092,6 +1092,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
   MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   auto cnode = kernel->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
+  // Can not free the input addr of communication op when enable multi stream
   if (AnfAlgo::IsCommunicationOp(kernel)) {
     return;
   }
@@ -1106,7 +1107,9 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
     }
 
     auto kernel_with_index = GetPrevNodeOutput(kernel, i);
-    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
+    // Maintain output addr of fused communication op to improve training performance
+    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) &&
+        AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) {
       continue;
     }
 
@@ -1049,7 +1049,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
   MS_EXCEPTION_IF_NULL(graph);
   for (const auto &kernel : graph->execution_order()) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
+    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
       continue;
     }
     auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
@@ -1061,13 +1062,15 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
       continue;
     }
     auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
-    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
-      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
+    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
+      MS_EXCEPTION_IF_NULL(input_index.first);
     }
-    if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
+    auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
+    if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
       bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
-      if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
+      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
+          (full_batch && (input_index_node_name != kMinimumOpName))) {
         MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                       << ") cache is from " << input_index.first->fullname_with_scope();
         MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
@@ -1082,6 +1085,28 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
   }
 }
 
+void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
+  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
+    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+    MS_EXCEPTION_IF_NULL(pre_node.first);
+  }
+  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
+    MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
+  }
+
+  pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
+    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
+    MS_EXCEPTION_IF_NULL(pre_node.first);
+  }
+  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
+    MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
+                         "value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
+  }
+}
+
 void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   AnfNodePtr first_cache_input_index = nullptr;
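The traversal above amounts to a structural constraint on where SparseGatherV2's indices come from; a compact restatement (an assumed simplification, not code from the commit) is:

def check_sparse_ps_embedding_cache(producer_chain):
    # producer_chain: op names walking back from SparseGatherV2's indices input
    # toward the dataset, e.g. ["Unique", "Cast", "GetNext"].
    if "Unique" not in producer_chain:
        raise RuntimeError("indices of SparseGatherV2 must come from Unique in PS cache mode")
    after_unique = producer_chain[producer_chain.index("Unique") + 1:]
    # Between Unique and the dataset's GetNext, only Cast ops are tolerated.
    if [op for op in after_unique if op != "Cast"] != ["GetNext"]:
        raise RuntimeError("indices of Unique must come directly from the dataset")

check_sparse_ps_embedding_cache(["Unique", "Cast", "GetNext"])  # passes
check_sparse_ps_embedding_cache(["Unique", "GetNext"])          # passes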
@@ -1090,16 +1115,23 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
   MS_EXCEPTION_IF_NULL(first_cache_input_index);
   for (const auto &kernel : graph->execution_order()) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
+    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
     }
     auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
     auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
+    MS_EXCEPTION_IF_NULL(input_param.first);
+    MS_EXCEPTION_IF_NULL(input_index.first);
     if (!input_param.first->isa<Parameter>()) {
       continue;
     }
     auto param_name = input_param.first->fullname_with_scope();
-    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
-      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
+    if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
+      CheckSparsePSEmbeddingCache(kernel);
+    }
+    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
+      MS_EXCEPTION_IF_NULL(input_index.first);
     }
     if (input_index.first == first_cache_input_index) {
@@ -138,6 +138,7 @@ class KernelRuntime {
   void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
                                 size_t *first_cache_size);
   void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
+  void CheckSparsePSEmbeddingCache(const CNodePtr &node);
 #endif
 
  protected:
@@ -83,7 +83,7 @@ constexpr auto kScatterNdOpName = "ScatterNd";
 constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign";
 constexpr auto kStridedSliceOpName = "StridedSlice";
 constexpr auto kStridedSliceGradOpName = "StridedSliceGrad";
-constexpr auto kSparseGatherV2 = "SparseGatherV2";
+constexpr auto kSparseGatherV2OpName = "SparseGatherV2";
 constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
 constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
 constexpr auto kFlattenGradOpName = "FlattenGrad";
@@ -73,6 +73,13 @@ inline size_t FloatToSize(float u) {
 }
 inline float IntToFloat(int32_t v) { return static_cast<float>(v); }
 
+inline int FloatToInt(float u) {
+  if (u > static_cast<float>((std::numeric_limits<int>::max)())) {
+    MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int.";
+  }
+  return static_cast<int>(u);
+}
+
 inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
 
 inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
@@ -20,10 +20,12 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
-from mindspore.communication.management import get_group_size
-from mindspore.context import ParallelMode, get_context
+from mindspore.communication.management import get_group_size, get_rank
+from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
+from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
+from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
+from mindspore import context
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr
@@ -227,8 +229,6 @@ class EmbeddingLookup(Cell):
         self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
         self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
-        if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.GatherV2()
@@ -238,6 +238,10 @@
         self.shape = P.Shape()
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
+        if self.cache_enable and enable_ps:
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
+            if is_auto_parallel:
+                self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
         if slice_mode == "field_slice" and is_auto_parallel:
             if not manual_shapes:
@@ -252,7 +256,7 @@ class EmbeddingLookup(Cell):
             self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
         elif slice_mode == "table_row_slice" and is_auto_parallel:
             full_batch = _get_full_batch()
-            if target == 'DEVICE' and not full_batch:
+            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                 indices_shape_size = 1
                 self.gather_revert.shard(((1, 1), (get_group_size(),)))
                 self.forward_unique = True
@@ -293,7 +297,7 @@ class EmbeddingLookup(Cell):
             raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
         if not self.sparse:
             raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
-        if get_context("device_target") != 'Ascend':
+        if context.get_context("device_target") != 'Ascend':
             raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
 
         logger.info("EmbeddingLookup cache enable takes effect.")
@@ -320,21 +324,29 @@ class EmbeddingLookup(Cell):
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
-            device_num = get_group_size()
+            rank_size = get_group_size()
+            rank_id = get_rank()
             full_batch = _get_full_batch()
-            if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"):
+            if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                 raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
                                  "in 'full_batch' and 'table_row_slice' parallel strategy.")
-            self.vocab_cache_size = self.vocab_cache_size * device_num
+            self.vocab_cache_size = self.vocab_cache_size * rank_size
+            _set_rank_id(rank_id)
         self.cache_enable = True
         if _is_role_worker():
             self.vocab_size = self.vocab_cache_size
+        if context.get_context("enable_sparse") != self.sparse:
+            raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
+                             "kernels and equal the value of 'enable_sparse' in context setting in "
+                             "parameter server cache mode")
 
     def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
         """PS embeddingLookup cache enable set."""
         self.embedding_table.cache_enable = True
         self.embedding_table.is_param_ps = True
         _set_cache_enable(True)
+        if self.sparse:
+            self.forward_unique = True
         if _is_role_worker():
             _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
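A hypothetical end-to-end configuration of the cache path set up above (the keyword names are the ones this diff touches; everything else, including the exact constructor defaults and the sizes, is assumed rather than taken from this commit):

from mindspore import context
from mindspore.nn import EmbeddingLookup

context.set_context(enable_sparse=True)  # must equal the sparse flag below in PS cache mode
context.set_ps_context(enable_ps=True)
lookup = EmbeddingLookup(vocab_size=4000000, embedding_size=80,
                         target='DEVICE', slice_mode='table_row_slice',
                         sparse=True, vocab_cache_size=3000000)
# With sparse=True on a worker, the cell flips to the forward_unique path and
# registers the table with the PS cache, as in _set_voacb_cache_enable_for_ps above.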
@@ -28,14 +28,15 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
 
 
 @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
-                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
+                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool",
+                         "Bool")
 def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
-                         beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
+                         beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable):
     """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     indices = gradient.indices
     values = gradient.values
-    if ps_parameter:
+    if ps_parameter and not cache_enable:
         op_shape = P.Shape()
         shapes = (op_shape(params), op_shape(m), op_shape(v),
                   op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
@@ -75,12 +76,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
 
 @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
-                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
-                             beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
+                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
+def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
+                             beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable):
     """Apply lazy adam optimizer to the weight parameter using Tensor."""
     success = True
-    if ps_parameter:
+    if ps_parameter and not cache_enable:
         op_shape = P.Shape()
         success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                               (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
@@ -245,12 +246,14 @@ class LazyAdam(Optimizer):
             success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                           self.use_locking, self.use_nesterov, self._is_device,
                                           self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
-                                lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
+                                lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
+                                self.cache_enable)
         else:
             success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                           self.use_locking, self.use_nesterov, self._is_device,
                                           self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
-                                gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
+                                gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
+                                self.cache_enable)
         return success
 
     @Optimizer.target.setter
@@ -142,3 +142,6 @@ def _set_cache_enable(cache_enable):
     os.environ['GOTO_NUM_THREADS'] = '2'
     os.environ['OMP_NUM_THREADS'] = '2'
     ps_context().set_cache_enable(cache_enable)
+
+def _set_rank_id(rank_id):
+    ps_context().set_rank_id(rank_id)
@@ -190,7 +190,10 @@ def _get_python_op(op_name, op_path, instance_name, arglist):
     """Get python operator."""
     module = __import__(op_path, fromlist=["None"])
     cls = getattr(module, op_name)
-    op = cls(*arglist)
+    if op_path != "mindspore.ops.functional":
+        op = cls(*arglist)
+    else:
+        op = cls
     op.set_prim_instance_name(instance_name)
     return op
 
@@ -17,7 +17,8 @@
 
 #bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET
 # LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM
-# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE VOCAB_CACHE_SIZE
+# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE
+# VOCAB_CACHE_SIZE SPARSE
 execute_path=$(pwd)
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
@@ -37,11 +38,16 @@ export MS_SCHED_PORT=$9
 export MS_ROLE=${10}
 export RANK_TABLE_FILE=${11}
 export VOCAB_CACHE_SIZE=${12}
+export SPARSE=${13}
 
 if [[ ! -n "${12}" ]]; then
   export VOCAB_CACHE_SIZE=0
 fi
+
+if [[ ! -n "${13}" ]]; then
+  export SPARSE=0
+fi
 
 echo "=====Role is $MS_ROLE======"
 
 if [[ "$MS_ROLE" == "MS_SCHED" ]]; then
@@ -73,7 +79,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then
     mpirun --allow-run-as-root -n $LOCAL_WORKER_NUM --output-filename log_output --merge-stderr-to-stdout \
       python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
         --device_target=$DEVICE --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-        --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
+        --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &
   else
     for((i=0;i<$LOCAL_WORKER_NUM;i++));
     do
@@ -84,7 +90,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then
       export DEVICE_ID=$i
       python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
         --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-        --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 &
+        --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 &
     done
   fi
 fi
@@ -17,7 +17,7 @@
 
 #bash run_parameter_server_train_distribute.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET
 # SERVER_NUM SCHED_HOST SCHED_PORT RANK_TABLE_FILE
-# VOCAB_CACHE_SIZE
+# VOCAB_CACHE_SIZE SPARSE
 execute_path=$(pwd)
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
@@ -33,11 +33,16 @@ export MS_SCHED_HOST=$6
 export MS_SCHED_PORT=$7
 export RANK_TABLE_FILE=$8
 export VOCAB_CACHE_SIZE=$9
+export SPARSE=${10}
 
 if [[ ! -n "$9" ]]; then
   export VOCAB_CACHE_SIZE=0
 fi
+
+if [[ ! -n "${10}" ]]; then
+  export SPARSE=0
+fi
 
 export MS_ROLE=MS_SCHED
 rm -rf ${execute_path}/sched/
 mkdir ${execute_path}/sched/
@@ -65,7 +70,7 @@ if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
   mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
     python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
       --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-      --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
+      --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &
 else
   for((i=0;i<$MS_WORKER_NUM;i++));
   do
@@ -76,7 +81,7 @@ else
     export DEVICE_ID=$i
     python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
       --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-      --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 &
+      --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 &
   done
 fi
 
@@ -16,7 +16,7 @@
 
 
 #bash run_parameter_server_train_standalone.sh EPOCHS DEVICE_TARGET DATASET SERVER_NUM SCHED_HOST
-# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE
+# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE SPARSE
 execute_path=$(pwd)
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
@@ -31,11 +31,16 @@ export MS_SCHED_HOST=$5
 export MS_SCHED_PORT=$6
 DEVICE_ID=$7
 export VOCAB_CACHE_SIZE=$8
+export SPARSE=$9
 
 if [[ ! -n "$8" ]]; then
   export VOCAB_CACHE_SIZE=0
 fi
+
+if [[ ! -n "$9" ]]; then
+  export SPARSE=0
+fi
 
 # Set device id
 if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
   if [[ ! -n "$DEVICE_ID" ]]; then
@@ -76,4 +81,4 @@ mkdir ${execute_path}/worker/
 cd ${execute_path}/worker/ || exit
 python -s ${self_path}/../train_and_eval_parameter_server_standalone.py --device_target=$DEVICE_TARGET \
   --epochs=$EPOCH_SIZE --data_path=$DATASET --parameter_server=1 \
-  --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
+  --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &
@@ -115,8 +115,11 @@ class EvalCallBack(Callback):
         if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                              ParallelMode.DATA_PARALLEL):
             rank_id = get_rank()
+        enable_data_sink = not self.sparse
+        if bool(self.config.parameter_server):
+            enable_data_sink = True
         start_time = time.time()
-        out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse))
+        out = self.model.eval(self.eval_dataset, dataset_sink_mode=enable_data_sink)
         end_time = time.time()
         eval_time = int(end_time - start_time)
 
@@ -202,7 +202,7 @@ class WideDeepModel(nn.Cell):
         self.unique = P.Unique().shard(((1,),))
         self.wide_gatherv2 = P.GatherV2()
         self.deep_gatherv2 = P.GatherV2()
-        if is_auto_parallel and sparse and not is_field_slice:
+        if is_auto_parallel and sparse and not is_field_slice and not parameter_server:
             target = 'DEVICE'
             if host_device_mix:
                 target = 'CPU'
@@ -376,12 +376,12 @@ class TrainStepWrap(nn.Cell):
         self.weights_w = ParameterTuple(weights_w)
         self.weights_d = ParameterTuple(weights_d)
 
-        if (sparse and is_auto_parallel) or (parameter_server and not cache_enable):
+        if (sparse and is_auto_parallel) or (sparse and parameter_server):
             self.optimizer_d = LazyAdam(
                 self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
             self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
                                     l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
-            if host_device_mix or parameter_server:
+            if host_device_mix or (parameter_server and not cache_enable):
                 self.optimizer_w.target = "CPU"
                 self.optimizer_d.target = "CPU"
             else:
@@ -43,7 +43,7 @@ def get_wide_deep_net(config):
     if cache_enable:
         loss_net = VirtualDatasetCellTriple(loss_net)
     train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server),
-                              cache_enable=(config.vocab_cache_size > 0))
+                              sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0))
     eval_net = PredictWithSigmoid(wide_deep_net)
     if cache_enable:
         eval_net = VirtualDatasetCellTriple(eval_net)
@@ -138,7 +138,7 @@ def train_and_eval(config):
     callback_list.append(ckpoint_cb)
     model.train(epochs, ds_train,
                 callbacks=callback_list,
-                dataset_sink_mode=bool(parameter_server and cache_enable))
+                dataset_sink_mode=(parameter_server and cache_enable))
 
 
 if __name__ == "__main__":
@@ -148,7 +148,6 @@ if __name__ == "__main__":
     cache_enable = wide_deep_config.vocab_cache_size > 0
     if cache_enable and wide_deep_config.device_target != "GPU":
         context.set_context(variable_memory_max_size="24GB")
-    context.set_context(enable_sparse=True)
     context.set_ps_context(enable_ps=True)
     init()
     context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
@@ -159,5 +158,8 @@ if __name__ == "__main__":
     else:
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=get_group_size())
+        wide_deep_config.sparse = True
 
+    if wide_deep_config.sparse:
+        context.set_context(enable_sparse=True)
     train_and_eval(wide_deep_config)
@@ -29,7 +29,6 @@ from src.metrics import AUCMetric
 from src.config import WideDeepConfig
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-context.set_context(enable_sparse=True)
 
 
 def get_wide_deep_net(config):
@@ -39,7 +38,7 @@ def get_wide_deep_net(config):
     wide_deep_net = WideDeepModel(config)
     loss_net = NetWithLossClass(wide_deep_net, config)
     train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server),
-                              cache_enable=(config.vocab_cache_size > 0))
+                              sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0))
     eval_net = PredictWithSigmoid(wide_deep_net)
     return train_net, eval_net
 
@@ -81,7 +80,6 @@ def train_and_eval(config):
     else:
         dataset_type = DataType.H5
-    parameter_server = bool(config.parameter_server)
     cache_enable = config.vocab_cache_size > 0
     print("epochs is {}".format(epochs))
     ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                               batch_size=batch_size, data_type=dataset_type)
@@ -121,6 +119,11 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
+    cache_enable = wide_deep_config.vocab_cache_size > 0
+    if not cache_enable:
+        wide_deep_config.sparse = True
+    if wide_deep_config.sparse:
+        context.set_context(enable_sparse=True)
     context.set_ps_context(enable_ps=True)
 
     train_and_eval(wide_deep_config)
@@ -0,0 +1,128 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+callbacks
+"""
+import time
+from mindspore.train.callback import Callback
+from mindspore import context
+from mindspore.context import ParallelMode
+from mindspore.communication.management import get_rank
+
+def add_write(file_path, out_str):
+    """
+    add lines to the file
+    """
+    with open(file_path, 'a+', encoding="utf-8") as file_out:
+        file_out.write(out_str + "\n")
+
+
+class LossCallBack(Callback):
+    """
+    Monitor the loss in training.
+
+    If the loss is NAN or INF, terminate the training.
+
+    Note:
+        If per_print_times is 0, do NOT print loss.
+        If this process is MS_PSERVER role, do not run callbacks.
+
+    Args:
+        per_print_times (int): Print loss every times. Default: 1.
+    """
+    def __init__(self, config=None, per_print_times=1):
+        super(LossCallBack, self).__init__()
+        if not isinstance(per_print_times, int) or per_print_times < 0:
+            raise ValueError("per_print_times must be in and >= 0.")
+        self._per_print_times = per_print_times
+        self.config = config
+
+    def step_end(self, run_context):
+        """Monitor the loss in training."""
+        cb_params = run_context.original_args()
+        if cb_params.net_outputs is None:
+            return
+        wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        cur_num = cb_params.cur_step_num
+        rank_id = 0
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
+                             ParallelMode.DATA_PARALLEL):
+            rank_id = get_rank()
+
+        print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch,
+              wide_loss, deep_loss, flush=True)
+
+        # raise ValueError
+        if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
+            loss_file = open(self.config.loss_file_name, "a+")
+            loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
+                            (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
+            loss_file.write("\n")
+            loss_file.close()
+            print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
+                  (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
+
+
+class EvalCallBack(Callback):
+    """
+    Monitor the loss in evaluating.
+
+    If the loss is NAN or INF, terminate evaluating.
+
+    Note:
+        If per_print_times is 0, do NOT print loss.
+
+    Args:
+        print_per_step (int): Print loss every times. Default: 1.
+    """
+    def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
+        super(EvalCallBack, self).__init__()
+        if not isinstance(print_per_step, int) or print_per_step < 0:
+            raise ValueError("print_per_step must be int and >= 0.")
+        self.print_per_step = print_per_step
+        self.model = model
+        self.eval_dataset = eval_dataset
+        self.aucMetric = auc_metric
+        self.aucMetric.clear()
+        self.eval_file_name = config.eval_file_name
+        self.eval_values = []
+        self.sparse = config.sparse
+        self.config = config
+
+    def epoch_end(self, run_context):
+        """
+        epoch end
+        """
+        self.aucMetric.clear()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            context.set_auto_parallel_context(strategy_ckpt_save_file="",
+                                              strategy_ckpt_load_file=self.config.stra_ckpt)
+        rank_id = 0
+        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
+                             ParallelMode.DATA_PARALLEL):
+            rank_id = get_rank()
+        start_time = time.time()
+        out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse))
+        end_time = time.time()
+        eval_time = int(end_time - start_time)
+
+        time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
+        out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\
+            format(time_str, rank_id, out.values(), eval_time)
+        print(out_str)
+        self.eval_values = out.values()
+        add_write(self.eval_file_name, out_str)
@@ -34,6 +34,7 @@ cp -r ${CODE_DIR} ${BASE_PATH}/wide_and_deep
 cp -f ${BASE_PATH}/python_file_for_ci/train_and_test_multinpu_ci.py ${BASE_PATH}/wide_and_deep/train_and_test_multinpu_ci.py
 cp -f ${BASE_PATH}/python_file_for_ci/__init__.py ${BASE_PATH}/wide_and_deep/__init__.py
 cp -f ${BASE_PATH}/python_file_for_ci/config.py ${BASE_PATH}/wide_and_deep/src/config.py
+cp -f ${BASE_PATH}/python_file_for_ci/callbacks.py ${BASE_PATH}/wide_and_deep/src/callbacks.py
 cp -f ${BASE_PATH}/python_file_for_ci/datasets.py ${BASE_PATH}/wide_and_deep/src/datasets.py
 cp -f ${BASE_PATH}/python_file_for_ci/wide_and_deep.py ${BASE_PATH}/wide_and_deep/src/wide_and_deep.py
 source ${BASE_PATH}/env.sh
@@ -55,7 +56,7 @@ for((i=0; i<${DEVICE_NUM}; i++)); do
     wait ${process_pid[i]}
     status=`echo $?`
     if [ "${status}" != "0" ]; then
-        echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}"
+        echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}"
         exit 1
     else
         echo "[INFO] test wide_and_deep semi auto parallel success."