add cache pass

This commit is contained in:
fangzehua 2021-01-04 16:13:13 +08:00
parent dfa6daaa57
commit f97e19f23f
28 changed files with 870 additions and 80 deletions

View File

@ -19,55 +19,20 @@
#include <memory>
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace kernel {
template <typename T>
struct HashmapEntry {
T key;
T value;
T step;
T tag;
bool IsEmpty() {
if (this->tag == NULLTAG)
return true;
else
return false;
}
bool IsUsing(const T &train_step) {
if (this->step >= (train_step - 1))
return true;
else
return false;
}
bool IsKey(const T &emb_idx) {
if (this->key == emb_idx)
return true;
else
return false;
}
void SetEmpty() { this->tag = NULLTAG; }
};
template <typename T>
T HashFunc(const T &key, const size_t &m) {
return (T)(((0.6180339 * key) - floor(0.6180339 * key)) * m);
}
template <typename T>
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1;
int compress_count = 0;
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
if (entry_p[i].tag > off) {
entry_p[entry].key = entry_p[i].key;
entry_p[entry].value = entry_p[i].value;
entry_p[entry].step = entry_p[i].step;
entry_p[entry].tag = entry_p[i].tag - off;
if (entry_p[i].tag_ > off) {
entry_p[entry].key_ = entry_p[i].key_;
entry_p[entry].value_ = entry_p[i].value_;
entry_p[entry].step_ = entry_p[i].step_;
entry_p[entry].tag_ = entry_p[i].tag_ - off;
entry_p[i].SetEmpty();
off = 0;
entry = i;
@ -127,6 +92,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
float total_count = 0;
int count_size = 0;
float hit_count = 0;
// search_cache_idx
for (size_t i = 0; i < batch_size_; ++i) {
T key = input_indices[i] - offset;
@ -140,7 +106,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!";
break;
}
count += 1;
@ -153,8 +119,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
miss_count++;
} else {
hit_count += 1;
output_cache_idx[i] = hashmap[tmp_entry].value;
hashmap[tmp_entry].step = step_[0];
output_cache_idx[i] = hashmap[tmp_entry].value_;
hashmap[tmp_entry].step_ = step_[0];
}
}
if (miss_count != 0) {
@ -175,27 +141,27 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
if (tag_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!";
break;
}
tag_count++;
}
hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;
hashmap[entry].key_ = emb_idx;
hashmap[entry].step_ = step_[0];
hashmap[entry].tag_ = tag_count;
T tmp_entry = (entry + 1) % hashmap_length_;
size_t delete_count = 1;
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (delete_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!";
break;
}
delete_count++;
}
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
output_swap_cache_idx[i] = hashmap[tmp_entry].value_;
output_old_emb_idx[i] = hashmap[tmp_entry].key_;
hashmap[entry].value_ = output_swap_cache_idx[i];
hashmap[tmp_entry].SetEmpty();
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
total_delete_count += (compress_count + delete_count);

View File

@ -23,8 +23,6 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#define NULLTAG 0
namespace mindspore {
namespace kernel {
class MapCacheIdxCPUKernel : public CPUKernel {

View File

@ -188,12 +188,18 @@ void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusi
MS_EXCEPTION_IF_NULL(manager);
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
if (buffer_fusion_info.outputs_list.size() == 1) { // single output
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
}
(void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
buffer_fusion_kernel);
} else { // multiple output
for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
}
(void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
tuple_item);

View File

@ -274,6 +274,10 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto target = GetCNodeTarget(node);
if (target == kCPUDevice) {
return false;
}
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false;

View File

@ -0,0 +1,535 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include <random>
#include <vector>
#include <list>
#include <utility>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <algorithm>
#include "backend/optimizer/common/helper.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/func_graph.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace parallel {
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
using ParamSet = std::unordered_set<ParameterPtr>;
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
ParamMap cache_host_params_map;
for (auto &param : parameter_cache_enable_set) {
auto param_info = param->param_info();
if (param_info && param_info->cache_enable()) {
auto data_type = param->Type();
auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
auto type_id = data_element_type->type_id();
auto cache_shape = param_info->cache_shape();
auto ori_param_name = param->name();
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
auto cache_name = ori_param_name + "_cache";
new_param_info->set_name(cache_name);
new_tensor->set_param_info(new_param_info);
auto cache_param = graph->AddWeightParameter(cache_name);
cache_param->set_default_param(MakeValue(new_tensor));
cache_param->set_abstract(new_tensor->ToAbstract());
cache_host_params_map[cache_param] = param;
}
}
return cache_host_params_map;
}
bool CheckHostCacheParamSize(const ParamSet &parameter_cache_enable_set) {
int64_t host_size = 0;
int64_t cache_size = 0;
for (auto &host_param : parameter_cache_enable_set) {
auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
auto host_param_info = host_param->param_info();
auto cache_shape = host_param_info->cache_shape();
if (cache_shape.empty()) {
MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
}
auto tmp_cache_size = cache_shape[0];
if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
MS_LOG(EXCEPTION)
<< "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same.";
}
cache_size = tmp_cache_size;
host_size = tmp_host_size;
}
if (cache_size >= host_size) {
MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache.";
return false;
}
return true;
}
void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto &ele : map) {
if (!manager->Replace(ele.second, ele.first)) {
MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
}
}
}
ParamSet MapKeysToSet(const ParamMap &map) {
ParamSet set;
for (auto &ele : map) {
set.insert(ele.first);
}
return set;
}
ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
ParamSet parameter_cache_enable_set;
auto parameters = graph->parameters();
auto params_size = parameters.size();
for (size_t i = 0; i < params_size; ++i) {
auto param = parameters[i]->cast<ParameterPtr>();
auto param_info = param->param_info();
if (param_info && param_info->cache_enable()) {
parameter_cache_enable_set.insert(param);
}
}
return parameter_cache_enable_set;
}
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
size_t cnodes_size = cnodes.size();
CNodePtrList unique_cache_enable;
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimUnique)) {
auto unique_node = cnodes[i];
auto unique_prim = GetCNodePrimitive(unique_node);
MS_EXCEPTION_IF_NULL(unique_prim);
auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
if (attr_value != nullptr && GetValue<bool>(attr_value)) {
unique_cache_enable.emplace_back(unique_node);
}
}
}
if (unique_cache_enable.size() > 1) {
MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size();
}
return unique_cache_enable;
}
void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
auto hashmap_tensor_value = hashmap->default_param();
auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
for (auto &ele : param_pair_list) {
auto host_tensor_value = ele.second->default_param();
auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
auto cache_tensor_value = ele.first->default_param();
auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
// bind host, cache, hashmap
host_tensor->set_cache_enable(true);
host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
host_tensor->set_cache_tensor_ptr(cache_tensor);
// init cache tensor data
auto cache_byte_size = cache_tensor->Size();
int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy failed.";
}
}
}
template <typename T>
void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
const size_t byte_size) {
MS_LOG(INFO) << "Start init hashmap data.";
MS_EXCEPTION_IF_NULL(data);
HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
MS_EXCEPTION_IF_NULL(hashmap_data);
int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memset failed.";
}
std::vector<T> host_range;
host_range.reserve(host_size);
for (int64_t i = 0; i < host_size; ++i) {
host_range.emplace_back(i);
}
std::random_shuffle(host_range.begin(), host_range.end());
size_t size = cache_size;
size_t hashmap_count = 0;
for (size_t i = 0; i < size; ++i) {
auto random_key = host_range[i];
auto entry = HashFunc(random_key, hashmap_size);
size_t count = 1;
while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
count += 1;
entry = (entry + 1) % hashmap_size;
}
if (hashmap_data[entry].IsEmpty()) {
hashmap_count++;
hashmap_data[entry].key_ = random_key;
hashmap_data[entry].value_ = i;
hashmap_data[entry].step_ = kInitStep;
hashmap_data[entry].tag_ = count;
}
}
MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
}
AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
TypeId type_id) {
// init new tensor
size_t hashmap_size = cache_size * kEmptyRate;
std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
size_t byte_size = new_tensor->Size();
if (type_id == TypeId::kNumberTypeInt64) {
InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
} else {
InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
}
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
std::string hashmap_name = "cache_hashmap";
new_param_info->set_name(hashmap_name);
new_tensor->set_param_info(new_param_info);
auto hashmap = func_graph->AddWeightParameter(hashmap_name);
hashmap->set_default_param(MakeValue(new_tensor));
hashmap->set_abstract(new_tensor->ToAbstract());
return hashmap;
}
AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
std::vector<int64_t> host_shape{1};
auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
auto step_data = static_cast<int64_t *>(new_tensor->data_c());
step_data[0] = 0;
ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
std::string step_name = "cache_step";
new_param_info->set_name(step_name);
new_tensor->set_param_info(new_param_info);
auto step = func_graph->AddWeightParameter(step_name);
step->set_default_param(MakeValue(new_tensor));
step->set_abstract(new_tensor->ToAbstract());
return step;
}
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
const ParamMap &cache_host_params_map) {
auto iter = cache_host_params_map.begin();
int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
auto indices_type = indices->Type();
auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
auto indices_type_id = indices_element_type->type_id();
auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
auto step = InitStep(func_graph, indices_type_id);
auto max_num = NewValueNode(MakeValue(host_size));
auto hashmap_param = hashmap->cast<ParameterPtr>();
BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
// add rank_id
int64_t offset_value = 0;
std::string rank_id_str = common::GetEnv("RANK_ID");
if (!rank_id_str.empty()) {
int64_t rank_id = atoi(rank_id_str.c_str());
offset_value = rank_id * host_size;
}
auto offset = NewValueNode(MakeValue(offset_value));
auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(host_size));
auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
max_num->set_abstract(max_num_abstract_scalar);
auto offset_imm = std::make_shared<Int64Imm>(SizeToLong(offset_value));
auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
offset->set_abstract(offset_abstract_scalar);
PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);
auto indices_ori_shp = indices->Shape();
auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
ShapeVector shape;
ShapeVector min_shape;
ShapeVector max_shape;
if (!indices_shp->max_shape().empty()) {
max_shape = indices_shp->max_shape();
} else {
max_shape = indices_shp->shape();
}
for (size_t i = 0; i < max_shape.size(); i++) {
shape.emplace_back(-1);
min_shape.emplace_back(1);
}
auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
auto old_emb_idx = std::make_shared<abstract::AbstractTensor>(
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>(
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>(
indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
map_cache_idx->set_abstract(abstract);
return map_cache_idx;
}
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
MS_EXCEPTION_IF_NULL(func_graph);
auto idx = NewValueNode(SizeToLong(index));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
auto tuple_getitem_abstract = input_abstract_tuple->elements()[index];
tuple_getitem->set_abstract(tuple_getitem_abstract);
return tuple_getitem;
}
void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
auto size = input_abstract_tuple->elements().size();
for (size_t i = 0; i < size; ++i) {
(*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
}
MS_EXCEPTION_IF_NULL(outputs);
}
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
return emb_lookup;
}
AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
AnfNodePtr miss_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
miss_value};
auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
return cache_swap_table;
}
AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
AnfNodePtr old_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
auto params_ori_shp = params->Shape();
MS_EXCEPTION_IF_NULL(params_ori_shp);
auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(params_shp);
auto params_shape = params_shp->shape();
auto max_size = params_shape[0];
auto max_size_node = NewValueNode(MakeValue(max_size));
auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(max_size));
auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
max_size_node->set_abstract(max_num_abstract_scalar);
std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
max_size_node};
auto update_cache = graph->NewCNode(update_cache_nodes);
return update_cache;
}
NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
const AnfNodePtrList &map_cache_idx_node_outputs) {
MS_EXCEPTION_IF_NULL(graph);
NodePairList node_pair_list;
for (auto &ele : param_pair_list) {
auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
}
return node_pair_list;
}
AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node) {
// Create control depend
MS_EXCEPTION_IF_NULL(main_graph);
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
return control_depend_cnode;
}
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
const AnfNodePtr &patron_node) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)};
std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list));
auto make_tuple = graph->NewCNode(make_tuple_list);
std::vector<AnfNodePtr> depend_list{NewValueNode(prim::kPrimDepend), patron_node, make_tuple};
auto depend_cnode = graph->NewCNode(depend_list);
depend_cnode->set_abstract(patron_node->abstract());
return depend_cnode;
}
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
size_t cnodes_size = cnodes.size();
CNodePtrList sparse_gather_v2_with_cache;
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {
sparse_gather_v2_with_cache.push_back(cnodes[i]);
}
}
}
if (sparse_gather_v2_with_cache.empty()) {
MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
}
auto indices = sparse_gather_v2_with_cache[0]->input(2);
for (auto &ele : sparse_gather_v2_with_cache) {
if (ele->input(2) != indices) {
MS_LOG(EXCEPTION) << "SparseGatherV2 which with cache param have different indices!.";
}
}
return sparse_gather_v2_with_cache;
}
AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
AnfNodePtrList gatherv2_nodes;
auto user_set = graph->manager()->node_users()[node];
for (auto &ele : user_set) {
if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) {
gatherv2_nodes.emplace_back(ele.first);
}
}
if (gatherv2_nodes.size() != 1) {
MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got "
<< gatherv2_nodes.size();
}
return gatherv2_nodes[0];
}
void AddCacheEmbedding(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
CNodePtrList cnodes(orders.begin(), orders.end());
size_t cnodes_size = cnodes.size();
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
bool training = graph->has_flag("training");
auto param_cache_enable_set = FindParamCacheEnable(graph);
if (param_cache_enable_set.empty()) {
MS_LOG(INFO) << "Parameters are all not cache enable.";
return;
} else {
MS_LOG(INFO) << "Parameters have cache enable.";
}
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
return;
}
if (training) {
// If training, create cache parameters corresponding to the host params with is cache_enable.
// Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
// Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values
// from cache to host in each epoch end.
// Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
// flush miss values to cache params and write back old values to host params.
// If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
// ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
return;
}
auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
auto param_set = MapKeysToSet(cache_host_params_map);
ReplaceCacheParams(graph, cache_host_params_map);
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
auto unique_node = unique_cache_enable[0];
CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
AnfNodePtrList map_cache_idx_node_outputs;
CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) {
MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
}
auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
AnfNodePtr last_node = cnodes[cnodes_size - 1];
CNodePtr return_node;
if (last_node->isa<CNode>()) {
return_node = last_node->cast<CNodePtr>();
}
MS_EXCEPTION_IF_NULL(return_node);
if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
}
if (return_node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
}
AnfNodePtrList invalid_nodes;
for (auto &ele : node_pair_list) {
std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
return CreateControlDepend(graph, ele.first, sparse_gatherv2);
});
invalid_nodes.emplace_back(ele.second);
}
auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
if (!manager->Replace(return_node->input(1), depend_node)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
} else {
// If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
// The network is the same as Host-Device mode.
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
return;
}
graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
// replace GatherV2 to EmbeddingLookupCPU
auto indices = unique_cache_enable[0]->input(1);
auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
for (auto &ele : sparse_gatherv2_with_cache) {
auto anf_ele = ele->cast<AnfNodePtr>();
auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
auto param = ele->input(1)->cast<ParameterPtr>();
auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
if (!manager->Replace(gatherv2, embedding_lookup)) {
MS_LOG(EXCEPTION) << "Depend replace node failed";
}
}
}
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,28 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#include "ir/anf.h"
namespace mindspore {
namespace parallel {
// Automatically adding control depend based on effect order and side effect analysis.
void AddCacheEmbedding(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_

View File

@ -36,11 +36,14 @@
#include "frontend/optimizer/graph_transform.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
namespace mindspore {
namespace pipeline {
using OptPassGroupMap = opt::OptPassGroupMap;
@ -391,6 +394,26 @@ bool AddRecomputationPass(const ResourcePtr &res) {
return true;
}
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
return true;
}
#endif
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
parallel::AddCacheEmbedding(func_graph);
if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
auto params = func_graph->parameters();
AbstractBasePtrList args_spec_list;
std::for_each(params.begin(), params.end(),
[&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
}
return true;
}
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -500,6 +523,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_cache_embedding", AddCacheEmbeddingPass},
{"add_control_depend", AddControlDependPass},
{"add_recomputation", AddRecomputationPass}};

View File

@ -37,6 +37,7 @@ bool PipelineSplitPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res);
bool ConvertPrepareAdapt(const ResourcePtr &res);
bool AddControlDependPass(const ResourcePtr &res);
bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);

View File

@ -32,6 +32,8 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
.def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
&ParamInfo::set_parallel_optimizer)
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
.def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
.def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
.def(py::pickle(
[](const ParamInfo &p) { // __getstate__
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());

View File

@ -24,6 +24,7 @@
#include "pybind_api/api_register.h"
#include "abstract/abstract_value.h"
#include "utils/shape_utils.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace tensor {
@ -272,6 +273,68 @@ py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().it
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
template <typename T>
void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
size_t hashmap_size, size_t col_size) {
auto host_data = static_cast<char *>(host_addr);
auto cache_data = static_cast<char *>(cache_addr);
auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
// default param type float
size_t param_type_size = 4;
size_t single_col_bytes = param_type_size * col_size;
for (size_t i = 0; i < hashmap_size; ++i) {
if (!hashmap_data[i].IsEmpty()) {
size_t host_offset = single_col_bytes * hashmap_data[i].key_;
size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
if (cache_offset + single_col_bytes <= cache_max) {
auto ret =
memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy failed.";
}
}
}
}
MS_LOG(INFO) << "Memcpy from cache to host success!";
}
void TensorPy::FlushFromCache(const Tensor &tensor) {
py::gil_scoped_release gil_release;
if (tensor.NeedWait()) {
tensor.Wait();
}
tensor.data_sync();
if (tensor.cache_enable()) {
MS_LOG(INFO) << tensor.ToString() << " is cache enable.";
auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
auto cache_tensor_ptr = tensor.cache_tensor_ptr();
if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
hashmap_tensor_ptr->data_sync();
cache_tensor_ptr->data_sync();
auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
auto host_shape = tensor.shape_c();
auto cache_shape = cache_tensor_ptr->shape_c();
if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
<< "host shape:" << host_shape << ", cache shape:" << cache_shape;
}
auto host_data_max_size = tensor.Size();
auto cache_data_max_size = cache_tensor_ptr->Size();
auto hashmap_data_type = hashmap_tensor_ptr->data_type();
if (hashmap_data_type == TypeId::kNumberTypeInt32) {
MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else {
MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
}
}
}
}
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
{
py::gil_scoped_release gil_release;
@ -457,6 +520,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
array([[1., 1., 1.],
[1., 1., 1.]])
)mydelimiter")
.def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
Flush Cache data to Host if tensor is cache enable.
Returns:
None.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data._flush_from_cache()
)mydelimiter")
.def("is_init", &Tensor::is_init, R"mydelimiter(
Get tensor init_flag.

View File

@ -115,6 +115,8 @@ class TensorPy {
static py::int_ GetPyItemSize(const Tensor &tensor);
static py::int_ GetPyNBytes(const Tensor &tensor);
static void FlushFromCache(const Tensor &tensor);
};
} // namespace tensor
} // namespace mindspore

View File

@ -268,7 +268,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
bound_addresses_.clear();
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
outputs->push_back(std::move(out));
}

View File

@ -0,0 +1,51 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
#define MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
#include <math.h>
namespace mindspore {
const int64_t kNullTag = 0;
const int64_t kInitStep = -5;
const int64_t kEmptyRate = 4;
const double kGoldenRatio = 0.6180339;
template <typename T>
struct HashmapEntry {
T key_;
T value_;
T step_;
T tag_;
bool IsEmpty() { return tag_ == kNullTag; }
bool IsUsing(const T train_step) { return step_ >= (train_step - 1); }
bool IsKey(const T emb_idx) { return key_ == emb_idx; }
void SetEmpty() { tag_ = kNullTag; }
};
template <typename T>
T HashFunc(const T key, const size_t m) {
return (T)(((kGoldenRatio * key) - floor(kGoldenRatio * key)) * m);
}
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_

View File

@ -350,6 +350,7 @@ constexpr auto kAttrPrimitiveTarget = "primitive_target";
constexpr auto kAttrUseLocking = "use_locking";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrCacheEnable = "cache_enable";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";

View File

@ -131,7 +131,7 @@ class Parameter(Tensor_):
if self.init_mode is not None:
data = self.init_mode
else:
# cast to break deep infinit loop while deepcopy
# cast to break deep infinite loop while deepcopy
data = Tensor(self)
return (
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
@ -348,6 +348,8 @@ class Parameter(Tensor_):
x.is_param_ps = self.is_param_ps
x.init_in_server = self.init_in_server
x.cache_enable = self.cache_enable
if self.cache_shape:
x.cache_shape = self.cache_shape
if init != 'same':
shape = self.shape
dtype = self.dtype
@ -375,6 +377,28 @@ class Parameter(Tensor_):
raise TypeError("`parallel_optimizer` parameter must be bool type")
self._param_info.parallel_optimizer = value
@property
def cache_enable(self):
"""Return whether the parameter is cache enable."""
return self._param_info.cache_enable
@cache_enable.setter
def cache_enable(self, value=True):
if not isinstance(value, bool):
raise TypeError("`cache_enable` parameter must be bool type")
self._param_info.cache_enable = value
@property
def cache_shape(self):
"""Return the cache shape corresponding to the parameter if use cache."""
return self._param_info.cache_shape
@cache_shape.setter
def cache_shape(self, value):
if not isinstance(value, (tuple, list)):
raise TypeError("`cache_shape` parameter must be tuple or list type")
self._param_info.cache_shape = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""

View File

@ -308,6 +308,10 @@ class Tensor(Tensor_):
"""Convert tensor to numpy array."""
return Tensor_.asnumpy(self)
def _flush_from_cache(self):
"""Flush cache data to host if tensor is cache enable."""
Tensor_._flush_from_cache(self)
def all(self, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.

View File

@ -60,6 +60,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>;
class CNode;
using CNodePtr = std::shared_ptr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;
class FuncGraph;
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
@ -88,7 +89,7 @@ using ParamInfoPtr = std::shared_ptr<ParamInfo>;
// intermediate_abstract: return the cached inferring abstract value.
// Type/Shape: return the related info of this AnfNode. When this AnfNode is an
// input of other CNodes, you can get the related info by this method.
// debug_info: return the information retrived from parser. Set it using set_debug_info.
// debug_info: return the information retrieved from parser. Set it using set_debug_info.
// fullname_with_scope: return the detailed debug info.
class AnfNode : public Base {
public:

View File

@ -167,7 +167,6 @@ class MetaTensor : public Value {
// Get tensor's param_info info.
ParamInfoPtr param_info() const { return param_info_; }
bool is_parameter() const { return is_parameter_; }
// Set tensor's param_info info.
void set_param_info(const ParamInfoPtr &param_info) {
is_parameter_ = true;

View File

@ -81,6 +81,12 @@ class ParamInfo {
bool parallel_optimizer() const { return parallel_optimizer_; }
void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
bool cache_enable() const { return cache_enable_; }
void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
std::vector<int64_t> cache_shape() const { return cache_shape_; }
void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }
private:
std::string name_{"Parameter"};
bool requires_grad_{true};
@ -92,6 +98,8 @@ class ParamInfo {
int32_t cloned_index_{0};
int32_t fusion_type_{1};
bool parallel_optimizer_{true};
bool cache_enable_{false};
std::vector<int64_t> cache_shape_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_

View File

@ -449,6 +449,9 @@ Tensor::Tensor(const Tensor &tensor)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
@ -459,6 +462,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
@ -511,7 +517,7 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
// assgin value to this tensor
// assign value to this tensor
Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);

View File

@ -206,7 +206,7 @@ class Tensor : public MetaTensor {
// it do real value comparison.
bool ValueEqual(const Tensor &tensor) const;
// assgin value to this tensor
// assign value to this tensor
Tensor &AssignValue(const Tensor &tensor);
bool operator==(const Value &other) const override {
@ -291,6 +291,18 @@ class Tensor : public MetaTensor {
TypePtr cast_dtype() { return cast_dtype_; }
void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }
// used if cache_enable, in order to update tensor from cache to host
bool cache_enable() const { return cache_enable_; }
void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; }
std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; }
void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) {
hashmap_tensor_ptr_ = hashmap_tensor_ptr;
}
std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; }
void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) {
cache_tensor_ptr_ = cache_tensor_ptr;
}
void SetNeedWait(bool need_wait) {
if (event_ != nullptr) {
event_->set_need_wait(need_wait);
@ -335,6 +347,9 @@ class Tensor : public MetaTensor {
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
bool graph_output_{false};
DeviceSyncPtr device_sync_{nullptr};
bool cache_enable_{false};
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};
};

View File

@ -21,6 +21,7 @@ const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";

View File

@ -21,6 +21,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_CACHE_ENABLE[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];

View File

@ -172,8 +172,8 @@ class EmbeddingLookup(Cell):
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
In addition, it should be noted that it will cost the 'DEVICE'
memory, so suggests setting a reasonable value to avoid insufficient memory.
Inputs:
@ -205,7 +205,12 @@ class EmbeddingLookup(Cell):
max_norm=None, sparse=True, vocab_cache_size=0):
super(EmbeddingLookup, self).__init__()
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self.target = target
self.sparse = sparse
self.cache_enable = self.vocab_cache_size > 0
self.forward_unique = False
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
@ -216,21 +221,23 @@ class EmbeddingLookup(Cell):
else:
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self._process_vocab_cache(slice_mode)
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
self._process_vocab_cache(slice_mode)
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
if self.cache_enable:
self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size)
if self.cache_enable and enable_ps:
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.forward_unique = False
self.gather_revert = P.GatherV2()
self.unique = P.Unique().shard(((1,),))
self.reshape_first = P.Reshape()
self.reshape = P.Reshape()
self.unique = P.Unique()
self.shape = P.Shape()
if is_auto_parallel:
self.unique = P.Unique().shard(((1,),))
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
@ -270,12 +277,34 @@ class EmbeddingLookup(Cell):
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
if self.cache_enable and not enable_ps:
if is_auto_parallel:
raise ValueError("parallel mode haven't supported cache enable yet.")
self._set_cache_enable()
self.embedding_table.unique = self.forward_unique
self.max_norm = max_norm
if self.max_norm is not None:
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
def _set_cache_enable(self):
"""EmbeddingLookup cache check for not ps env."""
if self.target != 'DEVICE':
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
"so it will be ignored.")
return
if not self.sparse:
logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
"so it will be ignored.")
return
logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
self.unique.add_prim_attr('cache_enable', True)
self.embedding_table.cache_enable = self.cache_enable
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def _process_vocab_cache(self, slice_mode):
"""PS embeddingLookup cache check and process."""
self.cache_enable = False
@ -302,7 +331,7 @@ class EmbeddingLookup(Cell):
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""
self.embedding_table.cache_enable = True
self.embedding_table.is_param_ps = True
@ -316,7 +345,7 @@ class EmbeddingLookup(Cell):
else:
if self.forward_unique:
shp = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape(indices, (-1,))
indices_flatten = self.reshape_first(indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)

