forked from mindspore-Ecosystem/mindspore
!14767 update tensor add: support input type of int32, uint32 on CPU kernel
From: @zyx5256 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit
4d38302c68
|
@ -19,8 +19,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
template <typename T>
|
||||
void TensorAddCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
// Init shape ans strides
|
||||
input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
|
@ -28,13 +28,14 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool TensorAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_addr_a = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto input_addr_b = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto output_size = outputs[0]->size / sizeof(float);
|
||||
template <typename T>
|
||||
bool TensorAddCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input_addr_a = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *input_addr_b = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t output_size = outputs[0]->size / sizeof(T);
|
||||
if (input_shape_a_ == input_shape_b_) {
|
||||
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class TensorAddCPUKernel : public CPUKernel {
|
||||
public:
|
||||
TensorAddCPUKernel() = default;
|
||||
|
@ -39,9 +40,15 @@ class TensorAddCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
TensorAddCPUKernel);
|
||||
TensorAddCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
TensorAddCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
TensorAddCPUKernel, uint32_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue