forked from mindspore-Ecosystem/mindspore
Modified documentation and GPU kernel for SmoothL1Loss.
Fix pylint; changed doc and code for SmoothL1Loss to match dchip; fixed grad kernel; fix CI.
parent 56bd92b88f
commit 0d5220d33c
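Note, as an orientation aid (not part of the diff): a minimal NumPy sketch of the beta-parameterized forward loss that the updated CUDA kernel below computes elementwise; the function name is illustrative only.

```python
import numpy as np

def smooth_l1_loss(prediction, target, beta=1.0):
    # Quadratic inside the |x| < beta band, linear outside -- mirrors the new kernel.
    diff = np.abs(prediction - target)
    return np.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
```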
@@ -18,47 +18,47 @@
 #include "runtime/device/gpu/cuda_common.h"

 template <typename T>
-__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                    T *loss) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-    T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]);
-    if (value < sigma) {
-      loss[i] = static_cast<T>(0.5) * value * value;
+    T value = fabsf(prediction[i] - target[i]);
+    if (value < beta) {
+      loss[i] = 0.5 * value * value / beta;
     } else {
-      loss[i] = value - static_cast<T>(0.5);
+      loss[i] = value - (0.5 * beta);
     }
   }
 }

 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream) {
-  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
+  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, loss);
 }

 template <typename T>
-__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossGradKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                        const T *dloss, T *dx) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
     T value = prediction[i] - target[i];
-    if (value > static_cast<T>(sigma)) {
+    if (value > beta) {
       dx[i] = dloss[i];
-    } else if (value < static_cast<T>(-sigma)) {
+    } else if (value < -beta) {
       dx[i] = -dloss[i];
     } else {
-      dx[i] = value * dloss[i];
+      dx[i] = (value / beta) * dloss[i];
     }
   }
 }

 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream) {
-  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
+  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target,
                                                                              dloss, dx);
 }

-template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                           float *loss, cudaStream_t stream);
-template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                               const float *dloss, float *dx, cudaStream_t stream);
+template void SmoothL1Loss<float>(const int &input_size, const float &beta, const float *prediction,
+                                  const float *target, float *loss, cudaStream_t stream);
+template void SmoothL1LossGrad<float>(const int &input_size, const float &beta, const float *prediction,
+                                      const float *target, const float *dloss, float *dx, cudaStream_t stream);
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream);
 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -39,7 +39,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
     T *target = GetDeviceAddress<T>(inputs, 1);
     T *loss = GetDeviceAddress<T>(outputs, 0);

-    SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1Loss(input_size_, beta_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }

@@ -49,7 +49,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }

-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }

@@ -63,7 +63,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {

  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGradGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGradGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGradGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,7 +40,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
     T *dloss = GetDeviceAddress<T>(inputs, 2);
     T *dx = GetDeviceAddress<T>(outputs, 0);

-    SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1LossGrad(input_size_, beta_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }

@@ -50,7 +50,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }

-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }

@@ -64,7 +64,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {

  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -689,7 +689,7 @@ def get_bprop_top_kv2(self):
 @bprop_getters.register(P.SmoothL1Loss)
 def get_bprop_smooth_l1_loss(self):
     """Grad definition for `SmoothL1Loss` operation."""
-    grad = G.SmoothL1LossGrad(self.sigma)
+    grad = G.SmoothL1LossGrad(self.beta)

     def bprop(prediction, target, out, dout):
         dx = grad(prediction, target, dout)
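For intuition (not part of the diff): a NumPy sketch of the gradient that the registered bprop routes through SmoothL1LossGrad under the beta parameterization. As the updated tests further below reflect, the gradient w.r.t. target is the negation of the gradient w.r.t. prediction. Names here are illustrative only.

```python
import numpy as np

def smooth_l1_loss_grad(prediction, target, dout, beta=1.0):
    # Clamp to +/-1 outside the beta band, scale linearly inside it,
    # then multiply by the incoming gradient dout.
    diff = prediction - target
    dx = np.where(diff > beta, 1.0,
                  np.where(diff < -beta, -1.0, diff / beta)) * dout
    return dx, -dx  # (d/d prediction, d/d target)
```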
@@ -1258,7 +1258,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
     """Computes gradient for prediction on SmoothL1Loss."""

     @prim_attr_register
-    def __init__(self, sigma=1.0):
+    def __init__(self, beta=1.0):
         pass

     def infer_shape(self, prediction, target, dloss):
@@ -1658,11 +1658,11 @@ class SmoothL1Loss(PrimitiveWithInfer):
     Sets input prediction as `X`, input target as `Y`, output as `loss`. Then,

     .. math::
-        \text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr
-        \left |x \right|-0.5, &\text{otherwise}\end{cases}
+        \text{SmoothL1Loss} = \begin{cases} \frac{0.5 x^{2}}{\text{beta}}, &if \left |x \right | < \text{beta} \cr
+        \left |x \right| - 0.5 \text{beta}, &\text{otherwise}\end{cases}

     Args:
-        sigma (float): A parameter used to control the point where the function will change from
+        beta (float): A parameter used to control the point where the function will change from
             quadratic to linear. Default: 1.0.

     Inputs:
@@ -1681,9 +1681,9 @@ class SmoothL1Loss(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, sigma=1.0):
-        validator.check_value_type('sigma', sigma, [float], self.name)
-        validator.check('sigma', sigma, '', 0, Rel.GT, self.name)
+    def __init__(self, beta=1.0):
+        validator.check_value_type('beta', beta, [float], self.name)
+        validator.check('beta', beta, '', 0, Rel.GT, self.name)
         self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])

     def infer_shape(self, prediction, target):
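A quick numeric check of the piecewise definition above (standalone NumPy, illustrative values only):

```python
import numpy as np

beta = 1.0
x = np.array([0.25, 2.0])                 # x = prediction - target
loss = np.where(np.abs(x) < beta,
                0.5 * x * x / beta,       # quadratic branch: |x| < beta
                np.abs(x) - 0.5 * beta)   # linear branch otherwise
# loss -> [0.03125, 1.5]
```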
@@ -21,25 +21,39 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import composite as C

-context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+def smoothl1loss(beta):
+    np.random.seed(42)
+    prediction = np.random.randn(20).astype(np.float32)
+    target = np.random.randn(20).astype(np.float32)
+
+    net = nn.SmoothL1Loss(beta)
+    return net(Tensor(prediction), Tensor(target))
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_smoothl1loss():
-    np.random.seed(42)
-    prediction = np.random.randn(20).astype(np.float32)
-    target = np.random.randn(20).astype(np.float32)
-    sigma = 1.0
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)

-    net = nn.SmoothL1Loss(sigma)
-    loss = net(Tensor(prediction), Tensor(target))
+    epsilon = 1e-6
+
+    beta = 1.0
+    loss = smoothl1loss(beta)
     expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304,
               2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008,
               0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174,
               0.08826803, 1.109165]
-    assert np.allclose(loss.asnumpy(), expect)
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert(diff < epsilon).all()
+
+    beta = 1 / 9
+    loss = smoothl1loss(beta)
+    expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504,
+              2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524,
+              0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619,
+              0.3646064, 1.5536094]
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert(diff < epsilon).all()


 class Grad(nn.Cell):
@@ -53,20 +67,26 @@ class Grad(nn.Cell):
         return gout


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_smoothl1loss_grad():
+def smoothl1loss_grad(beta):
     np.random.seed(42)
     prediction = np.random.randn(20).astype(np.float32)
     target = np.random.randn(20).astype(np.float32)
     sens = np.random.randn(20).astype(np.float32)
-    sigma = 1.0

-    net = nn.SmoothL1Loss(sigma)
+    net = nn.SmoothL1Loss(beta)
     grad = Grad(net)
-    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
+    return grad(Tensor(prediction), Tensor(target), Tensor(sens))

+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_smoothl1loss_grad():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    epsilon = 1e-6
+
+    beta = 1.0
+    dx = smoothl1loss_grad(beta)
     dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
                   0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
                   0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
@@ -77,5 +97,23 @@ def test_smoothl1loss_grad():
                   -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995,
                   -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451]

-    assert np.allclose(dx[0].asnumpy(), dx1_expect)
-    assert np.allclose(dx[1].asnumpy(), dx2_expect)
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert(diff1 < epsilon).all()
+    assert(diff2 < epsilon).all()
+
+    beta = 1 / 9
+    dx = smoothl1loss_grad(beta)
+    dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522,
+                  0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402,
+                  0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995,
+                  0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451]
+    dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522,
+                  -0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402,
+                  -0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995,
+                  -0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451]
+
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert(diff1 < epsilon).all()
+    assert(diff2 < epsilon).all()
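The `expect` lists in these tests can be sanity-checked against a NumPy reference (a sketch, assuming nn.SmoothL1Loss applies no reduction and using the same seed as the tests above):

```python
import numpy as np

np.random.seed(42)
prediction = np.random.randn(20).astype(np.float32)
target = np.random.randn(20).astype(np.float32)

for beta in (1.0, 1 / 9):
    diff = np.abs(prediction - target)
    ref = np.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    # ref should agree with the corresponding expect list to within ~1e-6
    print(beta, ref)
```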