ascend support untuple

This commit is contained in:
ttudu 2023-01-30 16:38:05 +08:00
parent 73e19297db
commit 37586dbf9b
59 changed files with 1049 additions and 298 deletions

View File

@ -76,6 +76,8 @@ if(ENABLE_ASAN)
endif()
endif()
add_compile_definitions(ENABLE_TUPLE_UNFOLD)
if(DEBUG_MODE)
set(CMAKE_BUILD_TYPE "Debug")
add_compile_definitions(MEM_REUSE_DEBUG)

View File

@ -1177,5 +1177,96 @@ size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
// The first input is Primitive and needs to be skipped.
return std::distance(input_list.begin() + kSizeOne, pos);
}
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs) {
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
auto abs = tuple_input->abstract();
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
return -1;
}
MS_EXCEPTION_IF_NULL(plant_inputs);
auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
auto make_tuple = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t j = 0; j < tuple_input_num; ++j) {
// using for graph kernel
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
continue;
}
(void)plant_inputs->emplace_back(dyn_input_node);
}
return input_size;
}
for (size_t index = 0; index < input_size; ++index) {
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
(void)plant_inputs->emplace_back(dynamic_input_node);
}
return input_size;
}
// Builds a replacement for `cnode_ptr` whose tuple inputs are expanded into plain inputs.
//
// graph:     graph that receives the new node; when it is a KernelGraph its front/backend map is updated.
// cnode_ptr: node whose inputs are examined.
//
// Returns nullptr when the node is a Call or Partial, when it has a dynamic tuple input, or when no input was
// expanded. Otherwise returns a new CNode that copies abstract, scope, primal attrs and attrs from the original and
// carries kAttrDynInputSizes (-1 for an input kept as is, otherwise the size returned by SplitTupleInputs).
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  MS_EXCEPTION_IF_NULL(graph);
  const bool is_call_or_partial = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
                                  common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial);
  if (is_call_or_partial) {
    return nullptr;
  }
  if (common::AnfAlgo::HasDynamicTupleInput(cnode_ptr)) {
    MS_LOG(INFO) << "Node " << cnode_ptr->fullname_with_scope()
                 << " has dynamic tuple input, can't convert. Node debug string:" << cnode_ptr->DebugString();
    return nullptr;
  }

  const bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
  const bool is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
  std::vector<AnfNodePtr> plant_inputs;
  std::vector<int64_t> dyn_input_sizes;
  plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
  // inputs()[0] is the primitive, so the real inputs are one fewer.
  const size_t real_input_num = cnode_ptr->inputs().size() - 1;
  for (size_t pos = 0; pos < real_input_num; ++pos) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, pos);
    MS_EXCEPTION_IF_NULL(input_node);
    const bool is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
    // For BpropCut, sparse tensor inputs are left untouched.
    const bool skip_sparse = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
    const auto keep_as_single_input = [&dyn_input_sizes, &plant_inputs, &input_node]() {
      dyn_input_sizes.push_back(-1);
      plant_inputs.push_back(input_node);
    };

    if (!is_tuple || (!is_print && skip_sparse)) {
      keep_as_single_input();
      continue;
    }
    const auto unfolded_size = SplitTupleInputs(graph, input_node, &plant_inputs);
    // Print records whatever size was returned; other nodes keep the original input when nothing was unfolded.
    if (!is_print && unfolded_size == 0) {
      keep_as_single_input();
      continue;
    }
    (void)dyn_input_sizes.emplace_back(unfolded_size);
  }

  // Without any dynamic input there is nothing to rewrite.
  const bool no_dynamic_input =
    std::all_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t size) { return size < 0; });
  if (no_dynamic_input) {
    return nullptr;
  }

  // Record dyn_input_sizes as an attribute on a new node built from the expanded inputs.
  auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
  new_cnode->set_abstract(cnode_ptr->abstract());
  new_cnode->set_scope(cnode_ptr->scope());
  new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
  new_cnode->set_attrs(cnode_ptr->attrs());
  common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr) {
    kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
  }
  return new_cnode;
}
} // namespace opt
} // namespace mindspore

View File

@ -257,6 +257,11 @@ BACKEND_EXPORT int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &ker
BACKEND_EXPORT void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes);
BACKEND_EXPORT size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node);
BACKEND_EXPORT int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs);
BACKEND_EXPORT AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_

View File

@ -24,92 +24,6 @@
namespace mindspore {
namespace opt {
namespace {
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs) {
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
auto abs = tuple_input->abstract();
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
return -1;
}
MS_EXCEPTION_IF_NULL(plant_inputs);
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
auto make_tuple = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t j = 0; j < tuple_input_num; ++j) {
// using for graph kernel
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
continue;
}
(void)plant_inputs->emplace_back(dyn_input_node);
}
return input_size;
}
for (size_t index = 0; index < input_size; ++index) {
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
(void)plant_inputs->emplace_back(dynamic_input_node);
}
return input_size;
}
// Builds a replacement for `cnode_ptr` whose tuple inputs are expanded into plain inputs.
//
// graph:     graph that receives the new node; when it is a KernelGraph its front/backend map is updated.
// cnode_ptr: node whose inputs are examined.
//
// Returns nullptr when the node is a Call or Partial, or when no input was expanded. Otherwise returns a new CNode
// that copies abstract, scope, primal attrs and attrs from the original and carries kAttrDynInputSizes (-1 for an
// input kept as is, otherwise the size returned by SplitTupleInputs).
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  MS_EXCEPTION_IF_NULL(graph);
  const bool is_call_or_partial = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
                                  common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial);
  if (is_call_or_partial) {
    return nullptr;
  }

  const bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
  const bool is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
  std::vector<AnfNodePtr> plant_inputs;
  std::vector<int64_t> dyn_input_sizes;
  plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
  // inputs()[0] is the primitive, so the real inputs are one fewer.
  const size_t real_input_num = cnode_ptr->inputs().size() - 1;
  for (size_t pos = 0; pos < real_input_num; ++pos) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, pos);
    MS_EXCEPTION_IF_NULL(input_node);
    const bool is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
    // For BpropCut, sparse tensor inputs are left untouched.
    const bool skip_sparse = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
    const auto keep_as_single_input = [&dyn_input_sizes, &plant_inputs, &input_node]() {
      dyn_input_sizes.push_back(-1);
      plant_inputs.push_back(input_node);
    };

    if (!is_tuple || (!is_print && skip_sparse)) {
      keep_as_single_input();
      continue;
    }
    const auto unfolded_size = SplitTupleInputs(graph, input_node, &plant_inputs);
    // Print records whatever size was returned; other nodes keep the original input when nothing was unfolded.
    if (!is_print && unfolded_size == 0) {
      keep_as_single_input();
      continue;
    }
    (void)dyn_input_sizes.emplace_back(unfolded_size);
  }

  // Without any dynamic input there is nothing to rewrite.
  const bool no_dynamic_input =
    std::all_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t size) { return size < 0; });
  if (no_dynamic_input) {
    return nullptr;
  }

  // Record dyn_input_sizes as an attribute on a new node built from the expanded inputs.
  auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
  new_cnode->set_abstract(cnode_ptr->abstract());
  new_cnode->set_scope(cnode_ptr->scope());
  new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
  new_cnode->set_attrs(cnode_ptr->attrs());
  common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr) {
    kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
  }
  return new_cnode;
}
} // namespace
const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const {
VarPtr V = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -106,40 +106,6 @@ size_t GetFusionSize(const AnfNodePtr &node) {
return 0;
}
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs) {
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
auto abs = tuple_input->abstract();
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
return -1;
}
MS_EXCEPTION_IF_NULL(plant_inputs);
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
auto make_tuple = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t j = 0; j < tuple_input_num; ++j) {
// using for graph kernel
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
input_size += SplitTupleInputs(graph, dyn_input_node, plant_inputs);
continue;
}
(void)plant_inputs->emplace_back(dyn_input_node);
}
return input_size;
}
for (size_t index = 0; index < input_size; ++index) {
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
(void)plant_inputs->emplace_back(dynamic_input_node);
}
return input_size;
}
void ExpandFlattenConcatTupleInput(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph);

View File

@ -26,8 +26,8 @@
namespace mindspore {
namespace opt {
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs) {
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(tuple_input);
MS_EXCEPTION_IF_NULL(plant_inputs);
@ -50,7 +50,7 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
MS_EXCEPTION_IF_NULL(dyn_input_node);
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
int64_t dyn_input_size = SplitTupleInputs(graph, dyn_input_node, plant_inputs);
int64_t dyn_input_size = SplitTupleInputsForInsertType(graph, dyn_input_node, plant_inputs);
input_size += LongToSize(dyn_input_size);
continue;
}
@ -169,7 +169,7 @@ void SetKernelInfoForNewCNodeByOrigNode(const CNodePtr &new_cnode, const CNodePt
MS_EXCEPTION_IF_NULL(kernel_info);
new_cnode->set_kernel_info(kernel_info);
// The node may not be supported in the current device.
new_kernel_builder->SetValid(false);
new_kernel_builder->SetValid(true);
AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());
auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
@ -198,6 +198,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
std::vector<KernelObjectType> input_obj_type;
std::vector<KernelObjectType> output_obj_type;
GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
builder->SetKernelType(CPU_KERNEL);
builder->SetInputsKernelObjectType(input_obj_type);
builder->SetOutputsKernelObjectType(output_obj_type);
@ -231,7 +232,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
}
// The node may not be supported in the current device.
builder->SetValid(false);
builder->SetValid(true);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
}
@ -568,7 +569,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold(const Func
}
AnfNodePtrList plant_inputs;
int64_t unfold_num = SplitTupleInputs(func_graph, input, &plant_inputs);
int64_t unfold_num = SplitTupleInputsForInsertType(func_graph, input, &plant_inputs);
MS_LOG(DEBUG) << "Transform tuple unfold input: " << input->fullname_with_scope() << " to " << unfold_num
<< " inputs.";
return plant_inputs;

