!16629 Binary_cross_entropy op supports optional weight input

From: @zuochuanyong
Reviewed-by: @zhoufeng54,@liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-05-25 09:26:00 +08:00 committed by Gitee
commit e1803a0938
11 changed files with 193 additions and 37 deletions

View File

@ -17,6 +17,8 @@
namespace mindspore {
namespace kernel {
constexpr size_t kBceInputNumWithWeight = 3;
template <typename T>
void BinaryCrossEntropyCpuKernel::LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss) {
if (input_size % 2 == 1) {
@ -44,24 +46,37 @@ void BinaryCrossEntropyCpuKernel::Launchkernel(const std::vector<AddressPtr> &in
const std::vector<AddressPtr> &outputs) {
T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
T *weight = reinterpret_cast<T *>(inputs[2]->addr);
T *weight = nullptr;
if (weight_defined_) {
weight = reinterpret_cast<T *>(inputs[2]->addr);
}
T *loss = reinterpret_cast<T *>(outputs[0]->addr);
std::vector<T> tmp_loss(input_size_);
T epsilon = static_cast<T>(1e-12);
T one = static_cast<T>(1);
if (reduction_ == 0) {
if (reduction_ == 0 && weight_defined_) {
for (size_t i = 0; i < input_size_; i++) {
T value =
-weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
loss[i] = value;
}
} else {
} else if (reduction_ == 0 && (!weight_defined_)) {
for (size_t i = 0; i < input_size_; i++) {
T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
loss[i] = value;
}
} else if ((reduction_ != 0) && weight_defined_) {
for (size_t i = 0; i < input_size_; i++) {
T value =
-weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
tmp_loss[i] = value;
}
} else {
for (size_t i = 0; i < input_size_; i++) {
T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
tmp_loss[i] = value;
}
}
if (reduction_ != 0) {
@ -93,7 +108,8 @@ void BinaryCrossEntropyCpuKernel::InitKernel(const CNodePtr &kernel_node) {
} else if (reduction == "sum") {
reduction_ = 2;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == kBceInputNumWithWeight);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
} // namespace kernel

View File

@ -25,7 +25,7 @@ namespace mindspore {
namespace kernel {
class BinaryCrossEntropyCpuKernel : public CPUKernel {
public:
BinaryCrossEntropyCpuKernel() : input_size_(1), reduction_(1) {}
BinaryCrossEntropyCpuKernel() : input_size_(1), reduction_(1), weight_defined_(false) {}
~BinaryCrossEntropyCpuKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@ -42,6 +42,7 @@ class BinaryCrossEntropyCpuKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
size_t input_size_;
int reduction_;
bool weight_defined_; // true: there are 3 inputs, false: there are 2 inputs(no [weight])
};
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
KernelAttr()
@ -57,6 +58,14 @@ MS_REG_CPU_KERNEL(BinaryCrossEntropy,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyCpuKernel);
MS_REG_CPU_KERNEL(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyCpuKernel);
MS_REG_CPU_KERNEL(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyCpuKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H

View File

@ -17,32 +17,54 @@
namespace mindspore {
namespace kernel {
constexpr size_t kBceGradInputNumWithWeight = 4;
template <typename T>
void BinaryCrossEntropyGradCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
T *dloss = reinterpret_cast<T *>(inputs[2]->addr);
T *weight = reinterpret_cast<T *>(inputs[3]->addr);
T *weight = nullptr;
if (weight_defined_) {
weight = reinterpret_cast<T *>(inputs[3]->addr);
}
T *dx = reinterpret_cast<T *>(outputs[0]->addr);
T epsilon = static_cast<T>(1e-12);
T one = static_cast<T>(1);
if (reduction_ == 0) {
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
if (weight_defined_) {
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
}
} else {
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
}
}
} else {
T dloss1 = dloss[0];
if (reduction_ == 1) {
dloss1 = dloss[0] / static_cast<T>(input_size_);
}
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
if (weight_defined_) {
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
}
} else {
for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
}
}
}
}
@ -72,6 +94,8 @@ void BinaryCrossEntropyGradCpuKernel::InitKernel(const CNodePtr &kernel_node) {
reduction_ = 2;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == kBceGradInputNumWithWeight);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
} // namespace kernel

View File

@ -25,7 +25,7 @@ namespace mindspore {
namespace kernel {
class BinaryCrossEntropyGradCpuKernel : public CPUKernel {
public:
BinaryCrossEntropyGradCpuKernel() : input_size_(1), reduction_(1) {}
BinaryCrossEntropyGradCpuKernel() : input_size_(1), reduction_(1), weight_defined_(false) {}
~BinaryCrossEntropyGradCpuKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@ -39,6 +39,7 @@ class BinaryCrossEntropyGradCpuKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
size_t input_size_;
int reduction_;
bool weight_defined_; // true: there are 4 inputs, false: there are 3 inputs(no [weight])
};
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
@ -56,6 +57,20 @@ MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGradCpuKernel);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGradCpuKernel);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGradCpuKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H

View File

@ -124,18 +124,28 @@ __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int red
const T *input_y, const T *weight, T *loss, T *tmp_loss) {
T epsilon = 1e-12;
T one = static_cast<T>(1);
if (reduction == 0) {
if (reduction == 0 && weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value =
-weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
loss[i] = value;
}
} else {
} else if (reduction == 0 && weight == nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
loss[i] = value;
}
} else if (reduction != 0 && weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value =
-weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
tmp_loss[i] = value;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
tmp_loss[i] = value;
}
}
}
@ -165,20 +175,36 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
T epsilon = 1e-12;
T one = static_cast<T>(1);
if (reduction == 0) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
if (weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
}
}
} else {
T dloss1 = dloss[0];
if (reduction == 1) {
dloss1 = dloss[0] / castT(dloss[0], input_size);
}
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
if (weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
}
}
}
}

