forked from mindspore-Ecosystem/mindspore
!4848 Combine duplicate embedding lookup IDs
Merge pull request !4848 from chengang/unique_id_5
This commit is contained in:
commit
78f4923e83
|
@ -314,7 +314,7 @@ template <typename T>
|
|||
void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
|
||||
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
||||
const Key &key = req_data.keys[0];
|
||||
for (size_t i = 0; i < req_data.keys.size(); i++) {
|
||||
for (size_t i = 1; i < req_data.keys.size(); i++) {
|
||||
res->keys.push_back(req_data.keys[i]);
|
||||
}
|
||||
ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
|
||||
|
|
|
@ -259,13 +259,29 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
|
|||
auto &kvs = lookup_results_[ts];
|
||||
mutex_.unlock();
|
||||
|
||||
::ps::SArray<T> result(kvs[0].vals.size(), 0);
|
||||
for (auto k : kvs) {
|
||||
for (size_t i = 0; i < k.vals.size(); i++) {
|
||||
result[i] += k.vals[i];
|
||||
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
|
||||
for (const auto &s : kvs) {
|
||||
int offset = 0;
|
||||
int len = s.vals.size() / s.keys.size();
|
||||
for (size_t i = 0; i < s.keys.size(); i++) {
|
||||
const Key &key = s.keys[i];
|
||||
T *addr = s.vals.data() + offset;
|
||||
offset += len;
|
||||
id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
|
||||
}
|
||||
}
|
||||
*lookup_result = result;
|
||||
|
||||
T *result_addr = lookup_result->data();
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < lookup_ids.size(); i++) {
|
||||
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
|
||||
int size = pair->second * sizeof(T);
|
||||
auto ret = memcpy_s(result_addr + offset, size, pair->first, size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
offset += pair->second;
|
||||
}
|
||||
|
||||
mutex_.lock();
|
||||
lookup_results_.erase(ts);
|
||||
|
@ -312,12 +328,23 @@ void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send,
|
|||
sliced->resize(ranges.size());
|
||||
|
||||
for (size_t i = 0; i < ranges.size(); i++) {
|
||||
const ::ps::Range &range = ranges[i];
|
||||
const auto &begin = range.begin();
|
||||
const auto &end = range.end();
|
||||
std::unordered_set<int> unique_ids;
|
||||
auto &kvs = sliced->at(i).second;
|
||||
|
||||
kvs.keys.push_back(key);
|
||||
kvs.vals.push_back(0.0f);
|
||||
|
||||
for (size_t j = 0; j < id_size; j++) {
|
||||
kvs.keys.push_back(lookup_ids[j]);
|
||||
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
||||
if (lookup_id >= begin && lookup_id <= end) {
|
||||
unique_ids.insert(lookup_id);
|
||||
}
|
||||
}
|
||||
for (const auto &lookup_id : unique_ids) {
|
||||
kvs.keys.push_back(lookup_id);
|
||||
kvs.vals.push_back(0.0f);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue