forked from mindspore-Ecosystem/mindspore
add dynamic ops
This commit is contained in:
parent
8aa78c2c8e
commit
b7d8e87647
|
@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
}
|
||||
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
|
||||
if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
|
||||
input_x_dtype_size_ = 4;
|
||||
} else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
|
||||
|
@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL(
|
|||
Assign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
AssignCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
|
||||
T i = (entry + 1) % length, off = 1;
|
||||
|
@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inpu
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
class CacheSwapHashmapCPUKernel : public CPUKernel {
|
||||
public:
|
||||
CacheSwapHashmapCPUKernel() = default;
|
||||
|
@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap,
|
|||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
CacheSwapHashmapCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
struct HashmapEntry {
|
||||
T key;
|
||||
|
@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
|
||||
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
|
||||
T i = (entry + 1) % length, off = 1;
|
||||
int compress_count = 0;
|
||||
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
|
||||
if (entry_p[i].tag > off) {
|
||||
entry_p[entry].key = entry_p[i].key;
|
||||
|
@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
|
|||
off = 0;
|
||||
entry = i;
|
||||
}
|
||||
compress_count++;
|
||||
}
|
||||
return compress_count;
|
||||
}
|
||||
|
||||
void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
node_ = kernel_node;
|
||||
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
|
||||
if (hashmap_shape.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
|
||||
batch_size_ *= emb_idx_shape[i];
|
||||
}
|
||||
|
||||
hashmap_length_ = hashmap_shape[0];
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
template <typename T>
|
||||
void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
|
||||
batch_size_ = 1;
|
||||
for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
|
||||
batch_size_ *= emb_idx_shape[i];
|
||||
}
|
||||
HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr);
|
||||
auto input_indices = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
T *step_ = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr);
|
||||
T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr);
|
||||
T offset = *reinterpret_cast<T *>(inputs[4]->addr);
|
||||
auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
|
||||
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
|
||||
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);
|
||||
|
||||
std::vector<T> output_miss_idx(batch_size_, -1);
|
||||
|
||||
std::vector<T> miss_idx;
|
||||
size_t miss_count = 0;
|
||||
float total_count = 0;
|
||||
int count_size = 0;
|
||||
float hit_count = 0;
|
||||
|
||||
// search_cache_idx
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
if (input_indices[i] == emb_max_num) {
|
||||
output_miss_idx[i] = -1;
|
||||
output_cache_idx[i] = cache_max_num;
|
||||
output_miss_emb_idx[i] = -1;
|
||||
T key = input_indices[i] - offset;
|
||||
if (key >= emb_max_num || key < 0) {
|
||||
output_cache_idx[i] = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
T key = input_indices[i];
|
||||
T tmp_entry = HashFunc(key, hashmap_length_);
|
||||
|
||||
int count = 1;
|
||||
size_t count = 1;
|
||||
count_size += 1;
|
||||
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
|
||||
tmp_entry = (tmp_entry + 1) % hashmap_length_;
|
||||
if (count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
total_count += count;
|
||||
if (hashmap[tmp_entry].IsEmpty()) {
|
||||
output_miss_idx[i] = i;
|
||||
output_miss_emb_idx[i] = key;
|
||||
miss_idx.emplace_back(i);
|
||||
output_miss_emb_idx[miss_count] = key;
|
||||
output_cache_idx[i] = -1;
|
||||
miss_count++;
|
||||
} else {
|
||||
hit_count += 1;
|
||||
output_miss_idx[i] = -1;
|
||||
output_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
hashmap[tmp_entry].step = step_[0];
|
||||
output_miss_emb_idx[i] = -1;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "avg search count: " << total_count / count_size;
|
||||
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
|
||||
MS_LOG(INFO) << "Miss count: " << miss_count;
|
||||
MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
|
||||
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
|
||||
|
||||
float total_insert_count = 0;
|
||||
float total_delete_count = 0;
|
||||
|
||||
// swap hash map
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
if (output_miss_emb_idx[i] < 0) {
|
||||
output_swap_cache_idx[i] = -1;
|
||||
output_old_emb_idx[i] = -1;
|
||||
} else {
|
||||
T emb_idx = output_miss_emb_idx[i];
|
||||
T entry = HashFunc(emb_idx, hashmap_length_);
|
||||
T tag_count = 1;
|
||||
while (!hashmap[entry].IsEmpty()) {
|
||||
entry = (entry + 1) % hashmap_length_;
|
||||
tag_count++;
|
||||
for (size_t i = 0; i < miss_count; ++i) {
|
||||
T emb_idx = output_miss_emb_idx[i];
|
||||
T entry = HashFunc(emb_idx, hashmap_length_);
|
||||
size_t tag_count = 1;
|
||||
while (!hashmap[entry].IsEmpty()) {
|
||||
entry = (entry + 1) % hashmap_length_;
|
||||
if (tag_count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
|
||||
break;
|
||||
}
|
||||
|
||||
hashmap[entry].key = emb_idx;
|
||||
hashmap[entry].step = step_[0];
|
||||
hashmap[entry].tag = tag_count;
|
||||
|
||||
T tmp_entry = (entry + 1) % hashmap_length_;
|
||||
|
||||
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
|
||||
tmp_entry = (tmp_entry + 1) % hashmap_length_;
|
||||
}
|
||||
|
||||
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
output_old_emb_idx[i] = hashmap[tmp_entry].key;
|
||||
hashmap[entry].value = output_swap_cache_idx[i];
|
||||
hashmap[tmp_entry].SetEmpty();
|
||||
Compress(hashmap, hashmap_length_, tmp_entry);
|
||||
tag_count++;
|
||||
}
|
||||
|
||||
hashmap[entry].key = emb_idx;
|
||||
hashmap[entry].step = step_[0];
|
||||
hashmap[entry].tag = tag_count;
|
||||
|
||||
T tmp_entry = (entry + 1) % hashmap_length_;
|
||||
size_t delete_count = 1;
|
||||
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
|
||||
tmp_entry = (tmp_entry + 1) % hashmap_length_;
|
||||
if (delete_count > hashmap_length_) {
|
||||
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
|
||||
break;
|
||||
}
|
||||
delete_count++;
|
||||
}
|
||||
|
||||
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
output_old_emb_idx[i] = hashmap[tmp_entry].key;
|
||||
hashmap[entry].value = output_swap_cache_idx[i];
|
||||
hashmap[tmp_entry].SetEmpty();
|
||||
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
|
||||
total_delete_count += (compress_count + delete_count);
|
||||
total_insert_count += tag_count;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
|
||||
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
|
||||
|
||||
// update step
|
||||
step_[0] += 1;
|
||||
|
||||
// update cache idx
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) {
|
||||
continue;
|
||||
}
|
||||
output_cache_idx[i] = output_swap_cache_idx[i];
|
||||
for (size_t i = 0; i < miss_count; ++i) {
|
||||
int idx = miss_idx[i];
|
||||
output_cache_idx[idx] = output_swap_cache_idx[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> out_shape;
|
||||
out_shape.emplace_back(miss_count);
|
||||
std::vector<TypeId> dtypes;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
|
||||
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
|
||||
node_.get());
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
class MapCacheIdxCPUKernel : public CPUKernel {
|
||||
public:
|
||||
MapCacheIdxCPUKernel() = default;
|
||||
|
@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel {
|
|||
size_t batch_size_{1};
|
||||
size_t hashmap_length_{1};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
CNodePtr node_ = nullptr;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MapCacheIdx,
|
||||
|
@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx,
|
|||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MapCacheIdxCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
MS_LOG(INFO) << "avg search count: " << total_count / count_size;
|
||||
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
struct HashmapEntry {
|
||||
T key;
|
||||
|
@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx,
|
|||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
SearchCacheIdxCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -21,20 +21,9 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
if (indices_shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "indices shape less than 2";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
node_ = kernel_node;
|
||||
|
||||
for (size_t i = 0; i < indices_shape.size(); ++i) {
|
||||
batch_size_ *= indices_shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < update_shape.size(); ++i) {
|
||||
update_size_ *= update_shape[i];
|
||||
}
|
||||
update_length_ = update_size_ / batch_size_;
|
||||
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
|
||||
|
||||
|
@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
template <typename T>
|
||||
void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
|
||||
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2);
|
||||
|
||||
batch_size_ = 1;
|
||||
for (size_t i = 0; i < indices_shape.size(); ++i) {
|
||||
batch_size_ *= indices_shape[i];
|
||||
}
|
||||
MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_;
|
||||
update_size_ = 1;
|
||||
for (size_t i = 0; i < update_shape.size(); ++i) {
|
||||
update_size_ *= update_shape[i];
|
||||
}
|
||||
update_length_ = update_shape[1];
|
||||
char *input_x = reinterpret_cast<char *>(inputs[0]->addr);
|
||||
T *indices = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
char *update = reinterpret_cast<char *>(inputs[2]->addr);
|
||||
|
@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel {
|
|||
TypeId input_x_dtype_{kTypeUnknown};
|
||||
TypeId indices_dtype_{kTypeUnknown};
|
||||
size_t input_x_dtype_size_ = 4;
|
||||
CNodePtr node_ = nullptr;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(UpdateCache,
|
||||
|
@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
UpdateCacheCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -254,6 +254,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv
|
|||
return std::make_shared<AbstractTensor>(x->element(), x->shape());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 5);
|
||||
auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(hash_map);
|
||||
MS_EXCEPTION_IF_NULL(hash_map->shape());
|
||||
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto indices_shp = indices->shape();
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(indices_shp);
|
||||
|
||||
ShapeVector shape;
|
||||
ShapeVector min_shape;
|
||||
ShapeVector max_shape;
|
||||
if (!indices_shp->max_shape().empty()) {
|
||||
max_shape = indices_shp->max_shape();
|
||||
} else {
|
||||
max_shape = indices_shp->shape();
|
||||
}
|
||||
for (size_t i = 0; i < max_shape.size(); i++) {
|
||||
shape.emplace_back(Shape::SHP_ANY);
|
||||
min_shape.emplace_back(1);
|
||||
}
|
||||
|
||||
auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
|
||||
auto old_emb_idx =
|
||||
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
auto miss_emb_idx =
|
||||
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
auto swap_emb_idx =
|
||||
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
|
||||
AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
|
||||
return std::make_shared<AbstractTuple>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto cache_table_shp = cache_table->shape();
|
||||
MS_EXCEPTION_IF_NULL(cache_table);
|
||||
MS_EXCEPTION_IF_NULL(cache_table_shp);
|
||||
|
||||
auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto swap_cache_idx_shp = swap_cache_idx->shape();
|
||||
MS_EXCEPTION_IF_NULL(swap_cache_idx);
|
||||
MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);
|
||||
|
||||
auto cache_table_shape = cache_table_shp->shape();
|
||||
auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
|
||||
ShapeVector shape;
|
||||
shape.emplace_back(swap_cache_idx_shape[0]);
|
||||
shape.emplace_back(cache_table_shape[1]);
|
||||
auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
|
||||
ShapeVector max_shape;
|
||||
ShapeVector min_shape;
|
||||
if (!swap_cache_idx_max_shape.empty()) {
|
||||
max_shape.emplace_back(swap_cache_idx_max_shape[0]);
|
||||
max_shape.emplace_back(cache_table_shape[1]);
|
||||
} else {
|
||||
max_shape = shape;
|
||||
}
|
||||
for (size_t i = 0; i < max_shape.size(); ++i) {
|
||||
min_shape.emplace_back(1);
|
||||
}
|
||||
|
||||
AbstractTensorPtr ret =
|
||||
std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
MS_EXCEPTION_IF_NULL(input_x->shape());
|
||||
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||
|
||||
ShapeVector shape;
|
||||
shape.emplace_back(1);
|
||||
|
||||
AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
|
|
@ -56,6 +56,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
|
||||
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
|
||||
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
|
||||
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
|
||||
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
{prim::kPrimShape, {InferImplShape, false}},
|
||||
|
|
|
@ -98,6 +98,9 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>(
|
|||
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
|
||||
inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
|
||||
inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");
|
||||
inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable");
|
||||
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
|
||||
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
"""cache_ops"""
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
|
||||
from .. import signature as sig
|
||||
|
||||
|
||||
class UpdateCache(PrimitiveWithInfer):
|
||||
class UpdateCache(PrimitiveWithCheck):
|
||||
"""
|
||||
Update the value fo input_x, similar to ScatterNdUpdate.
|
||||
The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num.
|
||||
|
@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
|
||||
outputs=['out'])
|
||||
|
||||
def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
|
||||
|
||||
if len(indices_shape) < 2:
|
||||
raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
|
||||
"but got %d." % len(indices_shape))
|
||||
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
|
||||
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
|
||||
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
|
||||
validator.check_tensor_dtype_valid(
|
||||
"indices", indices_dtype, mstype.int_type, self.name)
|
||||
return input_x_dtype
|
||||
|
||||
|
||||
|
@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
|
||||
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
|
||||
validator.check_tensors_dtypes_same_and_valid(
|
||||
args, mstype.int_type, self.name)
|
||||
out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
|
||||
return out_dtype
|
||||
|
||||
|
@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer):
|
|||
outputs=['swap_cache_idx', 'old_emb_idx'])
|
||||
|
||||
def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
|
||||
|
||||
if len(hashmap_shape) != 2:
|
||||
raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
|
||||
"but got %d." % len(hashmap_shape))
|
||||
|
@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
|
||||
validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
|
||||
validator.check_tensor_dtype_valid(
|
||||
"miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
|
||||
out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
|
||||
return out_dtype
|
||||
|
||||
|
||||
class CacheSwapTable(PrimitiveWithInfer):
|
||||
class CacheSwapTable(PrimitiveWithCheck):
|
||||
"""
|
||||
Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry.
|
||||
|
||||
|
@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
|
||||
outputs=['old_value'])
|
||||
|
||||
def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
|
||||
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
|
||||
if len(cache_table_shape) != 2:
|
||||
raise ValueError(
|
||||
"cache table shape must be 2, but got %d" % len(cache_table_shape))
|
||||
if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
|
||||
raise ValueError(
|
||||
"swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")
|
||||
|
||||
return miss_value_shape
|
||||
|
||||
def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
|
||||
validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
|
||||
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
|
||||
validator.check_tensor_dtype_valid(
|
||||
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
|
||||
return miss_value_dtype
|
||||
|
||||
|
||||
class MapCacheIdx(PrimitiveWithInfer):
|
||||
class MapCacheIdx(PrimitiveWithCheck):
|
||||
"""
|
||||
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together.
|
||||
When input an indices tensor, it will output the cache indices which search in hashmap.
|
||||
|
@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
"""init MapCacheIdx"""
|
||||
|
||||
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
|
||||
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
|
||||
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
|
||||
|
||||
def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
|
||||
|
||||
def __check__(self, hashmap, indices, step, emb_max_num, offset):
|
||||
hashmap_shape = hashmap['shape']
|
||||
if len(hashmap_shape) != 2:
|
||||
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
|
||||
"but got %d." % len(hashmap_shape))
|
||||
out_shape = (indices_shape, indices_shape,
|
||||
indices_shape, indices_shape)
|
||||
return out_shape
|
||||
out_shape = (indices['shape'], -1, -1, -1)
|
||||
|
||||
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
|
||||
hashmap_dtype = hashmap['dtype']
|
||||
indices_dtype = indices['dtype']
|
||||
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
|
||||
validator.check_tensor_type_same(args, mstype.int_type, self.name)
|
||||
out_dtype = (hashmap_dtype, hashmap_dtype,
|
||||
hashmap_dtype, hashmap_dtype)
|
||||
return out_dtype
|
||||
|
||||
out = {'shape': out_shape,
|
||||
'dtype': out_dtype,
|
||||
'value': None}
|
||||
if 'max_shape' in indices:
|
||||
out['max_shape'] = (indices['max_shape'], indices['max_shape'],
|
||||
indices['max_shape'], indices['max_shape'])
|
||||
else:
|
||||
out['max_shape'] = (indices['shape'], indices['shape'],
|
||||
indices['shape'], indices['shape'])
|
||||
if 'min_shape' in indices:
|
||||
out['min_shape'] = (indices['min_shape'], 0, 0, 0)
|
||||
else:
|
||||
out['min_shape'] = (0, 0, 0, 0)
|
||||
return out
|
||||
|
|
|
@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell):
|
|||
return self.ops(self.net.hashmap, miss_emb_idx, self.step)
|
||||
|
||||
|
||||
class MapCacheIdxNet(nn.Cell):
|
||||
def __init__(self, hashmap_np):
|
||||
super().__init__()
|
||||
self.ops = P.MapCacheIdx()
|
||||
self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
|
||||
self.emb_max = 25
|
||||
self.cache_max = 10
|
||||
self.step = 0
|
||||
|
||||
def construct(self, indices):
|
||||
return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max)
|
||||
|
||||
|
||||
class UpdateCacheNet(nn.Cell):
|
||||
def __init__(self, x):
|
||||
super().__init__()
|
||||
|
@ -165,45 +152,6 @@ def test_cache_swap_hashmap():
|
|||
np.array(hashmap_np_after_ops, np.int32))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_map_cache_idx():
|
||||
hashmap_np = init_hashmap(10)
|
||||
indices_np = np.array([10, 2, 20, 5, 3], np.int32)
|
||||
map_cache_idx = MapCacheIdxNet(hashmap_np)
|
||||
indices = Tensor(indices_np)
|
||||
cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
|
||||
indices)
|
||||
|
||||
expect_cache_idx = [5, 1, 9, 7, 3]
|
||||
expect_old_emb_idx = [-1, -1, 21, 15, -1]
|
||||
expect_miss_emb_idx = [-1, -1, 20, 5, -1]
|
||||
expect_swap_cache_idx = [-1, -1, 9, 7, -1]
|
||||
|
||||
hashmap_np_after_ops = [[5, 7, 0, 1],
|
||||
[10, 5, 0, 1],
|
||||
[2, 1, 0, 1],
|
||||
[20, 9, 0, 1],
|
||||
[20, 9, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[3, 3, 0, 1],
|
||||
[21, 9, -5, 0]]
|
||||
|
||||
assert np.allclose(cache_idx.asnumpy(),
|
||||
np.array(expect_cache_idx, np.int32))
|
||||
assert np.allclose(old_emb_idx.asnumpy(),
|
||||
np.array(expect_old_emb_idx, np.int32))
|
||||
assert np.allclose(miss_emb_idx.asnumpy(),
|
||||
np.array(expect_miss_emb_idx, np.int32))
|
||||
assert np.allclose(swap_cache_idx.asnumpy(),
|
||||
np.array(expect_swap_cache_idx, np.int32))
|
||||
assert np.allclose(map_cache_idx.hashmap.data.asnumpy(),
|
||||
np.array(hashmap_np_after_ops, np.int32))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue