forked from mindspore-Ecosystem/mindspore
!49503 AtomicAddrClean uses list_int64 attr
Merge pull request !49503 from xulei/atomic_clean
This commit is contained in:
commit
f2fe294c1b
|
@ -450,7 +450,8 @@
|
|||
"ScatterNdD ": "Accuracy issues",
|
||||
"Trace": "Hadn't adapted tbe implementation",
|
||||
"AssignAdd": "Frac_nz in pangu not support",
|
||||
"Range": "not support dynamic shape with tiling failed"
|
||||
"Range": "not support dynamic shape with tiling failed",
|
||||
"AtomicAddrClean": "need to clean addr larger than 2G, int32 is not enough"
|
||||
},
|
||||
"SkipNodes": [
|
||||
"Im2col",
|
||||
|
@ -481,7 +482,8 @@
|
|||
"TransData",
|
||||
"ScatterNdD",
|
||||
"AssignAdd",
|
||||
"Range"
|
||||
"Range",
|
||||
"AtomicAddrClean"
|
||||
],
|
||||
"FallbackOps": {
|
||||
"DeformableOffsets": [
|
||||
|
|
|
@ -239,11 +239,11 @@ bool IfAtomicOpNeedFusion(const size_t clean_total_num, const CNodePtr &first_no
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) {
|
||||
std::vector<int64_t> GetClearSize(const CNodePtr &pre_node) {
|
||||
MS_EXCEPTION_IF_NULL(pre_node);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(pre_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
std::vector<int32_t> clean_size_list;
|
||||
std::vector<int64_t> clean_size_list;
|
||||
constexpr size_t kAlignBytes = 32 - 1;
|
||||
// clean output
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
|
||||
|
@ -251,7 +251,7 @@ std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) {
|
|||
auto output_men_size = kernel_mod->GetOutputSizeList();
|
||||
for (auto index : output_indexes) {
|
||||
auto clean_item =
|
||||
SizeToInt((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
|
||||
SizeToLong((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
|
||||
(void)clean_size_list.emplace_back(clean_item);
|
||||
}
|
||||
}
|
||||
|
@ -261,7 +261,7 @@ std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) {
|
|||
auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
|
||||
for (const auto &index : workspace_indexes) {
|
||||
auto clean_item =
|
||||
SizeToInt((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
|
||||
SizeToLong((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
|
||||
(void)clean_size_list.emplace_back(clean_item);
|
||||
}
|
||||
}
|
||||
|
@ -303,7 +303,7 @@ CNodePtr NewAtomicOp(const CNodePtr &pre_node, const std::vector<AnfNodePtr> &fu
|
|||
}
|
||||
|
||||
void InsertFusionAtomicOp(const CNodePtr &first_clear_node, const std::vector<AnfNodePtr> &fusion_clear_inputs,
|
||||
const std::vector<int32_t> &clean_size_list, CleanOpsMap *clean_ops) {
|
||||
const std::vector<int64_t> &clean_size_list, CleanOpsMap *clean_ops) {
|
||||
MS_EXCEPTION_IF_NULL(first_clear_node);
|
||||
MS_EXCEPTION_IF_NULL(clean_ops);
|
||||
auto clear_zero = NewAtomicOp(first_clear_node, fusion_clear_inputs);
|
||||
|
@ -355,7 +355,7 @@ void SpecialAkgOps(const std::string &op_name, const CNodePtr &node, CleanOpsMap
|
|||
|
||||
void ProcessAtomicFusion(const std::vector<CNodePtr> &kernels, CleanOpsMap *clean_ops) {
|
||||
MS_EXCEPTION_IF_NULL(clean_ops);
|
||||
std::vector<int32_t> clean_size_list;
|
||||
std::vector<int64_t> clean_size_list;
|
||||
std::vector<AnfNodePtr> fusion_clear_inputs;
|
||||
CNodePtr first_node = nullptr;
|
||||
for (const auto &anf_node : kernels) {
|
||||
|
|
|
@ -153,7 +153,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
|
|||
MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size();
|
||||
}
|
||||
}
|
||||
auto clear_mems = common::AnfAlgo::GetNodeAttr<std::vector<int32_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
|
||||
auto clear_mems = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
|
||||
if (kernel_inputs->size() != clear_mems.size()) {
|
||||
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:"
|
||||
<< kernel_inputs->size() << ",clean mem size" << clear_mems.size();
|
||||
|
|
|
@ -48,6 +48,7 @@ static std::unordered_map<std::string, ATTR_DTYPE> type_attr_dtype_map = {
|
|||
{kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32},
|
||||
{kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32},
|
||||
{kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32},
|
||||
{kVTypeListInt64, ATTR_DTYPE::ATTR_LIST_INT64},
|
||||
{kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64},
|
||||
{kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}};
|
||||
|
||||
|
@ -181,6 +182,7 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n
|
|||
case ATTR_DTYPE::ATTR_FLOAT32:
|
||||
return ParseAttrFloat(value, attr_obj);
|
||||
case ATTR_DTYPE::ATTR_LIST_INT32:
|
||||
case ATTR_DTYPE::ATTR_LIST_INT64:
|
||||
return ParseAttrListInt(value, attr_obj);
|
||||
case ATTR_DTYPE::ATTR_LIST_FLOAT32:
|
||||
return ParseAttrListFloat(value, attr_obj);
|
||||
|
@ -232,7 +234,8 @@ bool ParseAttrDefaultValue(const std::string &type, const std::string &value, nl
|
|||
case ATTR_DTYPE::ATTR_FLOAT32:
|
||||
(*attr_obj)[kJValue] = std::stof(value);
|
||||
break;
|
||||
case ATTR_DTYPE::ATTR_LIST_INT32: {
|
||||
case ATTR_DTYPE::ATTR_LIST_INT32:
|
||||
case ATTR_DTYPE::ATTR_LIST_INT64: {
|
||||
std::stringstream string_value(value);
|
||||
std::string list_elem;
|
||||
std::vector<int64_t> attrs_value;
|
||||
|
|
|
@ -60,6 +60,7 @@ constexpr auto kVTypeFloat32 = "float32";
|
|||
constexpr auto kVTypeListInt = "listInt";
|
||||
constexpr auto kVTypeInt32 = "Int32";
|
||||
constexpr auto kVTypeInt64 = "Int64";
|
||||
constexpr auto kVTypeListInt64 = "listInt64";
|
||||
constexpr auto kVTypeListUInt64 = "listUInt64";
|
||||
constexpr auto kVTypeListFloat = "listFloat";
|
||||
constexpr auto kVTypeListListInt = "listListInt";
|
||||
|
|
|
@ -35,3 +35,4 @@ from .acos import _acos_tbe # Accuracy issues(task error in parallel)
|
|||
from .trans_data_ds import _trans_data_ds_tbe # support bool
|
||||
from .scatter_nd_d import _scatter_nd_d_tbe # in python no check supported
|
||||
from .assign_add_ds import _assign_add_ds_tbe # "Frac_nz in pangu not support"
|
||||
from .atomic_addr_clean import _atomic_addr_clean_tbe # need to clean addr larger than 2G, int32 is not enough
|
||||
|
|
|
@ -23,7 +23,7 @@ atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("atomic_addr_clean") \
|
||||
.partial_flag(True) \
|
||||
.attr("automic_add_mem_size", "required", "listInt", "all") \
|
||||
.attr("automic_add_mem_size", "required", "listInt64", "all") \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue