forked from mindspore-Ecosystem/mindspore
add float64 support to select
This commit is contained in:
parent
483bb9de60
commit
2637133242
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Select,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
SelectGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Select,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
|
|
|
@ -34,6 +34,8 @@ void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* i
|
|||
return;
|
||||
}
|
||||
|
||||
template void CalSelect<double>(const size_t size, const bool* cond, const double* input_X, const double* input_y,
|
||||
double* output, cudaStream_t cuda_stream);
|
||||
template void CalSelect<float>(const size_t size, const bool* cond, const float* input_X, const float* input_y,
|
||||
float* output, cudaStream_t cuda_stream);
|
||||
template void CalSelect<int>(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output,
|
||||
|
|
Loading…
Reference in New Issue