!29170 Add flag for more tuple inputs

Merge pull request !29170 from huanghui/tuple-input
This commit is contained in:
i-robot 2022-01-17 11:16:27 +00:00 committed by Gitee
commit ef5e5c2aae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 71 additions and 26 deletions

View File

@ -1386,7 +1386,7 @@ void KernelGraph::SetInputNodes() {
void KernelGraph::UpdateGraphAquireGilAttr() {
for (const auto &cnode : execution_order_) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPyFunc)) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
MS_LOG(INFO) << "The Graph require GIL. Graph id: " << graph_id_;
is_need_gil_ = true;
return;

View File

@ -51,6 +51,8 @@ AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const Abst
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
SetSequenceElementsUseFlags(input_x, true);
SetSequenceElementsUseFlags(input_y, true);
ValuePtr x_value = input_x->BuildValue();
ValuePtr y_value = input_y->BuildValue();
@ -296,7 +298,8 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
return std::make_shared<AbstractTuple>(elem_list);
}
SetSequenceElementsUseFlags(arg_x, true);
SetSequenceElementsUseFlags(arg_y, true);
return BroadcastGradientArgsDiff(x_shape, y_shape);
}
@ -351,6 +354,7 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
MS_EXCEPTION_IF_NULL(result2);
MS_EXCEPTION_IF_NULL(result1->abstract());
MS_EXCEPTION_IF_NULL(result2->abstract());
SetSequenceElementsUseFlags(lst, true);
return result1->abstract()->Join(result2->abstract());
}
@ -366,6 +370,7 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtrList elem_list;
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
[](const AbstractBasePtr &elem) { return elem->Clone(); });
SetSequenceElementsUseFlags(input, true);
return std::make_shared<AbstractTuple>(elem_list);
}
@ -407,7 +412,7 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
}
auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
SetSequenceElementsUseFlags(shape_x, true);
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
}
@ -472,7 +477,8 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
auto result_v = MakeValue(result);
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
}
SetSequenceElementsUseFlags(shape_x, true);
SetSequenceElementsUseFlags(div_shp, true);
return std::make_shared<AbstractTuple>(values);
}

View File

@ -57,25 +57,53 @@ mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
// The Python primitives who use tuple/list elements.
// We consider all tuple/list arguments are used by now.
// Should check 'tuple argument index' and 'element use index' later.
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
prim::kPrimConcat->name(),
prim::kPrimTupleToArray->name(),
prim::kPrimPack->name(),
prim::kPrimSlice->name(),
prim::kPrimStridedSlice->name(),
prim::kPrimScatterNd->name(),
prim::kPrimReshape->name(),
prim::kPrimTile->name(),
prim::kPrimConv3DBackpropFilter->name(),
prim::kPrimCentralization->name(),
prim::kPrimMerge->name(),
prim::kPrimCustom->name(),
prim::kPrimAssert->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",
"ParallelConcat",
"CudnnGRU"};
mindspore::HashSet<std::string> prims_use_sequence_elements{
prim::kPrimStack->name(),
prim::kPrimConcat->name(),
prim::kPrimTupleToArray->name(),
prim::kPrimPack->name(),
prim::kPrimSlice->name(),
prim::kPrimStridedSlice->name(),
prim::kPrimScatterNd->name(),
prim::kPrimReshape->name(),
prim::kPrimTile->name(),
prim::kPrimConv3DBackpropFilter->name(),
prim::kPrimCentralization->name(),
prim::kPrimMerge->name(),
prim::kPrimCustom->name(),
prim::kPrimAssert->name(),
prim::kPrimReduceMean->name(),
prim::kPrimReduceSum->name(),
prim::kPrimReduceAll->name(),
prim::kPrimReduceAny->name(),
prim::kPrimReduceMax->name(),
prim::kPrimReduceMin->name(),
prim::kPrimLstm->name(),
prim::kPrimConv3DBackpropInput->name(),
prim::kPrimCheckBprop->name(),
prim::kPrimPush->name(),
prim::kPrimIdentity->name(),
prim::kPrimPyFunc->name(),
prim::kPrimStandardNormal->name(),
prim::kPrimUniformReal->name(),
prim::kPrimSparseToDense->name(),
prim::kPrimSparseTensorDenseMatmul->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",
"ParallelConcat",
"CudnnGRU",
"GetModel",
"UpdateModel",
"ReduceProd",
"StandardLaplace",
"Gamma",
"Poisson",
"UniformInt",
"BufferSample",
"BufferAppend",
"BufferGetItem",
};
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {

View File

@ -303,6 +303,7 @@ AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const Abst
// Inputs: a tuple or list or dict.
CheckArgsSize(op_name, args_spec_list, 1);
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
SetSequenceElementsUseFlags(arg, true);
return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
}
} // namespace abstract

View File

@ -279,6 +279,7 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
SetSequenceElementsUseFlags(dense_shape, true);
return ret;
}
@ -382,6 +383,7 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
SetSequenceElementsUseFlags(dense_shape, true);
return ret;
}
@ -595,6 +597,7 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(shape);
SetSequenceElementsUseFlags(shape, true);
return ret;
}

View File

@ -60,6 +60,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
auto key_string = GetValue<std::string>(keyPtr);
(void)key_value.emplace_back(key_string, value_list[index]);
}
SetSequenceElementsUseFlags(keys, true);
SetSequenceElementsUseFlags(values, true);
return std::make_shared<AbstractDictionary>(key_value);
}
@ -163,7 +165,8 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
size_t index_unsigned_value = LongToSize(index_positive_value);
constexpr int target_value_index = 2;
elements[index_unsigned_value] = args_spec_list[target_value_index];
return std::make_shared<T>(elements);
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
return std::make_shared<T>(elements, queue->sequence_nodes());
}
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -295,7 +298,7 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item);
auto new_list = AbstractBasePtrList(list->elements());
new_list.emplace_back(item);
return std::make_shared<AbstractList>(new_list);
return std::make_shared<AbstractList>(new_list, list->sequence_nodes());
}
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -715,7 +715,7 @@ inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
inline const PrimitivePtr kPrimPull = std::make_shared<Primitive>("Pull");
inline const PrimitivePtr kPrimPush = std::make_shared<Primitive>("Push");
inline const PrimitivePtr kPrimNPUAllocFloatStatus = std::make_shared<Primitive>("NPUAllocFloatStatus");
inline const PrimitivePtr kPyFunc = std::make_shared<Primitive>("PyFunc");
inline const PrimitivePtr kPrimPyFunc = std::make_shared<Primitive>("PyFunc");
// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");

View File

@ -96,6 +96,7 @@ AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const Pr
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
SetSequenceElementsUseFlags(input_args[0], true);
auto infer_type = AccumulateNV2InferType(primitive, input_args);
auto infer_shape = AccumulateNV2InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);

View File

@ -39,6 +39,7 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
SetSequenceElementsUseFlags(item, true);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());

View File

@ -37,6 +37,7 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
SetSequenceElementsUseFlags(item, true);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto log_probs_shape = log_probs_shape_map[kShape];

View File

@ -105,6 +105,7 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
SetSequenceElementsUseFlags(input_args[0], true);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);