View File

@ -76,8 +76,8 @@ using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr
// SplitTupleInputs methods refer to the pass ConvertTupleInputToDynamicInput. It unfolds tuple inputs and returns the
// unfolded inputs nodes.
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs);
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
std::vector<AnfNodePtr> *plant_inputs);
// Create the new cnode which will replace the original cnode.
// This method is called at the last step of this pass specifically.

View File

@ -397,9 +397,9 @@ std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > AnfAlgo::GetOutputTensorNum(node)) {
if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
MS_LOG(EXCEPTION) << "Output index:" << output_idx
<< " is out of the node output range :" << AnfAlgo::GetOutputTensorNum(node) << " #node ["
<< " is out of the node output range :" << AnfAlgo::GetOutputElementNum(node) << " #node ["
<< node->DebugString() << "]" << trace::DumpSourceLines(node);
}
if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
@ -635,9 +635,9 @@ std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, siz
std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > AnfAlgo::GetOutputTensorNum(node)) {
if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< AnfAlgo::GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"
<< AnfAlgo::GetOutputElementNum(node) << "#node[ " << node->DebugString() << "]"
<< trace::DumpSourceLines(node);
}
if (!AnfUtils::IsRealKernel(node)) {
@ -1728,7 +1728,7 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr
std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<TypeId> outputs;
auto out_nums = AnfAlgo::GetOutputTensorNum(node);
auto out_nums = AnfAlgo::GetOutputElementNum(node);
for (size_t i = 0; i < out_nums; i++) {
auto type = common::AnfAlgo::GetOutputInferDataType(node, i);
outputs.push_back(type);
@ -1736,20 +1736,31 @@ std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node)
return outputs;
}
// if input node is MakeTuple, find the PrevNodeNum recursively;
// The monad node in the end is not included in the num;
size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t element_num = 0;
size_t input_num = cnode->inputs().size() - 1;
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
if (common::AnfAlgo::IsTupleOutput(input_node)) {
element_num += AnfUtils::GetOutputTensorNum(input_node);
bool cal_monad_flag = false;
for (size_t i = input_num; i > 0; --i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i - 1);
if (!cal_monad_flag && HasAbstractMonad(input_node)) {
continue;
} else if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
element_num += GetInputElementNum(input_node);
cal_monad_flag = true;
} else if (common::AnfAlgo::IsTupleOutput(input_node)) {
element_num += AnfAlgo::GetOutputElementNum(input_node);
cal_monad_flag = true;
} else {
++element_num;
cal_monad_flag = true;
}
}
return element_num;
}

View File

@ -221,6 +221,7 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
static std::vector<TypeId> GetAllOutputObjectType(const AnfNodePtr &node);
// Get all output infer data type.
static std::vector<TypeId> GetAllOutputInferDataTypes(const AnfNodePtr &node);
// Get unfold input num
static size_t GetInputElementNum(const AnfNodePtr &node);
static bool IsRealSquenceOutput(const AnfNodePtr &node);
static void SetDynamicAttrToPrim(const PrimitivePtr &prim);

View File

@ -115,6 +115,10 @@ class COMMON_EXPORT AnfAlgo {
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get prev node output with output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
// get all the untuple real prev_nodes output
static std::vector<KernelWithIndex> GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
bool skip_nop_node = false);
// get output shapes inferred by ME from input nodes.
static ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
static ShapeVector GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
@ -126,6 +130,8 @@ class COMMON_EXPORT AnfAlgo {
static TypeId GetOutputInferDataType(const TypePtr &type, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node
static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
// for tuple condition
static std::vector<TypeId> GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
// set infer shapes and types of anf node
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
AnfNode *node, bool disable_dynamic_len = false);

View File

@ -505,8 +505,8 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t real_input_num = AnfAlgo::GetInputElementNum(kernel_node);
size_t real_output_num = AnfAlgo::GetOutputElementNum(kernel_node);
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
std::vector<int64_t> dyn_input_sizes;
@ -729,7 +729,7 @@ bool IsWeightBoundary(const AnfNodePtr &node) {
}
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputTensorNum(cnode) != 1) {
if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputElementNum(cnode) != 1) {
MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output."
<< trace::DumpSourceLines(cnode);
}
@ -1278,7 +1278,10 @@ std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &ob
for (size_t i = 0; i < selected_object_types.size(); ++i) {
// Allsame/skip_check doesn't support the backoff.
bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
if (not_backoff) {
// Ops which support tensor also support scalar.
bool scalar_compact =
((selected_object_types[i] == kObjectTypeTensorType) && (object_types[i] == kObjectTypeNumber));
if (not_backoff || scalar_compact) {
(void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
} else {
(void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));
@ -1724,6 +1727,22 @@ std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &kernel_attr,
return std::make_pair(false, 0);
}
// Reports whether the build info contains a TUPLE kernel object type.
//
// kernel_build_info: build info to inspect; it is dereferenced without a null check.
//
// Returns true as soon as a TUPLE entry is found among the input kernel object types; the output kernel object
// types are only consulted when the inputs contain none. Returns false when neither list has a TUPLE entry.
bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info) {
  const auto contains_tuple = [](const auto &object_types) {
    return std::any_of(object_types.begin(), object_types.end(),
                       [](const auto &object_type) { return object_type == KernelObjectType::TUPLE; });
  };
  return contains_tuple(kernel_build_info->GetAllInputKernelObjectTypes()) ||
         contains_tuple(kernel_build_info->GetAllOutputKernelObjectTypes());
}
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
MS_EXCEPTION_IF_NULL(build_info);
KernelAttr kernel_attr;

View File

@ -400,6 +400,7 @@ BACKEND_EXPORT std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &k
const std::vector<KernelAttr> &kernel_attr_list);
BACKEND_EXPORT KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info);
BACKEND_EXPORT KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node);
BACKEND_EXPORT bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info);
struct KernelArgs {
BaseOperatorPtr op;
@ -455,6 +456,7 @@ BACKEND_EXPORT std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnf
BACKEND_EXPORT TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type);
KernelObjectType StringToKernelObjectType(const std::string &object_type);
BACKEND_EXPORT void UnfoldKernelBuildInfo(const CNodePtr &kernel_node);
BACKEND_EXPORT int64_t CalOutputTupleSize(const AnfNodePtr &node);
BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode);
BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name);

View File