View File

@ -156,8 +156,8 @@ class Optimizer(Cell):
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
ps_cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.need_scale = loss_scale != 1.0
self.global_step_increase_tensor = Tensor(1, mstype.int32)

View File

@ -526,6 +526,9 @@ class Model:
train_dataset.reset()
# if param is cache enable, flush data from cache to host before epoch end
self._flush_from_cache(cb_params)
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
@ -784,5 +787,11 @@ class Model:
predict_net.compile(*predict_data)
return predict_net.parameter_layout_dict
def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable."""
params = cb_params.train_network.get_parameters()
for param in params:
if param.cache_enable:
Tensor(param)._flush_from_cache()
__all__ = ["Model"]

View File

@ -53,8 +53,8 @@ def init_var_dict(init_args, in_vars):
'''
var_map = {}
_, _max_val = init_args
for _, iterm in enumerate(in_vars):
key, shape, method = iterm
for _, item in enumerate(in_vars):
key, shape, method = item
if key not in var_map.keys():
if method in ['random', 'uniform']:
var_map[key] = Parameter(initializer(
@ -257,9 +257,11 @@ class WideDeepModel(nn.Cell):
self.wide_embeddinglookup.embedding_table.set_param_ps()
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
target='DEVICE', sparse=sparse)
target='DEVICE', sparse=sparse,
vocab_cache_size=self.vocab_cache_size)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
target='DEVICE', sparse=sparse)
target='DEVICE', sparse=sparse,
vocab_cache_size=self.vocab_cache_size)
self.embedding_table = self.deep_embeddinglookup.embedding_table
def construct(self, id_hldr, wt_hldr):

View File

@ -57,8 +57,8 @@ def init_var_dict(init_args, in_vars):
"""
var_map = {}
_, _max_val = init_args
for _, iterm in enumerate(in_vars):
key, shape, method = iterm
for _, item in enumerate(in_vars):
key, shape, method = item
if key not in var_map.keys():
if method in ['random', 'uniform']:
var_map[key] = Parameter(initializer(Uniform(_max_val), shape,