modified documentation and GPU kernel for SmoothL1Loss

fix pylint

changed doc and code for SmoothL1Loss to be the same as dchip. fixed grad kernel

fix ci
Peilin Wang 2020-08-18 17:53:11 -04:00
parent 56bd92b88f
commit 0d5220d33c
8 changed files with 90 additions and 52 deletions
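For reference, the new code follows the beta-parameterized definition of smooth L1 loss: quadratic below beta, linear above it, with no reduction. A minimal NumPy sketch of what the updated forward kernel computes element-wise (the function name here is illustrative, not part of the commit):

import numpy as np

def smooth_l1_loss_reference(prediction, target, beta=1.0):
    # quadratic branch below beta, linear branch at or above it
    x = np.abs(prediction - target)
    return np.where(x < beta, 0.5 * x * x / beta, x - 0.5 * beta)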


@@ -18,47 +18,47 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target,
__global__ void SmoothL1LossKernel(const int input_size, const float beta, const T *prediction, const T *target,
T *loss) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]);
if (value < sigma) {
loss[i] = static_cast<T>(0.5) * value * value;
T value = fabsf(prediction[i] - target[i]);
if (value < beta) {
loss[i] = 0.5 * value * value / beta;
} else {
loss[i] = value - static_cast<T>(0.5);
loss[i] = value - (0.5 * beta);
}
}
}
template <typename T>
void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
cudaStream_t stream) {
SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, loss);
}
template <typename T>
__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
__global__ void SmoothL1LossGradKernel(const int input_size, const float beta, const T *prediction, const T *target,
const T *dloss, T *dx) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value = prediction[i] - target[i];
if (value > static_cast<T>(sigma)) {
if (value > beta) {
dx[i] = dloss[i];
} else if (value < static_cast<T>(-sigma)) {
} else if (value < -beta) {
dx[i] = -dloss[i];
} else {
dx[i] = value * dloss[i];
dx[i] = (value / beta) * dloss[i];
}
}
}
template <typename T>
void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
T *dx, cudaStream_t stream) {
SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target,
dloss, dx);
}
template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target,
float *loss, cudaStream_t stream);
template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target,
const float *dloss, float *dx, cudaStream_t stream);
template void SmoothL1Loss<float>(const int &input_size, const float &beta, const float *prediction,
const float *target, float *loss, cudaStream_t stream);
template void SmoothL1LossGrad<float>(const int &input_size, const float &beta, const float *prediction,
const float *target, const float *dloss, float *dx, cudaStream_t stream);
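The grad kernel above applies the derivative of that piecewise formula with respect to prediction, scaled by the incoming dloss. A hedged NumPy sketch of the same branching (illustrative only, not part of the commit):

import numpy as np

def smooth_l1_loss_grad_reference(prediction, target, dloss, beta=1.0):
    # mirrors SmoothL1LossGradKernel: +/-1 outside the [-beta, beta] band,
    # diff / beta inside it, then scaled by the incoming gradient dloss
    diff = prediction - target
    dldp = np.where(diff > beta, 1.0, np.where(diff < -beta, -1.0, diff / beta))
    return dldp * dloss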


@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
template <typename T>
void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
cudaStream_t stream);
template <typename T>
void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
T *dx, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_


@@ -26,7 +26,7 @@ namespace kernel {
template <typename T>
class SmoothL1LossGpuKernel : public GpuKernel {
public:
SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {}
SmoothL1LossGpuKernel() : input_size_(1), beta_(1.0) {}
~SmoothL1LossGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -39,7 +39,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
T *target = GetDeviceAddress<T>(inputs, 1);
T *loss = GetDeviceAddress<T>(outputs, 0);
SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
SmoothL1Loss(input_size_, beta_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -49,7 +49,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
sigma_ = GetAttr<float>(kernel_node, "sigma");
beta_ = GetAttr<float>(kernel_node, "beta");
InitSizeLists();
return true;
}
@@ -63,7 +63,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
private:
size_t input_size_;
float sigma_;
float beta_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


@@ -26,7 +26,7 @@ namespace kernel {
template <typename T>
class SmoothL1LossGradGpuKernel : public GpuKernel {
public:
SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
SmoothL1LossGradGpuKernel() : input_size_(1), beta_(1.0) {}
~SmoothL1LossGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,7 +40,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
T *dloss = GetDeviceAddress<T>(inputs, 2);
T *dx = GetDeviceAddress<T>(outputs, 0);
SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
SmoothL1LossGrad(input_size_, beta_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -50,7 +50,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
sigma_ = GetAttr<float>(kernel_node, "sigma");
beta_ = GetAttr<float>(kernel_node, "beta");
InitSizeLists();
return true;
}
@@ -64,7 +64,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
float sigma_;
float beta_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


@@ -689,7 +689,7 @@ def get_bprop_top_kv2(self):
@bprop_getters.register(P.SmoothL1Loss)
def get_bprop_smooth_l1_loss(self):
"""Grad definition for `SmoothL1Loss` operation."""
grad = G.SmoothL1LossGrad(self.sigma)
grad = G.SmoothL1LossGrad(self.beta)
def bprop(prediction, target, out, dout):
dx = grad(prediction, target, dout)


@@ -1258,7 +1258,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
"""Computes gradient for prediction on SmoothL1Loss."""
@prim_attr_register
def __init__(self, sigma=1.0):
def __init__(self, beta=1.0):
pass
def infer_shape(self, prediction, target, dloss):
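Per the bprop registration earlier in this commit, the grad primitive is constructed with beta and called with (prediction, target, dloss). A hedged sketch of calling it directly; the import path, PyNative direct-call usage, and the sample values are assumptions rather than part of this commit:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations import _grad_ops as G  # assumed import path

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
grad_op = G.SmoothL1LossGrad(beta=1.0)
prediction = Tensor(np.array([0.2, -1.5, 3.0], np.float32))
target = Tensor(np.zeros(3, np.float32))
dloss = Tensor(np.ones(3, np.float32))
dx = grad_op(prediction, target, dloss)  # expected roughly [0.2, -1.0, 1.0]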


@@ -1658,11 +1658,11 @@ class SmoothL1Loss(PrimitiveWithInfer):
Sets input prediction as `X`, input target as `Y`, output as `loss`. Then,
.. math::
\text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr
\left |x \right|-0.5, &\text{otherwise}\end{cases}
\text{SmoothL1Loss} = \begin{cases} \frac{0.5 x^{2}}{\text{beta}}, &if \left |x \right | < \text{beta} \cr
\left |x \right|-0.5 \text{beta}, &\text{otherwise}\end{cases}
Args:
sigma (float): A parameter used to control the point where the function will change from
beta (float): A parameter used to control the point where the function will change from
quadratic to linear. Default: 1.0.
Inputs:
@@ -1681,9 +1681,9 @@ class SmoothL1Loss(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, sigma=1.0):
validator.check_value_type('sigma', sigma, [float], self.name)
validator.check('sigma', sigma, '', 0, Rel.GT, self.name)
def __init__(self, beta=1.0):
validator.check_value_type('beta', beta, [float], self.name)
validator.check('beta', beta, '', 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
def infer_shape(self, prediction, target):
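With this change, callers pass beta where sigma used to go. A short hedged usage sketch mirroring the test below (sample values are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

beta = 1.0
net = nn.SmoothL1Loss(beta)
prediction = Tensor(np.array([0.0, 0.5, 2.0], np.float32))
target = Tensor(np.zeros(3, np.float32))
loss = net(prediction, target)
# element-wise: 0.5 * x * x / beta where |x| < beta, |x| - 0.5 * beta otherwise
# -> [0.0, 0.125, 1.5]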


@@ -21,25 +21,39 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
def smoothl1loss(beta):
np.random.seed(42)
prediction = np.random.randn(20).astype(np.float32)
target = np.random.randn(20).astype(np.float32)
net = nn.SmoothL1Loss(beta)
return net(Tensor(prediction), Tensor(target))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_smoothl1loss():
np.random.seed(42)
prediction = np.random.randn(20).astype(np.float32)
target = np.random.randn(20).astype(np.float32)
sigma = 1.0
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
net = nn.SmoothL1Loss(sigma)
loss = net(Tensor(prediction), Tensor(target))
epsilon = 1e-6
beta = 1.0
loss = smoothl1loss(beta)
expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304,
2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008,
0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174,
0.08826803, 1.109165]
assert np.allclose(loss.asnumpy(), expect)
diff = np.absolute(loss.asnumpy() - np.array(expect))
assert(diff < epsilon).all()
beta = 1 / 9
loss = smoothl1loss(beta)
expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504,
2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524,
0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619,
0.3646064, 1.5536094]
diff = np.absolute(loss.asnumpy() - np.array(expect))
assert(diff < epsilon).all()
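For context, the hard-coded expect values above can be regenerated from the same seed with a plain NumPy version of the formula; a hedged sketch that is not part of the test file:

import numpy as np

def numpy_forward_expect(beta):
    np.random.seed(42)
    prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)
    x = np.abs(prediction - target)
    return np.where(x < beta, 0.5 * x * x / beta, x - 0.5 * beta)

# numpy_forward_expect(1.0) and numpy_forward_expect(1 / 9) should reproduce
# the expect lists above to within the test's epsilon.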
class Grad(nn.Cell):
@@ -53,20 +67,26 @@ class Grad(nn.Cell):
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_smoothl1loss_grad():
def smoothl1loss_grad(beta):
np.random.seed(42)
prediction = np.random.randn(20).astype(np.float32)
target = np.random.randn(20).astype(np.float32)
sens = np.random.randn(20).astype(np.float32)
sigma = 1.0
net = nn.SmoothL1Loss(sigma)
net = nn.SmoothL1Loss(beta)
grad = Grad(net)
dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
return grad(Tensor(prediction), Tensor(target), Tensor(sens))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_smoothl1loss_grad():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
epsilon = 1e-6
beta = 1.0
dx = smoothl1loss_grad(beta)
dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
@@ -77,5 +97,23 @@ def test_smoothl1loss_grad():
-0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995,
-0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451]
assert np.allclose(dx[0].asnumpy(), dx1_expect)
assert np.allclose(dx[1].asnumpy(), dx2_expect)
diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
assert(diff1 < epsilon).all()
assert(diff2 < epsilon).all()
beta = 1 / 9
dx = smoothl1loss_grad(beta)
dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522,
0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402,
0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995,
0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451]
dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522,
-0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402,
-0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995,
-0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451]
diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
assert(diff1 < epsilon).all()
assert(diff2 < epsilon).all()
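The gradient expectations can be cross-checked the same way; note that dx2_expect is the element-wise negation of dx1_expect, since the loss depends only on prediction - target. A hedged sketch, not part of the test file:

import numpy as np

def numpy_grad_expect(beta):
    np.random.seed(42)
    prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)
    sens = np.random.randn(20).astype(np.float32)
    diff = prediction - target
    dpred = np.where(diff > beta, 1.0, np.where(diff < -beta, -1.0, diff / beta)) * sens
    return dpred, -dpred  # gradient w.r.t. target is the negation

# numpy_grad_expect(1.0) and numpy_grad_expect(1 / 9) should reproduce
# dx1_expect / dx2_expect above to within epsilon.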