!48512 fix dropout mask dtype error in ms_function

Merge pull request !48512 from chujinjin/fix_dropout_output_dtype_error_in_ms_function
This commit is contained in:
i-robot 2023-02-09 02:39:16 +00:00 committed by Gitee
commit e0915062cf
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 29 additions and 0 deletions

View File

@ -499,6 +499,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
// in that scene, need to be updated.
auto mask_input = dropout_grad_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(mask_input);
if (mask_input->isa<Parameter>()) {
// update abstract
auto mask_abstract = mask_input->abstract();
@ -514,6 +515,34 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
} else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) {
auto mask_input_cnode = mask_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mask_input_cnode);
auto tuple_input = mask_input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(tuple_input);
if (IsValueNode<ValueTuple>(tuple_input)) {
auto tuple_abstract = tuple_input->abstract();
MS_EXCEPTION_IF_NULL(tuple_abstract);
abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(sequence_abstract_ptr);
// Dropout's outputs only have two elements.
if (sequence_abstract_ptr->size() != kIndex2) {
MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size();
}
abstract::AbstractBasePtrList abs{};
abs.push_back(sequence_abstract_ptr->elements()[0]);
// modify mask abstract
auto mask_abstract = mask_input->abstract();
MS_EXCEPTION_IF_NULL(mask_abstract);
auto grad_shape_vec = grad_input_shape->shape();
auto mask_shape =
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
mask_input->set_abstract(mask_abstract);
abs.push_back(mask_abstract);
auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);
tuple_input->set_abstract(new_abstract);
}
}
// CreateDropoutDoMask