forked from mindspore-Ecosystem/mindspore
!29112 Add more primtives using tuple/list input.
Merge pull request !29112 from liqiliang/ops_with_tuple_input
This commit is contained in:
commit
a274d27a87
|
@ -58,17 +58,24 @@ mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
|
||||||
// We consider all tuple/list arguments are used by now.
|
// We consider all tuple/list arguments are used by now.
|
||||||
// Should check 'tuple argument index' and 'element use index' later.
|
// Should check 'tuple argument index' and 'element use index' later.
|
||||||
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
|
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
|
||||||
prim::kPrimBroadcast->name(),
|
|
||||||
prim::kPrimConcat->name(),
|
prim::kPrimConcat->name(),
|
||||||
prim::kPrimTupleToArray->name(),
|
prim::kPrimTupleToArray->name(),
|
||||||
prim::kPrimPack->name(),
|
prim::kPrimPack->name(),
|
||||||
prim::kPrimSlice->name(),
|
prim::kPrimSlice->name(),
|
||||||
prim::kPrimStridedSlice->name(),
|
prim::kPrimStridedSlice->name(),
|
||||||
prim::kPrimScatterNd->name(),
|
prim::kPrimScatterNd->name(),
|
||||||
|
prim::kPrimReshape->name(),
|
||||||
|
prim::kPrimTile->name(),
|
||||||
|
prim::kPrimConv3DBackpropFilter->name(),
|
||||||
|
prim::kPrimCentralization->name(),
|
||||||
|
prim::kPrimMerge->name(),
|
||||||
|
prim::kPrimCustom->name(),
|
||||||
|
prim::kPrimAssert->name(),
|
||||||
"InvertPermutation",
|
"InvertPermutation",
|
||||||
"Meshgrid",
|
"Meshgrid",
|
||||||
"TransShape",
|
"TransShape",
|
||||||
"ParallelConcat"};
|
"ParallelConcat",
|
||||||
|
"CudnnGRU"};
|
||||||
|
|
||||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
const AnfNodeConfigPtr &out_conf) {
|
const AnfNodeConfigPtr &out_conf) {
|
||||||
|
|
|
@ -122,7 +122,14 @@ AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEngine
|
||||||
const int64_t input_num = 2;
|
const int64_t input_num = 2;
|
||||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||||
kEqual, input_num, prim_name);
|
kEqual, input_num, prim_name);
|
||||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
|
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
|
||||||
DynamicResizeNearestNeighborInfer, nullptr, true);
|
DynamicResizeNearestNeighborInfer, nullptr, true);
|
||||||
|
|
|
@ -62,8 +62,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
||||||
|
|
||||||
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||||
InferShape(primitive, input_args)->shape());
|
InferShape(primitive, input_args)->shape());
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true);
|
||||||
|
|
|
@ -239,8 +239,15 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
auto res = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||||
Conv2DBackpropFilterInferShape(primitive, input_args));
|
Conv2DBackpropFilterInferShape(primitive, input_args));
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
||||||
true);
|
true);
|
||||||
|
|
|
@ -53,7 +53,14 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
||||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||||
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
|
auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name);
|
||||||
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
|
auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index);
|
||||||
return abstract::MakeAbstract(shape, out_type);
|
auto res = abstract::MakeAbstract(shape, out_type);
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -120,8 +120,15 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
const int64_t input_num = 5;
|
const int64_t input_num = 5;
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||||
return abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
|
auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args),
|
||||||
StridedSliceGradInferType(primitive, input_args));
|
StridedSliceGradInferType(primitive, input_args));
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
|
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
|
||||||
|
|
|
@ -185,7 +185,13 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const
|
||||||
Check(primitive, input_args);
|
Check(primitive, input_args);
|
||||||
auto type = InferType(primitive);
|
auto type = InferType(primitive);
|
||||||
auto shape = InferShape(primitive);
|
auto shape = InferShape(primitive);
|
||||||
return abstract::MakeAbstract(shape, type);
|
auto res = abstract::MakeAbstract(shape, type);
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -53,8 +53,14 @@ AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
||||||
const std::string op_name = primitive->name();
|
const std::string op_name = primitive->name();
|
||||||
const int64_t input_num = 2;
|
const int64_t input_num = 2;
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||||
|
auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
||||||
return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args));
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
|
|
@ -97,7 +97,14 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
|
||||||
primitive->name());
|
primitive->name());
|
||||||
auto type = InferType(primitive, input_args);
|
auto type = InferType(primitive, input_args);
|
||||||
auto shape = InferShape(primitive, input_args);
|
auto shape = InferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(shape, type);
|
auto res = abstract::MakeAbstract(shape, type);
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
|
@ -52,7 +52,14 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
||||||
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
|
||||||
|
// Set all used flags of tuple as true.
|
||||||
|
for (size_t i = 0; i < input_args.size(); i++) {
|
||||||
|
if (input_args[i] != nullptr) {
|
||||||
|
SetSequenceElementsUseFlags(input_args[i], true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
|
Loading…
Reference in New Issue