forked from mindspore-Ecosystem/mindspore
!43487 实现Embedding Store相关接口及周边适配
Merge pull request !43487 from larryzhangyong/emb-store-impl-withut
This commit is contained in:
commit
203eea3362
|
@ -27,12 +27,16 @@ class EmbeddingCache {
|
|||
EmbeddingCache() = default;
|
||||
virtual ~EmbeddingCache() = default;
|
||||
|
||||
virtual bool Initialize() = 0;
|
||||
virtual bool Finalize() = 0;
|
||||
|
||||
// Get the values indexed by keys from input. Input is the tensor data address of the embedding Parameter.
|
||||
virtual bool Get(void *input, size_t key_num, const void *keys, void *values) = 0;
|
||||
virtual bool Get(const void *input, size_t key_num, const void *keys, void *values, size_t *miss_num, void *miss_keys,
|
||||
size_t *miss_indices) = 0;
|
||||
|
||||
// Put the values indexed by keys into input. Input is the tensor data address of the embedding Parameter.
|
||||
// When input is full, save the evicted values and keys.
|
||||
virtual bool Put(void *input, size_t key_num, const void *keys, const void *values, size_t evicted_num,
|
||||
virtual bool Put(void *input, size_t key_num, const void *keys, const void *values, size_t *evicted_num,
|
||||
void *evicted_keys, void *evicted_values) = 0;
|
||||
|
||||
// Check if cache is full.
|
||||
|
|
|
@ -45,8 +45,8 @@ static constexpr size_t kNumberBase = 10;
|
|||
|
||||
static constexpr size_t kOneGBBitNum = 30;
|
||||
|
||||
// The default cache size of one embedding parameter on role of server: 100GB.
|
||||
static constexpr size_t kDefaultEmbeddingRemoteCacheMemorySize = size_t(100) << 30;
|
||||
// The default cache size of one embedding parameter on role of server: 1TB.
|
||||
static constexpr size_t kDefaultEmbeddingRemoteCacheMemorySize = size_t(1) << 40;
|
||||
|
||||
// The default cache size of one embedding parameter on role of worker: 10GB.
|
||||
static constexpr size_t kDefaultEmbeddingLocalCacheMemorySize = size_t(10) << 30;
|
||||
|
@ -276,10 +276,10 @@ class BACKEND_EXPORT EmbeddingStoreManager {
|
|||
public:
|
||||
static EmbeddingStoreManager &GetInstance();
|
||||
|
||||
void Add(const std::string &name, std::shared_ptr<EmbeddingStore<size_t, float_t>> emb_store) {
|
||||
void Add(const std::string &name, std::shared_ptr<EmbeddingStore<int32_t, float>> emb_store) {
|
||||
embedding_stores_[name] = emb_store;
|
||||
}
|
||||
std::shared_ptr<EmbeddingStore<size_t, float_t>> Get(const std::string &name) { return embedding_stores_[name]; }
|
||||
std::shared_ptr<EmbeddingStore<int32_t, float>> Get(const std::string &name) { return embedding_stores_[name]; }
|
||||
|
||||
bool IsExists(const std::string &name) const { return embedding_stores_.find(name) != embedding_stores_.end(); }
|
||||
|
||||
|
@ -288,7 +288,7 @@ class BACKEND_EXPORT EmbeddingStoreManager {
|
|||
~EmbeddingStoreManager() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(EmbeddingStoreManager);
|
||||
|
||||
mindspore::HashMap<std::string, std::shared_ptr<EmbeddingStore<size_t, float_t>>> embedding_stores_;
|
||||
mindspore::HashMap<std::string, std::shared_ptr<EmbeddingStore<int32_t, float>>> embedding_stores_;
|
||||
};
|
||||
|
||||
size_t GetEmbeddingRemoteCacheSize();
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "distributed/embedding_cache/embedding_lru_cache.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
|
@ -24,9 +25,110 @@ bool EmbeddingLRUCache<K, V>::Initialize() {
|
|||
return true;
|
||||
}
|
||||
|
||||
template class EmbeddingLRUCache<size_t, float>;
|
||||
template class EmbeddingLRUCache<size_t, double>;
|
||||
template class EmbeddingLRUCache<size_t, int64_t>;
|
||||
template class EmbeddingLRUCache<size_t, size_t>;
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingLRUCache<K, V>::Get(const void *input, size_t key_num, const void *keys, void *values, size_t *miss_num,
|
||||
void *miss_keys, size_t *miss_indices) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(keys);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(miss_keys);
|
||||
MS_EXCEPTION_IF_NULL(miss_indices);
|
||||
|
||||
size_t miss_count = 0;
|
||||
auto *miss_keys_list = static_cast<K *>(miss_keys);
|
||||
for (size_t i = 0; i < key_num; i++) {
|
||||
const K key = static_cast<const K *>(keys)[i];
|
||||
if (!keys_lru_cache_->Exists(key)) {
|
||||
miss_keys_list[miss_count] = key;
|
||||
miss_indices[miss_count] = i;
|
||||
miss_count += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t index = keys_lru_cache_->Get(key);
|
||||
auto ret = memcpy_s(AddressOffset(values, i * value_size_), value_size_,
|
||||
AddressOffset(const_cast<void *>(input), index * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
*miss_num = miss_count;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingLRUCache<K, V>::Put(void *input, size_t key_num, const void *keys, const void *values,
|
||||
size_t *evicted_num, void *evicted_keys, void *evicted_values) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(keys);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(evicted_keys);
|
||||
MS_EXCEPTION_IF_NULL(evicted_values);
|
||||
|
||||
auto *evicted_keys_list = static_cast<K *>(evicted_keys);
|
||||
size_t evicted_count = 0;
|
||||
size_t hit_count = 0;
|
||||
for (size_t i = 0; i < key_num; i++) {
|
||||
hit_count++;
|
||||
const K key = static_cast<const K *>(keys)[i];
|
||||
|
||||
if (keys_lru_cache_->Exists(key)) {
|
||||
size_t idx = static_cast<size_t>(keys_lru_cache_->Get(key));
|
||||
auto ret = memcpy_s(AddressOffset(input, idx * value_size_), value_size_,
|
||||
AddressOffset(const_cast<void *>(values), i * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Failed to update exist key: " << key;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!keys_lru_cache_->IsFull()) {
|
||||
keys_lru_cache_->Put(key, curr_index_);
|
||||
|
||||
auto ret = memcpy_s(AddressOffset(input, curr_index_ * value_size_), value_size_,
|
||||
AddressOffset(const_cast<void *>(values), i * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
curr_index_++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::pair<K, V> last = keys_lru_cache_->Back();
|
||||
evicted_keys_list[evicted_count] = last.first;
|
||||
// Save evicted values
|
||||
auto ret = memcpy_s(AddressOffset(evicted_values, evicted_count * value_size_), value_size_,
|
||||
AddressOffset(input, last.second * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update input use new values
|
||||
ret = memcpy_s(AddressOffset(input, last.second * value_size_), value_size_,
|
||||
AddressOffset(const_cast<void *>(values), i * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update key&index cache
|
||||
keys_lru_cache_->Put(key, last.second);
|
||||
evicted_count++;
|
||||
}
|
||||
*evicted_num = evicted_count;
|
||||
MS_LOG(INFO) << "Embedding lru cache size after put: " << keys_lru_cache_->Size() << ", hit count: " << hit_count;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingLRUCache<K, V>::IsFull() {
|
||||
return curr_index_ >= capacity_;
|
||||
}
|
||||
|
||||
template class EmbeddingLRUCache<int32_t, float>;
|
||||
template class EmbeddingLRUCache<int32_t, double>;
|
||||
template class EmbeddingLRUCache<int32_t, int64_t>;
|
||||
template class EmbeddingLRUCache<int32_t, size_t>;
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,7 +46,7 @@ class LRUCache {
|
|||
}
|
||||
cache_items_map_[key] = cache_items_list_.begin();
|
||||
|
||||
if (IsFull()) {
|
||||
if (cache_items_map_.size() > capacity_) {
|
||||
auto last = cache_items_list_.end();
|
||||
last--;
|
||||
cache_items_map_.erase(last->first);
|
||||
|
@ -64,11 +64,19 @@ class LRUCache {
|
|||
return it->second->second;
|
||||
}
|
||||
|
||||
// Returns a copy of the last item of cache_items_list_.
// Raises via MS_LOG(EXCEPTION) when the list holds no items.
const Item Back() const {
  const bool has_items = !cache_items_list_.empty();
  if (!has_items) {
    MS_LOG(EXCEPTION) << "Cache is empty.";
  }

  return cache_items_list_.back();
}
|
||||
|
||||
// Returns true when `key` has an entry in cache_items_map_.
bool Exists(const K &key) const {
  const auto found = cache_items_map_.find(key);
  return found != cache_items_map_.end();
}
|
||||
|
||||
size_t Size() const { return cache_items_map_.size(); }
|
||||
|
||||
bool IsFull() const { return Size() > capacity_; }
|
||||
bool IsFull() const { return Size() >= capacity_; }
|
||||
|
||||
private:
|
||||
std::list<Item> cache_items_list_;
|
||||
|
@ -80,22 +88,25 @@ class LRUCache {
|
|||
template <typename K, typename V>
|
||||
class BACKEND_EXPORT EmbeddingLRUCache : public EmbeddingCache {
|
||||
public:
|
||||
explicit EmbeddingLRUCache(size_t capacity) : capacity_(capacity) {}
|
||||
explicit EmbeddingLRUCache(size_t capacity, size_t value_size) : capacity_(capacity), value_size_(value_size) {}
|
||||
~EmbeddingLRUCache() = default;
|
||||
|
||||
bool Initialize();
|
||||
bool Finalize() { return true; }
|
||||
bool Initialize() override;
|
||||
bool Finalize() override { return true; }
|
||||
|
||||
bool Get(void *input, size_t key_num, const void *keys, void *values) override { return true; }
|
||||
bool Put(void *input, size_t key_num, const void *keys, const void *values, size_t evicted_num, void *evicted_keys,
|
||||
void *evicted_values) override {
|
||||
return true;
|
||||
}
|
||||
bool IsFull() override { return true; }
|
||||
bool Get(const void *input, size_t key_num, const void *keys, void *values, size_t *miss_num, void *miss_keys,
|
||||
size_t *miss_indices) override;
|
||||
bool Put(void *input, size_t key_num, const void *keys, const void *values, size_t *evicted_num, void *evicted_keys,
|
||||
void *evicted_values) override;
|
||||
bool IsFull() override;
|
||||
|
||||
private:
|
||||
size_t capacity_;
|
||||
|
||||
size_t value_size_;
|
||||
|
||||
size_t curr_index_{0};
|
||||
|
||||
// Cache the index of saved value in the Parameter of embedding.
|
||||
std::unique_ptr<LRUCache<K, size_t>> keys_lru_cache_;
|
||||
};
|
||||
|
|
|
@ -28,11 +28,16 @@ template <typename K, typename V>
|
|||
bool EmbeddingStore<K, V>::Initialize() {
|
||||
value_size_ = emb_dim_ * sizeof(V);
|
||||
key_size_ = sizeof(K);
|
||||
cache_ = std::make_unique<EmbeddingLRUCache<K, V>>(cache_capacity_);
|
||||
cache_ = std::make_unique<EmbeddingLRUCache<K, V>>(cache_capacity_, value_size_);
|
||||
if (!cache_->Initialize()) {
|
||||
MS_LOG(ERROR) << "Cannot initialize cache";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string storage_file_path = GetEmbeddingRemoteStoragePath();
|
||||
std::string storage_file_root_path = GetEmbeddingRemoteStoragePath();
|
||||
std::string storage_file_path = storage_file_root_path + "/" + name_;
|
||||
if (!distributed::storage::FileIOUtils::IsFileOrDirExist(storage_file_path)) {
|
||||
distributed::storage::FileIOUtils::CreateDir(storage_file_path);
|
||||
distributed::storage::FileIOUtils::CreateDirRecursive(storage_file_path);
|
||||
}
|
||||
auto ret = FileUtils::GetRealPath(storage_file_path.c_str());
|
||||
if (!ret.has_value()) {
|
||||
|
@ -44,6 +49,95 @@ bool EmbeddingStore<K, V>::Initialize() {
|
|||
std::map<std::string, std::string> config_map;
|
||||
config_map[distributed::storage::kFileStoragePath] = real_storage_file_path;
|
||||
storage_ = std::make_unique<storage::LocalFile>(config_map, key_size_, value_size_);
|
||||
storage_->Initialize();
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingStore<K, V>::Get(const void *input, size_t key_num, const void *keys, void *values) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(keys);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
|
||||
// 1. Get data from cache, save miss keys.
|
||||
size_t cache_miss_num = 0;
|
||||
cache_miss_keys_.resize(key_num);
|
||||
cache_miss_indices_.resize(key_num);
|
||||
|
||||
if (!cache_->Get(input, key_num, keys, values, &cache_miss_num, cache_miss_keys_.data(),
|
||||
cache_miss_indices_.data())) {
|
||||
MS_LOG(ERROR) << "Cannot get data from cache.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cache_miss_num == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 2. Get data of miss keys from storage
|
||||
MS_LOG(INFO) << "Embedding store read miss data from storage, num: " << cache_miss_num;
|
||||
size_t storage_miss_num = 0;
|
||||
storage_miss_indices_.resize(cache_miss_num);
|
||||
storage_output_buf_.resize(cache_miss_num * value_size_);
|
||||
storage_->Read(cache_miss_num, static_cast<const int32_t *>(cache_miss_keys_.data()), storage_output_buf_.data(),
|
||||
&storage_miss_num, storage_miss_indices_.data());
|
||||
if (storage_miss_num > 0) {
|
||||
MS_LOG(ERROR) << "Miss some key from storage. num: " << storage_miss_num;
|
||||
return false;
|
||||
}
|
||||
|
||||
// 3. Copy data of miss keys to values
|
||||
for (size_t i = 0; i < cache_miss_num; i++) {
|
||||
auto ret = memcpy_s(AddressOffset(values, cache_miss_indices_[i] * value_size_), value_size_,
|
||||
AddressOffset(storage_output_buf_.data(), i * value_size_), value_size_);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Failed to copy storage data to return values.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingStore<K, V>::Get(size_t key_num, const void *keys, void *values) {
|
||||
MS_EXCEPTION_IF_NULL(keys);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
|
||||
size_t storage_miss_num = 0;
|
||||
storage_miss_indices_.resize(key_num);
|
||||
|
||||
storage_->Read(key_num, static_cast<const int32_t *>(keys), values, &storage_miss_num, storage_miss_indices_.data());
|
||||
if (storage_miss_num > 0) {
|
||||
MS_LOG(INFO) << "Miss some key from storage. num: " << storage_miss_num;
|
||||
// After impl flush interface, all data will be in storage, it can not miss data from here.
|
||||
return true;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
bool EmbeddingStore<K, V>::Put(void *input, size_t key_num, const void *keys, const void *values) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(keys);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
|
||||
size_t evicted_num = 0;
|
||||
evicted_keys_.resize(key_num);
|
||||
evicted_values_buf_.resize(key_num * value_size_);
|
||||
if (!cache_->Put(input, key_num, keys, values, &evicted_num, evicted_keys_.data(), evicted_values_buf_.data())) {
|
||||
MS_LOG(ERROR) << "Cannot put data to cache.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (evicted_num == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// put evicted data to storage
|
||||
MS_LOG(INFO) << "Embedding store Write evicted data to storage, num: " << evicted_num;
|
||||
storage_->Write(evicted_values_buf_.data(), evicted_num, static_cast<const int32_t *>(evicted_keys_.data()));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -56,9 +150,9 @@ std::string GetEmbeddingRemoteStoragePath() {
|
|||
return stoage_path;
|
||||
}
|
||||
|
||||
template class EmbeddingStore<size_t, float>;
|
||||
template class EmbeddingStore<size_t, double>;
|
||||
template class EmbeddingStore<size_t, int64_t>;
|
||||
template class EmbeddingStore<size_t, size_t>;
|
||||
template class EmbeddingStore<int32_t, float>;
|
||||
template class EmbeddingStore<int32_t, double>;
|
||||
template class EmbeddingStore<int32_t, int64_t>;
|
||||
template class EmbeddingStore<int32_t, size_t>;
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "distributed/persistent/storage/storage.h"
|
||||
#include "distributed/embedding_cache/embedding_lru_cache.h"
|
||||
|
@ -39,7 +40,8 @@ std::string GetEmbeddingRemoteStoragePath();
|
|||
template <typename K, typename V>
|
||||
class BACKEND_EXPORT EmbeddingStore {
|
||||
public:
|
||||
EmbeddingStore(size_t cache_capacity, size_t emb_dim) : cache_capacity_(cache_capacity), emb_dim_(emb_dim) {}
|
||||
EmbeddingStore(std::string name, size_t cache_capacity, size_t emb_dim)
|
||||
: name_(name), cache_capacity_(cache_capacity), emb_dim_(emb_dim) {}
|
||||
~EmbeddingStore() = default;
|
||||
|
||||
bool Initialize();
|
||||
|
@ -48,20 +50,23 @@ class BACKEND_EXPORT EmbeddingStore {
|
|||
// Get the values indexed by keys from input. Input is the tensor data address of the embedding Parameter.
|
||||
// Values save the get result. Keys is lookup indices.
|
||||
// When keys not exist in input, will get values from persistent storage.
|
||||
bool Get(const void *input, size_t key_num, const void *keys, void *values) { return true; }
|
||||
bool Get(const void *input, size_t key_num, const void *keys, void *values);
|
||||
|
||||
// Get values which is indexed by keys at persistent storage.
|
||||
bool Get(size_t key_num, const void *keys, void *values) { return true; }
|
||||
bool Get(size_t key_num, const void *keys, void *values);
|
||||
|
||||
// Put values which is indexed by keys to input. Input is a tensor data address from Parameter of embedding.
|
||||
// Values is the data to be written to input. Keys are the indices to update.
|
||||
// When input is full, save evicted values to persistent storage.
|
||||
bool Put(void *input, size_t key_num, const void *keys, const void *values) { return true; }
|
||||
bool Put(void *input, size_t key_num, const void *keys, const void *values);
|
||||
|
||||
// Flush input to persistent storage.
|
||||
bool Flush(void *input);
|
||||
|
||||
private:
|
||||
// A unique name for this embedding store.
|
||||
std::string name_;
|
||||
|
||||
// Cache the keys of Parameter of embedding.
|
||||
std::unique_ptr<EmbeddingCache> cache_;
|
||||
|
||||
|
@ -78,6 +83,19 @@ class BACKEND_EXPORT EmbeddingStore {
|
|||
|
||||
// Total size of bytes of key.
|
||||
size_t key_size_;
|
||||
|
||||
// Vector to save miss keys when get from cache.
|
||||
std::vector<K> cache_miss_keys_;
|
||||
// Vector to save miss indices when get from cache.
|
||||
std::vector<size_t> cache_miss_indices_;
|
||||
// Vector to save miss indices when get from storage.
|
||||
std::vector<size_t> storage_miss_indices_;
|
||||
// Buffer to save data read from storage.
|
||||
std::vector<uint8_t> storage_output_buf_;
|
||||
// Buffer to save evicted keys when put from cache.
|
||||
std::vector<K> evicted_keys_;
|
||||
// Buffer to save evicted values when put from cache.
|
||||
std::vector<uint8_t> evicted_values_buf_;
|
||||
};
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -54,7 +54,6 @@ void LocalFile::Initialize() {
|
|||
MS_EXCEPTION_IF_ZERO("feature_size_", feature_size_);
|
||||
MS_EXCEPTION_IF_ZERO("page_size_", page_size_);
|
||||
|
||||
page_size_ = DEFAULT_PAGE_SIZE;
|
||||
num_features_per_page_ = page_size_ / feature_size_;
|
||||
num_pages_per_block_file_ = DEFAULT_BLOCK_FILE_SIZE / page_size_;
|
||||
num_features_per_block_file_ = num_features_per_page_ * num_pages_per_block_file_;
|
||||
|
@ -308,19 +307,20 @@ bool LocalFile::LoadBlocksInfo() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void LocalFile::Read(const std::vector<int> &ids, void *output, std::vector<int> *missing) {
|
||||
MS_EXCEPTION_IF_NULL(missing);
|
||||
void LocalFile::Read(size_t ids_num, const int32_t *ids, void *output, size_t *miss_num, size_t *miss_indices) {
|
||||
MS_EXCEPTION_IF_NULL(miss_indices);
|
||||
|
||||
size_t num_ids = ids.size();
|
||||
offsets_buf_.resize(num_ids);
|
||||
pages_buf_.resize(num_ids * page_size_);
|
||||
offsets_buf_.resize(ids_num);
|
||||
pages_buf_.resize(ids_num * page_size_);
|
||||
void *pages_ptr = pages_buf_.data();
|
||||
|
||||
ReadPages(num_ids, ids, pages_ptr, offsets_buf_.data());
|
||||
ReadPages(ids_num, ids, pages_ptr, offsets_buf_.data());
|
||||
|
||||
for (uint32_t i = 0; i < num_ids; ++i) {
|
||||
size_t miss_count = 0;
|
||||
for (uint32_t i = 0; i < ids_num; ++i) {
|
||||
if (offsets_buf_.at(i) == page_size_) {
|
||||
missing->emplace_back(i);
|
||||
miss_indices[miss_count] = i;
|
||||
miss_count++;
|
||||
continue;
|
||||
}
|
||||
auto ret = memcpy_s(AddressOffset(output, i * feature_size_), feature_size_,
|
||||
|
@ -329,16 +329,18 @@ void LocalFile::Read(const std::vector<int> &ids, void *output, std::vector<int>
|
|||
MS_LOG(EXCEPTION) << "Failed to copy output when read block, ret = " << ret;
|
||||
}
|
||||
}
|
||||
|
||||
*miss_num = miss_count;
|
||||
}
|
||||
|
||||
void LocalFile::ReadPages(size_t num_ids, const std::vector<int> &ids, void *pages_ptr, size_t *offsets) {
|
||||
void LocalFile::ReadPages(size_t ids_num, const int32_t *ids, void *pages_ptr, size_t *offsets) {
|
||||
MS_EXCEPTION_IF_NULL(pages_ptr);
|
||||
MS_EXCEPTION_IF_NULL(offsets);
|
||||
MS_EXCEPTION_IF_ZERO("num_features_per_page_", num_features_per_page_);
|
||||
MS_EXCEPTION_IF_ZERO("num_pages_per_block_file_", num_pages_per_block_file_);
|
||||
|
||||
for (size_t i = 0; i < num_ids; ++i) {
|
||||
const int id = ids[i];
|
||||
for (size_t i = 0; i < ids_num; ++i) {
|
||||
const int32_t id = ids[i];
|
||||
auto it = id_to_page_loc_.find(id);
|
||||
if (it == id_to_page_loc_.end()) {
|
||||
offsets[i] = page_size_;
|
||||
|
@ -357,35 +359,34 @@ void LocalFile::ReadPages(size_t num_ids, const std::vector<int> &ids, void *pag
|
|||
}
|
||||
}
|
||||
|
||||
void LocalFile::Write(const void *input, const std::vector<int> &ids) {
|
||||
void LocalFile::Write(const void *input, size_t ids_num, const int32_t *ids) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
||||
size_t num_ids = ids.size();
|
||||
pages_buf_.resize(num_ids * page_size_);
|
||||
pages_buf_.resize(ids_num * page_size_);
|
||||
|
||||
// Copy data at input to pages buf, page by page.
|
||||
for (size_t i = 0; i < num_ids; i += num_features_per_page_) {
|
||||
for (size_t i = 0; i < ids_num; i += num_features_per_page_) {
|
||||
const size_t page_id = i / num_features_per_page_;
|
||||
const size_t copy_size = (num_ids - i) < num_features_per_page_ ? (num_ids - i) * feature_size_ : page_size_;
|
||||
const size_t copy_size = (ids_num - i) < num_features_per_page_ ? (ids_num - i) * feature_size_ : page_size_;
|
||||
auto ret = memcpy_s(AddressOffset(pages_buf_.data(), page_id * page_size_), copy_size,
|
||||
AddressOffset(const_cast<void *>(input), i * feature_size_), copy_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Failed to copy input when write block, ret = " << ret;
|
||||
}
|
||||
}
|
||||
WritePages(num_ids, ids);
|
||||
WritePages(ids_num, ids);
|
||||
}
|
||||
|
||||
// This function write page buf to a block file, a block have many pages, a page have many feature, a feature is a
|
||||
// one-dim tensor. Block file is append only, so need use curr_storage_size_ to record the write pos at block file.
|
||||
void LocalFile::WritePages(size_t num_ids, const std::vector<int> &ids) {
|
||||
void LocalFile::WritePages(size_t ids_num, const int32_t *ids) {
|
||||
MS_EXCEPTION_IF_ZERO("num_features_per_page_", num_features_per_page_);
|
||||
MS_EXCEPTION_IF_ZERO("num_pages_per_block_file_", num_pages_per_block_file_);
|
||||
|
||||
const void *pages_ptr = pages_buf_.data();
|
||||
|
||||
// 1. Calculate how many pages will be written, we need aligned by features per page.
|
||||
const size_t num_pages = RoundUp(num_ids, num_features_per_page_) / num_features_per_page_;
|
||||
const size_t num_pages = RoundUp(ids_num, num_features_per_page_) / num_features_per_page_;
|
||||
const size_t num_padded_ids = num_pages * num_features_per_page_;
|
||||
const size_t start_index = curr_storage_size_;
|
||||
curr_storage_size_ += num_padded_ids;
|
||||
|
@ -427,7 +428,7 @@ void LocalFile::WritePages(size_t num_ids, const std::vector<int> &ids) {
|
|||
}
|
||||
|
||||
// 4. Record the index of page in block file of a feature.
|
||||
for (size_t i = 0; i < num_ids; ++i) {
|
||||
for (size_t i = 0; i < ids_num; ++i) {
|
||||
id_to_page_loc_[ids[i]] = start_index + i;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ class LocalFile : public StorageBase {
|
|||
// Write the entire blob data composed of multiple tensors to the block files on disk:
|
||||
void Write(const std::vector<InputData> &inputs, const DirtyInfo &dirty_info) override;
|
||||
// Write data of ids to block files.
|
||||
void Write(const void *input, const std::vector<int> &ids) override;
|
||||
void Write(const void *input, size_t ids_num, const int32_t *ids) override;
|
||||
|
||||
// The following two methods are override version function for Read:
|
||||
// 1.Tamper proof check.
|
||||
|
@ -89,7 +89,7 @@ class LocalFile : public StorageBase {
|
|||
// Read data from all block files in file_path_(dir) for multiple tensors.
|
||||
void Read(const std::vector<OutputData> &outputs) override;
|
||||
// Read ids from block files.
|
||||
void Read(const std::vector<int> &ids, void *output, std::vector<int> *missing) override;
|
||||
void Read(size_t ids_num, const int32_t *ids, void *output, size_t *miss_num, size_t *miss_indices) override;
|
||||
|
||||
private:
|
||||
// Create blocks and block metas and write input data to block files.
|
||||
|
@ -105,9 +105,9 @@ class LocalFile : public StorageBase {
|
|||
// Load file list info of block files and block meta files in the 'file_path_' to block list and block meta list.
|
||||
bool LoadBlocksInfo();
|
||||
|
||||
void ReadPages(size_t num_ids, const std::vector<int> &ids, void *pages_ptr, size_t *offsets);
|
||||
void ReadPages(size_t ids_num, const int32_t *ids, void *pages_ptr, size_t *offsets);
|
||||
|
||||
void WritePages(size_t num_ids, const std::vector<int> &ids);
|
||||
void WritePages(size_t ids_num, const int32_t *ids);
|
||||
|
||||
std::string BlockFilePath(size_t block_id) const;
|
||||
|
||||
|
@ -134,7 +134,7 @@ class LocalFile : public StorageBase {
|
|||
bool finish_create_block_files_{false};
|
||||
|
||||
// Size of a read/write page in block.
|
||||
size_t page_size_;
|
||||
size_t page_size_{DEFAULT_PAGE_SIZE};
|
||||
|
||||
size_t num_features_per_page_;
|
||||
size_t num_pages_per_block_file_;
|
||||
|
@ -152,7 +152,7 @@ class LocalFile : public StorageBase {
|
|||
std::vector<size_t> offsets_buf_;
|
||||
|
||||
// Map to record id in which page.
|
||||
mindspore::HashMap<int, size_t> id_to_page_loc_;
|
||||
mindspore::HashMap<int32_t, size_t> id_to_page_loc_;
|
||||
|
||||
// File System Ops handle.
|
||||
std::shared_ptr<system::FileSystem> fs_;
|
||||
|
|
|
@ -56,7 +56,7 @@ class StorageBase {
|
|||
virtual void Write(const std::vector<InputData> &input, const DirtyInfo &dirty_info) {}
|
||||
|
||||
// Write data of ids to block files.
|
||||
virtual void Write(const void *input, const std::vector<int> &ids) {}
|
||||
virtual void Write(const void *input, size_t ids_num, const int32_t *ids) {}
|
||||
|
||||
// Read data from the storage medium or memory buffer and merge them into contiguous memory.
|
||||
virtual void Read(const OutputData &output) {}
|
||||
|
@ -65,7 +65,7 @@ class StorageBase {
|
|||
virtual void Read(const std::vector<OutputData> &outputs) {}
|
||||
|
||||
// Read ids from block files.
|
||||
virtual void Read(const std::vector<int> &ids, void *output, std::vector<int> *missing) {}
|
||||
virtual void Read(size_t ids_num, const int32_t *ids, void *output, size_t *miss_num, size_t *miss_indices) {}
|
||||
};
|
||||
} // namespace storage
|
||||
} // namespace distributed
|
||||
|
|
|
@ -546,9 +546,15 @@ void PsEmbeddingCacheInserter::BuildEmbeddingStores() {
|
|||
|
||||
auto param_info = param->param_info();
|
||||
MS_EXCEPTION_IF_NULL(param_info);
|
||||
const std::vector<int64_t> &origin_shape = param_info->parameter_shape();
|
||||
size_t origin_capacity = LongToSize(origin_shape.front());
|
||||
size_t origin_emb_dim = LongToSize(origin_shape.back());
|
||||
MS_LOG(INFO) << "Get a parameter for embedding store: " << param->name() << ", origin emb_dim: " << origin_emb_dim
|
||||
<< ", origin capacity: " << origin_capacity;
|
||||
|
||||
if (!param_info->use_persistent_storage()) {
|
||||
MS_LOG(INFO) << "No need to use embedding store for this parameter(key): " << key;
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &slice_shape = param_info->parameter_persistent_slice_shape();
|
||||
|
@ -559,13 +565,17 @@ void PsEmbeddingCacheInserter::BuildEmbeddingStores() {
|
|||
}
|
||||
size_t capacity = LongToSize(slice_shape.front());
|
||||
size_t emb_dim = LongToSize(slice_shape.back());
|
||||
std::string name = std::to_string(key);
|
||||
|
||||
auto emb_store = std::make_shared<distributed::EmbeddingStore<size_t, float_t>>(capacity, emb_dim);
|
||||
auto emb_store = std::make_shared<distributed::EmbeddingStore<int32_t, float>>(name, capacity, emb_dim);
|
||||
MS_EXCEPTION_IF_NULL(emb_store);
|
||||
if (!emb_store->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to Initialize for parameter(key): " << key;
|
||||
}
|
||||
embedding_store_manager.Add(std::to_string(key), emb_store);
|
||||
embedding_store_manager.Add(name, emb_store);
|
||||
|
||||
MS_LOG(INFO) << "Add a new embedding store: " << name << ", emb_dim: " << emb_dim << ", capacity: " << capacity
|
||||
<< ", origin emb_dim:" << origin_emb_dim << ", origin capacity: " << origin_capacity;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -188,6 +188,10 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
}
|
||||
use_embedding_cache_ = GetValue<bool>(base_operator->GetAttr(kAttrUseEmbeddingStore));
|
||||
parameter_key_ = GetValue<int64_t>(base_operator->GetAttr(kAttrParameterKey));
|
||||
if (use_embedding_cache_) {
|
||||
MS_LOG(INFO) << "For embedding cache kernel: " << kernel_name_ << ", param key: " << parameter_key_
|
||||
<< ", vocab size: " << first_dim_size_ << ", emb dim:" << outer_dim_size_ << ", offset: " << offset_;
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,10 @@ int ScatterArithmeticCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
|
||||
use_embedding_cache_ = GetValue<bool>(base_operator->GetAttr(kAttrUseEmbeddingStore));
|
||||
parameter_key_ = GetValue<int64_t>(base_operator->GetAttr(kAttrParameterKey));
|
||||
if (use_embedding_cache_) {
|
||||
MS_LOG(INFO) << "For embedding cache kernel: " << kernel_name_ << ", param key: " << parameter_key_
|
||||
<< ", vocab size: " << first_dim_size_ << ", emb dim:" << inner_size_;
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -231,7 +231,7 @@ class Parameter(Tensor_):
|
|||
# And save out range data to persistent storage to support TB-Level size parameter.
|
||||
slice_num_of_persistent_data = split_to_slice_if_need(default_input.dtype, default_input.shape)
|
||||
if slice_num_of_persistent_data > 1:
|
||||
data_shape = default_input.shape
|
||||
data_shape = list(default_input.shape)
|
||||
slice_first_dim = math.ceil(data_shape[0] / slice_num_of_persistent_data)
|
||||
data_shape[0] = slice_first_dim
|
||||
self.param_info.parameter_persistent_slice_shape = data_shape
|
||||
|
@ -279,7 +279,7 @@ class Parameter(Tensor_):
|
|||
|
||||
@staticmethod
|
||||
def _not_init_data():
|
||||
is_worker_or_server = (_is_role_worker() or _is_role_pserver()) and not _enable_distributed_mindrt()
|
||||
is_worker_or_server = (_is_role_worker() or _is_role_pserver())
|
||||
if is_worker_or_server or _is_role_sched() or _is_in_parallel_mode():
|
||||
return True
|
||||
return False
|
||||
|
@ -298,7 +298,7 @@ class Parameter(Tensor_):
|
|||
# make a copy of Tensor to init the parameter.
|
||||
return (Tensor, data.asnumpy())
|
||||
|
||||
is_worker_or_server = (_is_role_worker() or _is_role_pserver()) and not _enable_distributed_mindrt()
|
||||
is_worker_or_server = (_is_role_worker() or _is_role_pserver())
|
||||
not_init_data = is_worker_or_server or _is_role_sched() or _is_in_parallel_mode()
|
||||
if not_init_data:
|
||||
# do not init data while in auto parallel.
|
||||
|
|
|
@ -3844,7 +3844,7 @@ class Tensor(Tensor_):
|
|||
shape = self.shape
|
||||
# At embedding cache scenes, we need limit the size of memory for tensor.
|
||||
# And save out of range data to persistent storage to support TB-Level size of tensor.
|
||||
data_shape = shape
|
||||
data_shape = list(shape)
|
||||
slice_num_of_persistent_data = split_to_slice_if_need(self.dtype, shape)
|
||||
if slice_num_of_persistent_data > 1:
|
||||
slice_first_dim = math.ceil(shape[0] / slice_num_of_persistent_data)
|
||||
|
|
|
@ -362,7 +362,7 @@ class EmbeddingLookup(Cell):
|
|||
if _enable_distributed_mindrt():
|
||||
self.rank_id = get_rank()
|
||||
if self.is_ps_server:
|
||||
self._slice_pserver_embeddings(param_init)
|
||||
self._slice_pserver_embeddings("zeros")
|
||||
self._set_cache_enable_and_key_for_pserver(param_key)
|
||||
|
||||
def _slice_pserver_embeddings(self, param_init):
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <random>
|
||||
|
||||
#include "distributed/embedding_cache/embedding_lru_cache.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace persistent {
|
||||
class TestEmbeddingLRUCache : public UT::Common {
|
||||
public:
|
||||
TestEmbeddingLRUCache() = default;
|
||||
virtual ~TestEmbeddingLRUCache() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
/// Feature: test lru cache.
|
||||
/// Description: test embedding lru cache data structure and interface.
|
||||
/// Expectation: all interface work normally and can not throw exception.
|
||||
TEST_F(TestEmbeddingLRUCache, test_lru_cache_base_api) {
  // A cache with a single slot: one Put is enough to fill it.
  const int stored_key = 8;
  const int stored_value = 987;
  distributed::LRUCache<int, int> cache(1);

  cache.Put(stored_key, stored_value);

  // The entry must be visible through every read-side accessor.
  EXPECT_TRUE(cache.Exists(stored_key));
  EXPECT_EQ(stored_value, cache.Get(stored_key));
  EXPECT_EQ(1, cache.Size());
  EXPECT_TRUE(cache.IsFull());
}
|
||||
|
||||
/// Feature: test lru cache.
|
||||
/// Description: test embedding lru cache data structure and interface of simple use case.
|
||||
/// Expectation: all interface work normally and can not throw exception.
|
||||
TEST_F(TestEmbeddingLRUCache, test_lru_cache_simple_case) {
  // Insert twice as many keys as the cache can hold; only the most recent half may remain.
  int all_num = 100;
  int capacity = 50;
  const int first_surviving_key = all_num - capacity;

  distributed::LRUCache<int, int> cache(capacity);

  int key = 0;
  while (key < all_num) {
    cache.Put(key, key);
    ++key;
  }

  // Walk all keys in insertion order: the oldest ones must be gone, the newest ones must still
  // map to themselves.
  for (key = 0; key < all_num; ++key) {
    if (key < first_surviving_key) {
      EXPECT_FALSE(cache.Exists(key));
      continue;
    }
    EXPECT_TRUE(cache.Exists(key));
    EXPECT_EQ(key, cache.Get(key));
  }

  // The cache must hold exactly `capacity` entries afterwards.
  size_t cached_entries = cache.Size();
  EXPECT_EQ(capacity, cached_entries);
}
|
||||
|
||||
/// Feature: test lru cache.
|
||||
/// Description: test embedding lru cache data structure and interface when miss key.
|
||||
/// Expectation: getting a key that is not in the cache throws an exception.
|
||||
TEST_F(TestEmbeddingLRUCache, test_lru_cache_miss_key) {
  // Nothing is ever inserted, so a lookup of any key is expected to throw.
  const int absent_key = 5;
  distributed::LRUCache<int, int> empty_cache(1);

  EXPECT_ANY_THROW(empty_cache.Get(absent_key));
}
|
||||
|
||||
/// Feature: test embedding lru cache.
|
||||
/// Description: test embedding lru cache data structure and interface of simple use case.
|
||||
/// Expectation: all interface work normally and can not throw exception.
|
||||
TEST_F(TestEmbeddingLRUCache, test_emb_lru_cache_simple_case) {
  // One cache slot holding one embedding row of 256 floats, so putting a second key must evict
  // the first one.
  const size_t emb_dim = 256;
  const size_t vocab_size = 1;
  const size_t value_size = emb_dim * sizeof(float);
  const size_t shape_size = vocab_size * emb_dim;
  const size_t cache_capacity = 1;

  auto cache = std::make_unique<EmbeddingLRUCache<int32_t, float>>(cache_capacity, value_size);
  EXPECT_NO_THROW(cache->Initialize());

  // `input` plays the role of the embedding parameter buffer the cache reads from / writes into.
  std::vector<float> input;
  input.reserve(shape_size);
  for (size_t i = 0; i < shape_size; i++) {
    input.emplace_back(1.0 * i);
  }
  size_t key_num = 1;
  std::vector<int32_t> keys{0};
  std::vector<float> values(shape_size);

  // Output buffers start with sentinels that differ from every expected result, so the
  // assertions below only pass if the cache really wrote them.
  const int32_t unset_key = -1;
  const size_t unset_index = key_num;
  size_t miss_num = 0;
  std::vector<int32_t> miss_keys(1, unset_key);
  std::vector<size_t> miss_indices(1, unset_index);

  // Get not exists key
  EXPECT_TRUE(
    cache->Get(input.data(), key_num, keys.data(), values.data(), &miss_num, miss_keys.data(), miss_indices.data()));
  EXPECT_EQ(1U, miss_num);
  EXPECT_EQ(0, miss_keys[0]);
  EXPECT_EQ(0U, miss_indices[0]);

  // Put key&value to input
  for (size_t i = 0; i < shape_size; i++) {
    values[i] = 3.0 * i;
  }
  size_t evicted_num = 0;
  std::vector<int32_t> evicted_keys(1, unset_key);
  std::vector<float> evicted_values(shape_size);
  EXPECT_TRUE(cache->Put(input.data(), key_num, keys.data(), values.data(), &evicted_num, evicted_keys.data(),
                         evicted_values.data()));
  // The cache was empty, so nothing is evicted and the row lands in `input`.
  EXPECT_EQ(0U, evicted_num);
  for (size_t i = 0; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(3.0 * i, input[i]);
  }

  // Put new key&value to input
  keys[0] = 42;
  for (size_t i = 0; i < shape_size; i++) {
    values[i] = 5.0 * i;
  }
  EXPECT_TRUE(cache->Put(input.data(), key_num, keys.data(), values.data(), &evicted_num, evicted_keys.data(),
                         evicted_values.data()));
  // The single slot was occupied by key 0: that key and its row must be reported as evicted,
  // and the new row must replace it in `input`.
  EXPECT_EQ(1U, evicted_num);
  EXPECT_EQ(0, evicted_keys[0]);
  for (size_t i = 0; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(5.0 * i, input[i]);
  }
  for (size_t i = 0; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(3.0 * i, evicted_values[i]);
  }

  // Get old key will miss
  keys[0] = 0;
  // Reset the outputs so results left over from the first Get cannot satisfy the assertions.
  miss_num = 0;
  miss_keys[0] = unset_key;
  miss_indices[0] = unset_index;
  EXPECT_TRUE(
    cache->Get(input.data(), key_num, keys.data(), values.data(), &miss_num, miss_keys.data(), miss_indices.data()));
  EXPECT_EQ(1U, miss_num);
  EXPECT_EQ(0, miss_keys[0]);
  EXPECT_EQ(0U, miss_indices[0]);
  // value not change after get
  for (size_t i = 0; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(5.0 * i, input[i]);
  }
}
|
||||
} // namespace persistent
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <random>
|
||||
|
||||
#include "include/common/random.h"
|
||||
#include "distributed/embedding_cache/embedding_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace persistent {
|
||||
class TestEmbeddingStore : public UT::Common {
|
||||
public:
|
||||
TestEmbeddingStore() = default;
|
||||
virtual ~TestEmbeddingStore() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
/// Feature: test embedding store.
|
||||
/// Description: test embedding store data structure and interface.
|
||||
/// Expectation: all interface work normally and can not throw exception.
|
||||
TEST_F(TestEmbeddingStore, test_embedding_store_simple_case) {
  // Three rows of three floats each, but the in-memory cache buffer (`input`) has room for only
  // one row.
  const size_t emb_dim = 3;
  const size_t vocab_size = 3;
  const size_t vocab_cache_size = 1;
  const size_t shape_size = vocab_size * emb_dim;
  const size_t cache_shape_size = vocab_cache_size * emb_dim;
  std::string name = "fake";

  auto emb_store = std::make_shared<distributed::EmbeddingStore<int32_t, float>>(name, vocab_cache_size, emb_dim);
  EXPECT_NO_THROW(emb_store->Initialize());

  std::vector<float> input(cache_shape_size);
  std::vector<float> values(shape_size);
  size_t key_num = 3;
  std::vector<int32_t> keys{0, 1, 2};

  // Get keys not exists
  EXPECT_FALSE(emb_store->Get(input.data(), key_num, keys.data(), values.data()));

  // Put key&value out of cache range
  for (size_t i = 0; i < shape_size; i++) {
    values[i] = 1.0 * i;
  }
  EXPECT_TRUE(emb_store->Put(input.data(), key_num, keys.data(), values.data()));
  // After the Put, the one-row cache buffer is expected to hold the row of the last key.
  const size_t last_row_begin = (key_num - 1) * emb_dim;
  for (size_t i = last_row_begin; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(1.0 * i, input[i - last_row_begin]);
  }

  // Get all key&value
  // Clear the output first so the values checked below must come from the store.
  for (size_t i = 0; i < shape_size; i++) {
    values[i] = 0;
  }
  EXPECT_TRUE(emb_store->Get(input.data(), key_num, keys.data(), values.data()));
  for (size_t i = 0; i < shape_size; i++) {
    EXPECT_FLOAT_EQ(1.0 * i, values[i]);
  }
}
|
||||
} // namespace persistent
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "distributed/persistent/storage/local_file.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace persistent {
|
||||
class TestLocalFile : public UT::Common {
|
||||
public:
|
||||
TestLocalFile() = default;
|
||||
virtual ~TestLocalFile() = default;
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
/// Feature: test parameter persistent storage and restore.
|
||||
/// Description: Modify part of the Embedding table content, persist it to the file, and read it from the file again.
|
||||
/// Expectation: The content after persistent recovery is consistent with expectations.
|
||||
TEST_F(TestLocalFile, test_read_write_by_ids_normal_size) {
  // Number of int elements stored per id. It sizes both the storage value size and the
  // read/write buffers, so it is defined once instead of repeating the literal.
  const size_t elements_per_id = 10;
  const size_t ids_num = 10000;
  const size_t table_size = ids_num * elements_per_id;

  // Make sure the directory backing the storage exists before the storage object is created.
  std::string storage_file_path = "./storage";
  if (!distributed::storage::FileIOUtils::IsFileOrDirExist(storage_file_path)) {
    distributed::storage::FileIOUtils::CreateDir(storage_file_path);
  }

  std::map<std::string, std::string> config_map;
  config_map[distributed::storage::kFileStoragePath] = storage_file_path;
  std::shared_ptr<storage::StorageBase> local_storage =
    std::make_shared<storage::LocalFile>(config_map, sizeof(int32_t), elements_per_id * sizeof(int));
  EXPECT_NO_THROW(local_storage->Initialize());

  // Ids are 1..ids_num.
  std::vector<int32_t> ids;
  ids.reserve(ids_num);
  for (size_t i = 1; i <= ids_num; i++) {
    ids.emplace_back(static_cast<int32_t>(i));
  }

  // Every element holds its own flat index, which makes the read-back check below trivial.
  std::vector<int> write_data;
  write_data.reserve(table_size);
  for (size_t i = 0; i < table_size; i++) {
    write_data.emplace_back(static_cast<int>(i));
  }

  size_t miss_num = 0;
  std::vector<int> read_data(table_size);
  std::vector<size_t> missing(ids_num);

  EXPECT_NO_THROW(local_storage->Write(write_data.data(), ids_num, ids.data()));
  EXPECT_NO_THROW(local_storage->Read(ids_num, ids.data(), read_data.data(), &miss_num, missing.data()));

  // Every id was written just before, so none may be reported missing and the data must round-trip.
  EXPECT_EQ(miss_num, 0U);
  for (size_t i = 0; i < table_size; i++) {
    EXPECT_EQ(static_cast<int>(i), read_data[i]);
  }
}
|
||||
} // namespace persistent
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue