!49834 fix dropout dynamic shape

Merge pull request !49834 from 王禹程/dyn_dropout
This commit is contained in:
i-robot 2023-03-08 08:44:08 +00:00 committed by Gitee
commit 8b137ce3b7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 19 additions and 19 deletions

View File

@ -124,6 +124,22 @@ std::vector<int64_t> CalGenMaskV3OutputShape(const std::vector<int64_t> &shape,
return shape;
}
std::shared_ptr<abstract::AbstractTensor> GetDropoutMaskShapeAbstract(const abstract::ShapePtr &input_shape,
const CNodePtr &dropout, bool use_v3) {
std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
if (input_shape->IsDynamic() || (dropout != nullptr && common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout))) {
ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
MS_EXCEPTION_IF_NULL(gen_mask_shape);
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
} else {
auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8)
: CalGenMaskOutputShape(input_shape->shape());
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
}
return gen_mask_abstract;
}
bool NeedUpdate(const CNodePtr &getitem_cnode) {
MS_EXCEPTION_IF_NULL(getitem_cnode);
MS_EXCEPTION_IF_NULL(getitem_cnode->input(kIndex2));
@ -200,17 +216,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePt
dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion));
}
std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
MS_EXCEPTION_IF_NULL(gen_mask_shape);
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
} else {
auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8)
: CalGenMaskOutputShape(input_shape->shape());
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
}
auto gen_mask_abstract = GetDropoutMaskShapeAbstract(input_shape, dropout, use_v3);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(dropout->scope());
@ -345,10 +351,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) {
// update abstract
auto mask_abstract = mask_input->abstract();
MS_EXCEPTION_IF_NULL(mask_abstract);
auto grad_shape_vec = grad_input_shape->shape();
auto mask_shape =
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3);
mask_input->set_abstract(mask_abstract);
// update kernel info
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@ -375,10 +378,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) {
// modify mask abstract
auto mask_abstract = mask_input->abstract();
MS_EXCEPTION_IF_NULL(mask_abstract);
auto grad_shape_vec = grad_input_shape->shape();
auto mask_shape =
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3);
mask_input->set_abstract(mask_abstract);
abs.push_back(mask_abstract);
auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);