@ -70,7 +70,7 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
if (input_index >= inputs_kernel_object_type_.size()) {
#ifdef ENABLE_TUPLE_UNFOLD
MS_LOG(ERROR) << "The input index [" << input_index
MS_LOG(DEBUG) << "The input index [" << input_index
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size();
#endif
return KernelObjectType::UNKNOWN_TYPE;
@ -81,7 +81,7 @@ KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) c
KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
if (output_index >= outputs_kernel_object_type_.size()) {
#ifdef ENABLE_TUPLE_UNFOLD
MS_LOG(ERROR) << "The output index [" << output_index
MS_LOG(DEBUG) << "The output index [" << output_index
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size();
#endif
return KernelObjectType::UNKNOWN_TYPE;
@ -182,18 +182,33 @@ std::string KernelBuildInfo::ToString() const {
if (index != 0) {
output_buffer << ", ";
}
output_buffer << "<" << TypeIdLabel(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << "x"
<< KernelObjectTypeLabel(GetInputKernelObjectType(index)) << ">";
output_buffer << "<" << TypeIdLabel(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
}
output_buffer << ") -> (";
output_buffer << ", object_type: [";
auto input_object_types = GetAllInputKernelObjectTypes();
for (size_t index = 0; index < input_object_types.size(); ++index) {
if (index != 0) {
output_buffer << ",";
}
output_buffer << KernelObjectTypeLabel(input_object_types[index]);
}
output_buffer << "]) -> (";
for (size_t index = 0; index < GetOutputNum(); ++index) {
if (index != 0) {
output_buffer << ",";
}
output_buffer << "<" << TypeIdLabel(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
}
output_buffer << ", object_type: [";
auto output_object_types = GetAllOutputKernelObjectTypes();
for (size_t index = 0; index < output_object_types.size(); ++index) {
if (index != 0) {
output_buffer << ", ";
}
output_buffer << "<" << TypeIdLabel(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << "x"
<< KernelObjectTypeLabel(GetOutputKernelObjectType(index)) << ">";
output_buffer << KernelObjectTypeLabel(output_object_types[index]);
}
output_buffer << ")";
output_buffer << "])";
return output_buffer.str();
}

View File

@ -174,13 +174,18 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
(void)kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
(void)kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
(void)kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
}
for (size_t index = 0; index < kernel_build_info->GetAllInputKernelObjectTypes().size(); ++index) {
(void)kernel_build_info_->inputs_kernel_object_type_.emplace_back(
kernel_build_info->GetInputKernelObjectType(index));
}
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
(void)kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
(void)kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
(void)kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
}
for (size_t index = 0; index < kernel_build_info->GetAllOutputKernelObjectTypes().size(); ++index) {
(void)kernel_build_info_->outputs_kernel_object_type_.emplace_back(
kernel_build_info->GetOutputKernelObjectType(index));
}

View File

@ -36,6 +36,8 @@
#include "frontend/operator/ops.h"
#include "utils/trace_base.h"
#include "mindspore/core/ops/op_name.h"
#include "kernel/common_utils.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace device {
@ -79,18 +81,27 @@ mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName,
kBiasAddGradOpName,
kSigmoidCrossEntropyWithLogitsV2OpName};
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
bool MatchUnfoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
TypeId input_origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
return false;
size_t kernel_input_index = 0;
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < fold_input_tensor_num; ++input_index) {
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
for (size_t i = 0; i < inputs_type.size(); ++i) {
if (kernel_input_index >= kernel_build_info->GetInputNum()) {
return false;
}
if (kernel_build_info->GetInputDeviceType(kernel_input_index) != inputs_type[i]) {
return false;
}
++kernel_input_index;
}
}
// Check output data type
for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
if (kernel_build_info.GetOutputDeviceType(output_index) !=
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
if (kernel_build_info->GetOutputDeviceType(output_index) !=
common::AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
return false;
}
@ -98,6 +109,52 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
return true;
}
// Compares the inferred data types of `cnode` with the device types of a build info that may hold TUPLE inputs.
//
// cnode:             node whose inputs and outputs supply the inferred types.
// kernel_build_info: candidate build info; it is dereferenced without a null check.
//
// For each node input, when the kernel slot at the current kernel index is a TUPLE, the type of output 0 of the
// input node is compared with that single slot. Otherwise every type returned by
// GetRealPrevNodesOutputInferDataType is compared with consecutive kernel slots. Returns false on the first
// mismatch or when the kernel index reaches the kernel input number in the non-TUPLE path; returns true when all
// inputs and all outputs match.
bool MatchFoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Inputs: walk the node inputs while advancing a separate cursor over the kernel inputs.
  const size_t node_input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  size_t kernel_pos = 0;
  for (size_t node_pos = 0; node_pos < node_input_num; ++node_pos) {
    const bool kernel_takes_tuple =
      (kernel_build_info->GetInputKernelObjectType(kernel_pos) == kernel::KernelObjectType::TUPLE);
    if (kernel_takes_tuple) {
      // A folded tuple occupies a single kernel slot (+1 skips the primitive in inputs()).
      auto input_node = cnode->inputs()[node_pos + 1];
      const TypeId inferred_type = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
      if (kernel_build_info->GetInputDeviceType(kernel_pos) != inferred_type) {
        return false;
      }
      ++kernel_pos;
      continue;
    }
    const std::vector<TypeId> inferred_types =
      common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, node_pos);
    for (const TypeId inferred_type : inferred_types) {
      if (kernel_pos >= kernel_build_info->GetInputNum()) {
        return false;
      }
      if (kernel_build_info->GetInputDeviceType(kernel_pos) != inferred_type) {
        return false;
      }
      ++kernel_pos;
    }
  }
  // Outputs: compare slot by slot.
  for (size_t out_pos = 0; out_pos < kernel_build_info->GetOutputNum(); ++out_pos) {
    const bool same_type =
      (kernel_build_info->GetOutputDeviceType(out_pos) == common::AnfAlgo::GetOutputInferDataType(cnode, out_pos));
    if (!same_type) {
      return false;
    }
  }
  return true;
}
// Checks whether the inferred data types of `cnode` match `kernel_build_info`.
// Dispatches to the fold matcher when kernel::IsFoldKernelBuildInfo reports true, and to the unfold matcher
// otherwise; the result of the chosen matcher is returned unchanged.
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  return kernel::IsFoldKernelBuildInfo(kernel_build_info) ? MatchFoldInferOutputDataType(cnode, kernel_build_info)
                                                          : MatchUnfoldInferOutputDataType(cnode, kernel_build_info);
}
string GetPriorityMatchFormat(const CNodePtr &cnode) {
constexpr size_t k5dSize = 5;
constexpr size_t k4dSize = 4;
@ -195,7 +252,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
}
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
// cal count of same output dtype between abstract and kernel info
if (kernel_build_info.GetOutputDeviceType(output_index) ==
@ -210,14 +267,14 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
std::string PrintRaiseOrReducePrecisionSelectedInfo(
const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
bool precision_reduce) {
KernelSelectStatus KernelSelectStatus) {
MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
MS_EXCEPTION_IF_NULL(cnode);
std::ostringstream buffer;
buffer << cnode->DebugString();
if (precision_reduce) {
if (KernelSelectStatus == kStatusReducePrecision) {
buffer << " Reduce precision, node datatype: \n";
} else {
} else if (KernelSelectStatus == kStatusRaisePrecision) {
buffer << " Raise precision, node datatype: \n";
}
GatherInputAndOutputInferType(buffer, cnode);
@ -250,7 +307,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
for (const auto &kernel_build_info : kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_build_info);
if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
if (!MatchInferOutputDataType(cnode, kernel_build_info)) {
continue;
}
result.push_back(kernel_build_info);
@ -258,6 +315,131 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
return result;
}
// Decides whether a node object type is accepted by a kernel object type.
//
// Returns true when both types are equal, or when the kernel type is TENSOR and the node type is TUPLE,
// TUPLE_UNFOLD, SCALAR or UNKNOWN_TYPE (the last one covers monad output ops such as labelset, labelswitch,
// labelgoto ...). Otherwise logs the mismatch at INFO level and returns false.
bool MatchObjectType(const kernel::KernelObjectType &node_object, const kernel::KernelObjectType &kernel_object) {
  const bool kernel_is_tensor = (kernel_object == kernel::TENSOR);
  const bool node_fits_tensor = (node_object == kernel::TUPLE || node_object == kernel::TUPLE_UNFOLD ||
                                 node_object == kernel::SCALAR || node_object == kernel::UNKNOWN_TYPE);
  if (node_object == kernel_object || (kernel_is_tensor && node_fits_tensor)) {
    return true;
  }
  MS_LOG(INFO) << "Object mismatch. node object type : " << node_object << ", kernel object type: " << kernel_object;
  return false;
}
// kernel:tuple, node:tuple -> compare objecttype
// kernel:tuple, node:tensor -> compare objecttype
// kernel:tensor, node:tensor -> compare objecttype
// kernel:tensor, node:tuple -> unfold node, then compare object type
bool MatchObjectType(const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input object type
auto kernel_inputs_object_type = kernel_build_info->GetAllInputKernelObjectTypes();
auto node_inputs_object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(cnode));
size_t kernel_input_index = 0;
std::vector<kernel::KernelObjectType> new_input_object_types = {};
for (size_t input_index = 0; input_index < node_inputs_object_type.size(); ++input_index) {
if (kernel_inputs_object_type[kernel_input_index] != kernel::KernelObjectType::TUPLE &&
node_inputs_object_type[input_index] == kernel::KernelObjectType::TUPLE) {
// tuple_unfold condition
std::vector<KernelWithIndex> index_inputs = common::AnfAlgo::GetRealPrevNodesOutput(cnode, input_index);
for (size_t i = 0; i < index_inputs.size(); ++i) {
auto real_input_node = index_inputs[i].first;
MS_EXCEPTION_IF_NULL(real_input_node);
if (kernel_input_index >= kernel_inputs_object_type.size()) {
MS_LOG(DEBUG) << "index is large equal than list size: " << kernel_input_index << " vs "
<< kernel_inputs_object_type.size();
return false;
}
if (!MatchObjectType(
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(real_input_node->abstract())),
kernel_inputs_object_type[kernel_input_index])) {
return false;
}
++kernel_input_index;
}
new_input_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
} else {
auto node_object = node_inputs_object_type[input_index];
auto kernel_object = kernel_inputs_object_type[kernel_input_index];
if (!MatchObjectType(node_object, kernel_object)) {
return false;
}
if (node_object == kernel::KernelObjectType::SCALAR && kernel_object == kernel::KernelObjectType::TENSOR) {
new_input_object_types.push_back(kernel::KernelObjectType::SCALAR);
} else {
new_input_object_types.push_back(kernel_inputs_object_type[kernel_input_index]);
}
++kernel_input_index;
}
}
if (kernel_input_index != kernel_inputs_object_type.size()) {
MS_LOG(DEBUG) << "index is not equal to list size: " << kernel_input_index << " vs "
<< kernel_inputs_object_type.size();
return false;
}
// Check output object type
auto kernel_outputs_object_type = kernel_build_info->GetAllOutputKernelObjectTypes();
auto node_output_object_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
std::vector<kernel::KernelObjectType> new_output_object_types = {};
if (node_output_object_type == kObjectTypeTuple) {
auto tuple_abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abs);
auto items = tuple_abs->elements();
size_t output_index = 0;
for (auto item : items) {
if (output_index >= kernel_outputs_object_type.size()) {
MS_LOG(DEBUG) << "index is large equal than list size: " << output_index << " vs "
<< kernel_outputs_object_type.size();
return false;
}
if (!MatchObjectType(kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(item)),
kernel_outputs_object_type[output_index])) {
return false;
}
++output_index;
}
new_output_object_types = {kernel::KernelObjectType::TUPLE_UNFOLD};
} else {
auto output_num = AnfAlgo::GetOutputElementNum(cnode);
if (output_num > 0) {
if (!MatchObjectType(kernel::TypeIdToKernelObjectType(node_output_object_type), kernel_outputs_object_type[0])) {
return false;
}
new_output_object_types.push_back(kernel_outputs_object_type[0]);
}
}
kernel_build_info->SetInputsKernelObjectType(new_input_object_types);
kernel_build_info->SetOutputsKernelObjectType(new_output_object_types);
return true;
}
// Returns copies of the build infos in kernel_info_list whose object types are accepted by MatchObjectType for
// cnode. The entries of kernel_info_list themselves are left untouched.
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByObjectType(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> matched;
  for (const auto &kernel_build_info : kernel_info_list) {
    MS_EXCEPTION_IF_NULL(kernel_build_info);
    // Match against a private copy: the matcher receives this copy, never the caller's instance.
    auto candidate_copy = std::make_shared<kernel::KernelBuildInfo>(*kernel_build_info);
    if (MatchObjectType(cnode, candidate_copy)) {
      matched.push_back(candidate_copy);
    }
  }
  return matched;
}
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
if (!common::AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
@ -386,14 +568,17 @@ TypeId GetInputDeviceType(const CNodePtr &kernel_node, size_t input_idx) {
return type;
}
void GetInputsDeviceType(const CNodePtr &kernel_node, std::vector<TypeId> *input_types) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(input_types);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; ++i) {
auto type = GetInputDeviceType(kernel_node, i);
input_types->emplace_back(type);
// Returns the data type of output `second` of node `first` in input_node_with_index.
// The device data type is used when the node already has a selected kernel build info;
// otherwise the inferred data type is returned.
TypeId GetInputDeviceType(const KernelWithIndex &input_node_with_index) {
  auto input_node = input_node_with_index.first;
  MS_EXCEPTION_IF_NULL(input_node);
  const auto output_idx = input_node_with_index.second;
  auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
  const bool has_selected_build_info = (kernel_info != nullptr) && (kernel_info->select_kernel_build_info() != nullptr);
  return has_selected_build_info ? AnfAlgo::GetOutputDeviceDataType(input_node, output_idx)
                                 : common::AnfAlgo::GetOutputInferDataType(input_node, output_idx);
}
string InferOutputFormat(const CNodePtr &kernel_node, const std::vector<std::string> &inputs_format) {
@ -464,32 +649,54 @@ KernelSelectStatus SelectCustomKernelInfo(const CNodePtr &kernel_node, KernelTyp
// set inputs info
std::vector<TypeId> inputs_device_type;
std::vector<std::string> inputs_format;
GetInputsDeviceType(kernel_node, &inputs_device_type);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
std::vector<kernel::KernelObjectType> inputs_kernel_object_type;
std::unordered_set<string> all_input_formats;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; ++i) {
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
inputs_format.emplace_back(format);
all_input_formats.insert(format);
auto input_node = common::AnfAlgo::GetInputNode(kernel_node, i);
MS_EXCEPTION_IF_NULL(input_node);
if (common::AnfAlgo::IsTupleOutput(input_node)) {
std::vector<KernelWithIndex> inputs_with_index = common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, i);
for (size_t j = 0; j < inputs_with_index.size(); ++j) {
auto type = GetInputDeviceType(inputs_with_index[j]);
inputs_device_type.emplace_back(type);
auto format = AnfAlgo::GetOutputFormat(inputs_with_index[j].first, inputs_with_index[j].second);
inputs_format.emplace_back(format);
all_input_formats.insert(format);
}
inputs_kernel_object_type.emplace_back(kernel::KernelObjectType::TUPLE_UNFOLD);
} else {
auto type = GetInputDeviceType(kernel_node, i);
inputs_device_type.emplace_back(type);
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
inputs_format.emplace_back(format);
all_input_formats.insert(format);
inputs_kernel_object_type.emplace_back(kernel::KernelObjectType::TENSOR);
}
}
if (all_input_formats.size() > 1) {
MS_LOG(WARNING) << op_name << " has different input formats, the number of input formats is "
<< all_input_formats.size();
}
builder->SetInputsDeviceType(inputs_device_type);
builder->SetInputsFormat(inputs_format);
builder->SetInputsKernelObjectType(inputs_kernel_object_type);
// set outputs info
std::vector<TypeId> outputs_device_type;
std::vector<std::string> outputs_format;
auto output_infer_format = InferOutputFormat(kernel_node, inputs_format);
MS_LOG(INFO) << "Outputs of " << op_name << " will use same inferred format: " << output_infer_format;
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
for (size_t i = 0; i < output_num; ++i) {
outputs_device_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, i));
outputs_format.push_back(output_infer_format);
}
builder->SetOutputsDeviceType(outputs_device_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsKernelObjectType(
std::vector<kernel::KernelObjectType>(outputs_format.size(), kernel::KernelObjectType::TENSOR));
// Set kernel build info to node
auto build_info = builder->Build();
MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope() << " selected: " << build_info;
@ -565,28 +772,47 @@ void ResetPreFixedFormat(const CNodePtr &kernel_node, kernel::KernelBuildInfoPtr
}
} // namespace
// Applies input_format to the real node behind input_kernel_node, which feeds input input_index of kernel_node.
// Nothing more is done when RefreshCastAndParamWeightFormat reports it handled the node, or when the real node
// is a Parameter that is not a weight; otherwise the format is forwarded to SetWeightFormat.
void RefreshInputParameter(const CNodePtr &kernel_node, const AnfNodePtr &input_kernel_node,
                           const std::string &input_format, size_t input_index) {
  auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
  MS_EXCEPTION_IF_NULL(input_with_index.first);
  const auto &real_input_node = input_with_index.first;
  if (RefreshCastAndParamWeightFormat(real_input_node, input_format)) {
    return;
  }
  const bool is_non_weight_parameter =
    real_input_node->isa<Parameter>() && !common::AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>());
  if (!is_non_weight_parameter) {
    std::vector<std::string> output_format = {input_format};
    SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
  }
}
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t real_input_num = 0;
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_kernel_node = common::AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
MS_EXCEPTION_IF_NULL(real_input_node);
if (RefreshCastAndParamWeightFormat(real_input_node, selected_kernel_info->GetInputFormat(input_index))) {
continue;
auto input_object_type = selected_kernel_info->GetInputKernelObjectType(input_index);
if (input_object_type == kernel::KernelObjectType::TUPLE_UNFOLD) {
std::vector<KernelWithIndex> kernels_with_index =
common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, input_index);
for (size_t i = 0; i < kernels_with_index.size(); ++i) {
RefreshInputParameter(kernel_node, kernels_with_index[i].first,
selected_kernel_info->GetInputFormat(real_input_num), real_input_num);
++real_input_num;
}
} else {
auto input_kernel_node = common::AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
RefreshInputParameter(kernel_node, input_kernel_node, selected_kernel_info->GetInputFormat(real_input_num),
real_input_num);
++real_input_num;
}
if (real_input_node->isa<Parameter>() &&
!common::AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
std::vector<std::string> output_format = {refresh_format};
SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
}
}
@ -595,28 +821,40 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
MS_EXCEPTION_IF_NULL(kernel_node);
KernelSelectStatus select_status = kNoMatched;
if (kernel_info_list.empty()) {
return select_status;
return kNoMatched;
}
bool precision_reduce = false;
kernel::KernelBuildInfoPtr selected_kernel_info = nullptr;
// Matched kernel info
// Filter kernel info matched with me inferred type
auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
select_status = kStatusAllMatched;
} else {
if (filtered_kernel_info_list.empty()) {
// selected kernel info using raised precision or reduce precision
filtered_kernel_info_list =
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
if (selected_kernel_info == nullptr) {
return select_status;
} else {
MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
if (filtered_kernel_info_list.empty()) {
return kNoMatched;
}
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
} else {
select_status = kStatusAllMatched;
}
// filter object_type and adjust tuple_unfold condition
MS_LOG(DEBUG) << "Node " << kernel_node->fullname_with_scope() << "'s kernel info list size is "
<< filtered_kernel_info_list.size() << " before object type matching";
filtered_kernel_info_list = FilteredKernelInfoByObjectType(kernel_node, filtered_kernel_info_list);
MS_LOG(DEBUG) << "Node " << kernel_node->fullname_with_scope() << "'s kernel info list size is "
<< filtered_kernel_info_list.size() << " after object type matching";
if (filtered_kernel_info_list.empty()) {
return kNoMatched;
}
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
// Log the selection only when the kernel was chosen by reducing or raising precision.
// Both sides of the `||` must test select_status; comparing the two status constants with each other is always false.
if (select_status == kStatusReducePrecision || select_status == kStatusRaisePrecision) {
  MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, select_status);
}
// Set kernel build info to node
MS_LOG(DEBUG) << "Current node: " << kernel_node->fullname_with_scope()
<< " selected: " << selected_kernel_info->ToString();
@ -861,6 +1099,47 @@ void SetAclKernelInfo(const CNodePtr &kernel_node) {
AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
}
// Attaches kAttrDynInputSizes to cnode before kernel selection, one entry per input (the primitive is excluded),
// each taken from kernel::CalOutputTupleSize on that input.
// Call/Partial nodes and nodes that already carry the attribute are skipped.
// The attribute is only set when at least one entry is non-negative.
void SetDynamicInputSizeAttrBeforeKernelSelect(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const bool is_call_or_partial = common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
                                  common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial);
  if (is_call_or_partial || common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
    return;
  }
  // inputs()[0] is not a data input, hence the "- 1".
  const size_t input_num = cnode->inputs().size() - 1;
  std::vector<int64_t> dyn_input_sizes;
  bool has_non_negative_size = false;
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
    // NOTE(review): a negative size presumably marks a non-tuple input -- confirm against CalOutputTupleSize.
    dyn_input_sizes.push_back(kernel::CalOutputTupleSize(input_node));
    has_non_negative_size = has_non_negative_size || (dyn_input_sizes.back() >= 0);
  }
  if (has_non_negative_size) {
    common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
  }
}
void RefreshDynamicInputSizeAttr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
MS_LOG(INFO) << "Node has not set kAttrDynInputSizes yet, node: " << cnode->fullname_with_scope();
return;
}
std::vector<int64_t> dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
size_t input_num = cnode->inputs().size() - 1;
for (size_t i = 0; i < input_num; ++i) {
if (input_obj_types[i] == kernel::KernelObjectType::TUPLE) {
dyn_input_sizes[i] = -1;
}
}
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
} else {
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, cnode);
}
}
std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithMsg(const CNodePtr &kernel_node,
KernelType kernel_type) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
@ -870,15 +1149,18 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
std::tuple<KernelSelectStatus, std::string, ExceptionType> result =
std::make_tuple(kStatusAllMatched, "", NoExceptionType);
MS_EXCEPTION_IF_NULL(kernel_node);
SetDynamicInputSizeAttrBeforeKernelSelect(kernel_node);
if (common::AnfAlgo::IsGraphKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
MS_EXCEPTION_IF_NULL(func_graph);
SelectGraphKernelInfo(kernel_node, func_graph);
RefreshDynamicInputSizeAttr(kernel_node);
return result;
}
if (IsPrimitiveCNode(kernel_node, prim::kPrimCallInline)) {
opt::SelectCallInlineKernelInfo(kernel_node);
SetTensorDeviceInfo(kernel_node);
RefreshDynamicInputSizeAttr(kernel_node);
return result;
}
if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node)) {
@ -892,6 +1174,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
auto select_status = SelectCustomKernelInfo(kernel_node, &kernel_type);
if (select_status == kStatusAllMatched) {
RefreshDynamicInputSizeAttr(kernel_node);
return result;
}
}
@ -942,6 +1225,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
common::AnfAlgo::SetNodeAttr(kAttrIsAiCpuKernel, MakeValue(true), kernel_node);
}
// The kernel info can not find in ai_cpu kernel lists and ai_core kernel lists
if (select_status == kNoMatched) {
GatherInputAndOutputInferType(aicpu_in_out_info, kernel_node);
@ -954,6 +1238,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
std::get<two>(result) = etype;
return result;
}
RefreshDynamicInputSizeAttr(kernel_node);
SetRaiseOrReduceFlag(kernel_node, select_status);
std::get<0>(result) = select_status;
return result;

