forked from mindspore-Ecosystem/mindspore
Update tuple/list use flags for Einsum primitive.
This commit is contained in:
parent
a33b509a48
commit
705e71f20d
|
@ -89,6 +89,7 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
|
|||
prim::kPrimSparseToDense->name(),
|
||||
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||
prim::kPrimBroadcast->name(),
|
||||
prim::kPrimEinsumGrad->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
|
|
|
@ -238,8 +238,11 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
auto res =
|
||||
std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(primitive, input_args));
|
||||
// For Einsum(tuple/list), set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[0], true);
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -188,7 +188,9 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const
|
|||
auto shape = InferShape(primitive);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
|
||||
if (!input_args.empty()) {
|
||||
SetSequenceElementsUseFlags(input_args[input_tuple_index], true);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||
|
|
Loading…
Reference in New Issue