forked from mindspore-Ecosystem/mindspore
add mul dtype register to support bool && uint_8
This commit is contained in:
parent
95f5075371
commit
d1e2dfcbd1
|
@ -148,7 +148,11 @@ void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
|
|||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
out[i] = static_cast<T>(input1[iter.GetInputPosA()] && input2[iter.GetInputPosB()]);
|
||||
} else {
|
||||
out[i] = static_cast<T>(input1[iter.GetInputPosA()] * input2[iter.GetInputPosB()]);
|
||||
}
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
@ -368,7 +372,11 @@ void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *inpu
|
|||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; i++) {
|
||||
T diff = input1[iter.GetInputPosA()] - input2[iter.GetInputPosB()];
|
||||
out[i] = static_cast<T>(diff * diff);
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
out[i] = static_cast<T>(diff);
|
||||
} else {
|
||||
out[i] = static_cast<T>(diff * diff);
|
||||
}
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -88,6 +88,12 @@ MS_REG_CPU_KERNEL_T(
|
|||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
ArithmeticCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
|
|
|
@ -472,5 +472,9 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernelMod, bool)
|
||||
// uint64
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernelMod, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1677,6 +1677,7 @@ class Neg(Primitive):
|
|||
"""Initialize Neg"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
class InplaceAdd(PrimitiveWithInfer):
|
||||
"""
|
||||
Adds `v` into specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
|
||||
|
@ -1915,6 +1916,10 @@ class Mul(_MathBinaryOp):
|
|||
[ 4. 10. 18.]
|
||||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
mul_valid_type = mstype.number_type + (mstype.bool_,)
|
||||
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mul_valid_type, self.name)
|
||||
|
||||
def infer_value(self, x, y):
|
||||
if x is not None and y is not None:
|
||||
x = x.asnumpy()
|
||||
|
@ -3496,6 +3501,7 @@ class Asinh(Primitive):
|
|||
"""Initialize Asinh"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
class Sinh(Primitive):
|
||||
r"""
|
||||
Computes hyperbolic sine of the input element-wise.
|
||||
|
|
Loading…
Reference in New Issue