Fix soft shrink grad bug
parent 758dff008d
commit b04d92fb0b
@@ -20,28 +20,24 @@
 namespace mindspore {
 namespace kernel {
 #define SOFT_SHRINK_GRAD_CPU_REGISTER(DT, T) \
-  KernelAttr().AddInputAttr(DT).AddOutputAttr(DT), &SoftShrinkGradCpuKernelMod::LaunchKernel<T>
+  KernelAttr().AddInputAttr(DT).AddInputAttr(DT).AddOutputAttr(DT), &SoftShrinkGradCpuKernelMod::LaunchKernel<T>
 
 template <typename T>
 bool SoftShrinkGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &outputs) {
-  T *input_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
-  T *output_addr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+  T *dy_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
+  T *x_addr = reinterpret_cast<T *>(inputs.at(kIndex1)->addr);
+  T *dx_addr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
 
   float lambd_value = lambd_;
-  auto task = [input_addr, output_addr, lambd_value](size_t start, size_t end) {
+  auto task = [dy_addr, x_addr, dx_addr, lambd_value](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
-      if (input_addr[i] > 0) {
-        output_addr[i] = input_addr[i] + lambd_value;
-      } else if (input_addr[i] < 0) {
-        output_addr[i] = input_addr[i] - lambd_value;
-      } else {
-        output_addr[i] = 0;
-      }
+      dx_addr[i] = (x_addr[i] >= -lambd_value && x_addr[i] <= lambd_value) ? 0 : dy_addr[i];
     }
   };
   ParallelLaunchAutoSearch(task, size_, this, &parallel_search_info_);
 
   return true;
 }
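For reference, a minimal NumPy sketch of what the fixed CPU kernel now computes per element (the names dy, x, and lambd mirror the kernel arguments; this is an illustration of the formula, not MindSpore code):

import numpy as np

def soft_shrink_grad_reference(dy, x, lambd=0.5):
    # SoftShrink backward: pass dy through where |x| > lambd, zero it inside [-lambd, lambd].
    dy = np.asarray(dy)
    x = np.asarray(x)
    return np.where((x >= -lambd) & (x <= lambd), np.zeros_like(dy), dy)

# Example: elements of x inside [-0.5, 0.5] block the incoming gradient.
x = np.array([-1.0, -0.5, 0.0, 0.3, 2.0], dtype=np.float32)
dy = np.ones_like(x)
print(soft_shrink_grad_reference(dy, x, 0.5))  # [1. 0. 0. 0. 1.]

The old code applied the forward soft-shrink formula (x ± lambd) to a single input, which is why the register macro, the kernel signature, and the launch all change below.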
@@ -30,12 +30,12 @@ __global__ void SoftShrinkComp(size_t size, const T *input, const float lambd, T
 }
 
 template <typename T>
-__global__ void SoftShrinkGradComp(size_t size, const T *input, const float lambd, T *output) {
+__global__ void SoftShrinkGradComp(size_t size, const T *dy_addr, const T *x_addr, const float lambd, T *dx_addr) {
   const T positive_lambd = static_cast<T>(lambd);
+  const T negative_lambd = static_cast<T>(-1 * lambd);
   const T zero = static_cast<T>(0);
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    output[pos] = (input[pos] > zero) ? (input[pos] + positive_lambd)
-                                      : ((input[pos] < zero) ? (input[pos] - positive_lambd) : (zero));
+    dx_addr[pos] = (x_addr[pos] >= negative_lambd && x_addr[pos] <= positive_lambd) ? zero : dy_addr[pos];
   }
 }
@@ -46,10 +46,10 @@ void SoftShrink(const size_t &size, const T *input, const float lambd, T *output
 }
 
 template <typename T>
-void SoftShrinkGrad(const size_t &size, const T *input, const float lambd, T *output, const uint32_t &device_id,
-                    cudaStream_t cuda_stream) {
-  SoftShrinkGradComp<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input, lambd,
-                                                                                                output);
+void SoftShrinkGrad(const size_t &size, const T *dy_addr, const T *x_addr, const float lambd, T *dx_addr,
+                    const uint32_t &device_id, cudaStream_t cuda_stream) {
+  SoftShrinkGradComp<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, dy_addr, x_addr,
+                                                                                                lambd, dx_addr);
 }
 
 template CUDA_LIB_EXPORT void SoftShrink(const size_t &size, const half *input, const float lambd, half *output,
@@ -60,11 +60,16 @@ template CUDA_LIB_EXPORT void SoftShrink(const size_t &size, const int *input, c
                                          const uint32_t &device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void SoftShrink(const size_t &size, const int64_t *input, const float lambd, int64_t *output,
                                          const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const half *input, const float lambd, half *output,
-                                             const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const float *input, const float lambd, float *output,
-                                             const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const int *input, const float lambd, int *output,
-                                             const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const int64_t *input, const float lambd,
-                                             int64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const half *dy_addr, const half *x_addr,
+                                             const float lambd, half *dx_addr, const uint32_t &device_id,
+                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const float *dy_addr, const float *x_addr,
+                                             const float lambd, float *dx_addr, const uint32_t &device_id,
+                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const int *dy_addr, const int *x_addr,
+                                             const float lambd, int *dx_addr, const uint32_t &device_id,
+                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const int64_t *dy_addr, const int64_t *x_addr,
+                                             const float lambd, int64_t *dx_addr, const uint32_t &device_id,
+                                             cudaStream_t cuda_stream);
@@ -24,7 +24,7 @@ CUDA_LIB_EXPORT void SoftShrink(const size_t &size, const T *input, const float
                                 const uint32_t &device_id, cudaStream_t cuda_stream);
 
 template <typename T>
-CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const T *input, const float lambd, T *output,
-                                    const uint32_t &device_id, cudaStream_t cuda_stream);
+CUDA_LIB_EXPORT void SoftShrinkGrad(const size_t &size, const T *dy_addr, const T *x_addr, const float lambd,
+                                    T *dx_addr, const uint32_t &device_id, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSHRINK_IMPL_CUH_
@@ -21,15 +21,16 @@
 namespace mindspore {
 namespace kernel {
 #define SOFT_SHRINK_GRAD_GPU_REGISTER(DT, T) \
-  KernelAttr().AddInputAttr(DT).AddOutputAttr(DT), &SoftShrinkGradGpuKernelMod::LaunchKernel<T>
+  KernelAttr().AddInputAttr(DT).AddInputAttr(DT).AddOutputAttr(DT), &SoftShrinkGradGpuKernelMod::LaunchKernel<T>
 
 template <typename T>
 bool SoftShrinkGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &,
                                               const std::vector<kernel::AddressPtr> &outputs) {
-  T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
-  T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
-  SoftShrinkGrad(size_, input_addr, lambd_, output_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
+  T *dy_addr = GetDeviceAddress<T>(inputs, kIndex0);
+  T *x_addr = GetDeviceAddress<T>(inputs, kIndex1);
+  T *dx_addr = GetDeviceAddress<T>(outputs, kIndex0);
+  SoftShrinkGrad(size_, dy_addr, x_addr, lambd_, dx_addr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
   return true;
 }
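A rough sketch of how this two-input registration could be exercised end to end from Python, assuming nn.SoftShrink and ops.GradOperation; the net definition mirrors the SoftShrinkNet in the test file, while the device target and sample values are illustrative assumptions:

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

class SoftShrinkNet(nn.Cell):
    def __init__(self, lambd):
        super(SoftShrinkNet, self).__init__()
        self.soft_shrink = nn.SoftShrink(lambd)

    def construct(self, x):
        return self.soft_shrink(x)

ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")  # assumption: exercising the GPU kernel
net = SoftShrinkNet(0.5)
x = Tensor(np.array([[-2.0, -0.3], [0.4, 1.5]], dtype=np.float32))
dy = Tensor(np.ones((2, 2), dtype=np.float32))
# With sens_param=True the supplied dy flows in as the first input of SoftShrinkGrad
# and x as the second; the kernel returns dx.
grad_net = ops.GradOperation(sens_param=True)(net)
print(grad_net(x, dy))  # zeros where |x| <= 0.5, dy elsewhere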
@@ -53,7 +53,7 @@ class SoftShrinkNet(nn.Cell):
 @pytest.mark.env_onecard
 @pytest.mark.parametrize('dtype', [np.float32])
 @pytest.mark.parametrize("data_shape", [(3, 4), (4, 5, 6, 7)])
-@pytest.mark.parametrize("lambd", [0.5])
+@pytest.mark.parametrize("lambd", [0.5, 0.75])
 def test_soft_shrink(dtype, data_shape, lambd):
     """
     Feature: SoftShrink cpu kernel
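For the forward SoftShrink values this test checks, a hedged NumPy reference (the helper name is illustrative, not taken from the test file):

import numpy as np

def soft_shrink_reference(x, lambd):
    # Forward SoftShrink: shift values outside [-lambd, lambd] toward zero, zero out the rest.
    x = np.asarray(x)
    return np.where(x > lambd, x - lambd, np.where(x < -lambd, x + lambd, np.zeros_like(x)))

x = np.random.randn(3, 4).astype(np.float32)
for lambd in (0.5, 0.75):  # the two values the test now parametrizes over
    expected = soft_shrink_reference(x, lambd)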