!48720 DropoutForGE support dynamic shape

Merge pull request !48720 from TuDouNi/ge2
i-robot 2023-02-11 08:23:50 +00:00 committed by Gitee
commit 961651e82d
3 changed files with 77 additions and 22 deletions

View File

@@ -45,6 +45,75 @@ std::vector<int64_t> CalGenMaskOutputShape(const std::vector<int64_t> &shape) {
   return {ret};
 }
 
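+// Returns the statically inferred shape of the Dropout node's data input.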
+abstract::ShapePtr GetDropoutInputShape(const CNodePtr &dropout_node) {
+  MS_EXCEPTION_IF_NULL(dropout_node);
+  auto input = dropout_node->input(kInputIndexOne);
+  MS_EXCEPTION_IF_NULL(input);
+  auto input_base_shape = input->Shape();
+  MS_EXCEPTION_IF_NULL(input_base_shape);
+  auto input_shape = input_base_shape->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(input_shape);
+  return input_shape;
+}
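+
+// Builds a Shape node so the mask generator can consume the input's shape as
+// a 1-D int64 tensor computed at run time rather than a compile-time constant.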
+CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_input,
+                                 const abstract::ShapePtr &input_shape) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input_shape);
+  std::vector<AnfNodePtr> dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("Shape")), node_input};
+  CNodePtr dynamic_shape = func_graph->NewCNode(dynamic_shape_inputs);
+  MS_EXCEPTION_IF_NULL(dynamic_shape);
+  ShapeVector tensor_shp({static_cast<int64_t>(input_shape->shape().size())});
+  auto dynamic_shape_abstract =
+    std::make_shared<abstract::AbstractTensor>(kInt64, std::make_shared<abstract::Shape>(tensor_shp));
+  MS_EXCEPTION_IF_NULL(dynamic_shape_abstract);
+  dynamic_shape->set_abstract(dynamic_shape_abstract);
+  return dynamic_shape;
+}
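+
+// Assembles the DropoutGenMask node, choosing between a runtime Shape input
+// (dynamic shape or mutable kernel) and a constant shape value node.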
+CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr &dropout,
+                                   const AnfNodePtr &keep_prob_value, const abstract::ShapePtr &input_shape,
+                                   const ValuePtr &seed_0, const ValuePtr &seed_1) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(dropout);
+  MS_EXCEPTION_IF_NULL(input_shape);
+  std::vector<AnfNodePtr> dropout_gen_mask_inputs =
+    std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
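+  // A dynamic input shape (or a mutable kernel) means the shape is only known
+  // at run time, so it is fed through a Shape node instead of a constant.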
+  if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
+    CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
+    dynamic_shape->set_scope(dropout->scope());
+    dropout_gen_mask_inputs.push_back(dynamic_shape);
+    dropout_gen_mask_inputs.push_back(keep_prob_value);
+  } else {
+    auto shape_value = CreateShapeValueNode(func_graph, input_shape->shape(), true);
+    dropout_gen_mask_inputs.push_back(shape_value);
+    dropout_gen_mask_inputs.push_back(keep_prob_value);
+  }
+  CNodePtr dropout_gen_mask = opt::NewCNode(dropout_gen_mask_inputs, func_graph, {dropout});
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (dropout->HasPrimalAttr(kAttrFusion)) {
+    dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion));
+  }
+  std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
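+  // The mask byte count cannot be computed at compile time for dynamic shapes,
+  // so the abstract gets a single unknown dimension (kShapeDimAny).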
+  if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
+    ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
+    auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
+    MS_EXCEPTION_IF_NULL(gen_mask_shape);
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  } else {
+    auto gen_mask_shape = CalGenMaskOutputShape(input_shape->shape());
+    gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
+  }
+  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
+  dropout_gen_mask->set_abstract(gen_mask_abstract);
+  dropout_gen_mask->set_scope(dropout->scope());
+  common::AnfAlgo::CopyNodeAttrs(dropout, dropout_gen_mask);
+  common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask);
+  common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask);
+  return dropout_gen_mask;
+}
+
 const BaseRef DropoutForGE::DefinePattern() const {
   VarPtr x1 = std::make_shared<Var>();
   return VectorRef({prim::kPrimDropout, x1});
@@ -61,21 +130,10 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
   auto seed_0 = origin_prim->GetAttr(kSeed0AttrName);
   auto seed_1 = origin_prim->GetAttr(kSeed1AttrName);
-  auto shape = dropout_node->input(kInputIndexOne)->Shape();
-  MS_EXCEPTION_IF_NULL(shape);
-  auto input_shape_ptr = shape->cast<abstract::ShapePtr>();
-  MS_EXCEPTION_IF_NULL(input_shape_ptr);
-  if (input_shape_ptr->IsDynamic()) {
-    MS_LOG(EXCEPTION) << "Dropout does not support dynamic shape in GE backend for now";
-  }
-  auto shape_vector = input_shape_ptr->shape();
-  auto shape_value = MakeValue(shape_vector);
-  auto gen_mask_input_shape = NewValueNode(shape_value);
-  gen_mask_input_shape->set_abstract(shape_value->ToAbstract());
+  auto input_shape_ptr = GetDropoutInputShape(dropout_node);
   auto keep_prob_node = NewValueNode(keep_prob);
   MS_EXCEPTION_IF_NULL(keep_prob_node);
   keep_prob_node->set_abstract(keep_prob->ToAbstract());
   AnfNodePtr gen_mask_input_prob = keep_prob_node;
   auto dtype_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(dropout_node, 0);
   if (dtype_id == TypeId::kNumberTypeFloat16) {
@@ -94,15 +152,8 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
     cast_node->set_abstract(cast_abstract);
     gen_mask_input_prob = cast_node;
   }
-  auto mask_shape = CalGenMaskOutputShape(shape_vector);
-  auto dropout_gen_mask_node = node->func_graph()->NewCNode(
-    {NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), gen_mask_input_shape, gen_mask_input_prob});
-  MS_EXCEPTION_IF_NULL(dropout_gen_mask_node);
-  common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask_node);
-  common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask_node);
-  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
-  dropout_gen_mask_node->set_abstract(gen_mask_abstract);
+  CNodePtr dropout_gen_mask_node =
+    CreateDropoutGenMaskCNode(graph, dropout_node, gen_mask_input_prob, input_shape_ptr, seed_0, seed_1);
   auto dropout_do_mask_node =
     node->func_graph()->NewCNode({NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
@@ -112,7 +163,7 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
   std::vector<abstract::AbstractBasePtr> make_tuple_input;
   make_tuple_input.push_back(do_mask_abstract);
-  make_tuple_input.push_back(gen_mask_abstract);
+  make_tuple_input.push_back(dropout_gen_mask_node->abstract());
   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), dropout_do_mask_node,
                                                dropout_gen_mask_node};
   auto new_make_tuple_node = NewCNode(make_tuple_inputs, graph);
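
The static branch still sizes the mask with CalGenMaskOutputShape. As a minimal standalone sketch of that arithmetic, assuming the 128-bit mask alignment commonly described for Ascend's DropoutGenMask (GenMaskOutputShape and the constants below are illustrative, not the repository's code):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Round the element count up to a multiple of 128 bits, then express the
// bit mask as uint8 bytes: ceil(n / 128) * 16.
std::vector<int64_t> GenMaskOutputShape(const std::vector<int64_t> &shape) {
  const int64_t n = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  const int64_t kAlignBits = 128;
  const int64_t kBitsPerByte = 8;
  const int64_t aligned_bits = ((n + kAlignBits - 1) / kAlignBits) * kAlignBits;
  return {aligned_bits / kBitsPerByte};
}

int main() {
  // A [32, 100] input holds 3200 elements -> 3200 bits, already 128-aligned -> 400 bytes.
  std::cout << GenMaskOutputShape({32, 100})[0] << std::endl;  // prints 400
  return 0;
}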

View File

@@ -348,6 +348,7 @@ constexpr const char kNameMaxPoolV3[] = "MaxPoolV3";
 constexpr const char kNameAvgPoolV2[] = "AvgPoolV2";
 constexpr const char kNameShape[] = "Shape";
 constexpr const char kNameTensorShape[] = "TensorShape";
+constexpr const char kNameDynamicShape[] = "DynamicShape";
 constexpr const char kNameGather[] = "Gather";
 constexpr const char kNameUnsqueeze[] = "Unsqueeze";
 constexpr const char kNamePadV3[] = "PadV3";

View File

@@ -55,6 +55,9 @@ REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
 // TensorShape
 REG_ADPT_DESC(TensorShape, kNameTensorShape, ADPT_DESC(Shape))
+
+// DynamicShape
+REG_ADPT_DESC(DynamicShape, kNameDynamicShape, ADPT_DESC(Shape))
 // GetShape
 INPUT_MAP(GetShape) = EMPTY_INPUT_MAP;
 DYN_INPUT_MAP(GetShape) = {{1, DYN_INPUT_DESC(x)}};
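
For reference, the dispatch in CreateDropoutGenMaskCNode keys off whether any dimension is unknown; MindSpore marks such dimensions with -1 (abstract::Shape::kShapeDimAny). A small self-contained sketch of that decision (IsDynamic below is illustrative, not the pass's code):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// A dimension of -1 (kShapeDimAny) marks a dynamic axis; any dynamic axis
// forces the runtime Shape-node path in the rewritten graph.
bool IsDynamic(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; });
}

int main() {
  std::cout << IsDynamic({32, 100}) << std::endl;  // 0: constant shape value node suffices
  std::cout << IsDynamic({-1, 100}) << std::endl;  // 1: feed the shape via a Shape op at run time
  return 0;
}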