forked from mindspore-Ecosystem/mindspore
!29290 Fix Broadcast primitive tuple input incorrect elimination.
Merge pull request !29290 from huanghui/fix-broadcast-use-flag
This commit is contained in:
commit
d65f8e3953
|
@ -88,6 +88,7 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
|
||||||
prim::kPrimUniformReal->name(),
|
prim::kPrimUniformReal->name(),
|
||||||
prim::kPrimSparseToDense->name(),
|
prim::kPrimSparseToDense->name(),
|
||||||
prim::kPrimSparseTensorDenseMatmul->name(),
|
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||||
|
prim::kPrimBroadcast->name(),
|
||||||
"InvertPermutation",
|
"InvertPermutation",
|
||||||
"Meshgrid",
|
"Meshgrid",
|
||||||
"TransShape",
|
"TransShape",
|
||||||
|
|
|
@ -115,7 +115,8 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
||||||
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
|
(void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr {
|
||||||
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
|
return std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(n), kInt64);
|
||||||
});
|
});
|
||||||
|
SetSequenceElementsUseFlags(xs, true);
|
||||||
|
SetSequenceElementsUseFlags(ys, true);
|
||||||
return std::make_shared<AbstractTuple>(elems);
|
return std::make_shared<AbstractTuple>(elems);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -94,13 +94,18 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
constexpr size_t input_perm_index = 1;
|
constexpr size_t input_perm_index = 1;
|
||||||
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, 1,
|
// The second input is optional.
|
||||||
|
constexpr size_t input_size1 = 1;
|
||||||
|
constexpr size_t input_size2 = 2;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1,
|
||||||
primitive->name());
|
primitive->name());
|
||||||
auto type = InferType(primitive, input_args);
|
auto type = InferType(primitive, input_args);
|
||||||
auto shape = InferShape(primitive, input_args);
|
auto shape = InferShape(primitive, input_args);
|
||||||
auto res = abstract::MakeAbstract(shape, type);
|
auto res = abstract::MakeAbstract(shape, type);
|
||||||
// Set all used flags of tuple as true.
|
if (input_args.size() == input_size2) {
|
||||||
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
|
SetSequenceElementsUseFlags(input_args[input_perm_index], true);
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
|
||||||
|
|
|
@ -526,7 +526,7 @@ class Broadcast(PrimitiveWithInfer):
|
||||||
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
|
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
- **input_x** (tuple[Tensor]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
|
Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
Loading…
Reference in New Issue