forked from mindspore-Ecosystem/mindspore
!48512 fix dropout mask dtype error in ms_function
Merge pull request !48512 from chujinjin/fix_dropout_output_dtype_error_in_ms_function
This commit is contained in:
commit
e0915062cf
|
@ -499,6 +499,7 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
|
||||
// in that scene, need to be updated.
|
||||
auto mask_input = dropout_grad_cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(mask_input);
|
||||
if (mask_input->isa<Parameter>()) {
|
||||
// update abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
|
@ -514,6 +515,34 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
|
|||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
|
||||
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
|
||||
} else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) {
|
||||
auto mask_input_cnode = mask_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mask_input_cnode);
|
||||
auto tuple_input = mask_input_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
if (IsValueNode<ValueTuple>(tuple_input)) {
|
||||
auto tuple_abstract = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sequence_abstract_ptr);
|
||||
// Dropout's outputs only have two elements.
|
||||
if (sequence_abstract_ptr->size() != kIndex2) {
|
||||
MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size();
|
||||
}
|
||||
abstract::AbstractBasePtrList abs{};
|
||||
abs.push_back(sequence_abstract_ptr->elements()[0]);
|
||||
// modify mask abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto grad_shape_vec = grad_input_shape->shape();
|
||||
auto mask_shape =
|
||||
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
abs.push_back(mask_abstract);
|
||||
auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);
|
||||
tuple_input->set_abstract(new_abstract);
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDropoutDoMask
|
||||
|
|
Loading…
Reference in New Issue