View File

@ -72,6 +72,7 @@ void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
}
}
@ -283,6 +284,7 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
}
}
@ -403,6 +405,7 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo
std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
}
@ -436,6 +439,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
std::vector<kernel::KernelObjectType> graph_output_object_type;
for (size_t i = 0; i < output_index.size(); ++i) {
auto const &output = output_index[i];
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
@ -447,13 +451,21 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
}
graph_output_type.push_back(output_type);
graph_output_object_type.push_back(kernel::KernelObjectType::TENSOR);
}
std::vector<kernel::KernelObjectType> graph_input_object_type;
for (size_t i = 0; i < graph_input_type.size(); ++i) {
graph_input_object_type.push_back(kernel::KernelObjectType::TENSOR);
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetInputsKernelObjectType(graph_input_object_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetOutputsKernelObjectType(graph_output_object_type);
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);

View File

@ -35,12 +35,12 @@ namespace mindspore {
namespace device {
namespace ascend {
void AscendDeviceContext::Initialize() {
MS_LOG(INFO) << "Start Initialize...";
if (initialized_) {
MS_EXCEPTION_IF_NULL(runtime_instance_);
runtime_instance_->SetContext();
return;
} else {
MS_LOG(INFO) << "Start Initialize...";
#ifndef ENABLE_SECURITY
AscendProfiler::GetInstance()->MsprofInitProfiler();
#endif

View File

@ -268,6 +268,8 @@ void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPt
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
CommOpReuse(graph);
opt::AscendUnfoldInputsForSpecialNodes(graph);
if (context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
HandleControlFlow(NOT_NULL(graph));
}
@ -448,6 +450,7 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
}
#endif
MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
graph->SetKernelObjectTypesForUnrealNodes();
SetOperatorInfo(graph);
MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();
#ifdef ENABLE_DUMP_IR

View File

@ -238,7 +238,7 @@ void UpdateOutputNodeShape(const AnfNodePtr &node, size_t index, TypeId output_t
if (node->isa<CNode>()) {
name = common::AnfAlgo::GetCNodeName(node);
}
size_t total_output_num = AnfAlgo::GetOutputTensorNum(node);
size_t total_output_num = AnfAlgo::GetOutputElementNum(node);
if (index >= total_output_num) {
MS_LOG(EXCEPTION) << "Invalid output index " << index << ", node " << node->fullname_with_scope() << " has "
<< total_output_num << " outputs.";

View File

@ -64,20 +64,27 @@ void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
if (kDynamicInputOps.find(op_name) != kDynamicInputOps.end()) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
inputs_object_type.emplace_back(KernelObjectType::TENSOR);
auto kernels_with_index = common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, input_index);
for (size_t i = 0; i < kernels_with_index.size(); ++i) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
(void)inputs_type.emplace_back(
common::AnfAlgo::GetOutputInferDataType(kernels_with_index[i].first, kernels_with_index[i].second));
inputs_object_type.emplace_back(kernel::TypeIdToKernelObjectType(
AnfAlgo::GetOutputObjectType(kernels_with_index[i].first, kernels_with_index[i].second)));
}
}
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
std::vector<KernelObjectType> outputs_object_type{};
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
outputs_object_type.emplace_back(KernelObjectType::TENSOR);
outputs_object_type.emplace_back(
kernel::TypeIdToKernelObjectType(AnfAlgo::GetOutputObjectType(kernel_node, output_index)));
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);

View File

@ -98,7 +98,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
std::vector<KernelObjectType> output_object_type{};
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
(void)outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
if (op_name == kReceiveOpName) {

View File

@ -37,7 +37,7 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
std::vector<KernelObjectType> inputs_object_type{};
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t input_num = AnfAlgo::GetInputElementNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
@ -46,7 +46,7 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
std::vector<KernelObjectType> outputs_object_type{};
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));

View File

@ -25,8 +25,10 @@
#include "plugin/device/ascend/kernel/akg/akg_kernel_metadata.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
@ -38,36 +40,91 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
return;
}
MS_EXCEPTION_IF_NULL(kernel_node);
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t unfold_output_tensor_num = AnfAlgo::GetOutputElementNum(kernel_node);
size_t unfold_input_tensor_num = AnfAlgo::GetInputElementNum(kernel_node);
size_t fold_output_tensor_num = 1;
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
(void)std::copy_if(
kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
[output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_build_info);
return kernel_build_info->GetOutputNum() == output_tensor_num &&
kernel_build_info->GetInputNum() == input_tensor_num;
});
std::ostringstream buffer;
size_t info_index = 0;
for (const auto &kernel_info : *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info);
bool is_fold = kernel::IsFoldKernelBuildInfo(kernel_info);
if (is_fold) {
bool is_match = true;
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
is_match = false;
} else {
// compare input num
std::vector<int64_t> dyn_input_sizes =
common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrDynInputSizes);
size_t real_input_num = 0;
for (size_t i = 0; i < fold_input_tensor_num; ++i) {
if (kernel_info->GetInputKernelObjectType(i) == kernel::KernelObjectType::TUPLE || dyn_input_sizes[i] == -1) {
++real_input_num;
} else {
real_input_num += dyn_input_sizes[i];
}
}
if (kernel_info->GetInputNum() != real_input_num) {
is_match = false;
}
}
if (is_match) {
// compare output num
size_t real_output_num = unfold_output_tensor_num;
if (kernel_info->GetOutputKernelObjectType(0) == kernel::KernelObjectType::TUPLE) {
real_output_num = 1;
}
if (kernel_info->GetOutputNum() != real_output_num) {
is_match = false;
}
}
if (is_match) {
(void)filtered_list.emplace_back(kernel_info);
} else {
buffer << "Kernel [ " << info_index << " ] [Fold]:";
if (kernel_info->GetOutputNum() != fold_output_tensor_num) {
buffer << "Kernel build info's output size [" << kernel_info->GetOutputNum() << "]"
<< " cannot match the node's output size [" << fold_output_tensor_num << "]\n";
} else {
buffer << "Kernel build info's input size [" << kernel_info->GetInputNum() << "]"
<< " cannot match the node's input size [" << fold_input_tensor_num << "]\n";
}
buffer << "\n kernel info:" << kernel_info->ToString();
}
} else {
if ((kernel_info->GetInputNum() == unfold_input_tensor_num) &&
(kernel_info->GetOutputNum() == unfold_output_tensor_num)) {
(void)filtered_list.emplace_back(kernel_info);
} else {
buffer << "Kernel [ " << info_index << " ] [Unfold]:";
if (kernel_info->GetOutputNum() != unfold_output_tensor_num) {
buffer << "Kernel build info's output size [" << kernel_info->GetOutputNum() << "]"
<< " cannot match the node's output size [" << unfold_output_tensor_num << "]\n";
} else {
buffer << "Kernel build info's input size [" << kernel_info->GetInputNum() << "]"
<< " cannot match the node's input size [" << unfold_input_tensor_num << "]\n";
}
buffer << "\n kernel info:" << kernel_info->ToString();
}
}
info_index++;
}
if (!filtered_list.empty()) {
kernel_info_list->clear();
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
} else {
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
std::ostringstream buffer;
auto &kernel_info = kernel_info_list->at(index);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->GetOutputNum() != output_tensor_num) {
buffer << "Kernel node's output size [" << output_tensor_num << "]"
<< " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
} else {
buffer << "Kernel node's input size [" << input_tensor_num << "]"
<< " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
}
MS_LOG(INFO) << "Kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
}
MS_LOG(INFO) << buffer.str();
kernel_info_list->clear();
MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
<< "input size : [" << input_tensor_num << "] can not match any kernelInfo !";
MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s fold output size : [" << fold_output_tensor_num << "]"
<< ", fold input size : [" << fold_input_tensor_num << "], unfold output size : ["
<< unfold_output_tensor_num << "]"
<< ", unfold input size : [" << unfold_input_tensor_num << "] can not match any kernelInfo !";
}
}
@ -99,22 +156,37 @@ void KernelQueryAll(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
auto select_cnode = kernel_node;
auto tuple_unfold_node = opt::ConvertMakeTupleInputToPlantInputs(kernel_node->func_graph(), kernel_node);
if (tuple_unfold_node != nullptr) {
auto tuple_unfold_cnode = tuple_unfold_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_unfold_cnode);
select_cnode = tuple_unfold_cnode;
select_cnode->set_fullname_with_scope(kernel_node->fullname_with_scope());
MS_LOG(INFO) << "Create tuple unfold node " << tuple_unfold_node->fullname_with_scope() << ", debug string ["
<< tuple_unfold_node->DebugString() << "] from " << kernel_node->fullname_with_scope()
<< ", debug string [" << kernel_node->DebugString() << "].";
}
TbeMetadataInfo(select_cnode, kernel_info_list);
if (kernel_info_list->empty()) {
GetRtKelInfo(kernel_node, kernel_info_list);
GetRtKelInfo(select_cnode, kernel_info_list);
CheckKernelInfoListEmpty(kernel_info_list, "RT_Kernel");
}
if (kernel_info_list->empty()) {
HcclMetadataInfo(kernel_node, kernel_info_list);
HcclMetadataInfo(select_cnode, kernel_info_list);
CheckKernelInfoListEmpty(kernel_info_list, "HCCL_Kernel");
}
if (SelectAicpuReshapeInTaskSink(kernel_node)) {
if (SelectAicpuReshapeInTaskSink(select_cnode)) {
return;
}
if (kernel_info_list->empty()) {
HostMetadataInfo(kernel_node, kernel_info_list);
HostMetadataInfo(select_cnode, kernel_info_list);
CheckKernelInfoListEmpty(kernel_info_list, "HOST_Kernel");
}
if (!kernel_info_list->empty()) {
common::AnfAlgo::CopyNodeAttrs(select_cnode, kernel_node);
}
}
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,

