forked from mindspore-Ecosystem/mindspore
commit
09caa368e9
|
@ -487,7 +487,7 @@ ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::
|
|||
auto temp_shape = shape;
|
||||
if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && shape.size() < kDim4 &&
|
||||
!IsOneOf3DFormat(format)) {
|
||||
MS_LOG(WARNING) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
|
||||
MS_LOG(INFO) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
|
||||
temp_shape = PaddingShapeTo4dDefault(shape);
|
||||
}
|
||||
if (shape.size() != kDim5 && IsOneOf3DFormat(format)) {
|
||||
|
|
|
@ -318,13 +318,6 @@ NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_
|
|||
(void)real_axis.emplace_back(i);
|
||||
}
|
||||
}
|
||||
if (std::any_of(real_axis.begin(), real_axis.end(), [](int64_t v) { return v < 0; })) {
|
||||
for (size_t i = 0; i < real_axis.size(); i++) {
|
||||
if (real_axis[i] < 0) {
|
||||
real_axis[i] = real_axis[i] + SizeToLong(real_axis.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(real_axis)}, {{"keep_dims", MakeValue(keep_dims)}});
|
||||
#endif
|
||||
return Emit(prim::kPrimReduceSum->name(), {x, Value<ShapeVector>(axis)}, {{"keep_dims", MakeValue(keep_dims)}});
|
||||
|
|
|
@ -27,12 +27,14 @@ broadcast_to_op_info = TBERegOp("DynamicBroadcastTo") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "shape", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue