forked from mindspore-Ecosystem/mindspore
!29290 Fix Broadcast primitive tuple input incorrect elimination.
Merge pull request !29290 from huanghui/fix-broadcast-use-flag
This commit is contained in:
commit
d65f8e3953
|
@ -88,6 +88,7 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
|
|||
prim::kPrimUniformReal->name(),
|
||||
prim::kPrimSparseToDense->name(),
|
||||
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||
prim::kPrimBroadcast->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
|
|
|
@ -115,7 +115,8 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
|||
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
|
||||
});
|
||||
|
||||
SetSequenceElementsUseFlags(xs, true);
|
||||
SetSequenceElementsUseFlags(ys, true);
|
||||
return std::make_shared<AbstractTuple>(elems);
|
||||
}
|
||||
|
||||
|
|
|
@ -94,13 +94,18 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t input_perm_index = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, 1,
|
||||
// The second input is optional.
|
||||
constexpr size_t input_size1 = 1;
|
||||
constexpr size_t input_size2 = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1,
|
||||
primitive->name());
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto res = abstract::MakeAbstract(shape, type);
|
||||
// Set all used flags of tuple as true.
|
||||
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
|
||||
if (input_args.size() == input_size2) {
|
||||
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
||||
|
|
|
@ -526,7 +526,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **input_x** (tuple[Tensor]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
|
Loading…
Reference in New Issue