View File

@ -69,7 +69,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
if (IsDefaultKernelInfo(node_name)) {
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set input infos
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
auto input_num = AnfAlgo::GetInputElementNum(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
kernel_build_info_builder->SetInputsKernelObjectType(

View File

@ -124,7 +124,8 @@ RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node,
kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetOutputFormat(anf_node, index);
auto data_type =
kernel_info->select_kernel_build_info() == nullptr ? type : AnfAlgo::GetOutputDeviceDataType(anf_node, index);
std::string reshape_type = AnfAlgo::GetOutputReshapeType(anf_node, index);
std::string reshape_type =
kernel_info->select_kernel_build_info() == nullptr ? "" : AnfAlgo::GetOutputReshapeType(anf_node, index);
trans::ShapeRangeTransfer shapeRangeTransfer;
RangePair ret;

View File

@ -257,7 +257,7 @@ bool SingleTbeJsonCreator::GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::
size_t sum_outputs_num =
std::accumulate(outputs_tensor_num.begin(), outputs_tensor_num.end(), static_cast<size_t>(0));
size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
size_t real_output_num = AnfAlgo::GetOutputElementNum(anf_node);
std::vector<nlohmann::json> outputs_desc;
for (size_t i = 0; i < real_output_num; i++) {
nlohmann::json output_desc;

View File

@ -54,7 +54,7 @@ bool TbeJsonUtils::GetInputsRealNum(const AnfNodePtr &anf_node, const std::vecto
bool TbeJsonUtils::GetOutputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &outputs_ptr,
std::vector<size_t> *outputs_num) {
MS_EXCEPTION_IF_NULL(anf_node);
size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
size_t real_output_num = AnfAlgo::GetOutputElementNum(anf_node);
for (const auto &output_ptr : outputs_ptr) {
if (output_ptr->param_type() == kJParamDynamic) {
if (outputs_ptr.size() > 1) {

View File

@ -24,8 +24,8 @@ namespace mindspore::kernel {
void TbeKernelAgnosticSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
MS_EXCEPTION_IF_NULL(cnode_ptr_);
SupportFormat support_format;
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
auto input_num = AnfAlgo::GetInputElementNum(cnode_ptr_);
auto output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
if (input_num != 1 || output_num != 1) {
MS_LOG(EXCEPTION) << "Agnostic only support one input. input_num: " << input_num << ", output num: " << output_num
<< ", full_name:" << cnode_ptr_->fullname_with_scope();

View File

@ -59,7 +59,7 @@ void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
(void)input_shapes_.emplace_back(dynamic_input_shape0_);
input_num_ = 1;
} else {
input_num_ = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
input_num_ = AnfAlgo::GetInputElementNum(cnode_ptr_);
for (size_t i = 0; i < input_num_; ++i) {
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
PadScalarShape(&input_shape);
@ -67,7 +67,7 @@ void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
}
}
output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
output_num_ = AnfAlgo::GetOutputElementNum(cnode_ptr_);
for (size_t i = 0; i < output_num_; ++i) {
auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
PadScalarShape(&output);

View File

@ -51,7 +51,7 @@ void TbeKernelReduceSelector::GetSupportedFormatDType(SupportFormatDType *suppor
void TbeKernelReduceSelector::GetReduceNodeInfo() {
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
auto output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
if (input_num != 1 || output_num != 1) {
MS_LOG(INFO) << "Reduce operator input/output is not 1, input num: " << input_num << ", output num: " << output_num
<< ", node info: " << cnode_ptr_->DebugString();

View File

@ -48,6 +48,11 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
if (op_info_ptr == nullptr) {
return;
}
if (common::AnfAlgo::HasDynamicTupleInput(kernel_node)) {
return;
}
if (IsKernelDynamicImpl(kernel_node)) {
common::AnfAlgo::SetNodeAttr(kAttrIsKernelDynamicImpl, MakeValue(true), kernel_node);
if (tbe_selector.CheckOpSupported()) {
@ -486,7 +491,7 @@ bool TbeKernelSelect::GetKernelBuildInfoFromCache() {
void TbeKernelSelect::GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype) {
auto dyn_input_sizes = GetNodeDynamicInputs();
// get real input/output num
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
size_t real_input_num = AnfAlgo::GetInputElementNum(cnode_ptr_);
size_t real_output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
auto op_info_input_num = support_format_dtype.input_dtypes.size();
auto op_info_output_num = support_format_dtype.output_dtypes.size();

View File

@ -27,6 +27,7 @@
#include "include/common/utils/utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
#include "utils/ms_context.h"
#include "kernel/common_utils.h"
namespace mindspore::kernel {
namespace {
@ -56,7 +57,7 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
}
}
size_t real_output_num = AnfAlgo::GetOutputTensorNum(node);
size_t real_output_num = AnfAlgo::GetOutputElementNum(node);
for (size_t i = 0; i < real_output_num; i++) {
auto format = AnfAlgo::GetOutputFormat(node, i);
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
@ -165,8 +166,8 @@ bool IsKernelDynamicImpl(const AnfNodePtr &node) {
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(support_format);
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
auto input_num = AnfAlgo::GetInputElementNum(cnode);
auto output_num = AnfAlgo::GetOutputElementNum(cnode);
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
MS_EXCEPTION_IF_NULL(op_info);
@ -266,30 +267,39 @@ bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId
return true;
}
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
const std::map<TypeId, TypeId> &type_map) {
// filte kernel info that unsupported raise or reduce datatype
bool TagUnfoldRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
const std::map<TypeId, TypeId> &type_map) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(kernel_build_info);
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
auto in_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
if (device_dtype == kNumberTypeFloat) {
device_dtype = kNumberTypeFloat32;
}
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
return false;
// Check input data type
size_t kernel_input_index = 0;
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < fold_input_tensor_num; ++input_index) {
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
for (size_t i = 0; i < inputs_type.size(); ++i) {
if (kernel_input_index >= kernel_build_info->GetInputNum()) {
return false;
}
auto device_dtype = kernel_build_info->GetInputDeviceType(kernel_input_index);
if (device_dtype == kNumberTypeFloat) {
device_dtype = kNumberTypeFloat32;
}
if (!CheckHitTargetDtype(type_map, inputs_type[i], device_dtype)) {
return false;
}
++kernel_input_index;
}
}
// Check output data type
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
auto in_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, output_index);
auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index);
if (device_dtype == kNumberTypeFloat) {
device_dtype = kNumberTypeFloat32;
}
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
return false;
}
@ -297,6 +307,71 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
return true;
}
// Checks whether `kernel_build_info` (a fold-style build info) is acceptable for `cnode` under `type_map`.
//
// Every inferred input/output data type of `cnode` is compared with the matching device data type of the build
// info through CheckHitTargetDtype; a device type of kNumberTypeFloat is compared as kNumberTypeFloat32.
//
// Inputs are walked with two cursors: `input_index` over the node's inputs and `kernel_input_index` over the
// build info's input slots.
//   - If the current kernel slot has object type TUPLE, the node input consumes exactly one slot and output 0 of
//     that input node provides the inferred data type.
//   - Otherwise the node input is expanded with GetRealPrevNodesOutputInferDataType and each returned element
//     consumes one kernel slot.
//
// Returns false as soon as one data type does not hit, or when the expanded inputs need more slots than the
// build info provides; returns true otherwise.
bool TagFoldRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
                        const std::map<TypeId, TypeId> &type_map) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_build_info);

  // Shared comparison: normalise kNumberTypeFloat to kNumberTypeFloat32, then consult the type map.
  const auto is_hit = [&type_map](const TypeId infer_dtype, TypeId device_dtype) {
    if (device_dtype == kNumberTypeFloat) {
      device_dtype = kNumberTypeFloat32;
    }
    return CheckHitTargetDtype(type_map, infer_dtype, device_dtype);
  };

  // Check input data type
  const size_t kernel_input_num = kernel_build_info->GetInputNum();
  const size_t node_input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  size_t kernel_input_index = 0;
  for (size_t input_index = 0; input_index < node_input_num; ++input_index) {
    // Only query the object type for a slot that exists. Once the slots are exhausted the input falls through to
    // the expanded path below, which rejects the build info if any element is still left to match.
    const bool is_tuple_slot =
      kernel_input_index < kernel_input_num &&
      kernel_build_info->GetInputKernelObjectType(kernel_input_index) == kernel::KernelObjectType::TUPLE;
    if (is_tuple_slot) {
      // inputs()[0] is skipped, hence the +1 offset.
      auto input_node = cnode->inputs()[input_index + 1];
      const TypeId in_dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
      if (!is_hit(in_dtype, kernel_build_info->GetInputDeviceType(kernel_input_index))) {
        return false;
      }
      ++kernel_input_index;
      continue;
    }

    const std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
    for (const TypeId in_dtype : inputs_type) {
      if (kernel_input_index >= kernel_input_num) {
        return false;
      }
      if (!is_hit(in_dtype, kernel_build_info->GetInputDeviceType(kernel_input_index))) {
        return false;
      }
      ++kernel_input_index;
    }
  }

  // Check output data type
  const size_t kernel_output_num = kernel_build_info->GetOutputNum();
  for (size_t output_index = 0; output_index < kernel_output_num; ++output_index) {
    const auto in_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, output_index);
    if (!is_hit(in_dtype, kernel_build_info->GetOutputDeviceType(output_index))) {
      return false;
    }
  }
  return true;
}
// Filters out kernel build infos whose data types are not supported for raising or reducing precision.
// The actual check is delegated to the fold or the unfold variant, chosen by kernel::IsFoldKernelBuildInfo.
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
                    const std::map<TypeId, TypeId> &type_map) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  if (!kernel::IsFoldKernelBuildInfo(kernel_build_info)) {
    return TagUnfoldRaiseReduce(kernel_build_info, cnode, type_map);
  }
  return TagFoldRaiseReduce(kernel_build_info, cnode, type_map);
}
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
bool *precision_reduce) {

View File

@ -184,6 +184,8 @@
#include "include/common/debug/draw.h"
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "backend/common/pass/insert_type_transform_op.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
namespace mindspore {
namespace opt {
@ -221,6 +223,7 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<AscendConvertTupleInputToDynamicInput>());
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
@ -289,6 +292,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
data_layout_pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
data_layout_pm->AddPass(std::make_shared<ReselectCallInlineFormat>());
data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
@ -739,5 +743,34 @@ void AscendOpAdaptation(const std::shared_ptr<session::KernelGraph> &kernel_grap
}
#endif
}
// Runs AscendConvertTupleInputToDynamicInput over `kernel_graph` in a dedicated pass manager and then resets the
// graph's execution order via SetExecOrderByDefault. When ENABLE_DUMP_IR is defined and the context allows
// dumping at kIntroductory level, the graph is dumped before and after the pass (the proto dump is only written
// before).
void AscendUnfoldInputsForSpecialNodes(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifdef ENABLE_DUMP_IR
  if (ms_context->CanDump(kIntroductory)) {
    const std::string graph_id_str = std::to_string(kernel_graph->graph_id());
    const std::string before_ir_name = "hwopt_d_before_unfold_inputs_for_special_nodes_graph_" + graph_id_str + ".ir";
    DumpIR(before_ir_name, kernel_graph, true, kWholeStack);
    DumpIRProto(kernel_graph, "before_unfold_inputs_for_special_nodes_hwopt_" + graph_id_str);
  }
#endif

  // Build a pass manager holding only the tuple-unfold pass and run it.
  auto pass_manager = std::make_shared<opt::PassManager>("unfold_inputs_for_special_nodes_pm");
  pass_manager->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  graph_optimizer->AddPassManager(pass_manager);
  (void)graph_optimizer->Optimize(kernel_graph);
  kernel_graph->SetExecOrderByDefault();

#ifdef ENABLE_DUMP_IR
  if (ms_context->CanDump(kIntroductory)) {
    const std::string after_ir_name =
      "hwopt_d_after_unfold_inputs_for_special_nodes_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(after_ir_name, kernel_graph, true, kWholeStack);
  }
#endif
}
} // namespace opt
} // namespace mindspore

View File

@ -30,6 +30,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendOpAdaptation(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendUnfoldInputsForSpecialNodes(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt
} // namespace mindspore

View File

@ -638,8 +638,10 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
auto sub_ret = sub_graph->output();
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<kernel::KernelObjectType> input_object_types;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
std::vector<kernel::KernelObjectType> output_object_types;
for (auto &param : sub_graph->inputs()) {
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
if (type_id == kTypeUnknown) {
@ -650,17 +652,25 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
}
input_types.push_back(type_id);
input_formats.push_back(AnfAlgo::GetOutputFormat(param, 0));
input_object_types.push_back(kernel::KernelObjectType::TENSOR);
}
for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
if (AnfAlgo::GetOutputObjectType(node, i) == TypeId::kObjectTypeTuple) {
output_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
} else {
output_object_types.push_back(kernel::KernelObjectType::TENSOR);
}
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetInputsKernelObjectType(input_object_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
builder->SetOutputsKernelObjectType(output_object_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}

View File

@ -28,7 +28,7 @@ AnfNodePtr InsertTensorMoveForGetNextOutputs(const FuncGraphPtr &func_graph, con
return nullptr;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
size_t output_num = AnfAlgo::GetOutputElementNum(node);
if (output_num == 0) {
MS_LOG(DEBUG) << "Output number is zero, no need to insert tensor_move!";
return node;

View File

@ -101,7 +101,7 @@ AnfNodePtr InsertForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_
}
std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto out_num = AnfAlgo::GetOutputTensorNum(node);
auto out_num = AnfAlgo::GetOutputElementNum(node);
for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);

View File

@ -68,7 +68,7 @@ CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std
}
} else if (op_name == kBasicLSTMCellWeightGradOpName) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
size_t out_num = AnfAlgo::GetOutputElementNum(cnode);
for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx);
auto origin_shape = common::AnfAlgo::GetOutputInferShape(cnode, output_idx);

View File

@ -0,0 +1,57 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
// Builds the pattern for this pass: a VectorRef holding one Var followed by one SeqVar.
// NOTE(review): presumably this matches any node (head plus arbitrary inputs) so that Process does the real
// filtering - confirm against the pattern engine.
const BaseRef AscendConvertTupleInputToDynamicInput::DefinePattern() const {
  VarPtr head_var = std::make_shared<Var>();
  VarPtr rest_vars = std::make_shared<SeqVar>();
  return VectorRef({head_var, rest_vars});
}
// Unfolds the MakeTuple inputs of `node` through ConvertMakeTupleInputToPlantInputs when `node` is a real-kernel
// CNode that is either a communication op or one of the primitives listed in `kUnfoldPrims`.
// Returns nullptr (no replacement) for every other node.
const AnfNodePtr AscendConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                                const EquivPtr &) const {
  if (node == nullptr) {
    return nullptr;
  }
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  if (!AnfUtils::IsRealKernel(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // This pass must run before concat_fission, pack_fission and addn_fission, because those passes expect their
  // inputs to be unfolded already. The auto_monad pass must run before this pass.
  const bool is_communication_op = common::AnfAlgo::IsCommunicationOp(node);
  static const PrimitiveSet kUnfoldPrims = {prim::kPrimAddN,        prim::kPrimConcatD,    prim::kPrimPack,
                                            prim::kPrimStack,       prim::kPrimCallInline, prim::kPrimPrint,
                                            prim::kPrimSwitchLayer, prim::kPrimCall,       prim::kPrimSwitch};
  PrimitivePtr primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);

  const bool should_unfold = is_communication_op || kUnfoldPrims.find(primitive) != kUnfoldPrims.end();
  if (!should_unfold) {
    return nullptr;
  }
  return ConvertMakeTupleInputToPlantInputs(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_
#include <string>
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
// Pattern-based optimizer pass registered under the name "ascend_convert_tuple_input_to_dynamic_input".
// NOTE(review): judging by the name, it turns tuple inputs of selected Ascend nodes into flat dynamic inputs;
// the matching and rewrite logic lives in the definitions of DefinePattern and Process.
class AscendConvertTupleInputToDynamicInput : public PatternProcessPass {
public:
// `multigraph` is forwarded unchanged to PatternProcessPass; it defaults to true.
explicit AscendConvertTupleInputToDynamicInput(bool multigraph = true)
: PatternProcessPass("ascend_convert_tuple_input_to_dynamic_input", multigraph) {}
~AscendConvertTupleInputToDynamicInput() override = default;
// Returns the pattern that candidate nodes are matched against.
const BaseRef DefinePattern() const override;
// Handles one matched `node` of `func_graph`; the EquivPtr argument is unnamed here.
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_

View File

@ -91,6 +91,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
op_builder.SetOutputsFormat({kOpFormat_NDC1HWC0});
op_builder.SetOutputsDeviceType({kNumberTypeFloat16});
op_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
return assist_const;
}

View File

@ -70,6 +70,7 @@ ValueNodePtr CreateValueNode(T seed) {
} else {
builder.SetOutputsDeviceType({kNumberTypeUInt64});
}
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());
return value_node;
}

View File

@ -86,6 +86,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
op_builder.SetOutputsFormat({kOpFormat_NC1HWC0});
op_builder.SetOutputsDeviceType({kNumberTypeFloat16});
op_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
return assist_const;
}

View File

@ -117,6 +117,7 @@ ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
builder1.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
builder1.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), assist_const.get());
return assist_const;
}
@ -130,6 +131,8 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
builder.SetInputsKernelObjectType({kernel::KernelObjectType::TENSOR});
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
return builder.Build();
}

