forked from mindspore-Ecosystem/mindspore
fix cpu codex and optimize controldepend
This commit is contained in:
parent
8239407bf3
commit
89ad244882
|
@ -23,7 +23,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
|
||||
size_t start, size_t end) {
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename S, typename T>
|
||||
void Cast(const S *in, T *out, size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
class CastCPUKernel : public CPUKernel {
|
||||
public:
|
||||
CastCPUKernel() = default;
|
||||
|
|
|
@ -95,6 +95,5 @@ void EqualCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -81,11 +81,9 @@ void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
node_ = kernel_node;
|
||||
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
if (hashmap_shape.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
|
||||
}
|
||||
|
||||
hashmap_length_ = hashmap_shape[0];
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
@ -121,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
|
||||
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
|
||||
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);
|
||||
|
||||
std::vector<T> miss_idx;
|
||||
size_t miss_count = 0;
|
||||
float total_count = 0;
|
||||
|
@ -134,9 +131,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
output_cache_idx[i] = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
T tmp_entry = HashFunc(key, hashmap_length_);
|
||||
|
||||
size_t count = 1;
|
||||
count_size += 1;
|
||||
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
|
||||
|
@ -147,7 +142,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
total_count += count;
|
||||
if (hashmap[tmp_entry].IsEmpty()) {
|
||||
miss_idx.emplace_back(i);
|
||||
|
@ -163,10 +157,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
MS_LOG(INFO) << "Miss count: " << miss_count;
|
||||
MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
|
||||
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
|
||||
|
||||
float total_insert_count = 0;
|
||||
float total_delete_count = 0;
|
||||
|
||||
// swap hash map
|
||||
for (size_t i = 0; i < miss_count; ++i) {
|
||||
T emb_idx = output_miss_emb_idx[i];
|
||||
|
@ -180,11 +172,9 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
tag_count++;
|
||||
}
|
||||
|
||||
hashmap[entry].key = emb_idx;
|
||||
hashmap[entry].step = step_[0];
|
||||
hashmap[entry].tag = tag_count;
|
||||
|
||||
T tmp_entry = (entry + 1) % hashmap_length_;
|
||||
size_t delete_count = 1;
|
||||
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
|
||||
|
@ -195,7 +185,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
delete_count++;
|
||||
}
|
||||
|
||||
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
|
||||
output_old_emb_idx[i] = hashmap[tmp_entry].key;
|
||||
hashmap[entry].value = output_swap_cache_idx[i];
|
||||
|
@ -204,19 +193,15 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
total_delete_count += (compress_count + delete_count);
|
||||
total_insert_count += tag_count;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
|
||||
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
|
||||
|
||||
// update step
|
||||
step_[0] += 1;
|
||||
|
||||
// update cache idx
|
||||
for (size_t i = 0; i < miss_count; ++i) {
|
||||
int idx = miss_idx[i];
|
||||
output_cache_idx[idx] = output_swap_cache_idx[i];
|
||||
}
|
||||
|
||||
std::vector<size_t> out_shape;
|
||||
out_shape.emplace_back(miss_count);
|
||||
std::vector<TypeId> dtypes;
|
||||
|
|
|
@ -54,7 +54,6 @@ MS_REG_CPU_KERNEL(FusedBatchNormGradCPU,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGradCPUKernel)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
|
||||
// Maps a data value to a bucket index.
//
// data:       the value to bucket; it is converted to size_t with static_cast.
// bucket_num: the number of buckets.
// Returns static_cast<size_t>(data) % bucket_num, or the converted value
// unchanged when bucket_num is 0.
// NOTE(review): if DataType is a signed type, a negative value wraps to a
// large size_t before the modulo is taken -- confirm callers expect that.
template <typename DataType>
|
||||
static size_t BucketId(DataType data, size_t bucket_num) {
|
||||
// bucket_num is unsigned, so "< 1" means "== 0"; this guard keeps the modulo
// below from dividing by zero.
if (bucket_num < 1) {
|
||||
return static_cast<size_t>(data);
|
||||
}
|
||||
return static_cast<size_t>(data) % bucket_num;
|
||||
}
|
||||
|
||||
|
@ -73,6 +76,9 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
MS_EXCEPTION_IF_NULL(params->input_);
|
||||
MS_EXCEPTION_IF_NULL(each_bucket_size);
|
||||
size_t bucket_num = each_bucket_size->size();
|
||||
if (params->input_size_ < 1) {
|
||||
return;
|
||||
}
|
||||
for (IndexType i = 0; i < params->input_size_; ++i) {
|
||||
auto bucket_id = BucketId(params->input_[i], bucket_num);
|
||||
each_bucket_size->at(bucket_id)++;
|
||||
|
@ -131,6 +137,9 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
MS_EXCEPTION_IF_NULL(segment->input_);
|
||||
std::vector<IndexType> bucket_data_num(segment->thread_num_, 0);
|
||||
auto bucket_size = buckets.size();
|
||||
if (segment->input_size_ < 1) {
|
||||
return;
|
||||
}
|
||||
for (IndexType i = 0; i < segment->input_size_; ++i) {
|
||||
DataType data = segment->input_[i];
|
||||
auto bucket_id = BucketId(data, segment->thread_num_);
|
||||
|
@ -233,6 +242,9 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_EXCEPTION_IF_NULL(inverse_idx);
|
||||
IndexType j = 0;
|
||||
if (params->input_size_ < 1) {
|
||||
return;
|
||||
}
|
||||
if (params->need_sort_) {
|
||||
for (IndexType i = 0; i < params->input_size_; ++i) {
|
||||
input_idx[i] = i;
|
||||
|
@ -296,6 +308,9 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
MS_EXCEPTION_IF_NULL(bucket->workspace_idx_);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(result->inverse_idx_);
|
||||
if (bucket->input_size_ < 1) {
|
||||
return;
|
||||
}
|
||||
for (IndexType i = 0; i < bucket->input_size_; ++i) {
|
||||
auto origin_idx = bucket->workspace_idx_[i];
|
||||
if (origin_idx >= 0 && origin_idx < result->input_size_) {
|
||||
|
|
|
@ -155,8 +155,10 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
VectorRef ret;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
|
||||
ret.push_back(out);
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) {
|
||||
auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
|
||||
ret.push_back(out);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -798,7 +798,6 @@ bool AscendKernelRuntime::DestroyHccl() {
|
|||
MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
|
||||
}
|
||||
HcclResult res = hcom_destroy();
|
||||
|
||||
if (res != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Hccl destroy failed";
|
||||
return false;
|
||||
|
|
|
@ -1832,7 +1832,6 @@ void AscendStreamAssign::AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr
|
|||
|
||||
graph_ptr->set_execution_order(update_orders);
|
||||
}
|
||||
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
|
|
|
@ -268,9 +268,11 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_
|
|||
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
|
||||
}
|
||||
GraphSegmentPtr node_segment{nullptr};
|
||||
auto node_iter = node_to_segment.find(node);
|
||||
if (node_iter != node_to_segment.end()) {
|
||||
node_segment = node_iter->second;
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
auto node_iter = node_to_segment.find(node);
|
||||
if (node_iter != node_to_segment.end()) {
|
||||
node_segment = node_iter->second;
|
||||
}
|
||||
}
|
||||
for (auto &input : node_inputs) {
|
||||
if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
|
||||
|
|
Loading…
Reference in New Issue