forked from mindspore-Ecosystem/mindspore
!29112 Add more primtives using tuple/list input.
Merge pull request !29112 from liqiliang/ops_with_tuple_input
This commit is contained in:
commit
a274d27a87
|
@ -58,17 +58,24 @@ mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
|
|||
// We consider all tuple/list arguments are used by now.
|
||||
// Should check 'tuple argument index' and 'element use index' later.
|
||||
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
|
||||
prim::kPrimBroadcast->name(),
|
||||
prim::kPrimConcat->name(),
|
||||
prim::kPrimTupleToArray->name(),
|
||||
prim::kPrimPack->name(),
|
||||
prim::kPrimSlice->name(),
|
||||
prim::kPrimStridedSlice->name(),
|
||||
prim::kPrimScatterNd->name(),
|
||||
prim::kPrimReshape->name(),
|
||||
prim::kPrimTile->name(),
|
||||
prim::kPrimConv3DBackpropFilter->name(),
|
||||
prim::kPrimCentralization->name(),
|
||||
prim::kPrimMerge->name(),
|
||||
prim::kPrimCustom->name(),
|
||||
prim::kPrimAssert->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
"ParallelConcat"};
|
||||
"ParallelConcat",
|
||||
"CudnnGRU"};
|
||||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
|
|
|
@ -122,7 +122,14 @@ AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEngine
|
|||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||
kEqual, input_num, prim_name);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
|
||||
DynamicResizeNearestNeighborInfer, nullptr, true);
|
||||
|
|
|
@ -62,8 +62,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true);
|
||||
|
|
|
@ -239,8 +239,15 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args));
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
||||
true);
|
||||
|
|
|
@ -53,7 +53,14 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
|
||||
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
|
||||
return abstract::MakeAbstract(shape, out_type);
|
||||
auto res = abstract::MakeAbstract(shape, out_type);
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -120,8 +120,15 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 5;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
|
||||
StridedSliceGradInferType(primitive, input_args));
|
||||
auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
|
||||
StridedSliceGradInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
|
||||
|
|
|
@ -185,7 +185,13 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const
|
|||
Check(primitive, input_args);
|
||||
auto type = InferType(primitive);
|
||||
auto shape = InferShape(primitive);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -53,8 +53,14 @@ AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::string op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||
|
||||
return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
||||
auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -97,7 +97,14 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
primitive->name());
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -52,7 +52,14 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
||||
auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
||||
// Set all used flags of tuple as true.
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
if (input_args[i] != nullptr) {
|
||||
SetSequenceElementsUseFlags(input_args[i], true);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
Loading…
Reference in New Issue