!29112 Add more primtives using tuple/list input.

Merge pull request !29112 from liqiliang/ops_with_tuple_input
This commit is contained in:
i-robot 2022-01-15 06:03:52 +00:00 committed by Gitee
commit a274d27a87
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 83 additions and 15 deletions

View File

@ -58,17 +58,24 @@ mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
// We consider all tuple/list arguments are used by now.
// Should check 'tuple argument index' and 'element use index' later.
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
prim::kPrimBroadcast->name(),
prim::kPrimConcat->name(),
prim::kPrimTupleToArray->name(),
prim::kPrimPack->name(),
prim::kPrimSlice->name(),
prim::kPrimStridedSlice->name(),
prim::kPrimScatterNd->name(),
prim::kPrimReshape->name(),
prim::kPrimTile->name(),
prim::kPrimConv3DBackpropFilter->name(),
prim::kPrimCentralization->name(),
prim::kPrimMerge->name(),
prim::kPrimCustom->name(),
prim::kPrimAssert->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",
"ParallelConcat"};
"ParallelConcat",
"CudnnGRU"};
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {

View File

@ -122,7 +122,14 @@ AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEngine
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
kEqual, input_num, prim_name);
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
DynamicResizeNearestNeighborInfer, nullptr, true);

View File

@ -62,8 +62,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true);

View File

@ -239,8 +239,15 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
Conv2DBackpropFilterInferShape(primitive, input_args));
auto res = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
Conv2DBackpropFilterInferShape(primitive, input_args));
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
true);

View File

@ -53,7 +53,14 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
return abstract::MakeAbstract(shape, out_type);
auto res = abstract::MakeAbstract(shape, out_type);
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);
} // namespace ops

View File

@ -120,8 +120,15 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
StridedSliceGradInferType(primitive, input_args));
auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
StridedSliceGradInferType(primitive, input_args));
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {

View File

@ -185,7 +185,13 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const
Check(primitive, input_args);
auto type = InferType(primitive);
auto shape = InferShape(primitive);
return abstract::MakeAbstract(shape, type);
auto res = abstract::MakeAbstract(shape, type);
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
} // namespace ops

View File

@ -53,8 +53,14 @@ AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
const std::string op_name = primitive->name();
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -97,7 +97,14 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
primitive->name());
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
auto res = abstract::MakeAbstract(shape, type);
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
} // namespace ops

View File

@ -52,7 +52,14 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
// Set all used flags of tuple as true.
for (size_t i = 0; i < input_args.size(); i++) {
if (input_args[i] != nullptr) {
SetSequenceElementsUseFlags(input_args[i], true);
}
}
return res;
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {