!49834 fix dropout dynamic shape
Merge pull request !49834 from 王禹程/dyn_dropout
commit 8b137ce3b7
@@ -124,6 +124,22 @@ std::vector<int64_t> CalGenMaskV3OutputShape(const std::vector<int64_t> &shape,
   return shape;
 }
 
+std::shared_ptr<abstract::AbstractTensor> GetDropoutMaskShapeAbstract(const abstract::ShapePtr &input_shape,
+                                                                      const CNodePtr &dropout, bool use_v3) {
+  std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
+  if (input_shape->IsDynamic() || (dropout != nullptr && common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout))) {
+    ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
+    auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
+    MS_EXCEPTION_IF_NULL(gen_mask_shape);
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  } else {
+    auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8)
+                                 : CalGenMaskOutputShape(input_shape->shape());
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  }
+  return gen_mask_abstract;
+}
+
 bool NeedUpdate(const CNodePtr &getitem_cnode) {
   MS_EXCEPTION_IF_NULL(getitem_cnode);
   MS_EXCEPTION_IF_NULL(getitem_cnode->input(kIndex2));
@@ -200,17 +216,7 @@ CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePt
     dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion));
   }
 
-  std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
-  if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
-    ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
-    auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
-    MS_EXCEPTION_IF_NULL(gen_mask_shape);
-    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
-  } else {
-    auto gen_mask_shape = use_v3 ? CalGenMaskV3OutputShape(input_shape->shape(), kNumberTypeUInt8)
-                                 : CalGenMaskOutputShape(input_shape->shape());
-    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
-  }
+  auto gen_mask_abstract = GetDropoutMaskShapeAbstract(input_shape, dropout, use_v3);
   MS_EXCEPTION_IF_NULL(gen_mask_abstract);
   dropout_gen_mask->set_abstract(gen_mask_abstract);
   dropout_gen_mask->set_scope(dropout->scope());
@@ -345,10 +351,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) {
   // update abstract
   auto mask_abstract = mask_input->abstract();
   MS_EXCEPTION_IF_NULL(mask_abstract);
-  auto grad_shape_vec = grad_input_shape->shape();
-  auto mask_shape =
-    use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
-  mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
+  mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3);
   mask_input->set_abstract(mask_abstract);
   // update kernel info
   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@@ -375,10 +378,7 @@ AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) {
   // modify mask abstract
   auto mask_abstract = mask_input->abstract();
   MS_EXCEPTION_IF_NULL(mask_abstract);
-  auto grad_shape_vec = grad_input_shape->shape();
-  auto mask_shape =
-    use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
-  mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
+  mask_abstract = GetDropoutMaskShapeAbstract(grad_input_shape, nullptr, use_v3);
   mask_input->set_abstract(mask_abstract);
   abs.push_back(mask_abstract);
   auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);
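The fix factors the mask-shape computation into one helper, GetDropoutMaskShapeAbstract: when the input shape is dynamic (or the Dropout node carries kAttrMutableKernel), the mask abstract gets a single kShapeDimAny dimension so the real size is resolved at runtime; otherwise the static shape is computed through CalGenMaskV3OutputShape or CalGenMaskOutputShape as before, and CreateDropoutGenMaskCNode plus both sites in BuildDropoutDoMask now route through the helper. Below is a minimal, self-contained sketch of that decision in plain C++; IsDynamic, ComputeStaticMaskShape, DropoutMaskShape, kDimAny and the alignment constant are hypothetical stand-ins for illustration, not MindSpore APIs, and the byte-count formula is only a placeholder for what CalGenMaskOutputShape actually computes.

// Standalone sketch (not MindSpore code) of the dynamic-vs-static mask shape
// decision introduced by GetDropoutMaskShapeAbstract. A dimension of -1 marks
// an unknown size, mirroring abstract::Shape::kShapeDimAny.
#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

constexpr int64_t kDimAny = -1;  // stand-in for abstract::Shape::kShapeDimAny

// True if any dimension of the shape is still unknown.
bool IsDynamic(const ShapeVector &shape) {
  for (int64_t dim : shape) {
    if (dim == kDimAny) return true;
  }
  return false;
}

// Placeholder for CalGenMaskOutputShape / CalGenMaskV3OutputShape: derive the
// 1-D uint8 mask buffer shape from a fully known input shape. The alignment
// constant is illustrative only, not the real formula.
ShapeVector ComputeStaticMaskShape(const ShapeVector &shape) {
  int64_t elements = 1;
  for (int64_t dim : shape) elements *= dim;
  constexpr int64_t kAlignBits = 128;
  int64_t bytes = (elements + kAlignBits - 1) / kAlignBits * (kAlignBits / 8);
  return {bytes};
}

// Mirrors the helper's control flow: a dynamic input (or, in the real code, a
// mutable kernel) yields one unknown dimension so the size is decided at
// runtime; otherwise the mask shape is fixed at compile time.
ShapeVector DropoutMaskShape(const ShapeVector &input_shape, bool force_dynamic) {
  if (force_dynamic || IsDynamic(input_shape)) {
    return {kDimAny};
  }
  return ComputeStaticMaskShape(input_shape);
}

int main() {
  ShapeVector static_in = {32, 128};
  ShapeVector dynamic_in = {kDimAny, 128};
  std::cout << DropoutMaskShape(static_in, false)[0] << "\n";   // fixed byte count
  std::cout << DropoutMaskShape(dynamic_in, false)[0] << "\n";  // -1: resolved at runtime
  return 0;
}

The point of the single-unknown-dimension placeholder is that downstream passes still see a valid uint8 tensor abstract without the pass committing to a mask length that a dynamic-shape kernel only knows at execution time.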