View File

@ -123,7 +123,7 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
std::vector<int64_t> new_output_size{output_h, output_w};
common::AnfAlgo::SetNodeAttr(kAttrOutputSize, MakeValue(new_output_size), adaptive_max_pool2d);
if (AnfAlgo::GetOutputTensorNum(adaptive_max_pool2d) > 1) {
if (AnfAlgo::GetOutputElementNum(adaptive_max_pool2d) > 1) {
return nullptr;
}

View File

@ -512,6 +512,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
}

View File

@ -59,6 +59,7 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
builder1.SetOutputsDeviceType({output_type});
builder1.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), new_node.get());
return new_node;
}

View File

@ -589,18 +589,18 @@ bool GetSelectKernelResult(const CNodePtr &kernel_node,
}
#ifdef ENABLE_TUPLE_UNFOLD
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node) {
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType kernel_type) {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
// Only the kernel nodes that register kernel attr can support the backoff.
bool backoff_support_condition =
((kernel_type == UNKNOWN_KERNEL_TYPE) && !IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
!common::AnfAlgo::IsGraphKernel(kernel_node));
std::vector<kernel::KernelAttr> kernel_attrs;
// Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type.
if (!kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
kernel_attrs = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name);
} else if (backoff_support_condition) {
// Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type.
kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name);
// CPU also doesn't support the kernel.
if (kernel_attrs.empty()) {
return false;
}
} else {
kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name);
}
// Some dynamic kernels may not set the kernel attrs on GPU. Skip check only supports the tuple fold.
@ -635,7 +635,7 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
#ifdef ENABLE_TUPLE_UNFOLD
bool selected = GetSelectKernelObjectTypeResult(kernel_node);
bool selected = GetSelectKernelObjectTypeResult(kernel_node, kernel_type);
if (!selected) {
std::stringstream ss;
ss << "kernel object types are not supported for " << common::AnfAlgo::GetCNodeName(kernel_node)

View File

@ -84,7 +84,10 @@ std::vector<KernelAttr> NativeGpuKernelModFactory::GetGpuSupportedList(const std
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
auto attr = (iter->second)[attr_index].first;
kernel_attr_list.push_back(attr);
// Skip the invalid attr.
if (attr.GetInputSize() > 0 || attr.GetOutputSize() > 0) {
kernel_attr_list.push_back(attr);
}
}
return kernel_attr_list;

View File

@ -574,6 +574,38 @@ KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t in
return res;
}
// if the prev_node is MakeTuple, get all the input_nodes recursively, else use the ori GetPrevNodeOutput function
std::vector<KernelWithIndex> AnfAlgo::GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
bool skip_nop_node) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<KernelWithIndex> res;
auto input_node = AnfAlgo::GetInputNode(cnode, input_idx);
MS_EXCEPTION_IF_NULL(input_node);
if (CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
auto maketuple_input_num = GetInputTensorNum(input_node);
for (size_t i = 0; i < maketuple_input_num; ++i) {
auto inputs_i = GetRealPrevNodesOutput(input_node, i, skip_nop_node);
res.insert(res.end(), inputs_i.begin(), inputs_i.end());
}
} else {
res.emplace_back(GetPrevNodeOutput(cnode, input_idx, skip_nop_node));
}
return res;
}
std::vector<TypeId> AnfAlgo::GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
std::vector<KernelWithIndex> kernels_with_index = AnfAlgo::GetRealPrevNodesOutput(node, input_idx);
std::vector<TypeId> res;
(void)std::transform(kernels_with_index.begin(), kernels_with_index.end(), std::back_inserter(res),
[](auto kernel_with_index) {
return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
});
return res;
}
inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
@ -763,8 +795,8 @@ void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
node_name = GetCNodeName(node_ptr);
}
if (types.size() != shapes.size()) {
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size() << "."
<< trace::DumpSourceLines(node);
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
<< " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
}
auto tuple_node = kNodeTupleOutSet.find(node_name);

