!29290 Fix Broadcast primitive tuple input incorrect elimination.

Merge pull request !29290 from huanghui/fix-broadcast-use-flag
This commit is contained in:
i-robot 2022-01-19 09:28:23 +00:00 committed by Gitee
commit d65f8e3953
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 12 additions and 5 deletions

View File

@ -88,6 +88,7 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
prim::kPrimUniformReal->name(),
prim::kPrimSparseToDense->name(),
prim::kPrimSparseTensorDenseMatmul->name(),
prim::kPrimBroadcast->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",

View File

@ -115,7 +115,8 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
});
SetSequenceElementsUseFlags(xs, true);
SetSequenceElementsUseFlags(ys, true);
return std::make_shared<AbstractTuple>(elems);
}

View File

@ -94,13 +94,18 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_perm_index = 1;
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, 1,
// The second input is optional.
constexpr size_t input_size1 = 1;
constexpr size_t input_size2 = 2;
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1,
primitive->name());
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
auto res = abstract::MakeAbstract(shape, type);
// Set all used flags of tuple as true.
if (input_args.size() == input_size2) {
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
}
return res;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);

View File

@ -526,7 +526,7 @@ class Broadcast(PrimitiveWithInfer):
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **input_x** (tuple[Tensor]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.