Merge pull request !49336 from liubuyu/acl_launch
This commit is contained in:
i-robot 2023-02-25 02:06:10 +00:00 committed by Gitee
commit 09caa368e9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 7 additions and 12 deletions

View File

@ -487,7 +487,7 @@ ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::
auto temp_shape = shape;
if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && shape.size() < kDim4 &&
!IsOneOf3DFormat(format)) {
MS_LOG(WARNING) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
MS_LOG(INFO) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dDefault(shape);
}
if (shape.size() != kDim5 && IsOneOf3DFormat(format)) {

View File

@ -318,13 +318,6 @@ NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_
(void)real_axis.emplace_back(i);
}
}
if (std::any_of(real_axis.begin(), real_axis.end(), [](int64_t v) { return v < 0; })) {
for (size_t i = 0; i < real_axis.size(); i++) {
if (real_axis[i] < 0) {
real_axis[i] = real_axis[i] + SizeToLong(real_axis.size());
}
}
}
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(real_axis)}, {{"keep_dims", MakeValue(keep_dims)}});
#endif
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}});

View File

@ -27,12 +27,14 @@ broadcast_to_op_info = TBERegOp("DynamicBroadcastTo") \
.input(0, "x", False, "required", "all") \
.input(1, "shape", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
.get_op_info()