View File

@ -483,7 +483,8 @@ def test_call_no_self_other_object_method_runtime():
assert np.all(result == z)
@pytest.mark.level0
@pytest.mark.skip(reason="Not supported by now")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -505,7 +506,8 @@ def test_getattr_tensor_with_wrong_attr():
assert "object has no attribute" in str(err.value)
@pytest.mark.level0
@pytest.mark.skip(reason="Not supported by now")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training

View File

@ -28,6 +28,7 @@
#define private public
#define protected public
#include "plugin/device/ascend/optimizer/enhancer/insert_tensor_move_for_hccl_op.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#undef private
#undef protected
namespace mindspore {
@ -168,6 +169,8 @@ TEST_F(TestHWInsertTensorMoveForHccl, test_cond5) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
// This pass run before hccl_pass to unfold inputs of hccl node
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto pass = std::make_shared<opt::InsertTensorMoveForHcclOp>();
pass->kernel_query_ = std::make_shared<MockInsertTensorMoveForHcclKernelQuery>();
pm->AddPass(pass);

View File

@ -19,6 +19,7 @@
#define private public
#define protected public
#include "plugin/device/ascend/optimizer/ir_fission/addn_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#undef private
#undef protected
@ -45,6 +46,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_2) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto addn_fission = std::make_shared<opt::AddnFission>();
addn_fission->inputs_divisor_ = 2;
pm->AddPass(addn_fission);
@ -54,7 +56,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_2) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_2");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
@ -70,6 +78,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto addn_fission = std::make_shared<opt::AddnFission>();
addn_fission->inputs_divisor_ = 3;
pm->AddPass(addn_fission);
@ -79,7 +88,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_3");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
@ -95,6 +110,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto addn_fission = std::make_shared<opt::AddnFission>();
addn_fission->inputs_divisor_ = 4;
pm->AddPass(addn_fission);
@ -104,7 +120,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_4");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
@ -120,6 +142,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto addn_fission = std::make_shared<opt::AddnFission>();
addn_fission->inputs_divisor_ = 8;
pm->AddPass(addn_fission);
@ -129,7 +152,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_8");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_9) {

