forked from mindspore-Ecosystem/mindspore

ps cache support save checkpoint

parent 5832bf0c3d
commit 4269dcece5
@@ -1085,10 +1085,10 @@ void ClearResAtexit() {
   session::ClearPythonParasMap();
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) {
-    ps::worker.Finalize();
     if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
       ps::ps_cache_instance.Finalize();
     }
+    ps::worker.Finalize();
   }
 #endif
   ad::g_k_prims.clear();
@@ -552,7 +552,6 @@ template <typename T>
 void ParameterServer<T>::Finalize() {
   running_ = false;
   apply_grads_cv_.notify_one();
-  SyncEmbeddingTables();
 }
 
 template <typename T>
@@ -774,7 +773,7 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() {
   for (auto cnode : cnodes) {
     MS_EXCEPTION_IF_NULL(cnode);
     std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
-    if (cnode_name == kEmbeddingLookupOpName) {
+    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) {
       auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
       MS_EXCEPTION_IF_NULL(embedding_table);
       MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
@@ -832,6 +831,7 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
   Init(func_graph);
   PSContext::instance()->SetPSRankId(rank_id_);
   thread_->join();
+  SyncEmbeddingTables();
   MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
   ::ps::Finalize(0, true);
   MS_LOG(INFO) << "PServer finalized successfully.";
@@ -30,21 +30,21 @@ int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out
     if (loop++ == hash_capacity_) {
       return INVALID_INDEX_VALUE;
     }
-    if (hash_map_unit_[hash_index].IsEmpty()) {
+    if (hash_map_elements_[hash_index].IsEmpty()) {
       hash_count_++;
       (void)hash_id_to_index_.emplace(id, hash_index);
-      hash_map_unit_[hash_index].set_id(id);
-      hash_map_unit_[hash_index].set_step(data_step);
+      hash_map_elements_[hash_index].set_id(id);
+      hash_map_elements_[hash_index].set_step(data_step);
       return hash_index;
-    } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) {
+    } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) {
       // Need swap out from the hash table.
       swap_out_index[*swap_out_size] = hash_index;
-      swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_;
+      swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
       (*swap_out_size)++;
-      (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_);
+      (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
       (void)hash_id_to_index_.emplace(id, hash_index);
-      hash_map_unit_[hash_index].set_id(id);
-      hash_map_unit_[hash_index].set_step(data_step);
+      hash_map_elements_[hash_index].set_id(id);
+      hash_map_elements_[hash_index].set_step(data_step);
       return hash_index;
     }
     hash_index = (hash_index + 1) % hash_capacity_;
@@ -58,9 +58,10 @@ void EmbeddingHashMap::DumpHashMap() {
     MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second;
   }
   MS_LOG(INFO) << "Dump hash_map_unit: ";
-  for (size_t i = 0; i < hash_map_unit_.size(); i++) {
-    if (!hash_map_unit_[i].IsEmpty()) {
-      MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_;
+  for (size_t i = 0; i < hash_map_elements_.size(); i++) {
+    if (!hash_map_elements_[i].IsEmpty()) {
+      MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_
+                   << " step: " << hash_map_elements_[i].step_;
     }
   }
   MS_LOG(INFO) << "Dump hash map info end.";
@@ -30,8 +30,8 @@ static const size_t INVALID_STEP_VALUE = 0;
 static const int INVALID_INDEX_VALUE = -1;
 
 struct HashMapElement {
-  int id_;
-  size_t step_;
+  int id_{INVALID_INDEX_VALUE};
+  size_t step_{INVALID_STEP_VALUE};
   bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
   bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
   void set_id(int id) { id_ = id; }
@@ -42,7 +42,7 @@ struct HashMapElement {
 class EmbeddingHashMap {
  public:
   EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) {
-    hash_map_unit_.resize(hash_capacity);
+    hash_map_elements_.resize(hash_capacity);
   }
   virtual ~EmbeddingHashMap() = default;
   int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
@@ -51,8 +51,10 @@ class EmbeddingHashMap {
   bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const {
     return iter != hash_id_to_index_.end();
   }
-  size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; }
-  void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); }
+  size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
+  void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); }
+  const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; }
+  size_t hash_capacity() const { return hash_capacity_; }
   void DumpHashMap();
 
  private:
@@ -60,7 +62,7 @@ class EmbeddingHashMap {
   bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); }
   size_t hash_count_;
   size_t hash_capacity_;
-  std::vector<HashMapElement> hash_map_unit_;
+  std::vector<HashMapElement> hash_map_elements_;
   std::unordered_map<int, int> hash_id_to_index_;
 };
 } // namespace ps
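For readers skimming the diff: the renamed hash_map_elements_ vector, together with the new default initializers on HashMapElement, implements a small open-addressing index. Starting from an initial slot (ParseData receives hash_index from the caller; the sketch below simply uses id % capacity), it probes linearly, reuses a slot when it is empty or its step has expired, and tracks the id-to-slot mapping in hash_id_to_index_. The following standalone sketch reproduces that probe-and-evict idea with plain standard-library types; it is illustrative only (the class name, capacity, main() driver, and example ids are invented here) and omits the hash_count_ / NeedSwap() bookkeeping from the real class.

// Standalone sketch, not the MindSpore EmbeddingHashMap itself.
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

namespace sketch {
constexpr size_t kInvalidStep = 0;  // mirrors INVALID_STEP_VALUE
constexpr int kInvalidIndex = -1;   // mirrors INVALID_INDEX_VALUE

struct Element {
  int id_{kInvalidIndex};       // default init: a freshly allocated slot reads as empty
  size_t step_{kInvalidStep};
  bool IsEmpty() const { return step_ == kInvalidStep; }
  bool IsExpired(size_t running_step) const { return running_step > step_; }
};

class CacheIndex {
 public:
  explicit CacheIndex(size_t capacity) : capacity_(capacity), elements_(capacity) {}

  // Returns the slot assigned to `id`, evicting an expired occupant if needed;
  // returns kInvalidIndex when every slot is live and unexpired.
  int Assign(int id, size_t data_step, size_t running_step) {
    size_t index = static_cast<size_t>(id) % capacity_;
    for (size_t probes = 0; probes < capacity_; ++probes) {
      Element &e = elements_[index];
      if (e.IsEmpty() || e.IsExpired(running_step)) {
        if (!e.IsEmpty()) {
          id_to_index_.erase(e.id_);  // evict the stale occupant
        }
        e.id_ = id;
        e.step_ = data_step;
        id_to_index_[id] = static_cast<int>(index);
        return static_cast<int>(index);
      }
      index = (index + 1) % capacity_;  // linear probing
    }
    return kInvalidIndex;
  }

 private:
  size_t capacity_;
  std::vector<Element> elements_;
  std::unordered_map<int, int> id_to_index_;
};
}  // namespace sketch

int main() {
  sketch::CacheIndex index(4);
  // Fill all four slots at step 1, then insert a fifth id at step 3 while the
  // graph is already past step 2, so one step-1 slot is expired and reused.
  for (int id : {10, 11, 12, 13}) {
    std::cout << "id " << id << " -> slot " << index.Assign(id, 1, 0) << "\n";
  }
  std::cout << "id 14 -> slot " << index.Assign(14, 3, 2) << "\n";
  return 0;
}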
@@ -226,9 +226,9 @@ void PsCacheManager::AllocMemForHashTable() {
     device_address.addr = addr;
 
     auto &host_address = item.second.host_address;
-    auto host_address_ptr = new int[host_cache_vocab_size_ * embedding_size];
+    auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size];
     MS_EXCEPTION_IF_NULL(host_address_ptr);
-    host_address = std::shared_ptr<int[]>(host_address_ptr, std::default_delete<int[]>());
+    host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>());
     MS_EXCEPTION_IF_NULL(host_address);
 
     max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
@@ -330,6 +330,14 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
 }
 
 void PsCacheManager::Finalize() {
+  if (running_) {
+    if (!SyncHostEmbeddingTable()) {
+      MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
+    }
+    if (!SyncDeviceEmbeddingTable()) {
+      MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
+    }
+  }
   running_ = false;
   PsDataPrefetch::GetInstance().NotifyFinalize();
   insert_init_info_.notify_all();
@@ -838,6 +846,99 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
   return true;
 }
 
+bool PsCacheManager::SyncHostEmbeddingTable() {
+  MS_ERROR_IF_NULL(embedding_host_cache_);
+  const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
+  size_t swap_indices_lens = hash_id_to_index.size();
+  if (swap_indices_lens == 0) {
+    return true;
+  }
+  std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
+  MS_ERROR_IF_NULL(host_to_server_ids_ptr);
+  std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
+  MS_ERROR_IF_NULL(host_to_server_indices_ptr);
+  size_t idx = 0;
+  for (const auto &item : hash_id_to_index) {
+    host_to_server_ids_ptr[idx] = item.first;
+    host_to_server_indices_ptr[idx++] = item.second;
+  }
+  for (const auto &item : hash_tables_) {
+    const auto &hash_info = item.second;
+    if (hash_info.param_init_info_.param_type_ != kWeight) {
+      continue;
+    }
+    auto key = worker.GetParamKey(item.first);
+    ::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
+    ::ps::SArray<float> swap_out_data;
+    auto embedding_size = hash_info.embedding_size;
+    swap_out_data.resize(swap_indices_lens * embedding_size);
+    auto host_hash_table_addr = hash_info.host_address.get();
+    MS_ERROR_IF_NULL(host_hash_table_addr);
+    RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr,
+                                        host_to_server_indices_ptr.get(), swap_out_data.data()));
+
+    auto copy_len = swap_indices_lens * sizeof(int);
+    auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len);
+    if (ret != EOK) {
+      MS_LOG(ERROR) << "Lookup id memcpy failed.";
+      return false;
+    }
+    worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
+  }
+  return true;
+}
+
+bool PsCacheManager::SyncDeviceEmbeddingTable() {
+  MS_ERROR_IF_NULL(embedding_device_cache_);
+  const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
+  const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
+  size_t swap_indices_lens = hash_id_to_index.size();
+  if (swap_indices_lens == 0) {
+    return true;
+  }
+  std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
+  MS_ERROR_IF_NULL(device_to_server_ids_ptr);
+  std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
+  MS_ERROR_IF_NULL(device_to_server_indices_ptr);
+  size_t idx = 0;
+  for (const auto &item : hash_id_to_index) {
+    device_to_server_ids_ptr[idx] = item.first;
+    device_to_server_indices_ptr[idx++] = item.second;
+  }
+  for (const auto &item : hash_tables_) {
+    const auto &hash_info = item.second;
+    if (hash_info.param_init_info_.param_type_ != kWeight) {
+      continue;
+    }
+    auto key = worker.GetParamKey(item.first);
+    ::ps::SArray<int> lookup_ids(swap_indices_lens, 0);
+    ::ps::SArray<float> swap_out_data;
+    auto embedding_size = hash_info.embedding_size;
+    swap_out_data.resize(swap_indices_lens * embedding_size);
+    std::unique_ptr<float[]> device_hash_table_addr_tmp =
+      std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size);
+    MS_ERROR_IF_NULL(device_hash_table_addr_tmp);
+
+    auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
+    MS_ERROR_IF_NULL(hash_table_addr);
+    auto hash_table_size = hash_info.device_address.size;
+    RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(device_hash_table_addr_tmp.get(),
+                                                                         hash_table_addr, hash_table_size));
+    RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
+    RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
+                                        device_to_server_indices_ptr.get(), swap_out_data.data()));
+
+    auto copy_len = swap_indices_lens * sizeof(int);
+    auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len);
+    if (ret != EOK) {
+      MS_LOG(ERROR) << "Lookup id memcpy failed.";
+      return false;
+    }
+    worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
+  }
+  return true;
+}
+
 void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
   for (const auto &item : hash_tables_) {
     const auto &param_name = item.first;
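Both new sync paths share the same gather-then-push shape: walk the cache's hash_id_to_index() map into parallel id/index arrays, read the rows at those slot indices into one contiguous swap-out buffer (directly from host memory, or via CopyDeviceMemToHost plus SynchronizeStream for the device cache), and hand the (ids, rows) pair to worker.UpdateEmbeddingTable. Below is a minimal standalone sketch of just the gather-and-copy step, with std::vector standing in for ::ps::SArray and a plain copy loop standing in for LookUpHostHashTable; the concrete ids, slot numbers, and sizes are invented for the example.

// Illustrative sketch only, not code from the commit.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  const size_t embedding_size = 4;
  // Host-side cache: 8 slots, each slot holds one embedding row of 4 floats.
  std::vector<float> host_cache(8 * embedding_size, 0.0f);
  host_cache[2 * embedding_size] = 1.5f;  // pretend slot 2 caches the row for id 42

  // hash_id_to_index(): which parameter-server id currently lives in which slot.
  const std::unordered_map<int, int> hash_id_to_index = {{42, 2}, {7, 5}};

  // Step 1: build parallel arrays of server-side ids and cache slot indices.
  std::vector<int> ids;
  std::vector<int> indices;
  for (const auto &item : hash_id_to_index) {
    ids.push_back(item.first);
    indices.push_back(item.second);
  }

  // Step 2: copy each cached row into one contiguous swap-out buffer.
  std::vector<float> swap_out_data(ids.size() * embedding_size);
  for (size_t i = 0; i < indices.size(); ++i) {
    const float *row = host_cache.data() + static_cast<size_t>(indices[i]) * embedding_size;
    std::copy(row, row + embedding_size, swap_out_data.begin() + i * embedding_size);
  }

  // Step 3 in the commit would then be: worker.UpdateEmbeddingTable({key}, ids, swap_out_data);
  std::cout << "gathered " << ids.size() << " rows for the parameter server\n";
  return 0;
}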
@@ -48,7 +48,7 @@ struct HashTableInfo {
   size_t embedding_size{0};
   size_t vocab_size{0};
   Address device_address{nullptr, 0};
-  std::shared_ptr<int[]> host_address{nullptr};
+  std::shared_ptr<float[]> host_address{nullptr};
   ParamInitInfo param_init_info_;
 };
 
@@ -166,6 +166,8 @@ class PsCacheManager {
   bool CheckFinishInsertInitInfo() const;
   void AddEmbeddingTable() const;
   void DumpStatisticsInfo(size_t each_print_step = 1000);
+  bool SyncHostEmbeddingTable();
+  bool SyncDeviceEmbeddingTable();
 
   bool initialized_ps_cache_{false};
   std::string channel_name_;
@@ -205,6 +205,7 @@ constexpr auto kPushOpName = "Push";
 constexpr auto kPullOpName = "Pull";
 constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
 constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
+constexpr auto kGatherV2OpName = "GatherV2";
 constexpr auto kPaddingOpName = "Padding";
 constexpr auto kAvgPoolOpName = "AvgPool";
 constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
@@ -292,7 +292,8 @@ class EmbeddingLookup(Cell):
                                  "in 'full_batch' and 'table_row_slice' parallel strategy.")
                 self.vocab_cache_size = self.vocab_cache_size * device_num
                 self.cache_enable = True
-                self.vocab_size = self.vocab_cache_size
+                if _is_role_worker():
+                    self.vocab_size = self.vocab_cache_size
 
     def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
         """PS embeddingLookup cache enable set."""
@@ -24,6 +24,7 @@ from mindspore.context import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from mindspore.common import set_seed
+from mindspore.parallel._ps_context import _is_role_worker
 
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
@@ -117,11 +118,14 @@ def train_and_eval(config):
     eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
 
     callback = LossCallBack(config=config)
-    if cache_enable:
-        ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
-                                      keep_checkpoint_max=5, integrated_save=False)
+    if _is_role_worker():
+        if cache_enable:
+            ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
+                                          keep_checkpoint_max=1, integrated_save=False)
+        else:
+            ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     else:
-        ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
+        ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
                                  config=ckptconfig)
@@ -20,6 +20,7 @@ import sys
 from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.common import set_seed
+from mindspore.parallel._ps_context import _is_role_worker
 
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
@@ -99,7 +100,14 @@ def train_and_eval(config):
 
     eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
     callback = LossCallBack(config=config)
-    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
+    if _is_role_worker():
+        if cache_enable:
+            ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * epochs,
+                                          keep_checkpoint_max=1)
+        else:
+            ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
+    else:
+        ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]
 