Update tuple/list use flags for Einsum primitive.

This commit is contained in:
Zhang Qinghua 2022-01-19 20:35:23 +08:00
parent a33b509a48
commit 705e71f20d
3 changed files with 9 additions and 3 deletions

View File

@ -89,6 +89,7 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
prim::kPrimSparseToDense->name(),
prim::kPrimSparseTensorDenseMatmul->name(),
prim::kPrimBroadcast->name(),
prim::kPrimEinsumGrad->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",

View File

@ -238,8 +238,11 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args));
auto res =
std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(primitive, input_args));
// For Einsum(tuple/list), set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[0], true);
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true);
} // namespace ops

View File

@ -188,7 +188,9 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const
auto shape = InferShape(primitive);
auto res = abstract::MakeAbstract(shape, type);
// Set all used flags of tuple as true.
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
if (!input_args.empty()) {
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);