forked from mindspore-Ecosystem/mindspore
!29170 Add flag for more tuple inputs
Merge pull request !29170 from huanghui/tuple-input
This commit is contained in:
commit
ef5e5c2aae
|
@ -1386,7 +1386,7 @@ void KernelGraph::SetInputNodes() {
|
|||
|
||||
void KernelGraph::UpdateGraphAquireGilAttr() {
|
||||
for (const auto &cnode : execution_order_) {
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPyFunc)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
|
||||
MS_LOG(INFO) << "The Graph require GIL. Graph id: " << graph_id_;
|
||||
is_need_gil_ = true;
|
||||
return;
|
||||
|
|
|
@ -51,6 +51,8 @@ AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const Abst
|
|||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
SetSequenceElementsUseFlags(input_x, true);
|
||||
SetSequenceElementsUseFlags(input_y, true);
|
||||
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
|
@ -296,7 +298,8 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
|
|||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
SetSequenceElementsUseFlags(arg_x, true);
|
||||
SetSequenceElementsUseFlags(arg_y, true);
|
||||
return BroadcastGradientArgsDiff(x_shape, y_shape);
|
||||
}
|
||||
|
||||
|
@ -351,6 +354,7 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
|
|||
MS_EXCEPTION_IF_NULL(result2);
|
||||
MS_EXCEPTION_IF_NULL(result1->abstract());
|
||||
MS_EXCEPTION_IF_NULL(result2->abstract());
|
||||
SetSequenceElementsUseFlags(lst, true);
|
||||
return result1->abstract()->Join(result2->abstract());
|
||||
}
|
||||
|
||||
|
@ -366,6 +370,7 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
|
|||
AbstractBasePtrList elem_list;
|
||||
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
|
||||
[](const AbstractBasePtr &elem) { return elem->Clone(); });
|
||||
SetSequenceElementsUseFlags(input, true);
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
|
@ -407,7 +412,7 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
|
|||
}
|
||||
auto axis_value_ptr = axis_value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||
|
||||
SetSequenceElementsUseFlags(shape_x, true);
|
||||
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
|
||||
}
|
||||
|
||||
|
@ -472,7 +477,8 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
auto result_v = MakeValue(result);
|
||||
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
|
||||
}
|
||||
|
||||
SetSequenceElementsUseFlags(shape_x, true);
|
||||
SetSequenceElementsUseFlags(div_shp, true);
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
|
|
|
@ -57,25 +57,53 @@ mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
|
|||
// The Python primitives who use tuple/list elements.
|
||||
// We consider all tuple/list arguments are used by now.
|
||||
// Should check 'tuple argument index' and 'element use index' later.
|
||||
mindspore::HashSet<std::string> prims_use_sequence_elements{prim::kPrimStack->name(),
|
||||
prim::kPrimConcat->name(),
|
||||
prim::kPrimTupleToArray->name(),
|
||||
prim::kPrimPack->name(),
|
||||
prim::kPrimSlice->name(),
|
||||
prim::kPrimStridedSlice->name(),
|
||||
prim::kPrimScatterNd->name(),
|
||||
prim::kPrimReshape->name(),
|
||||
prim::kPrimTile->name(),
|
||||
prim::kPrimConv3DBackpropFilter->name(),
|
||||
prim::kPrimCentralization->name(),
|
||||
prim::kPrimMerge->name(),
|
||||
prim::kPrimCustom->name(),
|
||||
prim::kPrimAssert->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
"ParallelConcat",
|
||||
"CudnnGRU"};
|
||||
mindspore::HashSet<std::string> prims_use_sequence_elements{
|
||||
prim::kPrimStack->name(),
|
||||
prim::kPrimConcat->name(),
|
||||
prim::kPrimTupleToArray->name(),
|
||||
prim::kPrimPack->name(),
|
||||
prim::kPrimSlice->name(),
|
||||
prim::kPrimStridedSlice->name(),
|
||||
prim::kPrimScatterNd->name(),
|
||||
prim::kPrimReshape->name(),
|
||||
prim::kPrimTile->name(),
|
||||
prim::kPrimConv3DBackpropFilter->name(),
|
||||
prim::kPrimCentralization->name(),
|
||||
prim::kPrimMerge->name(),
|
||||
prim::kPrimCustom->name(),
|
||||
prim::kPrimAssert->name(),
|
||||
prim::kPrimReduceMean->name(),
|
||||
prim::kPrimReduceSum->name(),
|
||||
prim::kPrimReduceAll->name(),
|
||||
prim::kPrimReduceAny->name(),
|
||||
prim::kPrimReduceMax->name(),
|
||||
prim::kPrimReduceMin->name(),
|
||||
prim::kPrimLstm->name(),
|
||||
prim::kPrimConv3DBackpropInput->name(),
|
||||
prim::kPrimCheckBprop->name(),
|
||||
prim::kPrimPush->name(),
|
||||
prim::kPrimIdentity->name(),
|
||||
prim::kPrimPyFunc->name(),
|
||||
prim::kPrimStandardNormal->name(),
|
||||
prim::kPrimUniformReal->name(),
|
||||
prim::kPrimSparseToDense->name(),
|
||||
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
"ParallelConcat",
|
||||
"CudnnGRU",
|
||||
"GetModel",
|
||||
"UpdateModel",
|
||||
"ReduceProd",
|
||||
"StandardLaplace",
|
||||
"Gamma",
|
||||
"Poisson",
|
||||
"UniformInt",
|
||||
"BufferSample",
|
||||
"BufferAppend",
|
||||
"BufferGetItem",
|
||||
};
|
||||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
|
|
|
@ -303,6 +303,7 @@ AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const Abst
|
|||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
SetSequenceElementsUseFlags(arg, true);
|
||||
return std::make_shared<AbstractScalar>(SizeToLong(arg->size()));
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
|
@ -279,6 +279,7 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
SetSequenceElementsUseFlags(dense_shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -382,6 +383,7 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
SetSequenceElementsUseFlags(dense_shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -595,6 +597,7 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv
|
|||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(shape);
|
||||
SetSequenceElementsUseFlags(shape, true);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -60,6 +60,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
auto key_string = GetValue<std::string>(keyPtr);
|
||||
(void)key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
SetSequenceElementsUseFlags(keys, true);
|
||||
SetSequenceElementsUseFlags(values, true);
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
|
@ -163,7 +165,8 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
|
|||
size_t index_unsigned_value = LongToSize(index_positive_value);
|
||||
constexpr int target_value_index = 2;
|
||||
elements[index_unsigned_value] = args_spec_list[target_value_index];
|
||||
return std::make_shared<T>(elements);
|
||||
SetSequenceElementsUseFlags(queue, index_unsigned_value, true);
|
||||
return std::make_shared<T>(elements, queue->sequence_nodes());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -295,7 +298,7 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
auto new_list = AbstractBasePtrList(list->elements());
|
||||
new_list.emplace_back(item);
|
||||
return std::make_shared<AbstractList>(new_list);
|
||||
return std::make_shared<AbstractList>(new_list, list->sequence_nodes());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -715,7 +715,7 @@ inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
|
|||
inline const PrimitivePtr kPrimPull = std::make_shared<Primitive>("Pull");
|
||||
inline const PrimitivePtr kPrimPush = std::make_shared<Primitive>("Push");
|
||||
inline const PrimitivePtr kPrimNPUAllocFloatStatus = std::make_shared<Primitive>("NPUAllocFloatStatus");
|
||||
inline const PrimitivePtr kPyFunc = std::make_shared<Primitive>("PyFunc");
|
||||
inline const PrimitivePtr kPrimPyFunc = std::make_shared<Primitive>("PyFunc");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
|
|
|
@ -96,6 +96,7 @@ AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
auto infer_type = AccumulateNV2InferType(primitive, input_args);
|
||||
auto infer_shape = AccumulateNV2InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
|
|
|
@ -39,6 +39,7 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
SetSequenceElementsUseFlags(item, true);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
|
|
|
@ -37,6 +37,7 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
|||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
SetSequenceElementsUseFlags(item, true);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_probs_shape = log_probs_shape_map[kShape];
|
||||
|
|
|
@ -105,6 +105,7 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);
|
||||
|
|
Loading…
Reference in New Issue