ascend support untuple
This commit is contained in:
parent
73e19297db
commit
37586dbf9b
|
@ -76,6 +76,8 @@ if(ENABLE_ASAN)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
add_compile_definitions(ENABLE_TUPLE_UNFOLD)
|
||||
|
||||
if(DEBUG_MODE)
|
||||
set(CMAKE_BUILD_TYPE "Debug")
|
||||
add_compile_definitions(MEM_REUSE_DEBUG)
|
||||
|
|
|
@ -1177,5 +1177,96 @@ size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) {
|
|||
// The first input is Primitive and needs to be skipped.
|
||||
return std::distance(input_list.begin() + kSizeOne, pos);
|
||||
}
|
||||
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
|
||||
auto abs = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
|
||||
return -1;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
auto input_size = AnfAlgo::GetOutputElementNum(tuple_input);
|
||||
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
// using for graph kernel
|
||||
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
// Handle tuple nested scenes.
|
||||
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||
input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
|
||||
continue;
|
||||
}
|
||||
(void)plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
for (size_t index = 0; index < input_size; ++index) {
|
||||
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
|
||||
(void)plant_inputs->emplace_back(dynamic_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (common::AnfAlgo::HasDynamicTupleInput(cnode_ptr)) {
|
||||
MS_LOG(INFO) << "Node " << cnode_ptr->fullname_with_scope()
|
||||
<< " has dynamic tuple input, can't convert. Node debug string:" << cnode_ptr->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
|
||||
bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
|
||||
std::vector<AnfNodePtr> plant_inputs;
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||
size_t input_num = cnode_ptr->inputs().size() - 1;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
bool output_is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
|
||||
bool skip = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
|
||||
if (output_is_tuple && cnode_is_print) {
|
||||
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||
} else if (output_is_tuple && !skip) {
|
||||
auto dyn_input_size = SplitTupleInputs(graph, input_node, &plant_inputs);
|
||||
if (dyn_input_size == 0) {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
} else {
|
||||
(void)dyn_input_sizes.emplace_back(dyn_input_size);
|
||||
}
|
||||
} else {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
}
|
||||
}
|
||||
// If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs.
|
||||
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||
auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
|
||||
new_cnode->set_abstract(cnode_ptr->abstract());
|
||||
new_cnode->set_scope(cnode_ptr->scope());
|
||||
new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
|
||||
new_cnode->set_attrs(cnode_ptr->attrs());
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -257,6 +257,11 @@ BACKEND_EXPORT int64_t GetNodeOutputTotalUsedNum(const session::KernelGraph &ker
|
|||
BACKEND_EXPORT void GetCustomOpAttrIndex(const PrimitivePtr &primitive, mindspore::HashSet<size_t> *indexes);
|
||||
|
||||
BACKEND_EXPORT size_t GetInputNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node);
|
||||
|
||||
BACKEND_EXPORT int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs);
|
||||
|
||||
BACKEND_EXPORT AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
|
||||
|
|
|
@ -24,92 +24,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
|
||||
auto abs = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
|
||||
return -1;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
|
||||
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
// using for graph kernel
|
||||
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
// Handle tuple nested scenes.
|
||||
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||
input_size += LongToSize(SplitTupleInputs(graph, dyn_input_node, plant_inputs));
|
||||
continue;
|
||||
}
|
||||
(void)plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
for (size_t index = 0; index < input_size; ++index) {
|
||||
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
|
||||
(void)plant_inputs->emplace_back(dynamic_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
|
||||
return nullptr;
|
||||
}
|
||||
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimBpropCut);
|
||||
bool cnode_is_print = common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPrint);
|
||||
std::vector<AnfNodePtr> plant_inputs;
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||
size_t input_num = cnode_ptr->inputs().size() - 1;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
bool output_is_tuple = common::AnfAlgo::IsTupleOutput(input_node);
|
||||
bool skip = (is_bprop_cut && input_node->abstract()->isa<abstract::AbstractSparseTensor>());
|
||||
if (output_is_tuple && cnode_is_print) {
|
||||
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||
} else if (output_is_tuple && !skip) {
|
||||
auto dyn_input_size = SplitTupleInputs(graph, input_node, &plant_inputs);
|
||||
if (dyn_input_size == 0) {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
} else {
|
||||
(void)dyn_input_sizes.emplace_back(dyn_input_size);
|
||||
}
|
||||
} else {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
}
|
||||
}
|
||||
// If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs.
|
||||
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||
auto new_cnode = NewCNode(plant_inputs, graph, {cnode_ptr});
|
||||
new_cnode->set_abstract(cnode_ptr->abstract());
|
||||
new_cnode->set_scope(cnode_ptr->scope());
|
||||
new_cnode->set_primal_attrs(cnode_ptr->primal_attrs());
|
||||
new_cnode->set_attrs(cnode_ptr->attrs());
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_cnode);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode_ptr, new_cnode);
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const {
|
||||
VarPtr V = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
|
|
|
@ -106,40 +106,6 @@ size_t GetFusionSize(const AnfNodePtr &node) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
|
||||
auto abs = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
|
||||
return -1;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
|
||||
if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
// using for graph kernel
|
||||
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
// Handle tuple nested scenes.
|
||||
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||
input_size += SplitTupleInputs(graph, dyn_input_node, plant_inputs);
|
||||
continue;
|
||||
}
|
||||
(void)plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
for (size_t index = 0; index < input_size; ++index) {
|
||||
auto dynamic_input_node = CreatTupleGetItemNode(graph, tuple_input, index);
|
||||
(void)plant_inputs->emplace_back(dynamic_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
|
||||
void ExpandFlattenConcatTupleInput(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -26,8 +26,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
MS_EXCEPTION_IF_NULL(plant_inputs);
|
||||
|
@ -50,7 +50,7 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
|
|||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
// Handle tuple nested scenes.
|
||||
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||
int64_t dyn_input_size = SplitTupleInputs(graph, dyn_input_node, plant_inputs);
|
||||
int64_t dyn_input_size = SplitTupleInputsForInsertType(graph, dyn_input_node, plant_inputs);
|
||||
input_size += LongToSize(dyn_input_size);
|
||||
continue;
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ void SetKernelInfoForNewCNodeByOrigNode(const CNodePtr &new_cnode, const CNodePt
|
|||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
new_cnode->set_kernel_info(kernel_info);
|
||||
// The node may not be supported in the current device.
|
||||
new_kernel_builder->SetValid(false);
|
||||
new_kernel_builder->SetValid(true);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());
|
||||
|
||||
auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
|
||||
|
@ -198,6 +198,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
|
|||
std::vector<KernelObjectType> input_obj_type;
|
||||
std::vector<KernelObjectType> output_obj_type;
|
||||
GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
|
||||
builder->SetKernelType(CPU_KERNEL);
|
||||
builder->SetInputsKernelObjectType(input_obj_type);
|
||||
builder->SetOutputsKernelObjectType(output_obj_type);
|
||||
|
||||
|
@ -231,7 +232,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
|
|||
}
|
||||
|
||||
// The node may not be supported in the current device.
|
||||
builder->SetValid(false);
|
||||
builder->SetValid(true);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
|
||||
}
|
||||
|
||||
|
@ -568,7 +569,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold(const Func
|
|||
}
|
||||
|
||||
AnfNodePtrList plant_inputs;
|
||||
int64_t unfold_num = SplitTupleInputs(func_graph, input, &plant_inputs);
|
||||
int64_t unfold_num = SplitTupleInputsForInsertType(func_graph, input, &plant_inputs);
|
||||
MS_LOG(DEBUG) << "Transform tuple unfold input: " << input->fullname_with_scope() << " to " << unfold_num
|
||||
<< " inputs.";
|
||||
return plant_inputs;
|
||||
|
|
|
@ -76,8 +76,8 @@ using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr
|
|||
|
||||
// SplitTupleInputs methods refer to the pass ConvertTupleInputToDynamicInput. It unfolds tuple inputs and returns the
|
||||
// unfolded inputs nodes.
|
||||
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs);
|
||||
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
|
||||
std::vector<AnfNodePtr> *plant_inputs);
|
||||
|
||||
// Create the new cnode which will replace the original cnode.
|
||||
// This method is called at the last step of this pass specifically.
|
||||
|
|
|
@ -397,9 +397,9 @@ std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
|
|||
|
||||
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (output_idx > AnfAlgo::GetOutputTensorNum(node)) {
|
||||
if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "Output index:" << output_idx
|
||||
<< " is out of the node output range :" << AnfAlgo::GetOutputTensorNum(node) << " #node ["
|
||||
<< " is out of the node output range :" << AnfAlgo::GetOutputElementNum(node) << " #node ["
|
||||
<< node->DebugString() << "]" << trace::DumpSourceLines(node);
|
||||
}
|
||||
if (common::AnfAlgo::CheckAbsSparseTensor(node)) {
|
||||
|
@ -635,9 +635,9 @@ std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, siz
|
|||
|
||||
std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (output_idx > AnfAlgo::GetOutputTensorNum(node)) {
|
||||
if (output_idx > AnfAlgo::GetOutputElementNum(node)) {
|
||||
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
|
||||
<< AnfAlgo::GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"
|
||||
<< AnfAlgo::GetOutputElementNum(node) << "#node[ " << node->DebugString() << "]"
|
||||
<< trace::DumpSourceLines(node);
|
||||
}
|
||||
if (!AnfUtils::IsRealKernel(node)) {
|
||||
|
@ -1728,7 +1728,7 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputObjectType(const AnfNodePtr
|
|||
std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<TypeId> outputs;
|
||||
auto out_nums = AnfAlgo::GetOutputTensorNum(node);
|
||||
auto out_nums = AnfAlgo::GetOutputElementNum(node);
|
||||
for (size_t i = 0; i < out_nums; i++) {
|
||||
auto type = common::AnfAlgo::GetOutputInferDataType(node, i);
|
||||
outputs.push_back(type);
|
||||
|
@ -1736,20 +1736,31 @@ std::vector<TypeId> AnfAlgo::GetAllOutputInferDataTypes(const AnfNodePtr &node)
|
|||
return outputs;
|
||||
}
|
||||
|
||||
// if input node is MakeTuple, find the PrevNodeNum recursively;
|
||||
// The monad node in the end is not included in the num;
|
||||
size_t AnfAlgo::GetInputElementNum(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t element_num = 0;
|
||||
size_t input_num = cnode->inputs().size() - 1;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
||||
if (common::AnfAlgo::IsTupleOutput(input_node)) {
|
||||
element_num += AnfUtils::GetOutputTensorNum(input_node);
|
||||
bool cal_monad_flag = false;
|
||||
for (size_t i = input_num; i > 0; --i) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, i - 1);
|
||||
if (!cal_monad_flag && HasAbstractMonad(input_node)) {
|
||||
continue;
|
||||
} else if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
||||
element_num += GetInputElementNum(input_node);
|
||||
cal_monad_flag = true;
|
||||
} else if (common::AnfAlgo::IsTupleOutput(input_node)) {
|
||||
element_num += AnfAlgo::GetOutputElementNum(input_node);
|
||||
cal_monad_flag = true;
|
||||
} else {
|
||||
++element_num;
|
||||
cal_monad_flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
return element_num;
|
||||
}
|
||||
|
||||
|
|
|
@ -221,6 +221,7 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
|
|||
static std::vector<TypeId> GetAllOutputObjectType(const AnfNodePtr &node);
|
||||
// Get all output infer data type.
|
||||
static std::vector<TypeId> GetAllOutputInferDataTypes(const AnfNodePtr &node);
|
||||
// Get unfold input num
|
||||
static size_t GetInputElementNum(const AnfNodePtr &node);
|
||||
static bool IsRealSquenceOutput(const AnfNodePtr &node);
|
||||
static void SetDynamicAttrToPrim(const PrimitivePtr &prim);
|
||||
|
|
|
@ -115,6 +115,10 @@ class COMMON_EXPORT AnfAlgo {
|
|||
static size_t GetInputTensorNum(const AnfNodePtr &node);
|
||||
// get prev node output width output index
|
||||
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
|
||||
// get all the untuple real prev_nodes output
|
||||
static std::vector<KernelWithIndex> GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool skip_nop_node = false);
|
||||
|
||||
// get output shapes inferred by ME from input nodes.
|
||||
static ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
|
||||
static ShapeVector GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
|
||||
|
@ -126,6 +130,8 @@ class COMMON_EXPORT AnfAlgo {
|
|||
static TypeId GetOutputInferDataType(const TypePtr &type, size_t output_idx);
|
||||
// get output original data type from prev node,input_index is the input index of current node related to prev node
|
||||
static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
|
||||
// for tuple condition
|
||||
static std::vector<TypeId> GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
|
||||
// set infer shapes and types of anf node
|
||||
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
|
||||
AnfNode *node, bool disable_dynamic_len = false);
|
||||
|
|
|
@ -505,8 +505,8 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
|
|||
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t real_input_num = AnfAlgo::GetInputElementNum(kernel_node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
|
||||
std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
|
@ -729,7 +729,7 @@ bool IsWeightBoundary(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
|
||||
if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputTensorNum(cnode) != 1) {
|
||||
if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputElementNum(cnode) != 1) {
|
||||
MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output."
|
||||
<< trace::DumpSourceLines(cnode);
|
||||
}
|
||||
|
@ -1278,7 +1278,10 @@ std::vector<KernelObjectType> CalKernelObjectTypes(const std::vector<TypeId> &ob
|
|||
for (size_t i = 0; i < selected_object_types.size(); ++i) {
|
||||
// Allsame/skip_check doesn't support the backoff.
|
||||
bool not_backoff = ((all_same || skip_check) && (selected_object_types[i] != object_types[i]));
|
||||
if (not_backoff) {
|
||||
// Ops which support tensor also support scalar.
|
||||
bool scalar_compact =
|
||||
((selected_object_types[i] == kObjectTypeTensorType) && (object_types[i] == kObjectTypeNumber));
|
||||
if (not_backoff || scalar_compact) {
|
||||
(void)ret.emplace_back(TypeIdToKernelObjectTypeForTupleUnfold(object_types[i]));
|
||||
} else {
|
||||
(void)ret.emplace_back(TypeIdToKernelObjectType(selected_object_types[i]));
|
||||
|
@ -1724,6 +1727,22 @@ std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &kernel_attr,
|
|||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info) {
|
||||
auto inputs_object_type = kernel_build_info->GetAllInputKernelObjectTypes();
|
||||
if (std::find(inputs_object_type.begin(), inputs_object_type.end(), KernelObjectType::TUPLE) !=
|
||||
inputs_object_type.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto outputs_object_type = kernel_build_info->GetAllOutputKernelObjectTypes();
|
||||
if (std::find(outputs_object_type.begin(), outputs_object_type.end(), KernelObjectType::TUPLE) !=
|
||||
outputs_object_type.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
|
||||
MS_EXCEPTION_IF_NULL(build_info);
|
||||
KernelAttr kernel_attr;
|
||||
|
|
|
@ -400,6 +400,7 @@ BACKEND_EXPORT std::pair<bool, size_t> MatchKernelAttrStrict(const KernelAttr &k
|
|||
const std::vector<KernelAttr> &kernel_attr_list);
|
||||
BACKEND_EXPORT KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info);
|
||||
BACKEND_EXPORT KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node);
|
||||
BACKEND_EXPORT bool IsFoldKernelBuildInfo(const KernelBuildInfoPtr &kernel_build_info);
|
||||
|
||||
struct KernelArgs {
|
||||
BaseOperatorPtr op;
|
||||
|
@ -455,6 +456,7 @@ BACKEND_EXPORT std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnf
|
|||
BACKEND_EXPORT TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type);
|
||||
KernelObjectType StringToKernelObjectType(const std::string &object_type);
|
||||
BACKEND_EXPORT void UnfoldKernelBuildInfo(const CNodePtr &kernel_node);
|
||||
BACKEND_EXPORT int64_t CalOutputTupleSize(const AnfNodePtr &node);
|
||||
BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode);
|
||||
BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name);
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
|
|||
KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
|
||||
if (input_index >= inputs_kernel_object_type_.size()) {
|
||||
#ifdef ENABLE_TUPLE_UNFOLD
|
||||
MS_LOG(ERROR) << "The input index [" << input_index
|
||||
MS_LOG(DEBUG) << "The input index [" << input_index
|
||||
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size();
|
||||
#endif
|
||||
return KernelObjectType::UNKNOWN_TYPE;
|
||||
|
@ -81,7 +81,7 @@ KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) c
|
|||
KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
|
||||
if (output_index >= outputs_kernel_object_type_.size()) {
|
||||
#ifdef ENABLE_TUPLE_UNFOLD
|
||||
MS_LOG(ERROR) << "The output index [" << output_index
|
||||
MS_LOG(DEBUG) << "The output index [" << output_index
|
||||
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size();
|
||||
#endif
|
||||
return KernelObjectType::UNKNOWN_TYPE;
|
||||
|
@ -182,18 +182,33 @@ std::string KernelBuildInfo::ToString() const {
|
|||
if (index != 0) {
|
||||
output_buffer << ", ";
|
||||
}
|
||||
output_buffer << "<" << TypeIdLabel(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << "x"
|
||||
<< KernelObjectTypeLabel(GetInputKernelObjectType(index)) << ">";
|
||||
output_buffer << "<" << TypeIdLabel(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
|
||||
}
|
||||
output_buffer << ") -> (";
|
||||
output_buffer << ", object_type: [";
|
||||
auto input_object_types = GetAllInputKernelObjectTypes();
|
||||
for (size_t index = 0; index < input_object_types.size(); ++index) {
|
||||
if (index != 0) {
|
||||
output_buffer << ",";
|
||||
}
|
||||
output_buffer << KernelObjectTypeLabel(input_object_types[index]);
|
||||
}
|
||||
|
||||
output_buffer << "]) -> (";
|
||||
for (size_t index = 0; index < GetOutputNum(); ++index) {
|
||||
if (index != 0) {
|
||||
output_buffer << ",";
|
||||
}
|
||||
output_buffer << "<" << TypeIdLabel(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
|
||||
}
|
||||
output_buffer << ", object_type: [";
|
||||
auto output_object_types = GetAllOutputKernelObjectTypes();
|
||||
for (size_t index = 0; index < output_object_types.size(); ++index) {
|
||||
if (index != 0) {
|
||||
output_buffer << ", ";
|
||||
}
|
||||
output_buffer << "<" << TypeIdLabel(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << "x"
|
||||
<< KernelObjectTypeLabel(GetOutputKernelObjectType(index)) << ">";
|
||||
output_buffer << KernelObjectTypeLabel(output_object_types[index]);
|
||||
}
|
||||
output_buffer << ")";
|
||||
output_buffer << "])";
|
||||
return output_buffer.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -174,13 +174,18 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
(void)kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
|
||||
(void)kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
|
||||
(void)kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
|
||||
}
|
||||
for (size_t index = 0; index < kernel_build_info->GetAllInputKernelObjectTypes().size(); ++index) {
|
||||
(void)kernel_build_info_->inputs_kernel_object_type_.emplace_back(
|
||||
kernel_build_info->GetInputKernelObjectType(index));
|
||||
}
|
||||
|
||||
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
|
||||
(void)kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
|
||||
(void)kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
|
||||
(void)kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
|
||||
}
|
||||
for (size_t index = 0; index < kernel_build_info->GetAllOutputKernelObjectTypes().size(); ++index) {
|
||||
(void)kernel_build_info_->outputs_kernel_object_type_.emplace_back(
|
||||
kernel_build_info->GetOutputKernelObjectType(index));
|
||||
}
|
||||
|
|
|
@ -36,6 +36,8 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "mindspore/core/ops/op_name.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -79,18 +81,27 @@ mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName,
|
|||
kBiasAddGradOpName,
|
||||
kSigmoidCrossEntropyWithLogitsV2OpName};
|
||||
|
||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
bool MatchUnfoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check input data type
|
||||
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
|
||||
TypeId input_origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
|
||||
return false;
|
||||
size_t kernel_input_index = 0;
|
||||
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < fold_input_tensor_num; ++input_index) {
|
||||
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
|
||||
for (size_t i = 0; i < inputs_type.size(); ++i) {
|
||||
if (kernel_input_index >= kernel_build_info->GetInputNum()) {
|
||||
return false;
|
||||
}
|
||||
if (kernel_build_info->GetInputDeviceType(kernel_input_index) != inputs_type[i]) {
|
||||
return false;
|
||||
}
|
||||
++kernel_input_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Check output data type
|
||||
for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
|
||||
if (kernel_build_info.GetOutputDeviceType(output_index) !=
|
||||
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
|
||||
if (kernel_build_info->GetOutputDeviceType(output_index) !=
|
||||
common::AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -98,6 +109,52 @@ bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildIn
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MatchFoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check input data type
|
||||
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
size_t kernel_index = 0;
|
||||
for (size_t input_index = 0; input_index < fold_input_tensor_num; ++input_index) {
|
||||
if (kernel_build_info->GetInputKernelObjectType(kernel_index) == kernel::KernelObjectType::TUPLE) {
|
||||
auto input_node = cnode->inputs()[input_index + 1];
|
||||
TypeId input_origin_type = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
|
||||
if (kernel_build_info->GetInputDeviceType(kernel_index) != input_origin_type) {
|
||||
return false;
|
||||
}
|
||||
++kernel_index;
|
||||
} else {
|
||||
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
|
||||
for (size_t i = 0; i < inputs_type.size(); ++i) {
|
||||
if (kernel_index >= kernel_build_info->GetInputNum()) {
|
||||
return false;
|
||||
}
|
||||
if (kernel_build_info->GetInputDeviceType(kernel_index) != inputs_type[i]) {
|
||||
return false;
|
||||
}
|
||||
++kernel_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check output data type
|
||||
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
|
||||
if (kernel_build_info->GetOutputDeviceType(output_index) !=
|
||||
common::AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
bool is_fold = kernel::IsFoldKernelBuildInfo(kernel_build_info);
|
||||
if (is_fold) {
|
||||
return MatchFoldInferOutputDataType(cnode, kernel_build_info);
|
||||
} else {
|
||||
return MatchUnfoldInferOutputDataType(cnode, kernel_build_info);
|
||||
}
|
||||
}
|
||||
|
||||
string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
||||
constexpr size_t k5dSize = 5;
|
||||
constexpr size_t k4dSize = 4;
|
||||
|
@ -195,7 +252,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
}
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
// cal count of same output dtype between abstract and kernel info
|
||||
if (kernel_build_info.GetOutputDeviceType(output_index) ==
|
||||
|
@ -210,14 +267,14 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
|
||||
std::string PrintRaiseOrReducePrecisionSelectedInfo(
|
||||
const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
|
||||
bool precision_reduce) {
|
||||
KernelSelectStatus KernelSelectStatus) {
|
||||
MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::ostringstream buffer;
|
||||
buffer << cnode->DebugString();
|
||||
if (precision_reduce) {
|
||||
if (KernelSelectStatus == kStatusReducePrecision) {
|
||||
buffer << " Reduce precision, node datatype: \n";
|
||||
} else {
|
||||
} else if (KernelSelectStatus == kStatusRaisePrecision) {
|
||||
buffer << " Raise precision, node datatype: \n";
|
||||
}
|
||||
GatherInputAndOutputInferType(buffer, cnode);
|
||||
|
@ -250,7 +307,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
|
||||
for (const auto &kernel_build_info : kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
|
||||
if (!MatchInferOutputDataType(cnode, kernel_build_info)) {
|
||||
continue;
|
||||
}
|
||||
result.push_back(kernel_build_info);
|
||||
|
@ -258,6 +315,131 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|||
return result;
|
||||
}
|
||||
|
||||
bool MatchObjectType(const kernel::KernelObjectType &node_object, const kernel::KernelObjectType &kernel_object) {
|
||||
if (node_object == kernel_object) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if ((node_object == kernel::TUPLE || node_object == kernel::TUPLE_UNFOLD || node_object == kernel::SCALAR) &&
|
||||
(kernel_object == kernel::TENSOR)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// for monad output op such as labelset labelswitch labelgoto ...
|
||||
if (node_object == kernel::UNKNOWN_TYPE && kernel_object == kernel::TENSOR) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Object mismatch. node object type : " << node_object << ", kernel object type: " << kernel_object;
|
||||
return false;
|
||||
}
|
||||
// kernel:tuple, node:tuple -> compare objecttype
|
||||
// kernel:tuple, node:tensor -> compare objecttype
|
||||
// kernel:tensor, node:tensor -> compare objecttype
|
||||
// kernel:tensor, node:tuple -> unfold node, then compare object type
|
||||
bool MatchObjectType(const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check input object type
|
||||
auto kernel_inputs_object_type = kernel_build_info->GetAllInputKernelObjectTypes();
|
||||
auto node_inputs_object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(cnode));
|
||||
|
||||
size_t kernel_input_index = 0;
|
||||
std::vector<kernel::KernelObjectType> new_input_object_types = {};
|
||||
for (size_t input_index = 0; input_index < node_inputs_object_type.size(); ++input_index) {
|
||||
if (kernel_inputs_object_type[kernel_input_index] != kernel::KernelObjectType::TUPLE &&
|
||||
node_inputs_object_type[input_index] == kernel::KernelObjectType::TUPLE) {
|
||||
// tuple_unfold condition
|
||||
std::vector<KernelWithIndex> index_inputs = common::AnfAlgo::GetRealPrevNodesOutput(cnode, input_index);
|
||||
for (size_t i = 0; i < index_inputs.size(); ++i) {
|
||||
auto real_input_node = index_inputs[i].first;
|
||||
MS_EXCEPTION_IF_NULL(real_input_node);
|
||||
if (kernel_input_index >= kernel_inputs_object_type.size()) {
|
||||
MS_LOG(DEBUG) << "index is large equal than list size: " << kernel_input_index << " vs "
|
||||
<< kernel_inputs_object_type.size();
|
||||
return false;
|
||||
}
|
||||
if (!MatchObjectType(
|
||||
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(real_input_node->abstract())),
|
||||
kernel_inputs_object_type[kernel_input_index])) {
|
||||
return false;
|
||||
}
|
||||
++kernel_input_index;
|
||||
}
|
||||
|
||||
new_input_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
|
||||
} else {
|
||||
auto node_object = node_inputs_object_type[input_index];
|
||||
auto kernel_object = kernel_inputs_object_type[kernel_input_index];
|
||||
if (!MatchObjectType(node_object, kernel_object)) {
|
||||
return false;
|
||||
}
|
||||
if (node_object == kernel::KernelObjectType::SCALAR && kernel_object == kernel::KernelObjectType::TENSOR) {
|
||||
new_input_object_types.push_back(kernel::KernelObjectType::SCALAR);
|
||||
} else {
|
||||
new_input_object_types.push_back(kernel_inputs_object_type[kernel_input_index]);
|
||||
}
|
||||
++kernel_input_index;
|
||||
}
|
||||
}
|
||||
if (kernel_input_index != kernel_inputs_object_type.size()) {
|
||||
MS_LOG(DEBUG) << "index is not equal to list size: " << kernel_input_index << " vs "
|
||||
<< kernel_inputs_object_type.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check output object type
|
||||
auto kernel_outputs_object_type = kernel_build_info->GetAllOutputKernelObjectTypes();
|
||||
auto node_output_object_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
|
||||
std::vector<kernel::KernelObjectType> new_output_object_types = {};
|
||||
|
||||
if (node_output_object_type == kObjectTypeTuple) {
|
||||
auto tuple_abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abs);
|
||||
auto items = tuple_abs->elements();
|
||||
size_t output_index = 0;
|
||||
for (auto item : items) {
|
||||
if (output_index >= kernel_outputs_object_type.size()) {
|
||||
MS_LOG(DEBUG) << "index is large equal than list size: " << output_index << " vs "
|
||||
<< kernel_outputs_object_type.size();
|
||||
return false;
|
||||
}
|
||||
if (!MatchObjectType(kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(item)),
|
||||
kernel_outputs_object_type[output_index])) {
|
||||
return false;
|
||||
}
|
||||
++output_index;
|
||||
}
|
||||
new_output_object_types = {kernel::KernelObjectType::TUPLE_UNFOLD};
|
||||
} else {
|
||||
auto output_num = AnfAlgo::GetOutputElementNum(cnode);
|
||||
if (output_num > 0) {
|
||||
if (!MatchObjectType(kernel::TypeIdToKernelObjectType(node_output_object_type), kernel_outputs_object_type[0])) {
|
||||
return false;
|
||||
}
|
||||
new_output_object_types.push_back(kernel_outputs_object_type[0]);
|
||||
}
|
||||
}
|
||||
|
||||
kernel_build_info->SetInputsKernelObjectType(new_input_object_types);
|
||||
kernel_build_info->SetOutputsKernelObjectType(new_output_object_types);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByObjectType(
|
||||
const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
|
||||
for (const auto &kernel_build_info : kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
auto new_kernel_build_info = std::make_shared<kernel::KernelBuildInfo>(*kernel_build_info);
|
||||
if (!MatchObjectType(cnode, new_kernel_build_info)) {
|
||||
continue;
|
||||
}
|
||||
result.push_back(new_kernel_build_info);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
|
||||
|
@ -386,14 +568,17 @@ TypeId GetInputDeviceType(const CNodePtr &kernel_node, size_t input_idx) {
|
|||
return type;
|
||||
}
|
||||
|
||||
void GetInputsDeviceType(const CNodePtr &kernel_node, std::vector<TypeId> *input_types) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(input_types);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto type = GetInputDeviceType(kernel_node, i);
|
||||
input_types->emplace_back(type);
|
||||
TypeId GetInputDeviceType(const KernelWithIndex &input_node_with_index) {
|
||||
TypeId type = kTypeUnknown;
|
||||
auto input_node = input_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
|
||||
if (kernel_info != nullptr && kernel_info->select_kernel_build_info() != nullptr) {
|
||||
type = AnfAlgo::GetOutputDeviceDataType(input_node_with_index.first, input_node_with_index.second);
|
||||
} else {
|
||||
type = common::AnfAlgo::GetOutputInferDataType(input_node_with_index.first, input_node_with_index.second);
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
string InferOutputFormat(const CNodePtr &kernel_node, const std::vector<std::string> &inputs_format) {
|
||||
|
@ -464,32 +649,54 @@ KernelSelectStatus SelectCustomKernelInfo(const CNodePtr &kernel_node, KernelTyp
|
|||
// set inputs info
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
std::vector<std::string> inputs_format;
|
||||
GetInputsDeviceType(kernel_node, &inputs_device_type);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
std::vector<kernel::KernelObjectType> inputs_kernel_object_type;
|
||||
std::unordered_set<string> all_input_formats;
|
||||
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
|
||||
inputs_format.emplace_back(format);
|
||||
all_input_formats.insert(format);
|
||||
auto input_node = common::AnfAlgo::GetInputNode(kernel_node, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (common::AnfAlgo::IsTupleOutput(input_node)) {
|
||||
std::vector<KernelWithIndex> inputs_with_index = common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, i);
|
||||
for (size_t j = 0; j < inputs_with_index.size(); ++j) {
|
||||
auto type = GetInputDeviceType(inputs_with_index[j]);
|
||||
inputs_device_type.emplace_back(type);
|
||||
auto format = AnfAlgo::GetOutputFormat(inputs_with_index[j].first, inputs_with_index[j].second);
|
||||
inputs_format.emplace_back(format);
|
||||
all_input_formats.insert(format);
|
||||
}
|
||||
inputs_kernel_object_type.emplace_back(kernel::KernelObjectType::TUPLE_UNFOLD);
|
||||
} else {
|
||||
auto type = GetInputDeviceType(kernel_node, i);
|
||||
inputs_device_type.emplace_back(type);
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
|
||||
inputs_format.emplace_back(format);
|
||||
all_input_formats.insert(format);
|
||||
inputs_kernel_object_type.emplace_back(kernel::KernelObjectType::TENSOR);
|
||||
}
|
||||
}
|
||||
|
||||
if (all_input_formats.size() > 1) {
|
||||
MS_LOG(WARNING) << op_name << " has different input formats, the number of input formats is "
|
||||
<< all_input_formats.size();
|
||||
}
|
||||
builder->SetInputsDeviceType(inputs_device_type);
|
||||
builder->SetInputsFormat(inputs_format);
|
||||
builder->SetInputsKernelObjectType(inputs_kernel_object_type);
|
||||
// set outputs info
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
std::vector<std::string> outputs_format;
|
||||
auto output_infer_format = InferOutputFormat(kernel_node, inputs_format);
|
||||
MS_LOG(INFO) << "Outputs of " << op_name << " will use same inferred format: " << output_infer_format;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
outputs_device_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, i));
|
||||
outputs_format.push_back(output_infer_format);
|
||||
}
|
||||
builder->SetOutputsDeviceType(outputs_device_type);
|
||||
builder->SetOutputsFormat(outputs_format);
|
||||
builder->SetOutputsKernelObjectType(
|
||||
std::vector<kernel::KernelObjectType>(outputs_format.size(), kernel::KernelObjectType::TENSOR));
|
||||
// Set kernel build info to node
|
||||
auto build_info = builder->Build();
|
||||
MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope() << " selected: " << build_info;
|
||||
|
@ -565,28 +772,47 @@ void ResetPreFixedFormat(const CNodePtr &kernel_node, kernel::KernelBuildInfoPtr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void RefreshInputParameter(const CNodePtr &kernel_node, const AnfNodePtr &input_kernel_node,
|
||||
const std::string &input_format, size_t input_index) {
|
||||
auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
auto real_input_node = input_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(real_input_node);
|
||||
if (RefreshCastAndParamWeightFormat(real_input_node, input_format)) {
|
||||
return;
|
||||
}
|
||||
if (real_input_node->isa<Parameter>() && !common::AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> output_format = {input_format};
|
||||
SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
|
||||
return;
|
||||
}
|
||||
|
||||
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t real_input_num = 0;
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_kernel_node = common::AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
||||
auto real_input_node = input_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(real_input_node);
|
||||
if (RefreshCastAndParamWeightFormat(real_input_node, selected_kernel_info->GetInputFormat(input_index))) {
|
||||
continue;
|
||||
auto input_object_type = selected_kernel_info->GetInputKernelObjectType(input_index);
|
||||
if (input_object_type == kernel::KernelObjectType::TUPLE_UNFOLD) {
|
||||
std::vector<KernelWithIndex> kernels_with_index =
|
||||
common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, input_index);
|
||||
for (size_t i = 0; i < kernels_with_index.size(); ++i) {
|
||||
RefreshInputParameter(kernel_node, kernels_with_index[i].first,
|
||||
selected_kernel_info->GetInputFormat(real_input_num), real_input_num);
|
||||
++real_input_num;
|
||||
}
|
||||
} else {
|
||||
auto input_kernel_node = common::AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
RefreshInputParameter(kernel_node, input_kernel_node, selected_kernel_info->GetInputFormat(real_input_num),
|
||||
real_input_num);
|
||||
++real_input_num;
|
||||
}
|
||||
if (real_input_node->isa<Parameter>() &&
|
||||
!common::AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
||||
continue;
|
||||
}
|
||||
auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
|
||||
std::vector<std::string> output_format = {refresh_format};
|
||||
SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -595,28 +821,40 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
KernelSelectStatus select_status = kNoMatched;
|
||||
if (kernel_info_list.empty()) {
|
||||
return select_status;
|
||||
return kNoMatched;
|
||||
}
|
||||
bool precision_reduce = false;
|
||||
kernel::KernelBuildInfoPtr selected_kernel_info = nullptr;
|
||||
// Matched kernel info
|
||||
// Filter kernel info matched with me inferred type
|
||||
auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
|
||||
if (!filtered_kernel_info_list.empty()) {
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
select_status = kStatusAllMatched;
|
||||
} else {
|
||||
if (filtered_kernel_info_list.empty()) {
|
||||
// selected kernel info using raised precision or reduce precision
|
||||
filtered_kernel_info_list =
|
||||
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
if (selected_kernel_info == nullptr) {
|
||||
return select_status;
|
||||
} else {
|
||||
MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
||||
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
if (filtered_kernel_info_list.empty()) {
|
||||
return kNoMatched;
|
||||
}
|
||||
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
} else {
|
||||
select_status = kStatusAllMatched;
|
||||
}
|
||||
|
||||
// filter object_type and adjust tuple_unfold condition
|
||||
MS_LOG(DEBUG) << "Node " << kernel_node->fullname_with_scope() << "'s kernel info list size is "
|
||||
<< filtered_kernel_info_list.size() << " before object type matching";
|
||||
filtered_kernel_info_list = FilteredKernelInfoByObjectType(kernel_node, filtered_kernel_info_list);
|
||||
MS_LOG(DEBUG) << "Node " << kernel_node->fullname_with_scope() << "'s kernel info list size is "
|
||||
<< filtered_kernel_info_list.size() << " after object type matching";
|
||||
if (filtered_kernel_info_list.empty()) {
|
||||
return kNoMatched;
|
||||
}
|
||||
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
if (select_status == kStatusReducePrecision || kStatusReducePrecision == kStatusRaisePrecision) {
|
||||
MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, select_status);
|
||||
}
|
||||
|
||||
// Set kernel build info to node
|
||||
MS_LOG(DEBUG) << "Current node: " << kernel_node->fullname_with_scope()
|
||||
<< " selected: " << selected_kernel_info->ToString();
|
||||
|
@ -861,6 +1099,47 @@ void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
|||
AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
|
||||
}
|
||||
|
||||
void SetDynamicInputSizeAttrBeforeKernelSelect(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
|
||||
return;
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
|
||||
return;
|
||||
}
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
size_t input_num = cnode->inputs().size() - 1;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
||||
dyn_input_sizes.push_back(kernel::CalOutputTupleSize(input_node));
|
||||
}
|
||||
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
||||
}
|
||||
}
|
||||
|
||||
void RefreshDynamicInputSizeAttr(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
|
||||
MS_LOG(INFO) << "Node has not set kAttrDynInputSizes yet, node: " << cnode->fullname_with_scope();
|
||||
return;
|
||||
}
|
||||
std::vector<int64_t> dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
|
||||
auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
|
||||
size_t input_num = cnode->inputs().size() - 1;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (input_obj_types[i] == kernel::KernelObjectType::TUPLE) {
|
||||
dyn_input_sizes[i] = -1;
|
||||
}
|
||||
}
|
||||
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
||||
} else {
|
||||
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, cnode);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithMsg(const CNodePtr &kernel_node,
|
||||
KernelType kernel_type) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
|
@ -870,15 +1149,18 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
|
|||
std::tuple<KernelSelectStatus, std::string, ExceptionType> result =
|
||||
std::make_tuple(kStatusAllMatched, "", NoExceptionType);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
SetDynamicInputSizeAttrBeforeKernelSelect(kernel_node);
|
||||
if (common::AnfAlgo::IsGraphKernel(kernel_node)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
SelectGraphKernelInfo(kernel_node, func_graph);
|
||||
RefreshDynamicInputSizeAttr(kernel_node);
|
||||
return result;
|
||||
}
|
||||
if (IsPrimitiveCNode(kernel_node, prim::kPrimCallInline)) {
|
||||
opt::SelectCallInlineKernelInfo(kernel_node);
|
||||
SetTensorDeviceInfo(kernel_node);
|
||||
RefreshDynamicInputSizeAttr(kernel_node);
|
||||
return result;
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node)) {
|
||||
|
@ -892,6 +1174,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
|
|||
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
|
||||
auto select_status = SelectCustomKernelInfo(kernel_node, &kernel_type);
|
||||
if (select_status == kStatusAllMatched) {
|
||||
RefreshDynamicInputSizeAttr(kernel_node);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -942,6 +1225,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
|
|||
select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIsAiCpuKernel, MakeValue(true), kernel_node);
|
||||
}
|
||||
|
||||
// The kernel info can not find in ai_cpu kernel lists and ai_core kernel lists
|
||||
if (select_status == kNoMatched) {
|
||||
GatherInputAndOutputInferType(aicpu_in_out_info, kernel_node);
|
||||
|
@ -954,6 +1238,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
|
|||
std::get<two>(result) = etype;
|
||||
return result;
|
||||
}
|
||||
RefreshDynamicInputSizeAttr(kernel_node);
|
||||
SetRaiseOrReduceFlag(kernel_node, select_status);
|
||||
std::get<0>(result) = select_status;
|
||||
return result;
|
||||
|
|
|
@ -72,6 +72,7 @@ void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
|
||||
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
|
||||
}
|
||||
}
|
||||
|
@ -283,6 +284,7 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
|
|||
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
||||
}
|
||||
}
|
||||
|
@ -403,6 +405,7 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo
|
|||
std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
||||
}
|
||||
|
||||
|
@ -436,6 +439,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<std::string> graph_output_format;
|
||||
std::vector<TypeId> graph_output_type;
|
||||
std::vector<kernel::KernelObjectType> graph_output_object_type;
|
||||
for (size_t i = 0; i < output_index.size(); ++i) {
|
||||
auto const &output = output_index[i];
|
||||
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
|
||||
|
@ -447,13 +451,21 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
|
|||
output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
|
||||
}
|
||||
graph_output_type.push_back(output_type);
|
||||
graph_output_object_type.push_back(kernel::KernelObjectType::TENSOR);
|
||||
}
|
||||
|
||||
std::vector<kernel::KernelObjectType> graph_input_object_type;
|
||||
for (size_t i = 0; i < graph_input_type.size(); ++i) {
|
||||
graph_input_object_type.push_back(kernel::KernelObjectType::TENSOR);
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(graph_input_format);
|
||||
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
||||
graph_info_builder.SetInputsKernelObjectType(graph_input_object_type);
|
||||
graph_info_builder.SetOutputsFormat(graph_output_format);
|
||||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||
graph_info_builder.SetOutputsKernelObjectType(graph_output_object_type);
|
||||
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
|
|
|
@ -35,12 +35,12 @@ namespace mindspore {
|
|||
namespace device {
|
||||
namespace ascend {
|
||||
void AscendDeviceContext::Initialize() {
|
||||
MS_LOG(INFO) << "Start Initialize...";
|
||||
if (initialized_) {
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance_);
|
||||
runtime_instance_->SetContext();
|
||||
return;
|
||||
} else {
|
||||
MS_LOG(INFO) << "Start Initialize...";
|
||||
#ifndef ENABLE_SECURITY
|
||||
AscendProfiler::GetInstance()->MsprofInitProfiler();
|
||||
#endif
|
||||
|
|
|
@ -268,6 +268,8 @@ void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPt
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
CommOpReuse(graph);
|
||||
opt::AscendUnfoldInputsForSpecialNodes(graph);
|
||||
|
||||
if (context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
|
||||
HandleControlFlow(NOT_NULL(graph));
|
||||
}
|
||||
|
@ -448,6 +450,7 @@ void AscendGraphOptimization::RecurseSelectKernelInfo(const KernelGraphPtr &grap
|
|||
}
|
||||
#endif
|
||||
MS_LOG(INFO) << "Status record: start select kernel info. graph id: " << graph->graph_id();
|
||||
graph->SetKernelObjectTypesForUnrealNodes();
|
||||
SetOperatorInfo(graph);
|
||||
MS_LOG(INFO) << "Status record: end select kernel info. graph id: " << graph->graph_id();
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
|
|
@ -238,7 +238,7 @@ void UpdateOutputNodeShape(const AnfNodePtr &node, size_t index, TypeId output_t
|
|||
if (node->isa<CNode>()) {
|
||||
name = common::AnfAlgo::GetCNodeName(node);
|
||||
}
|
||||
size_t total_output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
size_t total_output_num = AnfAlgo::GetOutputElementNum(node);
|
||||
if (index >= total_output_num) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output index " << index << ", node " << node->fullname_with_scope() << " has "
|
||||
<< total_output_num << " outputs.";
|
||||
|
|
|
@ -64,20 +64,27 @@ void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
|
|||
if (kDynamicInputOps.find(op_name) != kDynamicInputOps.end()) {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
inputs_object_type.emplace_back(KernelObjectType::TENSOR);
|
||||
auto kernels_with_index = common::AnfAlgo::GetRealPrevNodesOutput(kernel_node, input_index);
|
||||
for (size_t i = 0; i < kernels_with_index.size(); ++i) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)inputs_type.emplace_back(
|
||||
common::AnfAlgo::GetOutputInferDataType(kernels_with_index[i].first, kernels_with_index[i].second));
|
||||
inputs_object_type.emplace_back(kernel::TypeIdToKernelObjectType(
|
||||
AnfAlgo::GetOutputObjectType(kernels_with_index[i].first, kernels_with_index[i].second)));
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<KernelObjectType> outputs_object_type{};
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
outputs_object_type.emplace_back(KernelObjectType::TENSOR);
|
||||
outputs_object_type.emplace_back(
|
||||
kernel::TypeIdToKernelObjectType(AnfAlgo::GetOutputObjectType(kernel_node, output_index)));
|
||||
}
|
||||
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetInputsFormat(inputs_format);
|
||||
builder.SetInputsDeviceType(inputs_type);
|
||||
|
|
|
@ -98,7 +98,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<KernelObjectType> output_object_type{};
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
(void)outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
|
||||
if (op_name == kReceiveOpName) {
|
||||
|
|
|
@ -37,7 +37,7 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
std::vector<KernelObjectType> inputs_object_type{};
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputElementNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
|
@ -46,7 +46,7 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<KernelObjectType> outputs_object_type{};
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
|
|
|
@ -25,8 +25,10 @@
|
|||
#include "plugin/device/ascend/kernel/akg/akg_kernel_metadata.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -38,36 +40,91 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
size_t input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t unfold_output_tensor_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||
size_t unfold_input_tensor_num = AnfAlgo::GetInputElementNum(kernel_node);
|
||||
size_t fold_output_tensor_num = 1;
|
||||
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
|
||||
(void)std::copy_if(
|
||||
kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
||||
[output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
return kernel_build_info->GetOutputNum() == output_tensor_num &&
|
||||
kernel_build_info->GetInputNum() == input_tensor_num;
|
||||
});
|
||||
std::ostringstream buffer;
|
||||
size_t info_index = 0;
|
||||
for (const auto &kernel_info : *kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
bool is_fold = kernel::IsFoldKernelBuildInfo(kernel_info);
|
||||
if (is_fold) {
|
||||
bool is_match = true;
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
|
||||
is_match = false;
|
||||
} else {
|
||||
// compare input num
|
||||
std::vector<int64_t> dyn_input_sizes =
|
||||
common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, kAttrDynInputSizes);
|
||||
size_t real_input_num = 0;
|
||||
for (size_t i = 0; i < fold_input_tensor_num; ++i) {
|
||||
if (kernel_info->GetInputKernelObjectType(i) == kernel::KernelObjectType::TUPLE || dyn_input_sizes[i] == -1) {
|
||||
++real_input_num;
|
||||
} else {
|
||||
real_input_num += dyn_input_sizes[i];
|
||||
}
|
||||
}
|
||||
if (kernel_info->GetInputNum() != real_input_num) {
|
||||
is_match = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_match) {
|
||||
// compare output num
|
||||
size_t real_output_num = unfold_output_tensor_num;
|
||||
if (kernel_info->GetOutputKernelObjectType(0) == kernel::KernelObjectType::TUPLE) {
|
||||
real_output_num = 1;
|
||||
}
|
||||
|
||||
if (kernel_info->GetOutputNum() != real_output_num) {
|
||||
is_match = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_match) {
|
||||
(void)filtered_list.emplace_back(kernel_info);
|
||||
} else {
|
||||
buffer << "Kernel [ " << info_index << " ] [Fold]:";
|
||||
if (kernel_info->GetOutputNum() != fold_output_tensor_num) {
|
||||
buffer << "Kernel build info's output size [" << kernel_info->GetOutputNum() << "]"
|
||||
<< " cannot match the node's output size [" << fold_output_tensor_num << "]\n";
|
||||
} else {
|
||||
buffer << "Kernel build info's input size [" << kernel_info->GetInputNum() << "]"
|
||||
<< " cannot match the node's input size [" << fold_input_tensor_num << "]\n";
|
||||
}
|
||||
buffer << "\n kernel info:" << kernel_info->ToString();
|
||||
}
|
||||
} else {
|
||||
if ((kernel_info->GetInputNum() == unfold_input_tensor_num) &&
|
||||
(kernel_info->GetOutputNum() == unfold_output_tensor_num)) {
|
||||
(void)filtered_list.emplace_back(kernel_info);
|
||||
} else {
|
||||
buffer << "Kernel [ " << info_index << " ] [Unfold]:";
|
||||
if (kernel_info->GetOutputNum() != unfold_output_tensor_num) {
|
||||
buffer << "Kernel build info's output size [" << kernel_info->GetOutputNum() << "]"
|
||||
<< " cannot match the node's output size [" << unfold_output_tensor_num << "]\n";
|
||||
} else {
|
||||
buffer << "Kernel build info's input size [" << kernel_info->GetInputNum() << "]"
|
||||
<< " cannot match the node's input size [" << unfold_input_tensor_num << "]\n";
|
||||
}
|
||||
buffer << "\n kernel info:" << kernel_info->ToString();
|
||||
}
|
||||
}
|
||||
info_index++;
|
||||
}
|
||||
|
||||
if (!filtered_list.empty()) {
|
||||
kernel_info_list->clear();
|
||||
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
|
||||
} else {
|
||||
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
|
||||
std::ostringstream buffer;
|
||||
auto &kernel_info = kernel_info_list->at(index);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
if (kernel_info->GetOutputNum() != output_tensor_num) {
|
||||
buffer << "Kernel node's output size [" << output_tensor_num << "]"
|
||||
<< " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
|
||||
} else {
|
||||
buffer << "Kernel node's input size [" << input_tensor_num << "]"
|
||||
<< " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
|
||||
}
|
||||
MS_LOG(INFO) << "Kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
|
||||
}
|
||||
MS_LOG(INFO) << buffer.str();
|
||||
kernel_info_list->clear();
|
||||
MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
|
||||
<< "input size : [" << input_tensor_num << "] can not match any kernelInfo !";
|
||||
MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s fold output size : [" << fold_output_tensor_num << "]"
|
||||
<< ", fold input size : [" << fold_input_tensor_num << "], unfold output size : ["
|
||||
<< unfold_output_tensor_num << "]"
|
||||
<< ", unfold input size : [" << unfold_input_tensor_num << "] can not match any kernelInfo !";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,22 +156,37 @@ void KernelQueryAll(const CNodePtr &kernel_node,
|
|||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||
auto select_cnode = kernel_node;
|
||||
auto tuple_unfold_node = opt::ConvertMakeTupleInputToPlantInputs(kernel_node->func_graph(), kernel_node);
|
||||
if (tuple_unfold_node != nullptr) {
|
||||
auto tuple_unfold_cnode = tuple_unfold_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_unfold_cnode);
|
||||
select_cnode = tuple_unfold_cnode;
|
||||
select_cnode->set_fullname_with_scope(kernel_node->fullname_with_scope());
|
||||
MS_LOG(INFO) << "Create tuple unfold node " << tuple_unfold_node->fullname_with_scope() << ", debug string ["
|
||||
<< tuple_unfold_node->DebugString() << "] from " << kernel_node->fullname_with_scope()
|
||||
<< ", debug string [" << kernel_node->DebugString() << "].";
|
||||
}
|
||||
|
||||
TbeMetadataInfo(select_cnode, kernel_info_list);
|
||||
if (kernel_info_list->empty()) {
|
||||
GetRtKelInfo(kernel_node, kernel_info_list);
|
||||
GetRtKelInfo(select_cnode, kernel_info_list);
|
||||
CheckKernelInfoListEmpty(kernel_info_list, "RT_Kernel");
|
||||
}
|
||||
if (kernel_info_list->empty()) {
|
||||
HcclMetadataInfo(kernel_node, kernel_info_list);
|
||||
HcclMetadataInfo(select_cnode, kernel_info_list);
|
||||
CheckKernelInfoListEmpty(kernel_info_list, "HCCL_Kernel");
|
||||
}
|
||||
if (SelectAicpuReshapeInTaskSink(kernel_node)) {
|
||||
if (SelectAicpuReshapeInTaskSink(select_cnode)) {
|
||||
return;
|
||||
}
|
||||
if (kernel_info_list->empty()) {
|
||||
HostMetadataInfo(kernel_node, kernel_info_list);
|
||||
HostMetadataInfo(select_cnode, kernel_info_list);
|
||||
CheckKernelInfoListEmpty(kernel_info_list, "HOST_Kernel");
|
||||
}
|
||||
if (!kernel_info_list->empty()) {
|
||||
common::AnfAlgo::CopyNodeAttrs(select_cnode, kernel_node);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
|
||||
|
|
|
@ -69,7 +69,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
|
|||
if (IsDefaultKernelInfo(node_name)) {
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set input infos
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
auto input_num = AnfAlgo::GetInputElementNum(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
|
||||
kernel_build_info_builder->SetInputsFormat(std::vector<std::string>(input_num, kOpFormat_DEFAULT));
|
||||
kernel_build_info_builder->SetInputsKernelObjectType(
|
||||
|
|
|
@ -124,7 +124,8 @@ RangePair TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node,
|
|||
kernel_info->select_kernel_build_info() == nullptr ? def_format : AnfAlgo::GetOutputFormat(anf_node, index);
|
||||
auto data_type =
|
||||
kernel_info->select_kernel_build_info() == nullptr ? type : AnfAlgo::GetOutputDeviceDataType(anf_node, index);
|
||||
std::string reshape_type = AnfAlgo::GetOutputReshapeType(anf_node, index);
|
||||
std::string reshape_type =
|
||||
kernel_info->select_kernel_build_info() == nullptr ? "" : AnfAlgo::GetOutputReshapeType(anf_node, index);
|
||||
trans::ShapeRangeTransfer shapeRangeTransfer;
|
||||
RangePair ret;
|
||||
|
||||
|
|
|
@ -257,7 +257,7 @@ bool SingleTbeJsonCreator::GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::
|
|||
|
||||
size_t sum_outputs_num =
|
||||
std::accumulate(outputs_tensor_num.begin(), outputs_tensor_num.end(), static_cast<size_t>(0));
|
||||
size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputElementNum(anf_node);
|
||||
std::vector<nlohmann::json> outputs_desc;
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
nlohmann::json output_desc;
|
||||
|
|
|
@ -54,7 +54,7 @@ bool TbeJsonUtils::GetInputsRealNum(const AnfNodePtr &anf_node, const std::vecto
|
|||
bool TbeJsonUtils::GetOutputsRealNum(const AnfNodePtr &anf_node, const std::vector<OpIOInfoPtr> &outputs_ptr,
|
||||
std::vector<size_t> *outputs_num) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputElementNum(anf_node);
|
||||
for (const auto &output_ptr : outputs_ptr) {
|
||||
if (output_ptr->param_type() == kJParamDynamic) {
|
||||
if (outputs_ptr.size() > 1) {
|
||||
|
|
|
@ -24,8 +24,8 @@ namespace mindspore::kernel {
|
|||
void TbeKernelAgnosticSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||
SupportFormat support_format;
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
auto input_num = AnfAlgo::GetInputElementNum(cnode_ptr_);
|
||||
auto output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
|
||||
if (input_num != 1 || output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Agnostic only support one input. input_num: " << input_num << ", output num: " << output_num
|
||||
<< ", full_name:" << cnode_ptr_->fullname_with_scope();
|
||||
|
|
|
@ -59,7 +59,7 @@ void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
|
|||
(void)input_shapes_.emplace_back(dynamic_input_shape0_);
|
||||
input_num_ = 1;
|
||||
} else {
|
||||
input_num_ = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
input_num_ = AnfAlgo::GetInputElementNum(cnode_ptr_);
|
||||
for (size_t i = 0; i < input_num_; ++i) {
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&input_shape);
|
||||
|
@ -67,7 +67,7 @@ void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
|
|||
}
|
||||
}
|
||||
|
||||
output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
output_num_ = AnfAlgo::GetOutputElementNum(cnode_ptr_);
|
||||
for (size_t i = 0; i < output_num_; ++i) {
|
||||
auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&output);
|
||||
|
|
|
@ -51,7 +51,7 @@ void TbeKernelReduceSelector::GetSupportedFormatDType(SupportFormatDType *suppor
|
|||
|
||||
void TbeKernelReduceSelector::GetReduceNodeInfo() {
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
auto output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
|
||||
if (input_num != 1 || output_num != 1) {
|
||||
MS_LOG(INFO) << "Reduce operator input/output is not 1, input num: " << input_num << ", output num: " << output_num
|
||||
<< ", node info: " << cnode_ptr_->DebugString();
|
||||
|
|
|
@ -48,6 +48,11 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
|
|||
if (op_info_ptr == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (common::AnfAlgo::HasDynamicTupleInput(kernel_node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (IsKernelDynamicImpl(kernel_node)) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIsKernelDynamicImpl, MakeValue(true), kernel_node);
|
||||
if (tbe_selector.CheckOpSupported()) {
|
||||
|
@ -486,7 +491,7 @@ bool TbeKernelSelect::GetKernelBuildInfoFromCache() {
|
|||
void TbeKernelSelect::GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype) {
|
||||
auto dyn_input_sizes = GetNodeDynamicInputs();
|
||||
// get real input/output num
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
size_t real_input_num = AnfAlgo::GetInputElementNum(cnode_ptr_);
|
||||
size_t real_output_num = AnfAlgo::GetOutputElementNum(cnode_ptr_);
|
||||
auto op_info_input_num = support_format_dtype.input_dtypes.size();
|
||||
auto op_info_output_num = support_format_dtype.output_dtypes.size();
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "include/common/utils/utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
|
@ -56,7 +57,7 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t real_output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
size_t real_output_num = AnfAlgo::GetOutputElementNum(node);
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
auto format = AnfAlgo::GetOutputFormat(node, i);
|
||||
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
|
||||
|
@ -165,8 +166,8 @@ bool IsKernelDynamicImpl(const AnfNodePtr &node) {
|
|||
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(support_format);
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
auto input_num = AnfAlgo::GetInputElementNum(cnode);
|
||||
auto output_num = AnfAlgo::GetOutputElementNum(cnode);
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
|
@ -266,30 +267,39 @@ bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
// filte kernel info that unsupported raise or reduce datatype
|
||||
bool TagUnfoldRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
|
||||
auto in_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
|
||||
return false;
|
||||
|
||||
// Check input data type
|
||||
size_t kernel_input_index = 0;
|
||||
size_t fold_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < fold_input_tensor_num; ++input_index) {
|
||||
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
|
||||
for (size_t i = 0; i < inputs_type.size(); ++i) {
|
||||
if (kernel_input_index >= kernel_build_info->GetInputNum()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto device_dtype = kernel_build_info->GetInputDeviceType(kernel_input_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
if (!CheckHitTargetDtype(type_map, inputs_type[i], device_dtype)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
++kernel_input_index;
|
||||
}
|
||||
}
|
||||
|
||||
// Check output data type
|
||||
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
|
||||
auto in_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, output_index);
|
||||
auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -297,6 +307,71 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TagFoldRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
|
||||
// Check input data type
|
||||
size_t kernel_input_index = 0;
|
||||
for (size_t input_index = 0; input_index < common::AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
if (kernel_build_info->GetInputKernelObjectType(kernel_input_index) == kernel::KernelObjectType::TUPLE) {
|
||||
auto input_node = cnode->inputs()[input_index + 1];
|
||||
TypeId in_dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
|
||||
auto device_dtype = kernel_build_info->GetInputDeviceType(kernel_input_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
|
||||
return false;
|
||||
}
|
||||
++kernel_input_index;
|
||||
} else {
|
||||
std::vector<TypeId> inputs_type = common::AnfAlgo::GetRealPrevNodesOutputInferDataType(cnode, input_index);
|
||||
for (size_t i = 0; i < inputs_type.size(); ++i) {
|
||||
if (kernel_input_index >= kernel_build_info->GetInputNum()) {
|
||||
return false;
|
||||
}
|
||||
auto device_dtype = kernel_build_info->GetInputDeviceType(kernel_input_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
if (!CheckHitTargetDtype(type_map, inputs_type[i], device_dtype)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
++kernel_input_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check output data type
|
||||
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
|
||||
auto in_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, output_index);
|
||||
auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index);
|
||||
if (device_dtype == kNumberTypeFloat) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
// filte kernel info that unsupported raise or reduce datatype
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
|
||||
bool is_fold = kernel::IsFoldKernelBuildInfo(kernel_build_info);
|
||||
if (is_fold) {
|
||||
return TagFoldRaiseReduce(kernel_build_info, cnode, type_map);
|
||||
} else {
|
||||
return TagUnfoldRaiseReduce(kernel_build_info, cnode, type_map);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
|
||||
const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
|
||||
bool *precision_reduce) {
|
||||
|
|
|
@ -184,6 +184,8 @@
|
|||
#include "include/common/debug/draw.h"
|
||||
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
|
||||
#include "plugin/device/ascend/hal/common/ascend_utils.h"
|
||||
#include "backend/common/pass/insert_type_transform_op.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -221,6 +223,7 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
|
|||
|
||||
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
||||
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
|
||||
ir_fusion_pm->AddPass(std::make_shared<AscendConvertTupleInputToDynamicInput>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
|
||||
|
@ -289,6 +292,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
|
||||
data_layout_pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
data_layout_pm->AddPass(std::make_shared<ReselectCallInlineFormat>());
|
||||
data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
|
||||
data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
|
||||
|
@ -739,5 +743,34 @@ void AscendOpAdaptation(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void AscendUnfoldInputsForSpecialNodes(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (context_ptr->CanDump(kIntroductory)) {
|
||||
std::string file_name =
|
||||
"hwopt_d_before_unfold_inputs_for_special_nodes_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true, kWholeStack);
|
||||
DumpIRProto(kernel_graph,
|
||||
"before_unfold_inputs_for_special_nodes_hwopt_" + std::to_string(kernel_graph->graph_id()));
|
||||
}
|
||||
#endif
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto unfold_inputs_pm = std::make_shared<opt::PassManager>("unfold_inputs_for_special_nodes_pm");
|
||||
unfold_inputs_pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
|
||||
optimizer->AddPassManager(unfold_inputs_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (context_ptr->CanDump(kIntroductory)) {
|
||||
std::string file_name =
|
||||
"hwopt_d_after_unfold_inputs_for_special_nodes_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true, kWholeStack);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,6 +30,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendOpAdaptation(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendUnfoldInputsForSpecialNodes(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -638,8 +638,10 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
|
|||
auto sub_ret = sub_graph->output();
|
||||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<kernel::KernelObjectType> input_object_types;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
std::vector<kernel::KernelObjectType> output_object_types;
|
||||
for (auto ¶m : sub_graph->inputs()) {
|
||||
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(param, 0);
|
||||
if (type_id == kTypeUnknown) {
|
||||
|
@ -650,17 +652,25 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
|
|||
}
|
||||
input_types.push_back(type_id);
|
||||
input_formats.push_back(AnfAlgo::GetOutputFormat(param, 0));
|
||||
input_object_types.push_back(kernel::KernelObjectType::TENSOR);
|
||||
}
|
||||
for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
|
||||
output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
|
||||
output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
|
||||
if (AnfAlgo::GetOutputObjectType(node, i) == TypeId::kObjectTypeTuple) {
|
||||
output_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
|
||||
} else {
|
||||
output_object_types.push_back(kernel::KernelObjectType::TENSOR);
|
||||
}
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputsFormat(input_formats);
|
||||
builder->SetInputsDeviceType(input_types);
|
||||
builder->SetInputsKernelObjectType(input_object_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
builder->SetOutputsKernelObjectType(output_object_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ AnfNodePtr InsertTensorMoveForGetNextOutputs(const FuncGraphPtr &func_graph, con
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
size_t output_num = AnfAlgo::GetOutputElementNum(node);
|
||||
if (output_num == 0) {
|
||||
MS_LOG(DEBUG) << "Output number is zero, no need to insert tensor_move!";
|
||||
return node;
|
||||
|
|
|
@ -101,7 +101,7 @@ AnfNodePtr InsertForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_
|
|||
}
|
||||
|
||||
std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
auto out_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
auto out_num = AnfAlgo::GetOutputElementNum(node);
|
||||
|
||||
for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
|
||||
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
|
||||
|
|
|
@ -68,7 +68,7 @@ CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std
|
|||
}
|
||||
} else if (op_name == kBasicLSTMCellWeightGradOpName) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
size_t out_num = AnfAlgo::GetOutputElementNum(cnode);
|
||||
for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
|
||||
auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx);
|
||||
auto origin_shape = common::AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef AscendConvertTupleInputToDynamicInput::DefinePattern() const {
|
||||
VarPtr V = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr AscendConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// this pass should be in front of concat_fission, pack_fission, addn_fission, since the input should be unfold before
|
||||
// this passes.
|
||||
// the auto_monad pass should before this pass
|
||||
bool is_communication_op = common::AnfAlgo::IsCommunicationOp(node);
|
||||
static const PrimitiveSet need_unfold_node = {prim::kPrimAddN, prim::kPrimConcatD, prim::kPrimPack,
|
||||
prim::kPrimStack, prim::kPrimCallInline, prim::kPrimPrint,
|
||||
prim::kPrimSwitchLayer, prim::kPrimCall, prim::kPrimSwitch};
|
||||
PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (!is_communication_op && need_unfold_node.find(prim) == need_unfold_node.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AscendConvertTupleInputToDynamicInput : public PatternProcessPass {
|
||||
public:
|
||||
explicit AscendConvertTupleInputToDynamicInput(bool multigraph = true)
|
||||
: PatternProcessPass("ascend_convert_tuple_input_to_dynamic_input", multigraph) {}
|
||||
|
||||
~AscendConvertTupleInputToDynamicInput() override = default;
|
||||
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_ASCEND_IR_ASCEND_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_
|
|
@ -91,6 +91,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
|
||||
op_builder.SetOutputsFormat({kOpFormat_NDC1HWC0});
|
||||
op_builder.SetOutputsDeviceType({kNumberTypeFloat16});
|
||||
op_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
|
||||
return assist_const;
|
||||
}
|
||||
|
|
|
@ -70,6 +70,7 @@ ValueNodePtr CreateValueNode(T seed) {
|
|||
} else {
|
||||
builder.SetOutputsDeviceType({kNumberTypeUInt64});
|
||||
}
|
||||
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), value_node.get());
|
||||
return value_node;
|
||||
}
|
||||
|
|
|
@ -86,6 +86,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
|
||||
op_builder.SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||
op_builder.SetOutputsDeviceType({kNumberTypeFloat16});
|
||||
op_builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
|
||||
return assist_const;
|
||||
}
|
||||
|
|
|
@ -117,6 +117,7 @@ ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder1.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
|
||||
builder1.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), assist_const.get());
|
||||
return assist_const;
|
||||
}
|
||||
|
@ -130,6 +131,8 @@ kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
|||
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
|
||||
builder.SetInputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ const AnfNodePtr AdaptiveMaxPool2DFusion::Process(const FuncGraphPtr &func_graph
|
|||
std::vector<int64_t> new_output_size{output_h, output_w};
|
||||
common::AnfAlgo::SetNodeAttr(kAttrOutputSize, MakeValue(new_output_size), adaptive_max_pool2d);
|
||||
|
||||
if (AnfAlgo::GetOutputTensorNum(adaptive_max_pool2d) > 1) {
|
||||
if (AnfAlgo::GetOutputElementNum(adaptive_max_pool2d) > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -512,6 +512,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
|
||||
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,7 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder1.SetOutputsDeviceType({output_type});
|
||||
builder1.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), new_node.get());
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -589,18 +589,18 @@ bool GetSelectKernelResult(const CNodePtr &kernel_node,
|
|||
}
|
||||
|
||||
#ifdef ENABLE_TUPLE_UNFOLD
|
||||
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node) {
|
||||
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
// Only the kernel nodes that register kernel attr can support the backoff.
|
||||
bool backoff_support_condition =
|
||||
((kernel_type == UNKNOWN_KERNEL_TYPE) && !IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
|
||||
!common::AnfAlgo::IsGraphKernel(kernel_node));
|
||||
std::vector<kernel::KernelAttr> kernel_attrs;
|
||||
// Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type.
|
||||
if (!kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
|
||||
if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
|
||||
kernel_attrs = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name);
|
||||
} else if (backoff_support_condition) {
|
||||
// Kernel that is not supported can try to backed off on CPU and use the CPU kernel attrs to set object type.
|
||||
kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name);
|
||||
// CPU also doesn't support the kernel.
|
||||
if (kernel_attrs.empty()) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name);
|
||||
}
|
||||
|
||||
// Some dynamic kernels may not set the kernel attrs on GPU. Skip check only supports the tuple fold.
|
||||
|
@ -635,7 +635,7 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
|||
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
#ifdef ENABLE_TUPLE_UNFOLD
|
||||
bool selected = GetSelectKernelObjectTypeResult(kernel_node);
|
||||
bool selected = GetSelectKernelObjectTypeResult(kernel_node, kernel_type);
|
||||
if (!selected) {
|
||||
std::stringstream ss;
|
||||
ss << "kernel object types are not supported for " << common::AnfAlgo::GetCNodeName(kernel_node)
|
||||
|
|
|
@ -84,7 +84,10 @@ std::vector<KernelAttr> NativeGpuKernelModFactory::GetGpuSupportedList(const std
|
|||
|
||||
for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
|
||||
auto attr = (iter->second)[attr_index].first;
|
||||
kernel_attr_list.push_back(attr);
|
||||
// Skip the invalid attr.
|
||||
if (attr.GetInputSize() > 0 || attr.GetOutputSize() > 0) {
|
||||
kernel_attr_list.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_attr_list;
|
||||
|
|
|
@ -574,6 +574,38 @@ KernelWithIndex AnfAlgo::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t in
|
|||
return res;
|
||||
}
|
||||
|
||||
// if the prev_node is MakeTuple, get all the input_nodes recursively, else use the ori GetPrevNodeOutput function
|
||||
std::vector<KernelWithIndex> AnfAlgo::GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
std::vector<KernelWithIndex> res;
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode, input_idx);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
||||
auto maketuple_input_num = GetInputTensorNum(input_node);
|
||||
for (size_t i = 0; i < maketuple_input_num; ++i) {
|
||||
auto inputs_i = GetRealPrevNodesOutput(input_node, i, skip_nop_node);
|
||||
res.insert(res.end(), inputs_i.begin(), inputs_i.end());
|
||||
}
|
||||
} else {
|
||||
res.emplace_back(GetPrevNodeOutput(cnode, input_idx, skip_nop_node));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<TypeId> AnfAlgo::GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
|
||||
std::vector<KernelWithIndex> kernels_with_index = AnfAlgo::GetRealPrevNodesOutput(node, input_idx);
|
||||
std::vector<TypeId> res;
|
||||
(void)std::transform(kernels_with_index.begin(), kernels_with_index.end(), std::back_inserter(res),
|
||||
[](auto kernel_with_index) {
|
||||
return AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
inline ShapeVector GetShape(const abstract::BaseShapePtr &base_shape) {
|
||||
auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
|
@ -763,8 +795,8 @@ void AnfAlgo::SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
|
|||
node_name = GetCNodeName(node_ptr);
|
||||
}
|
||||
if (types.size() != shapes.size()) {
|
||||
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size() << "."
|
||||
<< trace::DumpSourceLines(node);
|
||||
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size()
|
||||
<< " for node " << node->fullname_with_scope() << "." << trace::DumpSourceLines(node);
|
||||
}
|
||||
|
||||
auto tuple_node = kNodeTupleOutSet.find(node_name);
|
||||
|
|
|
@ -483,7 +483,8 @@ def test_call_no_self_other_object_method_runtime():
|
|||
assert np.all(result == z)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -505,7 +506,8 @@ def test_getattr_tensor_with_wrong_attr():
|
|||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#define private public
|
||||
#define protected public
|
||||
#include "plugin/device/ascend/optimizer/enhancer/insert_tensor_move_for_hccl_op.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
namespace mindspore {
|
||||
|
@ -168,6 +169,8 @@ TEST_F(TestHWInsertTensorMoveForHccl, test_cond5) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
// This pass run before hccl_pass to unfold inputs of hccl node
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto pass = std::make_shared<opt::InsertTensorMoveForHcclOp>();
|
||||
pass->kernel_query_ = std::make_shared<MockInsertTensorMoveForHcclKernelQuery>();
|
||||
pm->AddPass(pass);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#define private public
|
||||
#define protected public
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/addn_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
|
@ -45,6 +46,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_2) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto addn_fission = std::make_shared<opt::AddnFission>();
|
||||
addn_fission->inputs_divisor_ = 2;
|
||||
pm->AddPass(addn_fission);
|
||||
|
@ -54,7 +56,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_2) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_2");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
|
||||
|
@ -70,6 +78,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto addn_fission = std::make_shared<opt::AddnFission>();
|
||||
addn_fission->inputs_divisor_ = 3;
|
||||
pm->AddPass(addn_fission);
|
||||
|
@ -79,7 +88,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_3) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_3");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
|
||||
|
@ -95,6 +110,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto addn_fission = std::make_shared<opt::AddnFission>();
|
||||
addn_fission->inputs_divisor_ = 4;
|
||||
pm->AddPass(addn_fission);
|
||||
|
@ -104,7 +120,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_4) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_4");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
|
||||
|
@ -120,6 +142,7 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto addn_fission = std::make_shared<opt::AddnFission>();
|
||||
addn_fission->inputs_divisor_ = 8;
|
||||
pm->AddPass(addn_fission);
|
||||
|
@ -129,7 +152,13 @@ TEST_F(TestHWAddnFission, test_addn_fission_divided_by_8) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_addn_fission", "after_divided_by_8");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAddnFission, test_addn_fission_divided_by_9) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#define private public
|
||||
#define protected public
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/concat_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
|
@ -45,6 +46,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 2;
|
||||
pm->AddPass(concat_fission);
|
||||
|
@ -54,7 +56,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
|
||||
|
@ -70,6 +78,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 3;
|
||||
pm->AddPass(concat_fission);
|
||||
|
@ -79,7 +88,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
|
||||
|
@ -95,6 +110,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 4;
|
||||
pm->AddPass(concat_fission);
|
||||
|
@ -104,7 +120,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
|
||||
|
@ -120,6 +142,7 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto concat_fission = std::make_shared<opt::ConcatFission>();
|
||||
concat_fission->inputs_divisor_ = 8;
|
||||
pm->AddPass(concat_fission);
|
||||
|
@ -129,7 +152,13 @@ TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8");
|
||||
EXPECT_NE(g_after, nullptr);
|
||||
auto kg_after = GetKernelGraph(g_after, args_spec_list);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
|
||||
auto optimizer2 = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm2 = std::make_shared<opt::PassManager>();
|
||||
pm2->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
optimizer2->AddPassManager(pm2);
|
||||
auto kg_after2 = optimizer2->Optimize(kg_after);
|
||||
|
||||
EXPECT_TRUE(CheckEqualGraph(kg_after2, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#define private public
|
||||
#define protected public
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/pack_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
|
@ -45,6 +46,7 @@ TEST_F(TestHWPackFission, test_stack_fission_divided_by_3) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto pack_fission = std::make_shared<opt::PackFission>();
|
||||
pack_fission->inputs_divisor_ = 3;
|
||||
pm->AddPass(pack_fission);
|
||||
|
@ -69,6 +71,7 @@ TEST_F(TestHWPackFission, test_stack_fission_divided_by_4) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
auto pack_fission = std::make_shared<opt::PackFission>();
|
||||
pack_fission->inputs_divisor_ = 4;
|
||||
pm->AddPass(pack_fission);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fusion/confusion_mul_grad_fusion.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -42,6 +43,7 @@ TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
pm->AddPass(std::make_shared<opt::ConfusionMulGradFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fusion/mul_addn_fusion.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/ascend_convert_tuple_input_to_dynamic_input.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,6 +38,7 @@ TEST_F(TestHWMulAddNFusion, test_mul_addn_fusion) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
pm->AddPass(std::make_shared<opt::MulAddNFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
@ -55,6 +57,7 @@ TEST_F(TestHWMulAddNFusion, test_unmatch) {
|
|||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AscendConvertTupleInputToDynamicInput>());
|
||||
pm->AddPass(std::make_shared<opt::MulAddNFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
|
|
@ -138,9 +138,7 @@ def test_insert_tensor_move_for_hccl_op_cond5(tag):
|
|||
m1 = tensor_move(b)
|
||||
m2 = tensor_move(c)
|
||||
y = broadcast(m1, m2)
|
||||
y0 = tuple_getitem(y, 0)
|
||||
y1 = tuple_getitem(y, 1)
|
||||
res = depend(x, make_tuple(y0, y1))
|
||||
res = depend(x, y)
|
||||
return make_tuple(res)
|
||||
|
||||
return fns[tag]
|
||||
|
|
|
@ -48,7 +48,7 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
|
|||
auto tensor_infos = analyzer->GetMemUsageTensorInfos();
|
||||
|
||||
ASSERT_EQ(5, kernel_infos.size());
|
||||
ASSERT_EQ(15, tensor_infos.size());
|
||||
ASSERT_EQ(16, tensor_infos.size());
|
||||
for (size_t i = 0; i < kernel_infos.size(); ++i) {
|
||||
ASSERT_NE(nullptr, analyzer->GetMemUsageKernelInfo(i));
|
||||
}
|
||||
|
@ -57,6 +57,6 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
|
|||
ASSERT_NE(nullptr, analyzer->GetMemUsageTensorInfo(i));
|
||||
}
|
||||
|
||||
ASSERT_EQ(132, analyzer->LeastMemNeeded());
|
||||
ASSERT_EQ(100, analyzer->LeastMemNeeded());
|
||||
}
|
||||
} // namespace mindspore::device
|
Loading…
Reference in New Issue