Add Mul dtype registrations to support bool and uint8

This commit is contained in:
z00512249 2022-02-21 20:12:11 +08:00
parent 95f5075371
commit d1e2dfcbd1
5 changed files with 27 additions and 4 deletions

View File

@ -148,7 +148,11 @@ void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
if constexpr (std::is_same_v<T, bool>) {
out[i] = static_cast<T>(input1[iter.GetInputPosA()] && input2[iter.GetInputPosB()]);
} else {
out[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
}
iter.GenNextPos();
}
};
@ -368,7 +372,11 @@ void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *inpu
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
T diff = input1[iter.GetInputPosA()] - input2[iter.GetInputPosB()];
out[i] = static_cast<T>(diff * diff);
if constexpr (std::is_same_v<T, bool>) {
out[i] = static_cast<T>(diff);
} else {
out[i] = static_cast<T>(diff * diff);
}
iter.GenNextPos();
}
};

View File

@ -88,6 +88,12 @@ MS_REG_CPU_KERNEL_T(
// Mul: float64 x float64 -> float64.
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
// Mul: uint8 x uint8 -> uint8 (newly registered dtype).
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ArithmeticCpuKernelMod, uint8_t);
// Mul: bool x bool -> bool (newly registered; the kernel computes logical AND for bool).
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCpuKernelMod, bool);
// Div: int32 x int32 -> int32.
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);

View File

@ -472,5 +472,9 @@ MS_REG_GPU_KERNEL_ONE(
// LogicalOr: bool x bool -> bool.
MS_REG_GPU_KERNEL_ONE(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
// bool  (NOTE: the original "// uint64" section comment was wrong — this registration is for bool.)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -18,7 +18,6 @@
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1677,6 +1677,7 @@ class Neg(Primitive):
"""Initialize Neg"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
class InplaceAdd(PrimitiveWithInfer):
"""
Adds `v` into specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
@ -1915,6 +1916,10 @@ class Mul(_MathBinaryOp):
[ 4. 10. 18.]
"""
def infer_dtype(self, x_dtype, y_dtype):
    """Infer the output dtype of Mul from the two input dtypes.

    Delegates validation to ``_MathBinaryOp.do_infer_dtype``; the accepted
    dtype set is all number types plus bool (bool was added so that
    bool * bool is treated as a valid Mul — see the matching bool kernel
    registrations in this commit).
    """
    # Extend the standard numeric dtypes with bool for this op.
    mul_valid_type = mstype.number_type + (mstype.bool_,)
    return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mul_valid_type, self.name)
def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
@ -3496,6 +3501,7 @@ class Asinh(Primitive):
"""Initialize Asinh"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
class Sinh(Primitive):
r"""
Computes hyperbolic sine of the input element-wise.