add float64 support to select

This commit is contained in:
Peilin Wang 2021-03-12 16:01:25 -05:00
parent 483bb9de60
commit 2637133242
2 changed files with 10 additions and 1 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,13 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
SelectGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Select,
KernelAttr()
.AddInputAttr(kNumberTypeBool)

View File

@ -34,6 +34,8 @@ void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* i
return;
}
template void CalSelect<double>(const size_t size, const bool* cond, const double* input_X, const double* input_y,
double* output, cudaStream_t cuda_stream);
template void CalSelect<float>(const size_t size, const bool* cond, const float* input_X, const float* input_y,
float* output, cudaStream_t cuda_stream);
template void CalSelect<int>(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output,