View File

@ -31,5 +31,13 @@ MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropy,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -28,7 +28,7 @@ namespace kernel {
template <typename T>
class BinaryCrossEntropyGpuKernel : public GpuKernel {
public:
BinaryCrossEntropyGpuKernel() : input_size_(1), reduction_(1) {}
BinaryCrossEntropyGpuKernel() : weight_defined_(false), input_size_(1), reduction_(1) {}
~BinaryCrossEntropyGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -37,7 +37,10 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_x = GetDeviceAddress<T>(inputs, 0);
T *input_y = GetDeviceAddress<T>(inputs, 1);
T *weight = GetDeviceAddress<T>(inputs, 2);
T *weight = nullptr;
if (weight_defined_) {
weight = GetDeviceAddress<T>(inputs, 2);
}
T *loss = GetDeviceAddress<T>(outputs, 0);
T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
if (input_size_ > 0) {
@ -49,6 +52,8 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == 3);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
@ -70,7 +75,9 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
if (weight_defined_) {
input_size_list_.push_back(input_size_ * sizeof(T));
}
if (reduction_ == 0) {
output_size_list_.push_back(input_size_ * sizeof(T));
} else {
@ -80,6 +87,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
}
private:
bool weight_defined_; // true: there are 3 inputs, false: there are 2 inputs(no [weight])
size_t input_size_;
int reduction_;
size_t workspace_size_;

View File

@ -34,5 +34,19 @@ MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -28,7 +28,7 @@ namespace kernel {
template <typename T>
class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
public:
BinaryCrossEntropyGradGpuKernel() : input_size_(1), reduction_(1) {}
BinaryCrossEntropyGradGpuKernel() : input_size_(1), reduction_(1), weight_defined_(false) {}
~BinaryCrossEntropyGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -40,7 +40,10 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
T *input_x = GetDeviceAddress<T>(inputs, 0);
T *input_y = GetDeviceAddress<T>(inputs, 1);
T *dloss = GetDeviceAddress<T>(inputs, 2);
T *weight = GetDeviceAddress<T>(inputs, 3);
T *weight = nullptr;
if (weight_defined_) {
weight = GetDeviceAddress<T>(inputs, 3);
}
T *dx = GetDeviceAddress<T>(outputs, 0);
if (input_size_ > 0) {
BinaryCrossEntropyLossGrad(input_size_, reduction_, input_x, input_y, weight, dloss, dx,
@ -51,6 +54,8 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == 4);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
@ -73,14 +78,16 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
} else {
input_size_list_.push_back(sizeof(T));
}
input_size_list_.push_back(input_size_ * sizeof(T));
if (weight_defined_) {
input_size_list_.push_back(input_size_ * sizeof(T));
}
output_size_list_.push_back(input_size_ * sizeof(T));
}
private:
size_t input_size_;
int reduction_;
bool weight_defined_; // true: there are 4 inputs, false: there are 3 inputs(no [weight])
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -24,12 +24,13 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, reduction="none"):
super(Net, self).__init__()
self.BinaryCrossEntropy = P.BinaryCrossEntropy(reduction)
def construct(self, x, y, weight):
def construct(self, x, y, weight=None):
return self.BinaryCrossEntropy(x, y, weight)
@ -50,6 +51,7 @@ def test_binary_cross_entropy_loss():
0.03405444, 0.23934692]
assert np.allclose(loss.asnumpy(), expect)
def test_binary_cross_entropy_loss_mean():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
@ -61,6 +63,7 @@ def test_binary_cross_entropy_loss_mean():
expect = [0.7447324991226196]
assert loss.asnumpy() == expect
def test_binary_cross_entropy_loss_sum():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
@ -72,6 +75,18 @@ def test_binary_cross_entropy_loss_sum():
expect = [14.894649505615234]
assert loss.asnumpy() == expect
def test_binary_cross_entropy_loss_sum_without_weight():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
target = np.random.rand(20).astype(np.float32)
reduction = "sum"
net = Net(reduction)
loss = net(Tensor(prediction), Tensor(target))
expect = [25.48195216753522]
assert np.allclose(loss.asnumpy(), expect)
def test_binary_cross_entropy_loss_16():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float16)
@ -86,6 +101,7 @@ def test_binary_cross_entropy_loss_16():
0.0340576, 0.239258]
assert np.allclose(loss.asnumpy(), expect)
def test_binary_cross_entropy_loss_mean_16():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float16)
@ -97,6 +113,7 @@ def test_binary_cross_entropy_loss_mean_16():
expect = [0.74462890625]
assert loss.asnumpy() == expect
def test_binary_cross_entropy_loss_sum_16():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float16)
@ -108,13 +125,14 @@ def test_binary_cross_entropy_loss_sum_16():
expect = [14.890625]
assert loss.asnumpy() == expect
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network
def construct(self, x1, x2, sens, weight):
def construct(self, x1, x2, sens, weight=None):
gout = self.grad(self.network)(x1, x2, sens, weight)
return gout

View File

@ -28,9 +28,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self, reduction="none"):
super(Net, self).__init__()
self.BinaryCrossEntropy = P.BinaryCrossEntropy("none")
self.BinaryCrossEntropy = P.BinaryCrossEntropy(reduction)
def construct(self, x, y, weight):
def construct(self, x, y, weight=None):
return self.BinaryCrossEntropy(x, y, weight)
@ -51,13 +51,24 @@ def test_binary_cross_entropy_loss():
assert np.allclose(loss.asnumpy(), expect)
def test_binary_cross_entropy_loss_sum_without_weight():
np.random.seed(42)
prediction = np.random.rand(20).astype(np.float32)
target = np.random.rand(20).astype(np.float32)
reduction = "sum"
net = Net(reduction)
loss = net(Tensor(prediction), Tensor(target))
expect = [25.48195216753522]
assert np.allclose(loss.asnumpy(), expect)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network
def construct(self, x1, x2, sens, weight):
def construct(self, x1, x2, sens, weight=None):
gout = self.grad(self.network)(x1, x2, sens, weight)
return gout