View File

@ -19,6 +19,7 @@
#define private public
#define protected public
#include "plugin/device/ascend/optimizer/ir_fission/concat_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#undef private
#undef protected
@ -45,6 +46,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto concat_fission = std::make_shared<opt::ConcatFission>();
concat_fission->inputs_divisor_ = 2;
pm->AddPass(concat_fission);
@ -54,7 +56,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
@ -70,6 +78,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto concat_fission = std::make_shared<opt::ConcatFission>();
concat_fission->inputs_divisor_ = 3;
pm->AddPass(concat_fission);
@ -79,7 +88,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
@ -95,6 +110,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto concat_fission = std::make_shared<opt::ConcatFission>();
concat_fission->inputs_divisor_ = 4;
pm->AddPass(concat_fission);
@ -104,7 +120,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
@ -120,6 +142,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto concat_fission = std::make_shared<opt::ConcatFission>();
concat_fission->inputs_divisor_ = 8;
pm->AddPass(concat_fission);
@ -129,7 +152,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8");
EXPECT_NE(g_after, nullptr);
auto kg_after = GetKernelGraph(g_after, args_spec_list);
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
auto pm2 = std::make_shared<opt::PassManager>();
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
optimizer2->AddPassManager(pm2);
auto kg_after2 = optimizer2->Optimize(kg_after);
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
}
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) {

View File

@ -19,6 +19,7 @@
#define private public
#define protected public
#include "plugin/device/ascend/optimizer/ir_fission/pack_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#undef private
#undef protected
@ -45,6 +46,7 @@ TEST_F(TestHWPackFission, test_stack_fission_divided_by_3) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto pack_fission = std::make_shared<opt::PackFission>();
pack_fission->inputs_divisor_ = 3;
pm->AddPass(pack_fission);
@ -69,6 +71,7 @@ TEST_F(TestHWPackFission, test_stack_fission_divided_by_4) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
auto pack_fission = std::make_shared<opt::PackFission>();
pack_fission->inputs_divisor_ = 4;
pm->AddPass(pack_fission);

View File

@ -17,6 +17,7 @@
#include "common/py_func_graph_fetcher.h"
#include "backend/common/optimizer/optimizer.h"
#include "plugin/device/ascend/optimizer/ir_fusion/confusion_mul_grad_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#include "include/common/debug/anf_ir_dump.h"
namespace mindspore {
@ -42,6 +43,7 @@ TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
pm->AddPass(std::make_shared<opt::ConfusionMulGradFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

View File

@ -16,6 +16,7 @@
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "plugin/device/ascend/optimizer/ir_fusion/mul_addn_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
#include "include/common/debug/anf_ir_dump.h"
namespace mindspore {
@ -37,6 +38,7 @@ TEST_F(TestHWMulAddNFusion, test_mul_addn_fusion) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
pm->AddPass(std::make_shared<opt::MulAddNFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
@ -55,6 +57,7 @@ TEST_F(TestHWMulAddNFusion, test_unmatch) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
pm->AddPass(std::make_shared<opt::MulAddNFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

View File

@ -138,9 +138,7 @@ def test_insert_tensor_move_for_hccl_op_cond5(tag):
m1 = tensor_move(b)
m2 = tensor_move(c)
y = broadcast(m1, m2)
y0 = tuple_getitem(y, 0)
y1 = tuple_getitem(y, 1)
res = depend(x, make_tuple(y0, y1))
res = depend(x, y)
return make_tuple(res)
return fns[tag]

View File

@ -48,7 +48,7 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
auto tensor_infos = analyzer->GetMemUsageTensorInfos();
ASSERT_EQ(5, kernel_infos.size());
ASSERT_EQ(15, tensor_infos.size());
ASSERT_EQ(16, tensor_infos.size());
for (size_t i = 0; i < kernel_infos.size(); ++i) {
ASSERT_NE(nullptr, analyzer->GetMemUsageKernelInfo(i));
}
@ -57,6 +57,6 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
ASSERT_NE(nullptr, analyzer->GetMemUsageTensorInfo(i));
}
ASSERT_EQ(132, analyzer->LeastMemNeeded());
ASSERT_EQ(100, analyzer->LeastMemNeeded());
}
} // namespace mindspore::device