!8632 [MD] Add Multiply op for Lite Dataset

From: @jiangzhiwen8
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2020-11-19 09:58:43 +08:00 committed by Gitee
commit c8047a35ab
3 changed files with 158 additions and 3 deletions

View File

@ -495,9 +495,7 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
}
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
if (src_a.data_type_ == LDataType::BOOL) {
DivideImpl<bool>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::INT8) {
if (src_a.data_type_ == LDataType::INT8) {
DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::UINT8) {
DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
@ -523,5 +521,102 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
return true;
}
template <typename T>
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
for (size_t i = 0; i < total_size; i++) {
dst[i] = src0[i] * src1[i];
}
}
template <>
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
int64_t x = 0;
#ifdef USE_NEON
const int64_t step = 32;
for (; x <= total_size - step; x += step) {
uint8x16_t v_src00 = vld1q_u8(src0 + x);
uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
uint8x16_t v_src10 = vld1q_u8(src1 + x);
uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
uint8x16_t v_dst_l, v_dst_h;
v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
}
#endif
for (; x < total_size; x++) {
int32_t val = src0[x] * src1[x];
dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
}
}
template <>
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
for (size_t i = 0; i < total_size; i++) {
int32_t val = src0[i] * src1[i];
dst[i] = std::max<int32_t>(std::numeric_limits<uint16_t>::min(),
std::min<int32_t>(std::numeric_limits<uint16_t>::max(), val));
}
}
template <>
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
for (size_t i = 0; i < total_size; i++) {
int64_t val = src0[i] * src1[i];
dst[i] = std::max<int64_t>(std::numeric_limits<uint32_t>::min(),
std::min<int64_t>(std::numeric_limits<uint32_t>::max(), val));
}
}
bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
return false;
}
if (src_a.data_type_ != src_b.data_type_) {
return false;
}
if (dst->IsEmpty()) {
dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
} else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
return false;
} else if (src_a.data_type_ != dst->data_type_) {
return false;
}
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
if (src_a.data_type_ == LDataType::INT8) {
MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::UINT8) {
MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::INT16) {
MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::UINT16) {
MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::INT32) {
MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::UINT32) {
MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::INT64) {
MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::UINT64) {
MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::FLOAT32) {
MultiplyImpl<float>(src_a, src_b, *dst, total_size);
} else if (src_a.data_type_ == LDataType::FLOAT64) {
MultiplyImpl<double>(src_a, src_b, *dst, total_size);
} else {
return false;
}
return true;
}
} // namespace dataset
} // namespace mindspore

View File

@ -260,6 +260,9 @@ bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
/// \brief Calculates the division between the two images for each element
bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
/// \brief Calculates the multiply between the two images for each element
bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
} // namespace dataset
} // namespace mindspore
#endif // MINI_MAT_H_

View File

@ -822,3 +822,60 @@ TEST_F(MindDataImageProcess, TestDivideFloat) {
static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1);
}
}
TEST_F(MindDataImageProcess, TestMultiplyUint8) {
const size_t cols = 4;
// Test uint8
LiteMat src1_uint8(1, cols);
LiteMat src2_uint8(1, cols);
LiteMat expect_uint8(1, cols);
for (size_t i = 0; i < cols; i++) {
static_cast<UINT8_C1 *>(src1_uint8.data_ptr_)[i] = 8;
static_cast<UINT8_C1 *>(src2_uint8.data_ptr_)[i] = 4;
static_cast<UINT8_C1 *>(expect_uint8.data_ptr_)[i] = 32;
}
LiteMat dst_uint8;
EXPECT_TRUE(Multiply(src1_uint8, src2_uint8, &dst_uint8));
for (size_t i = 0; i < cols; i++) {
EXPECT_EQ(static_cast<UINT8_C1 *>(expect_uint8.data_ptr_)[i].c1,
static_cast<UINT8_C1 *>(dst_uint8.data_ptr_)[i].c1);
}
}
TEST_F(MindDataImageProcess, TestMultiplyUInt16) {
const size_t cols = 4;
// Test int16
LiteMat src1_int16(1, cols, LDataType(LDataType::UINT16));
LiteMat src2_int16(1, cols, LDataType(LDataType::UINT16));
LiteMat expect_int16(1, cols, LDataType(LDataType::UINT16));
for (size_t i = 0; i < cols; i++) {
static_cast<UINT16_C1 *>(src1_int16.data_ptr_)[i] = 60000;
static_cast<UINT16_C1 *>(src2_int16.data_ptr_)[i] = 2;
static_cast<UINT16_C1 *>(expect_int16.data_ptr_)[i] = 65535;
}
LiteMat dst_int16;
EXPECT_TRUE(Multiply(src1_int16, src2_int16, &dst_int16));
for (size_t i = 0; i < cols; i++) {
EXPECT_EQ(static_cast<UINT16_C1 *>(expect_int16.data_ptr_)[i].c1,
static_cast<UINT16_C1 *>(dst_int16.data_ptr_)[i].c1);
}
}
TEST_F(MindDataImageProcess, TestMultiplyFloat) {
const size_t cols = 4;
// Test float
LiteMat src1_float(1, cols, LDataType(LDataType::FLOAT32));
LiteMat src2_float(1, cols, LDataType(LDataType::FLOAT32));
LiteMat expect_float(1, cols, LDataType(LDataType::FLOAT32));
for (size_t i = 0; i < cols; i++) {
static_cast<FLOAT32_C1 *>(src1_float.data_ptr_)[i] = 30.0f;
static_cast<FLOAT32_C1 *>(src2_float.data_ptr_)[i] = -2.0f;
static_cast<FLOAT32_C1 *>(expect_float.data_ptr_)[i] = -60.0f;
}
LiteMat dst_float;
EXPECT_TRUE(Multiply(src1_float, src2_float, &dst_float));
for (size_t i = 0; i < cols; i++) {
EXPECT_FLOAT_EQ(static_cast<FLOAT32_C1 *>(expect_float.data_ptr_)[i].c1,
static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1);
}
}