forked from mindspore-Ecosystem/mindspore
!48339 set dynamic input size for dynamic param kernel
Merge pull request !48339 from wYann/test_kobj
This commit is contained in:
commit
2ecba9eea2
|
@ -102,9 +102,7 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
|
||||||
// methods. Only reset input types.
|
// methods. Only reset input types.
|
||||||
auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
|
auto new_prim = GetValueNode<PrimitivePtr>(new_cnode->input(kIndex0));
|
||||||
auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0));
|
auto origin_prim = GetValueNode<PrimitivePtr>(origin_node->input(kIndex0));
|
||||||
if (kernel::IsDynamicParamKernel(origin_prim->name())) {
|
if (IsPrimitiveEquals(new_prim, origin_prim) && !kernel::IsDynamicParamKernel(origin_prim->name())) {
|
||||||
SetKernelInfoForDynamicParamKernel(new_cnode);
|
|
||||||
} else if (IsPrimitiveEquals(new_prim, origin_prim)) {
|
|
||||||
SetKernelInfoForNewCNode(new_cnode, false);
|
SetKernelInfoForNewCNode(new_cnode, false);
|
||||||
} else {
|
} else {
|
||||||
SetKernelInfoForNewCNode(new_cnode, true);
|
SetKernelInfoForNewCNode(new_cnode, true);
|
||||||
|
@ -249,42 +247,6 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
|
||||||
<< build_info->ToString();
|
<< build_info->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetKernelInfoForDynamicParamKernel(const CNodePtr &cnode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
||||||
cnode->set_kernel_info(kernel_info);
|
|
||||||
auto builder = std::make_shared<KernelBuildInfoBuilder>();
|
|
||||||
MS_EXCEPTION_IF_NULL(builder);
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
|
|
||||||
std::vector<KernelObjectType> input_obj_type =
|
|
||||||
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(cnode));
|
|
||||||
std::vector<KernelObjectType> output_obj_type =
|
|
||||||
kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllOutputObjectType(cnode));
|
|
||||||
builder->SetInputsKernelObjectType(input_obj_type);
|
|
||||||
builder->SetOutputsKernelObjectType(output_obj_type);
|
|
||||||
// Set input and output format.
|
|
||||||
std::vector<std::string> inputs_format;
|
|
||||||
std::vector<TypeId> inputs_type;
|
|
||||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
|
||||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
|
||||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, input_index);
|
|
||||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
|
||||||
inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
|
||||||
}
|
|
||||||
std::vector<std::string> outputs_format;
|
|
||||||
std::vector<TypeId> outputs_type;
|
|
||||||
size_t output_num = AnfAlgo::GetOutputElementNum(cnode);
|
|
||||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
|
||||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
|
||||||
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
|
||||||
}
|
|
||||||
builder->SetInputsFormat(inputs_format);
|
|
||||||
builder->SetInputsDeviceType(inputs_type);
|
|
||||||
builder->SetOutputsFormat(outputs_format);
|
|
||||||
builder->SetOutputsDeviceType(outputs_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
|
void SetKernelInfoForValueNode(const ValueNodePtr &value_node) {
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
|
|
@ -97,9 +97,6 @@ void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode)
|
||||||
// In some cases, there's no need to set input/output format and type for the node.
|
// In some cases, there's no need to set input/output format and type for the node.
|
||||||
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true);
|
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true);
|
||||||
|
|
||||||
// Set kernel info for dynamic param kernel.
|
|
||||||
void SetKernelInfoForDynamicParamKernel(const CNodePtr &cnode);
|
|
||||||
|
|
||||||
// Set kernel info for some value nodes manually.
|
// Set kernel info for some value nodes manually.
|
||||||
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);
|
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);
|
||||||
|
|
||||||
|
|
|
@ -87,10 +87,14 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
|
||||||
|
|
||||||
KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
|
KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
|
||||||
if (input_index >= inputs_kernel_object_type_.size()) {
|
if (input_index >= inputs_kernel_object_type_.size()) {
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
bool has_tuple_unfold =
|
||||||
MS_LOG(DEBUG) << "The input index [" << input_index
|
std::any_of(inputs_kernel_object_type_.begin(), inputs_kernel_object_type_.end(),
|
||||||
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size();
|
[](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
|
||||||
#endif
|
// tuple unfold may correspond to many formats or dtypes
|
||||||
|
if (!has_tuple_unfold) {
|
||||||
|
MS_LOG(ERROR) << "The input index [" << input_index
|
||||||
|
<< "] is exceed the number of input:" << inputs_kernel_object_type_.size();
|
||||||
|
}
|
||||||
return KernelObjectType::UNKNOWN_TYPE;
|
return KernelObjectType::UNKNOWN_TYPE;
|
||||||
}
|
}
|
||||||
return inputs_kernel_object_type_[input_index];
|
return inputs_kernel_object_type_[input_index];
|
||||||
|
@ -98,10 +102,14 @@ KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) c
|
||||||
|
|
||||||
KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
|
KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
|
||||||
if (output_index >= outputs_kernel_object_type_.size()) {
|
if (output_index >= outputs_kernel_object_type_.size()) {
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
bool has_tuple_unfold =
|
||||||
MS_LOG(DEBUG) << "The output index [" << output_index
|
std::any_of(outputs_kernel_object_type_.begin(), outputs_kernel_object_type_.end(),
|
||||||
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size();
|
[](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
|
||||||
#endif
|
// tuple unfold may correspond to many formats or dtypes
|
||||||
|
if (!has_tuple_unfold) {
|
||||||
|
MS_LOG(ERROR) << "The output index [" << output_index
|
||||||
|
<< "] is exceed the number of output:" << outputs_kernel_object_type_.size();
|
||||||
|
}
|
||||||
return KernelObjectType::UNKNOWN_TYPE;
|
return KernelObjectType::UNKNOWN_TYPE;
|
||||||
}
|
}
|
||||||
return outputs_kernel_object_type_[output_index];
|
return outputs_kernel_object_type_[output_index];
|
||||||
|
|
|
@ -361,6 +361,10 @@ void UpdateDynamicKernelBuildInfo(const CNodePtr &kernel_node) {
|
||||||
auto output_object_types =
|
auto output_object_types =
|
||||||
kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
|
kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
|
||||||
kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types);
|
kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types);
|
||||||
|
kernel::UnfoldKernelBuildInfo(kernel_node);
|
||||||
|
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
|
||||||
|
kernel::SetDynamicInputSizeAttr(kernel_node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
|
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
|
||||||
|
|
|
@ -179,7 +179,6 @@ void CPUKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
|
||||||
// Update Graph Dynamic Shape Attr.
|
// Update Graph Dynamic Shape Attr.
|
||||||
opt::AddDynamicShapeAttrPass(kernel_graph);
|
opt::AddDynamicShapeAttrPass(kernel_graph);
|
||||||
|
|
||||||
kernel_graph->SetKernelObjectTypesForUnrealNodes();
|
|
||||||
SetOperatorInfo(kernel_graph);
|
SetOperatorInfo(kernel_graph);
|
||||||
// SetOperatorInfo may generate new node, so need set kernel object type again.
|
// SetOperatorInfo may generate new node, so need set kernel object type again.
|
||||||
kernel_graph->SetKernelObjectTypesForUnrealNodes();
|
kernel_graph->SetKernelObjectTypesForUnrealNodes();
|
||||||
|
|
Loading…
Reference in New Issue