!48339 set dynamic input size for dynamic param kernel

Merge pull request !48339 from wYann/test_kobj
This commit is contained in:
i-robot 2023-02-08 08:08:58 +00:00 committed by Gitee
commit 2ecba9eea2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 21 additions and 51 deletions

View File

@ -102,9 +102,7 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
// methods. Only reset input types. // methods. Only reset input types.
auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0)); auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0)); auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0));
if (kernel::IsDynamicParamKernel(origin_prim->name())) { if (IsPrimitiveEquals(new_prim, origin_prim) && !kernel::IsDynamicParamKernel(origin_prim->name())) {
SetKernelInfoForDynamicParamKernel(new_cnode);
} else if (IsPrimitiveEquals(new_prim, origin_prim)) {
SetKernelInfoForNewCNode(new_cnode, false); SetKernelInfoForNewCNode(new_cnode, false);
} else { } else {
SetKernelInfoForNewCNode(new_cnode, true); SetKernelInfoForNewCNode(new_cnode, true);
@ -249,42 +247,6 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
<< build_info->ToString(); << build_info->ToString();
} }
void SetKernelInfoForDynamicParamKernel(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
cnode->set_kernel_info(kernel_info);
auto builder = std::make_shared<KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
std::vector<KernelObjectType> input_obj_type =
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(cnode));
std::vector<KernelObjectType> output_obj_type =
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllOutputObjectType(cnode));
builder->SetInputsKernelObjectType(input_obj_type);
builder->SetOutputsKernelObjectType(output_obj_type);
// Set input and output format.
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, input_index);
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
size_t output_num = AnfAlgo::GetOutputElementNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
}
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type);
}
void SetKernelInfoForValueNode(const ValueNodePtr &value_node) { void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto kernel_info = std::make_shared<device::KernelInfo>(); auto kernel_info = std::make_shared<device::KernelInfo>();

View File

@ -97,9 +97,6 @@ void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode)
// In some cases, there's no need to set input/output format and type for the node. // In some cases, there's no need to set input/output format and type for the node.
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true); void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true);
// Set kernel info for dynamic param kernel.
void SetKernelInfoForDynamicParamKernel(const CNodePtr &cnode);
// Set kernel info for some value nodes manually. // Set kernel info for some value nodes manually.
void SetKernelInfoForValueNode(const ValueNodePtr &value_node); void SetKernelInfoForValueNode(const ValueNodePtr &value_node);

View File

@ -87,10 +87,14 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const { KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
if (input_index >= inputs_kernel_object_type_.size()) { if (input_index >= inputs_kernel_object_type_.size()) {
#ifdef ENABLE_TUPLE_UNFOLD bool has_tuple_unfold =
MS_LOG(DEBUG) << "The input index [" << input_index std::any_of(inputs_kernel_object_type_.begin(), inputs_kernel_object_type_.end(),
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size(); [](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
#endif // tuple unfold may correspond to many formats or dtypes
if (!has_tuple_unfold) {
MS_LOG(ERROR) << "The input index [" << input_index
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size();
}
return KernelObjectType::UNKNOWN_TYPE; return KernelObjectType::UNKNOWN_TYPE;
} }
return inputs_kernel_object_type_[input_index]; return inputs_kernel_object_type_[input_index];
@ -98,10 +102,14 @@ KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) c
KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const { KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
if (output_index >= outputs_kernel_object_type_.size()) { if (output_index >= outputs_kernel_object_type_.size()) {
#ifdef ENABLE_TUPLE_UNFOLD bool has_tuple_unfold =
MS_LOG(DEBUG) << "The output index [" << output_index std::any_of(outputs_kernel_object_type_.begin(), outputs_kernel_object_type_.end(),
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size(); [](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
#endif // tuple unfold may correspond to many formats or dtypes
if (!has_tuple_unfold) {
MS_LOG(ERROR) << "The output index [" << output_index
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size();
}
return KernelObjectType::UNKNOWN_TYPE; return KernelObjectType::UNKNOWN_TYPE;
} }
return outputs_kernel_object_type_[output_index]; return outputs_kernel_object_type_[output_index];

View File

@ -361,6 +361,10 @@ void UpdateDynamicKernelBuildInfo(const CNodePtr &kernel_node) {
auto output_object_types = auto output_object_types =
kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node)); kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types); kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types);
kernel::UnfoldKernelBuildInfo(kernel_node);
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
kernel::SetDynamicInputSizeAttr(kernel_node);
}
} }
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info, bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,

View File

@ -179,7 +179,6 @@ void CPUKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
// Update Graph Dynamic Shape Attr. // Update Graph Dynamic Shape Attr.
opt::AddDynamicShapeAttrPass(kernel_graph); opt::AddDynamicShapeAttrPass(kernel_graph);
kernel_graph->SetKernelObjectTypesForUnrealNodes();
SetOperatorInfo(kernel_graph); SetOperatorInfo(kernel_graph);
// SetOperatorInfo may generate new node, so need set kernel object type again. // SetOperatorInfo may generate new node, so need set kernel object type again.
kernel_graph->SetKernelObjectTypesForUnrealNodes(); kernel_graph->SetKernelObjectTypesForUnrealNodes();