add cache pass
This commit is contained in:
parent
dfa6daaa57
commit
f97e19f23f
|
@ -19,55 +19,20 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "utils/cache_embedding_hashmap_struct.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
struct HashmapEntry {
|
||||
T key;
|
||||
T value;
|
||||
T step;
|
||||
T tag;
|
||||
|
||||
bool IsEmpty() {
|
||||
if (this->tag == NULLTAG)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsUsing(const T &train_step) {
|
||||
if (this->step >= (train_step - 1))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsKey(const T &emb_idx) {
|
||||
if (this->key == emb_idx)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
void SetEmpty() { this->tag = NULLTAG; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T HashFunc(const T &key, const size_t &m) {
|
||||
return (T)(((0.6180339 * key) - floor(0.6180339 * key)) * m);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
|
||||
T i = (entry + 1) % length, off = 1;
|
||||
int compress_count = 0;
|
||||
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
|
||||
if (entry_p[i].tag > off) {
|
||||
entry_p[entry].key = entry_p[i].key;
|
||||
entry_p[entry].value = entry_p[i].value;
|
||||
entry_p[entry].step = entry_p[i].step;
|
||||
entry_p[entry].tag = entry_p[i].tag - off;
|
||||
if (entry_p[i].tag_ > off) {
|
||||
entry_p[entry].key_ = entry_p[i].key_;
|
||||
entry_p[entry].value_ = entry_p[i].value_;
|
||||
entry_p[entry].step_ = entry_p[i].step_;
|
||||
entry_p[entry].tag_ = entry_p[i].tag_ - off;
|
||||
entry_p[i].SetEmpty();
|
||||
off = 0;
|
||||
entry = i;
|
||||
|
@ -127,6 +92,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
float total_count = 0;
|
||||
int count_size = 0;
|
||||
float hit_count = 0;
|
||||
|
||||
// search_cache_idx
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
T key = input_indices[i] - offset;
|
||||
|
@ -140,7 +106,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
|
||||
tmp_entry = (tmp_entry + 1) % hashmap_length_;
|
||||
if (count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
|
||||
MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!";
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
|
@ -153,8 +119,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
miss_count++;
|
||||
} else {
|
||||
hit_count += 1;
|
||||
output_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
hashmap[tmp_entry].step = step_[0];
|
||||
output_cache_idx[i] = hashmap[tmp_entry].value_;
|
||||
hashmap[tmp_entry].step_ = step_[0];
|
||||
}
|
||||
}
|
||||
if (miss_count != 0) {
|
||||
|
@ -175,27 +141,27 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
while (!hashmap[entry].IsEmpty()) {
|
||||
entry = (entry + 1) % hashmap_length_;
|
||||
if (tag_count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
|
||||
MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!";
|
||||
break;
|
||||
}
|
||||
tag_count++;
|
||||
}
|
||||
hashmap[entry].key = emb_idx;
|
||||
hashmap[entry].step = step_[0];
|
||||
hashmap[entry].tag = tag_count;
|
||||
hashmap[entry].key_ = emb_idx;
|
||||
hashmap[entry].step_ = step_[0];
|
||||
hashmap[entry].tag_ = tag_count;
|
||||
T tmp_entry = (entry + 1) % hashmap_length_;
|
||||
size_t delete_count = 1;
|
||||
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
|
||||
tmp_entry = (tmp_entry + 1) % hashmap_length_;
|
||||
if (delete_count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
|
||||
MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!";
|
||||
break;
|
||||
}
|
||||
delete_count++;
|
||||
}
|
||||
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
output_old_emb_idx[i] = hashmap[tmp_entry].key;
|
||||
hashmap[entry].value = output_swap_cache_idx[i];
|
||||
output_swap_cache_idx[i] = hashmap[tmp_entry].value_;
|
||||
output_old_emb_idx[i] = hashmap[tmp_entry].key_;
|
||||
hashmap[entry].value_ = output_swap_cache_idx[i];
|
||||
hashmap[tmp_entry].SetEmpty();
|
||||
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
|
||||
total_delete_count += (compress_count + delete_count);
|
||||
|
|
|
@ -23,8 +23,6 @@
|
|||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
#define NULLTAG 0
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MapCacheIdxCPUKernel : public CPUKernel {
|
||||
|
|
|
@ -188,12 +188,18 @@ void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusi
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
|
||||
if (buffer_fusion_info.outputs_list.size() == 1) { // single output
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
|
||||
}
|
||||
(void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
|
||||
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
|
||||
buffer_fusion_kernel);
|
||||
} else { // multiple output
|
||||
for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
|
||||
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
|
||||
}
|
||||
(void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
|
||||
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
|
||||
tuple_item);
|
||||
|
|
|
@ -274,6 +274,10 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
|
|||
bool IsNopNode(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto target = GetCNodeTarget(node);
|
||||
if (target == kCPUDevice) {
|
||||
return false;
|
||||
}
|
||||
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
|
||||
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
return false;
|
||||
|
|
|
@ -0,0 +1,535 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/cache_embedding/cache_embedding.h"
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/cache_embedding_hashmap_struct.h"
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
|
||||
using ParamSet = std::unordered_set<ParameterPtr>;
|
||||
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
|
||||
|
||||
ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet ¶meter_cache_enable_set) {
|
||||
ParamMap cache_host_params_map;
|
||||
for (auto ¶m : parameter_cache_enable_set) {
|
||||
auto param_info = param->param_info();
|
||||
if (param_info && param_info->cache_enable()) {
|
||||
auto data_type = param->Type();
|
||||
auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
|
||||
auto type_id = data_element_type->type_id();
|
||||
auto cache_shape = param_info->cache_shape();
|
||||
auto ori_param_name = param->name();
|
||||
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
|
||||
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
|
||||
auto cache_name = ori_param_name + "_cache";
|
||||
new_param_info->set_name(cache_name);
|
||||
new_tensor->set_param_info(new_param_info);
|
||||
auto cache_param = graph->AddWeightParameter(cache_name);
|
||||
cache_param->set_default_param(MakeValue(new_tensor));
|
||||
cache_param->set_abstract(new_tensor->ToAbstract());
|
||||
cache_host_params_map[cache_param] = param;
|
||||
}
|
||||
}
|
||||
return cache_host_params_map;
|
||||
}
|
||||
|
||||
bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) {
|
||||
int64_t host_size = 0;
|
||||
int64_t cache_size = 0;
|
||||
for (auto &host_param : parameter_cache_enable_set) {
|
||||
auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
|
||||
auto host_param_info = host_param->param_info();
|
||||
auto cache_shape = host_param_info->cache_shape();
|
||||
if (cache_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
|
||||
}
|
||||
auto tmp_cache_size = cache_shape[0];
|
||||
if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same.";
|
||||
}
|
||||
cache_size = tmp_cache_size;
|
||||
host_size = tmp_host_size;
|
||||
}
|
||||
if (cache_size >= host_size) {
|
||||
MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (auto &ele : map) {
|
||||
if (!manager->Replace(ele.second, ele.first)) {
|
||||
MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ParamSet MapKeysToSet(const ParamMap &map) {
|
||||
ParamSet set;
|
||||
for (auto &ele : map) {
|
||||
set.insert(ele.first);
|
||||
}
|
||||
return set;
|
||||
}
|
||||
|
||||
ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
|
||||
ParamSet parameter_cache_enable_set;
|
||||
auto parameters = graph->parameters();
|
||||
auto params_size = parameters.size();
|
||||
for (size_t i = 0; i < params_size; ++i) {
|
||||
auto param = parameters[i]->cast<ParameterPtr>();
|
||||
auto param_info = param->param_info();
|
||||
if (param_info && param_info->cache_enable()) {
|
||||
parameter_cache_enable_set.insert(param);
|
||||
}
|
||||
}
|
||||
return parameter_cache_enable_set;
|
||||
}
|
||||
|
||||
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
|
||||
size_t cnodes_size = cnodes.size();
|
||||
CNodePtrList unique_cache_enable;
|
||||
for (size_t i = 0; i < cnodes_size; ++i) {
|
||||
if (IsPrimitiveCNode(cnodes[i], prim::kPrimUnique)) {
|
||||
auto unique_node = cnodes[i];
|
||||
auto unique_prim = GetCNodePrimitive(unique_node);
|
||||
MS_EXCEPTION_IF_NULL(unique_prim);
|
||||
auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
|
||||
if (attr_value != nullptr && GetValue<bool>(attr_value)) {
|
||||
unique_cache_enable.emplace_back(unique_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (unique_cache_enable.size() > 1) {
|
||||
MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size();
|
||||
}
|
||||
return unique_cache_enable;
|
||||
}
|
||||
|
||||
void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr &hashmap) {
|
||||
auto hashmap_tensor_value = hashmap->default_param();
|
||||
auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
|
||||
for (auto &ele : param_pair_list) {
|
||||
auto host_tensor_value = ele.second->default_param();
|
||||
auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
|
||||
auto cache_tensor_value = ele.first->default_param();
|
||||
auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
|
||||
|
||||
// bind host, cache, hashmap
|
||||
host_tensor->set_cache_enable(true);
|
||||
host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
|
||||
host_tensor->set_cache_tensor_ptr(cache_tensor);
|
||||
// init cache tensor data
|
||||
auto cache_byte_size = cache_tensor->Size();
|
||||
int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
|
||||
const size_t byte_size) {
|
||||
MS_LOG(INFO) << "Start init hashmap data.";
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
|
||||
MS_EXCEPTION_IF_NULL(hashmap_data);
|
||||
int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memset failed.";
|
||||
}
|
||||
std::vector<T> host_range;
|
||||
host_range.reserve(host_size);
|
||||
for (int64_t i = 0; i < host_size; ++i) {
|
||||
host_range.emplace_back(i);
|
||||
}
|
||||
std::random_shuffle(host_range.begin(), host_range.end());
|
||||
size_t size = cache_size;
|
||||
size_t hashmap_count = 0;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto random_key = host_range[i];
|
||||
auto entry = HashFunc(random_key, hashmap_size);
|
||||
size_t count = 1;
|
||||
while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
|
||||
count += 1;
|
||||
entry = (entry + 1) % hashmap_size;
|
||||
}
|
||||
if (hashmap_data[entry].IsEmpty()) {
|
||||
hashmap_count++;
|
||||
hashmap_data[entry].key_ = random_key;
|
||||
hashmap_data[entry].value_ = i;
|
||||
hashmap_data[entry].step_ = kInitStep;
|
||||
hashmap_data[entry].tag_ = count;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
|
||||
}
|
||||
|
||||
AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
|
||||
TypeId type_id) {
|
||||
// init new tensor
|
||||
size_t hashmap_size = cache_size * kEmptyRate;
|
||||
std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
|
||||
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
|
||||
size_t byte_size = new_tensor->Size();
|
||||
if (type_id == TypeId::kNumberTypeInt64) {
|
||||
InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
|
||||
} else {
|
||||
InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
|
||||
}
|
||||
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
|
||||
std::string hashmap_name = "cache_hashmap";
|
||||
new_param_info->set_name(hashmap_name);
|
||||
new_tensor->set_param_info(new_param_info);
|
||||
auto hashmap = func_graph->AddWeightParameter(hashmap_name);
|
||||
hashmap->set_default_param(MakeValue(new_tensor));
|
||||
hashmap->set_abstract(new_tensor->ToAbstract());
|
||||
return hashmap;
|
||||
}
|
||||
|
||||
AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
|
||||
std::vector<int64_t> host_shape{1};
|
||||
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
|
||||
auto step_data = static_cast<int64_t *>(new_tensor->data_c());
|
||||
step_data[0] = 0;
|
||||
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
|
||||
std::string step_name = "cache_step";
|
||||
new_param_info->set_name(step_name);
|
||||
new_tensor->set_param_info(new_param_info);
|
||||
auto step = func_graph->AddWeightParameter(step_name);
|
||||
step->set_default_param(MakeValue(new_tensor));
|
||||
step->set_abstract(new_tensor->ToAbstract());
|
||||
return step;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
|
||||
const ParamMap &cache_host_params_map) {
|
||||
auto iter = cache_host_params_map.begin();
|
||||
int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
|
||||
int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
|
||||
auto indices_type = indices->Type();
|
||||
auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
|
||||
auto indices_type_id = indices_element_type->type_id();
|
||||
auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
|
||||
auto step = InitStep(func_graph, indices_type_id);
|
||||
auto max_num = NewValueNode(MakeValue(host_size));
|
||||
auto hashmap_param = hashmap->cast<ParameterPtr>();
|
||||
BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
|
||||
// add rank_id
|
||||
int64_t offset_value = 0;
|
||||
std::string rank_id_str = common::GetEnv("RANK_ID");
|
||||
if (!rank_id_str.empty()) {
|
||||
int64_t rank_id = atoi(rank_id_str.c_str());
|
||||
offset_value = rank_id * host_size;
|
||||
}
|
||||
auto offset = NewValueNode(MakeValue(offset_value));
|
||||
auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(host_size));
|
||||
auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
|
||||
max_num->set_abstract(max_num_abstract_scalar);
|
||||
auto offset_imm = std::make_shared<Int64Imm>(SizeToLong(offset_value));
|
||||
auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
|
||||
offset->set_abstract(offset_abstract_scalar);
|
||||
|
||||
PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
|
||||
map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
|
||||
std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
|
||||
auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);
|
||||
|
||||
auto indices_ori_shp = indices->Shape();
|
||||
auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
|
||||
ShapeVector shape;
|
||||
ShapeVector min_shape;
|
||||
ShapeVector max_shape;
|
||||
if (!indices_shp->max_shape().empty()) {
|
||||
max_shape = indices_shp->max_shape();
|
||||
} else {
|
||||
max_shape = indices_shp->shape();
|
||||
}
|
||||
for (size_t i = 0; i < max_shape.size(); i++) {
|
||||
shape.emplace_back(-1);
|
||||
min_shape.emplace_back(1);
|
||||
}
|
||||
|
||||
auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
|
||||
auto old_emb_idx = std::make_shared<abstract::AbstractTensor>(
|
||||
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
|
||||
auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>(
|
||||
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
|
||||
auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>(
|
||||
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
|
||||
|
||||
std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
|
||||
auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
|
||||
map_cache_idx->set_abstract(abstract);
|
||||
return map_cache_idx;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto idx = NewValueNode(SizeToLong(index));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
idx->set_abstract(abstract_scalar);
|
||||
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
|
||||
auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
|
||||
auto tuple_getitem_abstract = input_abstract_tuple->elements()[index];
|
||||
tuple_getitem->set_abstract(tuple_getitem_abstract);
|
||||
return tuple_getitem;
|
||||
}
|
||||
|
||||
void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
|
||||
auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
|
||||
auto size = input_abstract_tuple->elements().size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
(*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
}
|
||||
|
||||
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
|
||||
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
|
||||
emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
|
||||
std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
|
||||
auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
|
||||
return emb_lookup;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
|
||||
AnfNodePtr miss_value) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
|
||||
std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
|
||||
miss_value};
|
||||
auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
|
||||
return cache_swap_table;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
|
||||
AnfNodePtr old_value) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
|
||||
update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
|
||||
|
||||
auto params_ori_shp = params->Shape();
|
||||
MS_EXCEPTION_IF_NULL(params_ori_shp);
|
||||
auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(params_shp);
|
||||
auto params_shape = params_shp->shape();
|
||||
auto max_size = params_shape[0];
|
||||
auto max_size_node = NewValueNode(MakeValue(max_size));
|
||||
auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(max_size));
|
||||
auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
|
||||
max_size_node->set_abstract(max_num_abstract_scalar);
|
||||
|
||||
std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
|
||||
max_size_node};
|
||||
auto update_cache = graph->NewCNode(update_cache_nodes);
|
||||
return update_cache;
|
||||
}
|
||||
|
||||
NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
|
||||
const AnfNodePtrList &map_cache_idx_node_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
NodePairList node_pair_list;
|
||||
for (auto &ele : param_pair_list) {
|
||||
auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
|
||||
auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
|
||||
auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
|
||||
node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
|
||||
}
|
||||
return node_pair_list;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
|
||||
const AnfNodePtr &behind_node) {
|
||||
// Create control depend
|
||||
MS_EXCEPTION_IF_NULL(main_graph);
|
||||
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
|
||||
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
|
||||
return control_depend_cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
|
||||
const AnfNodePtr &patron_node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list));
|
||||
auto make_tuple = graph->NewCNode(make_tuple_list);
|
||||
std::vector<AnfNodePtr> depend_list{NewValueNode(prim::kPrimDepend), patron_node, make_tuple};
|
||||
auto depend_cnode = graph->NewCNode(depend_list);
|
||||
depend_cnode->set_abstract(patron_node->abstract());
|
||||
return depend_cnode;
|
||||
}
|
||||
|
||||
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet ¶m_set) {
|
||||
size_t cnodes_size = cnodes.size();
|
||||
CNodePtrList sparse_gather_v2_with_cache;
|
||||
for (size_t i = 0; i < cnodes_size; ++i) {
|
||||
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
|
||||
auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>();
|
||||
if (param_set.find(param_node) != param_set.end()) {
|
||||
sparse_gather_v2_with_cache.push_back(cnodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sparse_gather_v2_with_cache.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
|
||||
}
|
||||
auto indices = sparse_gather_v2_with_cache[0]->input(2);
|
||||
for (auto &ele : sparse_gather_v2_with_cache) {
|
||||
if (ele->input(2) != indices) {
|
||||
MS_LOG(EXCEPTION) << "SparseGatherV2 which with cache param have different indices!.";
|
||||
}
|
||||
}
|
||||
return sparse_gather_v2_with_cache;
|
||||
}
|
||||
|
||||
AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtrList gatherv2_nodes;
|
||||
auto user_set = graph->manager()->node_users()[node];
|
||||
for (auto &ele : user_set) {
|
||||
if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) {
|
||||
gatherv2_nodes.emplace_back(ele.first);
|
||||
}
|
||||
}
|
||||
if (gatherv2_nodes.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got "
|
||||
<< gatherv2_nodes.size();
|
||||
}
|
||||
return gatherv2_nodes[0];
|
||||
}
|
||||
|
||||
void AddCacheEmbedding(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
|
||||
CNodePtrList cnodes(orders.begin(), orders.end());
|
||||
size_t cnodes_size = cnodes.size();
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
bool training = graph->has_flag("training");
|
||||
auto param_cache_enable_set = FindParamCacheEnable(graph);
|
||||
if (param_cache_enable_set.empty()) {
|
||||
MS_LOG(INFO) << "Parameters are all not cache enable.";
|
||||
return;
|
||||
} else {
|
||||
MS_LOG(INFO) << "Parameters have cache enable.";
|
||||
}
|
||||
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
|
||||
return;
|
||||
}
|
||||
if (training) {
|
||||
// If training, create cache parameters corresponding to the host params with is cache_enable.
|
||||
// Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
|
||||
// Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values
|
||||
// from cache to host in each epoch end.
|
||||
// Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
|
||||
// flush miss values to cache params and write back old values to host params.
|
||||
// If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
|
||||
// ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
|
||||
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
|
||||
if (unique_cache_enable.empty()) {
|
||||
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
|
||||
return;
|
||||
}
|
||||
auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
|
||||
auto param_set = MapKeysToSet(cache_host_params_map);
|
||||
ReplaceCacheParams(graph, cache_host_params_map);
|
||||
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
|
||||
auto unique_node = unique_cache_enable[0];
|
||||
|
||||
CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
|
||||
auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
|
||||
auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
|
||||
|
||||
AnfNodePtrList map_cache_idx_node_outputs;
|
||||
CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
|
||||
|
||||
if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) {
|
||||
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
|
||||
}
|
||||
|
||||
auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
|
||||
|
||||
AnfNodePtr last_node = cnodes[cnodes_size - 1];
|
||||
CNodePtr return_node;
|
||||
if (last_node->isa<CNode>()) {
|
||||
return_node = last_node->cast<CNodePtr>();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
|
||||
MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
|
||||
}
|
||||
if (return_node->inputs().size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
|
||||
}
|
||||
AnfNodePtrList invalid_nodes;
|
||||
for (auto &ele : node_pair_list) {
|
||||
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
|
||||
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
|
||||
return CreateControlDepend(graph, ele.first, sparse_gatherv2);
|
||||
});
|
||||
invalid_nodes.emplace_back(ele.second);
|
||||
}
|
||||
auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
|
||||
if (!manager->Replace(return_node->input(1), depend_node)) {
|
||||
MS_LOG(EXCEPTION) << "Depend replace node failed";
|
||||
}
|
||||
} else {
|
||||
// If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
|
||||
// The network is the same as Host-Device mode.
|
||||
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
|
||||
if (unique_cache_enable.empty()) {
|
||||
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
|
||||
return;
|
||||
}
|
||||
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
|
||||
// replace GatherV2 to EmbeddingLookupCPU
|
||||
auto indices = unique_cache_enable[0]->input(1);
|
||||
auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
|
||||
for (auto &ele : sparse_gatherv2_with_cache) {
|
||||
auto anf_ele = ele->cast<AnfNodePtr>();
|
||||
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
|
||||
auto param = ele->input(1)->cast<ParameterPtr>();
|
||||
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
|
||||
if (!manager->Replace(gatherv2, embedding_lookup)) {
|
||||
MS_LOG(EXCEPTION) << "Depend replace node failed";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
|
||||
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// Automatically adding control depend based on effect order and side effect analysis.
|
||||
void AddCacheEmbedding(const FuncGraphPtr &graph);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
|
|
@ -36,11 +36,14 @@
|
|||
#include "frontend/optimizer/graph_transform.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
#include "frontend/parallel/step_auto_parallel.h"
|
||||
#include "frontend/parallel/cache_embedding/cache_embedding.h"
|
||||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
#include "frontend/optimizer/recompute.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/util.h"
|
||||
#endif
|
||||
namespace mindspore {
|
||||
namespace pipeline {
|
||||
using OptPassGroupMap = opt::OptPassGroupMap;
|
||||
|
@ -391,6 +394,26 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::Util::IsParamServerMode()) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
parallel::AddCacheEmbedding(func_graph);
|
||||
if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
|
||||
auto params = func_graph->parameters();
|
||||
AbstractBasePtrList args_spec_list;
|
||||
std::for_each(params.begin(), params.end(),
|
||||
[&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
|
||||
func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MergeDupGraphPass(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -500,6 +523,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"tuple_transform", OptPassTransformGraphGroup},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||
{"add_control_depend", AddControlDependPass},
|
||||
{"add_recomputation", AddRecomputationPass}};
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ bool PipelineSplitPass(const ResourcePtr &res);
|
|||
bool ValidatePass(const ResourcePtr &res);
|
||||
bool ConvertPrepareAdapt(const ResourcePtr &res);
|
||||
bool AddControlDependPass(const ResourcePtr &res);
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res);
|
||||
bool InferenceOptPreparePass(const ResourcePtr &res);
|
||||
void ReclaimOptimizer();
|
||||
bool PynativeOptPass(const ResourcePtr &res);
|
||||
|
|
|
@ -32,6 +32,8 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
|||
.def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
|
||||
&ParamInfo::set_parallel_optimizer)
|
||||
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
|
||||
.def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
|
||||
.def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
|
||||
.def(py::pickle(
|
||||
[](const ParamInfo &p) { // __getstate__
|
||||
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "pybind_api/api_register.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/cache_embedding_hashmap_struct.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
|
@ -272,6 +273,68 @@ py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().it
|
|||
|
||||
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
|
||||
|
||||
template <typename T>
|
||||
void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
|
||||
size_t hashmap_size, size_t col_size) {
|
||||
auto host_data = static_cast<char *>(host_addr);
|
||||
auto cache_data = static_cast<char *>(cache_addr);
|
||||
auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
|
||||
// default param type float
|
||||
size_t param_type_size = 4;
|
||||
size_t single_col_bytes = param_type_size * col_size;
|
||||
for (size_t i = 0; i < hashmap_size; ++i) {
|
||||
if (!hashmap_data[i].IsEmpty()) {
|
||||
size_t host_offset = single_col_bytes * hashmap_data[i].key_;
|
||||
size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
|
||||
if (cache_offset + single_col_bytes <= cache_max) {
|
||||
auto ret =
|
||||
memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Memcpy from cache to host success!";
|
||||
}
|
||||
|
||||
void TensorPy::FlushFromCache(const Tensor &tensor) {
|
||||
py::gil_scoped_release gil_release;
|
||||
if (tensor.NeedWait()) {
|
||||
tensor.Wait();
|
||||
}
|
||||
tensor.data_sync();
|
||||
|
||||
if (tensor.cache_enable()) {
|
||||
MS_LOG(INFO) << tensor.ToString() << " is cache enable.";
|
||||
auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
|
||||
auto cache_tensor_ptr = tensor.cache_tensor_ptr();
|
||||
if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
|
||||
hashmap_tensor_ptr->data_sync();
|
||||
cache_tensor_ptr->data_sync();
|
||||
auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
|
||||
auto host_shape = tensor.shape_c();
|
||||
auto cache_shape = cache_tensor_ptr->shape_c();
|
||||
if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
|
||||
MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
|
||||
<< "host shape:" << host_shape << ", cache shape:" << cache_shape;
|
||||
}
|
||||
auto host_data_max_size = tensor.Size();
|
||||
auto cache_data_max_size = cache_tensor_ptr->Size();
|
||||
auto hashmap_data_type = hashmap_tensor_ptr->data_type();
|
||||
if (hashmap_data_type == TypeId::kNumberTypeInt32) {
|
||||
MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
|
||||
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
|
||||
} else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
|
||||
MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
|
||||
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
|
@ -457,6 +520,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
array([[1., 1., 1.],
|
||||
[1., 1., 1.]])
|
||||
)mydelimiter")
|
||||
.def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
|
||||
Flush Cache data to Host if tensor is cache enable.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 3)))
|
||||
>>> data._flush_from_cache()
|
||||
)mydelimiter")
|
||||
.def("is_init", &Tensor::is_init, R"mydelimiter(
|
||||
Get tensor init_flag.
|
||||
|
||||
|
|
|
@ -115,6 +115,8 @@ class TensorPy {
|
|||
static py::int_ GetPyItemSize(const Tensor &tensor);
|
||||
|
||||
static py::int_ GetPyNBytes(const Tensor &tensor);
|
||||
|
||||
static void FlushFromCache(const Tensor &tensor);
|
||||
};
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -268,7 +268,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
|
|||
bound_addresses_.clear();
|
||||
auto output_nodes = kernel_graph->outputs();
|
||||
for (const auto &item : output_nodes) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false);
|
||||
auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
|
||||
outputs->push_back(std::move(out));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
|
||||
#define MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
|
||||
|
||||
#include <math.h>
|
||||
|
||||
namespace mindspore {
|
||||
const int64_t kNullTag = 0;
|
||||
const int64_t kInitStep = -5;
|
||||
const int64_t kEmptyRate = 4;
|
||||
const double kGoldenRatio = 0.6180339;
|
||||
template <typename T>
|
||||
struct HashmapEntry {
|
||||
T key_;
|
||||
T value_;
|
||||
T step_;
|
||||
T tag_;
|
||||
|
||||
bool IsEmpty() { return tag_ == kNullTag; }
|
||||
|
||||
bool IsUsing(const T train_step) { return step_ >= (train_step - 1); }
|
||||
|
||||
bool IsKey(const T emb_idx) { return key_ == emb_idx; }
|
||||
|
||||
void SetEmpty() { tag_ = kNullTag; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T HashFunc(const T key, const size_t m) {
|
||||
return (T)(((kGoldenRatio * key) - floor(kGoldenRatio * key)) * m);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
|
|
@ -350,6 +350,7 @@ constexpr auto kAttrPrimitiveTarget = "primitive_target";
|
|||
constexpr auto kAttrUseLocking = "use_locking";
|
||||
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
||||
constexpr auto kAttrOffset = "offset";
|
||||
constexpr auto kAttrCacheEnable = "cache_enable";
|
||||
constexpr auto kAttrPsKey = "ps_key";
|
||||
constexpr auto kAttrOptimizerType = "optim_type";
|
||||
constexpr auto kAttrChildGraph = "child_graph";
|
||||
|
|
|
@ -131,7 +131,7 @@ class Parameter(Tensor_):
|
|||
if self.init_mode is not None:
|
||||
data = self.init_mode
|
||||
else:
|
||||
# cast to break deep infinit loop while deepcopy
|
||||
# cast to break deep infinite loop while deepcopy
|
||||
data = Tensor(self)
|
||||
return (
|
||||
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
|
||||
|
@ -348,6 +348,8 @@ class Parameter(Tensor_):
|
|||
x.is_param_ps = self.is_param_ps
|
||||
x.init_in_server = self.init_in_server
|
||||
x.cache_enable = self.cache_enable
|
||||
if self.cache_shape:
|
||||
x.cache_shape = self.cache_shape
|
||||
if init != 'same':
|
||||
shape = self.shape
|
||||
dtype = self.dtype
|
||||
|
@ -375,6 +377,28 @@ class Parameter(Tensor_):
|
|||
raise TypeError("`parallel_optimizer` parameter must be bool type")
|
||||
self._param_info.parallel_optimizer = value
|
||||
|
||||
@property
|
||||
def cache_enable(self):
|
||||
"""Return whether the parameter is cache enable."""
|
||||
return self._param_info.cache_enable
|
||||
|
||||
@cache_enable.setter
|
||||
def cache_enable(self, value=True):
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("`cache_enable` parameter must be bool type")
|
||||
self._param_info.cache_enable = value
|
||||
|
||||
@property
|
||||
def cache_shape(self):
|
||||
"""Return the cache shape corresponding to the parameter if use cache."""
|
||||
return self._param_info.cache_shape
|
||||
|
||||
@cache_shape.setter
|
||||
def cache_shape(self, value):
|
||||
if not isinstance(value, (tuple, list)):
|
||||
raise TypeError("`cache_shape` parameter must be tuple or list type")
|
||||
self._param_info.cache_shape = value
|
||||
|
||||
@property
|
||||
def requires_grad(self):
|
||||
"""Return whether the parameter requires gradient."""
|
||||
|
|
|
@ -308,6 +308,10 @@ class Tensor(Tensor_):
|
|||
"""Convert tensor to numpy array."""
|
||||
return Tensor_.asnumpy(self)
|
||||
|
||||
def _flush_from_cache(self):
|
||||
"""Flush cache data to host if tensor is cache enable."""
|
||||
Tensor_._flush_from_cache(self)
|
||||
|
||||
def all(self, axis=(), keep_dims=False):
|
||||
"""
|
||||
Check all array elements along a given axis evaluate to True.
|
||||
|
|
|
@ -60,6 +60,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>;
|
|||
|
||||
class CNode;
|
||||
using CNodePtr = std::shared_ptr<CNode>;
|
||||
using CNodePtrList = std::vector<CNodePtr>;
|
||||
|
||||
class FuncGraph;
|
||||
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
|
||||
|
@ -88,7 +89,7 @@ using ParamInfoPtr = std::shared_ptr<ParamInfo>;
|
|||
// intermediate_abstract: return the cached inferring abstract value.
|
||||
// Type/Shape: return the related info of this AnfNode. When this AnfNode is an
|
||||
// input of other CNodes, you can get the related info by this method.
|
||||
// debug_info: return the information retrived from parser. Set it using set_debug_info.
|
||||
// debug_info: return the information retrieved from parser. Set it using set_debug_info.
|
||||
// fullname_with_scope: return the detailed debug info.
|
||||
class AnfNode : public Base {
|
||||
public:
|
||||
|
|
|
@ -167,7 +167,6 @@ class MetaTensor : public Value {
|
|||
// Get tensor's param_info info.
|
||||
ParamInfoPtr param_info() const { return param_info_; }
|
||||
bool is_parameter() const { return is_parameter_; }
|
||||
|
||||
// Set tensor's param_info info.
|
||||
void set_param_info(const ParamInfoPtr ¶m_info) {
|
||||
is_parameter_ = true;
|
||||
|
|
|
@ -81,6 +81,12 @@ class ParamInfo {
|
|||
bool parallel_optimizer() const { return parallel_optimizer_; }
|
||||
void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
|
||||
|
||||
bool cache_enable() const { return cache_enable_; }
|
||||
void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
|
||||
|
||||
std::vector<int64_t> cache_shape() const { return cache_shape_; }
|
||||
void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }
|
||||
|
||||
private:
|
||||
std::string name_{"Parameter"};
|
||||
bool requires_grad_{true};
|
||||
|
@ -92,6 +98,8 @@ class ParamInfo {
|
|||
int32_t cloned_index_{0};
|
||||
int32_t fusion_type_{1};
|
||||
bool parallel_optimizer_{true};
|
||||
bool cache_enable_{false};
|
||||
std::vector<int64_t> cache_shape_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_
|
||||
|
|
|
@ -449,6 +449,9 @@ Tensor::Tensor(const Tensor &tensor)
|
|||
event_(tensor.event_),
|
||||
sync_status_(tensor.sync_status_),
|
||||
device_sync_(tensor.device_sync_),
|
||||
cache_enable_(tensor.cache_enable_),
|
||||
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
|
||||
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
|
||||
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
||||
|
@ -459,6 +462,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
|
|||
event_(tensor.event_),
|
||||
sync_status_(tensor.sync_status_),
|
||||
device_sync_(tensor.device_sync_),
|
||||
cache_enable_(tensor.cache_enable_),
|
||||
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
|
||||
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
|
||||
padding_type_(tensor.padding_type()) {}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
|
||||
|
@ -511,7 +517,7 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
|
|||
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
|
||||
}
|
||||
|
||||
// assgin value to this tensor
|
||||
// assign value to this tensor
|
||||
Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
||||
if (this != &tensor) {
|
||||
MetaTensor::operator=(tensor);
|
||||
|
|
|
@ -206,7 +206,7 @@ class Tensor : public MetaTensor {
|
|||
// it do real value comparison.
|
||||
bool ValueEqual(const Tensor &tensor) const;
|
||||
|
||||
// assgin value to this tensor
|
||||
// assign value to this tensor
|
||||
Tensor &AssignValue(const Tensor &tensor);
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
|
@ -291,6 +291,18 @@ class Tensor : public MetaTensor {
|
|||
TypePtr cast_dtype() { return cast_dtype_; }
|
||||
void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }
|
||||
|
||||
// used if cache_enable, in order to update tensor from cache to host
|
||||
bool cache_enable() const { return cache_enable_; }
|
||||
void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; }
|
||||
std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; }
|
||||
void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) {
|
||||
hashmap_tensor_ptr_ = hashmap_tensor_ptr;
|
||||
}
|
||||
std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; }
|
||||
void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) {
|
||||
cache_tensor_ptr_ = cache_tensor_ptr;
|
||||
}
|
||||
|
||||
void SetNeedWait(bool need_wait) {
|
||||
if (event_ != nullptr) {
|
||||
event_->set_need_wait(need_wait);
|
||||
|
@ -335,6 +347,9 @@ class Tensor : public MetaTensor {
|
|||
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
|
||||
bool graph_output_{false};
|
||||
DeviceSyncPtr device_sync_{nullptr};
|
||||
bool cache_enable_{false};
|
||||
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
|
||||
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
|
||||
std::vector<Axis> padding_type_;
|
||||
TypePtr cast_dtype_{nullptr};
|
||||
};
|
||||
|
|
|
@ -21,6 +21,7 @@ const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
|
|||
const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
|
||||
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
|
||||
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
|
||||
const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable";
|
||||
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
|
||||
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
|
|||
extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
|
||||
extern const char GRAPH_FLAG_HAS_EFFECT[];
|
||||
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
|
||||
extern const char GRAPH_FLAG_CACHE_ENABLE[];
|
||||
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
|
||||
extern const char GRAPH_FLAG_SIDE_EFFECT[];
|
||||
|
||||
|
|
|
@ -172,8 +172,8 @@ class EmbeddingLookup(Cell):
|
|||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
|
||||
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
|
||||
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
|
||||
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
|
||||
In addition, it should be noted that it will cost the 'DEVICE'
|
||||
memory, so suggests setting a reasonable value to avoid insufficient memory.
|
||||
|
||||
Inputs:
|
||||
|
@ -205,7 +205,12 @@ class EmbeddingLookup(Cell):
|
|||
max_norm=None, sparse=True, vocab_cache_size=0):
|
||||
super(EmbeddingLookup, self).__init__()
|
||||
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
|
||||
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
|
||||
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
|
||||
self.target = target
|
||||
self.sparse = sparse
|
||||
self.cache_enable = self.vocab_cache_size > 0
|
||||
self.forward_unique = False
|
||||
if target not in ('CPU', 'DEVICE'):
|
||||
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
||||
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
|
||||
|
@ -216,21 +221,23 @@ class EmbeddingLookup(Cell):
|
|||
else:
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
|
||||
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
|
||||
self._process_vocab_cache(slice_mode)
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if enable_ps:
|
||||
self._process_vocab_cache(slice_mode)
|
||||
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
|
||||
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
|
||||
name='embedding_table')
|
||||
if self.cache_enable:
|
||||
self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size)
|
||||
if self.cache_enable and enable_ps:
|
||||
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.forward_unique = False
|
||||
self.gather_revert = P.GatherV2()
|
||||
self.unique = P.Unique().shard(((1,),))
|
||||
self.reshape_first = P.Reshape()
|
||||
self.reshape = P.Reshape()
|
||||
self.unique = P.Unique()
|
||||
self.shape = P.Shape()
|
||||
if is_auto_parallel:
|
||||
self.unique = P.Unique().shard(((1,),))
|
||||
indices_shape_size = 2
|
||||
if slice_mode == "field_slice" and is_auto_parallel:
|
||||
if not manual_shapes:
|
||||
|
@ -270,12 +277,34 @@ class EmbeddingLookup(Cell):
|
|||
if is_auto_parallel:
|
||||
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
|
||||
+ str(slice_mode))
|
||||
if self.cache_enable and not enable_ps:
|
||||
if is_auto_parallel:
|
||||
raise ValueError("parallel mode haven't supported cache enable yet.")
|
||||
self._set_cache_enable()
|
||||
self.embedding_table.unique = self.forward_unique
|
||||
self.max_norm = max_norm
|
||||
if self.max_norm is not None:
|
||||
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
|
||||
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
|
||||
|
||||
def _set_cache_enable(self):
|
||||
"""EmbeddingLookup cache check for not ps env."""
|
||||
if self.target != 'DEVICE':
|
||||
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
|
||||
"so it will be ignored.")
|
||||
return
|
||||
if not self.sparse:
|
||||
logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
|
||||
"so it will be ignored.")
|
||||
return
|
||||
logger.info("EmbeddingLookup cache enable takes effect.")
|
||||
self.forward_unique = True
|
||||
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
|
||||
self.unique.add_prim_attr('cache_enable', True)
|
||||
self.embedding_table.cache_enable = self.cache_enable
|
||||
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
|
||||
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def _process_vocab_cache(self, slice_mode):
|
||||
"""PS embeddingLookup cache check and process."""
|
||||
self.cache_enable = False
|
||||
|
@ -302,7 +331,7 @@ class EmbeddingLookup(Cell):
|
|||
if _is_role_worker():
|
||||
self.vocab_size = self.vocab_cache_size
|
||||
|
||||
def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
|
||||
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
|
||||
"""PS embeddingLookup cache enable set."""
|
||||
self.embedding_table.cache_enable = True
|
||||
self.embedding_table.is_param_ps = True
|
||||
|
@ -316,7 +345,7 @@ class EmbeddingLookup(Cell):
|
|||
else:
|
||||
if self.forward_unique:
|
||||
shp = self.shape(indices) + (self.embedding_size,)
|
||||
indices_flatten = self.reshape(indices, (-1,))
|
||||
indices_flatten = self.reshape_first(indices, (-1,))
|
||||
unique_id, unique_idx = self.unique(indices_flatten)
|
||||
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
|
||||
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
|
||||
|
|
|
@ -156,8 +156,8 @@ class Optimizer(Cell):
|
|||
break
|
||||
ps_filter = lambda x: x.is_param_ps
|
||||
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
|
||||
ps_cache_filter = lambda x: x.cache_enable
|
||||
self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
|
||||
cache_filter = lambda x: x.cache_enable
|
||||
self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
|
||||
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
|
||||
self.need_scale = loss_scale != 1.0
|
||||
self.global_step_increase_tensor = Tensor(1, mstype.int32)
|
||||
|
|
|
@ -526,6 +526,9 @@ class Model:
|
|||
|
||||
train_dataset.reset()
|
||||
|
||||
# if param is cache enable, flush data from cache to host before epoch end
|
||||
self._flush_from_cache(cb_params)
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
|
@ -784,5 +787,11 @@ class Model:
|
|||
predict_net.compile(*predict_data)
|
||||
return predict_net.parameter_layout_dict
|
||||
|
||||
def _flush_from_cache(self, cb_params):
|
||||
"""Flush cache data to host if tensor is cache enable."""
|
||||
params = cb_params.train_network.get_parameters()
|
||||
for param in params:
|
||||
if param.cache_enable:
|
||||
Tensor(param)._flush_from_cache()
|
||||
|
||||
__all__ = ["Model"]
|
||||
|
|
|
@ -53,8 +53,8 @@ def init_var_dict(init_args, in_vars):
|
|||
'''
|
||||
var_map = {}
|
||||
_, _max_val = init_args
|
||||
for _, iterm in enumerate(in_vars):
|
||||
key, shape, method = iterm
|
||||
for _, item in enumerate(in_vars):
|
||||
key, shape, method = item
|
||||
if key not in var_map.keys():
|
||||
if method in ['random', 'uniform']:
|
||||
var_map[key] = Parameter(initializer(
|
||||
|
@ -257,9 +257,11 @@ class WideDeepModel(nn.Cell):
|
|||
self.wide_embeddinglookup.embedding_table.set_param_ps()
|
||||
else:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
|
||||
target='DEVICE', sparse=sparse)
|
||||
target='DEVICE', sparse=sparse,
|
||||
vocab_cache_size=self.vocab_cache_size)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
|
||||
target='DEVICE', sparse=sparse)
|
||||
target='DEVICE', sparse=sparse,
|
||||
vocab_cache_size=self.vocab_cache_size)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
|
||||
def construct(self, id_hldr, wt_hldr):
|
||||
|
|
|
@ -57,8 +57,8 @@ def init_var_dict(init_args, in_vars):
|
|||
"""
|
||||
var_map = {}
|
||||
_, _max_val = init_args
|
||||
for _, iterm in enumerate(in_vars):
|
||||
key, shape, method = iterm
|
||||
for _, item in enumerate(in_vars):
|
||||
key, shape, method = item
|
||||
if key not in var_map.keys():
|
||||
if method in ['random', 'uniform']:
|
||||
var_map[key] = Parameter(initializer(Uniform(_max_val), shape,
|
||||
|
|
Loading…
Reference in New Issue