forked from mindspore-Ecosystem/mindspore
!48720 DropoutForGE suport dynamicshape
Merge pull request !48720 from TuDouNi/ge2
This commit is contained in:
commit
961651e82d
|
@ -45,6 +45,75 @@ std::vector<int64_t> CalGenMaskOutputShape(const std::vector<int64_t> &shape) {
|
|||
return {ret};
|
||||
}
|
||||
|
||||
abstract::ShapePtr GetDropoutInputShape(const CNodePtr &dropout_node) {
|
||||
MS_EXCEPTION_IF_NULL(dropout_node);
|
||||
auto input = dropout_node->input(kInputIndexOne);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
auto input_base_shape = input->Shape();
|
||||
MS_EXCEPTION_IF_NULL(input_base_shape);
|
||||
auto input_shape = input_base_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
return input_shape;
|
||||
}
|
||||
|
||||
CNodePtr CreateDynamicShapeCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_input,
|
||||
const abstract::ShapePtr &input_shape) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
std::vector<AnfNodePtr> dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("Shape")), node_input};
|
||||
CNodePtr dynamic_shape = func_graph->NewCNode(dynamic_shape_inputs);
|
||||
MS_EXCEPTION_IF_NULL(dynamic_shape);
|
||||
ShapeVector tensor_shp({static_cast<int64_t>(input_shape->shape().size())});
|
||||
auto dynamic_shape_abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(kInt64, std::make_shared<abstract::Shape>(tensor_shp));
|
||||
MS_EXCEPTION_IF_NULL(dynamic_shape_abstract);
|
||||
dynamic_shape->set_abstract(dynamic_shape_abstract);
|
||||
return dynamic_shape;
|
||||
}
|
||||
|
||||
CNodePtr CreateDropoutGenMaskCNode(const FuncGraphPtr &func_graph, const CNodePtr &dropout,
|
||||
const AnfNodePtr &keep_prob_value, const abstract::ShapePtr &input_shape,
|
||||
const ValuePtr &seed_0, const ValuePtr &seed_1) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(dropout);
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
std::vector<AnfNodePtr> dropout_gen_mask_inputs =
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName))};
|
||||
if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
|
||||
CNodePtr dynamic_shape = CreateDynamicShapeCNode(func_graph, dropout->input(kIndex1), input_shape);
|
||||
dynamic_shape->set_scope(dropout->scope());
|
||||
dropout_gen_mask_inputs.push_back(dynamic_shape);
|
||||
dropout_gen_mask_inputs.push_back(keep_prob_value);
|
||||
} else {
|
||||
auto shape_value = CreateShapeValueNode(func_graph, input_shape->shape(), true);
|
||||
dropout_gen_mask_inputs.push_back(shape_value);
|
||||
dropout_gen_mask_inputs.push_back(keep_prob_value);
|
||||
}
|
||||
CNodePtr dropout_gen_mask = opt::NewCNode(dropout_gen_mask_inputs, func_graph, {dropout});
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
|
||||
if (dropout->HasPrimalAttr(kAttrFusion)) {
|
||||
dropout_gen_mask->AddPrimalAttr(kAttrFusion, dropout->GetPrimalAttr(kAttrFusion));
|
||||
}
|
||||
|
||||
std::shared_ptr<abstract::AbstractTensor> gen_mask_abstract;
|
||||
if (input_shape->IsDynamic() || common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, dropout)) {
|
||||
ShapeVector mask_shp = {abstract::Shape::kShapeDimAny};
|
||||
auto gen_mask_shape = std::make_shared<abstract::Shape>(mask_shp);
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_shape);
|
||||
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
|
||||
} else {
|
||||
auto gen_mask_shape = CalGenMaskOutputShape(input_shape->shape());
|
||||
gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, gen_mask_shape);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
|
||||
dropout_gen_mask->set_abstract(gen_mask_abstract);
|
||||
dropout_gen_mask->set_scope(dropout->scope());
|
||||
common::AnfAlgo::CopyNodeAttrs(dropout, dropout_gen_mask);
|
||||
common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask);
|
||||
common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask);
|
||||
return dropout_gen_mask;
|
||||
}
|
||||
|
||||
const BaseRef DropoutForGE::DefinePattern() const {
|
||||
VarPtr x1 = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimDropout, x1});
|
||||
|
@ -61,21 +130,10 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
auto seed_0 = origin_prim->GetAttr(kSeed0AttrName);
|
||||
auto seed_1 = origin_prim->GetAttr(kSeed1AttrName);
|
||||
|
||||
auto shape = dropout_node->input(kInputIndexOne)->Shape();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto input_shape_ptr = shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_ptr);
|
||||
if (input_shape_ptr->IsDynamic()) {
|
||||
MS_LOG(EXCEPTION) << "Dropout does not support dynamic shape in GE backend for now";
|
||||
}
|
||||
auto shape_vector = input_shape_ptr->shape();
|
||||
auto shape_value = MakeValue(shape_vector);
|
||||
auto gen_mask_input_shape = NewValueNode(shape_value);
|
||||
gen_mask_input_shape->set_abstract(shape_value->ToAbstract());
|
||||
auto input_shape_ptr = GetDropoutInputShape(dropout_node);
|
||||
auto keep_prob_node = NewValueNode(keep_prob);
|
||||
MS_EXCEPTION_IF_NULL(keep_prob_node);
|
||||
keep_prob_node->set_abstract(keep_prob->ToAbstract());
|
||||
|
||||
AnfNodePtr gen_mask_input_prob = keep_prob_node;
|
||||
auto dtype_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(dropout_node, 0);
|
||||
if (dtype_id == TypeId::kNumberTypeFloat16) {
|
||||
|
@ -94,15 +152,8 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
cast_node->set_abstract(cast_abstract);
|
||||
gen_mask_input_prob = cast_node;
|
||||
}
|
||||
|
||||
auto mask_shape = CalGenMaskOutputShape(shape_vector);
|
||||
auto dropout_gen_mask_node = node->func_graph()->NewCNode(
|
||||
{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), gen_mask_input_shape, gen_mask_input_prob});
|
||||
MS_EXCEPTION_IF_NULL(dropout_gen_mask_node);
|
||||
common::AnfAlgo::SetNodeAttr(kSeed0AttrName, seed_0, dropout_gen_mask_node);
|
||||
common::AnfAlgo::SetNodeAttr(kSeed1AttrName, seed_1, dropout_gen_mask_node);
|
||||
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
dropout_gen_mask_node->set_abstract(gen_mask_abstract);
|
||||
CNodePtr dropout_gen_mask_node =
|
||||
CreateDropoutGenMaskCNode(graph, dropout_node, gen_mask_input_prob, input_shape_ptr, seed_0, seed_1);
|
||||
|
||||
auto dropout_do_mask_node =
|
||||
node->func_graph()->NewCNode({NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
|
||||
|
@ -112,7 +163,7 @@ const AnfNodePtr DropoutForGE::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
|
||||
std::vector<abstract::AbstractBasePtr> make_tuple_input;
|
||||
make_tuple_input.push_back(do_mask_abstract);
|
||||
make_tuple_input.push_back(gen_mask_abstract);
|
||||
make_tuple_input.push_back(dropout_gen_mask_node->abstract());
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), dropout_do_mask_node,
|
||||
dropout_gen_mask_node};
|
||||
auto new_make_tuple_node = NewCNode(make_tuple_inputs, graph);
|
||||
|
|
|
@ -348,6 +348,7 @@ constexpr const char kNameMaxPoolV3[] = "MaxPoolV3";
|
|||
constexpr const char kNameAvgPoolV2[] = "AvgPoolV2";
|
||||
constexpr const char kNameShape[] = "Shape";
|
||||
constexpr const char kNameTensorShape[] = "TensorShape";
|
||||
constexpr const char kNameDynamicShape[] = "DynamicShape";
|
||||
constexpr const char kNameGather[] = "Gather";
|
||||
constexpr const char kNameUnsqueeze[] = "Unsqueeze";
|
||||
constexpr const char kNamePadV3[] = "PadV3";
|
||||
|
|
|
@ -55,6 +55,9 @@ REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
|
|||
// TensorShape
|
||||
REG_ADPT_DESC(TensorShape, kNameTensorShape, ADPT_DESC(Shape))
|
||||
|
||||
// DynamicShape
|
||||
REG_ADPT_DESC(DynamicShape, kNameDynamicShape, ADPT_DESC(Shape))
|
||||
|
||||
// GetShape
|
||||
INPUT_MAP(GetShape) = EMPTY_INPUT_MAP;
|
||||
DYN_INPUT_MAP(GetShape) = {{1, DYN_INPUT_DESC(x)}};
|
||||
|
|
Loading…
Reference in New Issue