mirror of https://github.com/tracel-ai/burn.git
Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278)
This commit is contained in:
parent
81ec64a929
commit
7ac5deebe2
|
@ -78,20 +78,28 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv1d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
let grad = B::conv1d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv1d_bias_backward(x, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -109,16 +117,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv1d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
let grad = B::conv1d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
let grad = B::conv1d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -188,20 +202,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose1d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose1d_weight_backward(
|
||||
x.clone(),
|
||||
weight,
|
||||
grad.clone(),
|
||||
options,
|
||||
);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -219,16 +245,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 3>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 3>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose1d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 3>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 3>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -307,20 +338,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv2d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
let grad = B::conv2d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
let grad =
|
||||
B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv2d_bias_backward(x, weight, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -338,16 +378,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv2d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
let grad = B::conv2d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
let grad = B::conv2d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -419,20 +465,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose2d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose2d_weight_backward(
|
||||
x.clone(),
|
||||
weight,
|
||||
grad.clone(),
|
||||
options,
|
||||
);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -450,16 +508,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose2d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 4>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -540,20 +603,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 5>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv3d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 5>(node.id, backward.x_grad)
|
||||
let grad = B::conv3d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 5>(node.id, backward.weights_grad)
|
||||
let grad =
|
||||
B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv3d_bias_backward(x, weight, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -571,16 +643,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 5>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv3d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 5>(node.id, backward.x_grad)
|
||||
let grad = B::conv3d_x_backward(
|
||||
x.clone(),
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 5>(node.id, backward.weights_grad)
|
||||
let grad = B::conv3d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -652,20 +730,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 5>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward = B::conv_transpose3d_backward(x, weight, bias, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
|
||||
let bias =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 5>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose3d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 5>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose3d_weight_backward(
|
||||
x.clone(),
|
||||
weight,
|
||||
grad.clone(),
|
||||
options,
|
||||
);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
|
||||
grads.register::<B, 1>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -683,16 +773,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
let grad = grads.consume::<B, 5>(&ops.node);
|
||||
|
||||
let (x_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward = B::conv_transpose3d_backward(x, weight, None, grad, options);
|
||||
let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
|
||||
let weight =
|
||||
checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 5>(node.id, backward.x_grad)
|
||||
let grad = B::conv_transpose3d_x_backward(
|
||||
weight.clone(),
|
||||
grad.clone(),
|
||||
options.clone(),
|
||||
);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 5>(node.id, backward.weights_grad)
|
||||
let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
|
||||
grads.register::<B, 5>(node.id, grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,32 +5,6 @@ use crate::{
|
|||
Shape,
|
||||
};
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
|
||||
#[derive(new)]
|
||||
pub struct Conv2dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
|
||||
#[derive(new)]
|
||||
pub struct Conv3dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 5>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: FloatTensor<B, 5>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
|
||||
#[derive(new)]
|
||||
pub struct MaxPool1dBackward<B: Backend> {
|
||||
|
@ -65,19 +39,6 @@ pub struct MaxPool2dWithIndices<B: Backend> {
|
|||
pub indices: IntTensor<B, 4>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
|
||||
#[derive(new)]
|
||||
pub struct Conv1dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 3>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: FloatTensor<B, 3>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Convolution options.
|
||||
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct ConvOptions<const N: usize> {
|
||||
|
@ -221,15 +182,31 @@ pub trait ModuleOps<B: Backend> {
|
|||
) -> FloatTensor<B, 3> {
|
||||
conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
|
||||
}
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
|
||||
fn conv1d_backward(
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
|
||||
fn conv1d_x_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
conv::conv1d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 3> {
|
||||
conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
|
||||
fn conv1d_weight_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<B, 3> {
|
||||
conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
|
||||
fn conv1d_bias_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv1d_bias_backward::<B>(x, bias, output_grad)
|
||||
}
|
||||
/// Two dimensional convolution.
|
||||
///
|
||||
|
@ -244,15 +221,32 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<FloatTensor<B, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<B, 4>;
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
|
||||
fn conv2d_backward(
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
|
||||
fn conv2d_x_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
conv::conv2d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 4> {
|
||||
conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
|
||||
fn conv2d_weight_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<B, 4> {
|
||||
conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
|
||||
fn conv2d_bias_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
|
||||
}
|
||||
/// Three dimensional convolution.
|
||||
///
|
||||
|
@ -267,15 +261,32 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<FloatTensor<B, 1>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<B, 5>;
|
||||
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation.
|
||||
fn conv3d_backward(
|
||||
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
|
||||
fn conv3d_x_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvOptions<3>,
|
||||
) -> Conv3dBackward<B> {
|
||||
conv::conv3d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 5> {
|
||||
conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
|
||||
fn conv3d_weight_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<B, 5> {
|
||||
conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
|
||||
fn conv3d_bias_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
|
||||
}
|
||||
/// One dimensional transposed convolution.
|
||||
///
|
||||
|
@ -292,15 +303,30 @@ pub trait ModuleOps<B: Backend> {
|
|||
) -> FloatTensor<B, 3> {
|
||||
conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation.
|
||||
fn conv_transpose1d_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
|
||||
fn conv_transpose1d_x_backward(
|
||||
weight: FloatTensor<B, 3>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
conv::conv_transpose1d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 3> {
|
||||
conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
|
||||
fn conv_transpose1d_weight_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<B, 3> {
|
||||
conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
|
||||
fn conv_transpose1d_bias_backward(
|
||||
x: FloatTensor<B, 3>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
|
||||
}
|
||||
|
||||
/// Two dimensional transposed convolution.
|
||||
|
@ -316,15 +342,30 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<FloatTensor<B, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<B, 4>;
|
||||
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation.
|
||||
fn conv_transpose2d_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
|
||||
fn conv_transpose2d_x_backward(
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
conv::conv_transpose2d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 4> {
|
||||
conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
|
||||
fn conv_transpose2d_weight_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<B, 4> {
|
||||
conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
|
||||
fn conv_transpose2d_bias_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
|
||||
}
|
||||
|
||||
/// Three dimensional transposed convolution.
|
||||
|
@ -340,15 +381,30 @@ pub trait ModuleOps<B: Backend> {
|
|||
bias: Option<FloatTensor<B, 1>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<B, 5>;
|
||||
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation.
|
||||
fn conv_transpose3d_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
|
||||
fn conv_transpose3d_x_backward(
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> Conv3dBackward<B> {
|
||||
conv::conv_transpose3d_backward(x, weight, bias, output_grad, options)
|
||||
) -> FloatTensor<B, 5> {
|
||||
conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
|
||||
fn conv_transpose3d_weight_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<B, 5> {
|
||||
conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
|
||||
}
|
||||
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
|
||||
fn conv_transpose3d_bias_backward(
|
||||
x: FloatTensor<B, 5>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
|
||||
}
|
||||
|
||||
/// Four-dimensional unfolding.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#![allow(clippy::single_range_in_vec_init)]
|
||||
use super::{Conv1dBackward, Conv2dBackward, Conv3dBackward, ConvOptions, ConvTransposeOptions};
|
||||
use super::{ConvOptions, ConvTransposeOptions};
|
||||
use crate::{backend::Backend, ops::FloatTensor, Shape};
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
|
@ -57,19 +57,17 @@ pub fn calculate_pool_output_size(
|
|||
((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
|
||||
pub(crate) fn conv1d_backward<B: Backend>(
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv1d_x_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
) -> FloatTensor<B, 3> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _, length_in] = B::float_shape(&x).dims;
|
||||
let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
|
||||
let [_batch_size, _, length_in] = B::float_shape(&x).dims;
|
||||
let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims;
|
||||
let [_, _, kernel_size] = weight_shape.dims;
|
||||
|
||||
let padding_out = calculate_padding_out(
|
||||
|
@ -81,8 +79,8 @@ pub(crate) fn conv1d_backward<B: Backend>(
|
|||
length_out,
|
||||
);
|
||||
|
||||
let x_grad = B::conv_transpose1d(
|
||||
output_grad.clone(),
|
||||
B::conv_transpose1d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
|
@ -92,45 +90,58 @@ pub(crate) fn conv1d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv1d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<B, 3> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv1d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv1d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let [batch_size, _, _length_in] = B::float_shape(&x).dims;
|
||||
let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
Conv1dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions.
|
||||
pub(crate) fn conv2d_backward<B: Backend>(
|
||||
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv2d_x_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
) -> FloatTensor<B, 4> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
|
||||
let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
|
||||
let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
|
||||
let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
|
||||
|
||||
let padding_1_out = calculate_padding_out(
|
||||
kernel_size_1,
|
||||
|
@ -149,8 +160,8 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
width_out,
|
||||
);
|
||||
|
||||
let x_grad = B::conv_transpose2d(
|
||||
output_grad.clone(),
|
||||
B::conv_transpose2d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
|
@ -160,22 +171,43 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
|
||||
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv2d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<B, 4> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv2d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv2d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
|
||||
let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims;
|
||||
let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims;
|
||||
|
||||
Conv2dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(
|
||||
grad,
|
||||
|
@ -183,25 +215,21 @@ pub(crate) fn conv2d_backward<B: Backend>(
|
|||
);
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass using convolutions.
|
||||
pub(crate) fn conv3d_backward<B: Backend>(
|
||||
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv3d_x_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvOptions<3>,
|
||||
) -> Conv3dBackward<B> {
|
||||
) -> FloatTensor<B, 5> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
|
||||
let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
|
||||
let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
let [channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
|
||||
let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
|
||||
|
||||
let padding_1_out = calculate_padding_out(
|
||||
kernel_size_1,
|
||||
|
@ -228,8 +256,8 @@ pub(crate) fn conv3d_backward<B: Backend>(
|
|||
width_out,
|
||||
);
|
||||
|
||||
let x_grad = B::conv_transpose3d(
|
||||
output_grad.clone(),
|
||||
B::conv_transpose3d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvTransposeOptions::new(
|
||||
|
@ -239,22 +267,43 @@ pub(crate) fn conv3d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv3d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
|
||||
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv3d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<B, 5> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv3d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv3d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
|
||||
let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims;
|
||||
let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims;
|
||||
|
||||
Conv3dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(
|
||||
grad,
|
||||
|
@ -265,27 +314,17 @@ pub(crate) fn conv3d_backward<B: Backend>(
|
|||
);
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions.
|
||||
pub(crate) fn conv_transpose1d_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
|
||||
weight: FloatTensor<B, 3>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> Conv1dBackward<B> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
let x_grad = B::conv1d(
|
||||
output_grad.clone(),
|
||||
) -> FloatTensor<B, 3> {
|
||||
B::conv1d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvOptions::new(
|
||||
|
@ -294,52 +333,54 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv_transpose1d_weight_grad_no_groups::<B>(
|
||||
x,
|
||||
output_grad.clone(),
|
||||
weight_shape,
|
||||
options,
|
||||
),
|
||||
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
weight: FloatTensor<B, 3>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<B, 3> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv_transpose1d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 3>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 3>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
Conv1dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions.
|
||||
pub(crate) fn conv_transpose2d_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
|
||||
weight: FloatTensor<B, 4>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> Conv2dBackward<B> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
let x_grad = B::conv2d(
|
||||
output_grad.clone(),
|
||||
) -> FloatTensor<B, 4> {
|
||||
B::conv2d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvOptions::new(
|
||||
|
@ -348,27 +389,39 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv_transpose2d_weight_grad_no_groups::<B>(
|
||||
x,
|
||||
output_grad.clone(),
|
||||
weight_shape,
|
||||
options,
|
||||
),
|
||||
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<B, 4> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv_transpose2d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 4>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
Conv2dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(
|
||||
grad,
|
||||
|
@ -376,27 +429,17 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
|
|||
);
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass using convolutions.
|
||||
pub(crate) fn conv_transpose3d_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
|
||||
pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
|
||||
weight: FloatTensor<B, 5>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> Conv3dBackward<B> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
let x_grad = B::conv3d(
|
||||
output_grad.clone(),
|
||||
) -> FloatTensor<B, 5> {
|
||||
B::conv3d(
|
||||
output_grad,
|
||||
weight,
|
||||
None,
|
||||
ConvOptions::new(
|
||||
|
@ -405,27 +448,39 @@ pub(crate) fn conv_transpose3d_backward<B: Backend>(
|
|||
options.dilation,
|
||||
options.groups,
|
||||
),
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
let weight_grad = match options.groups == 1 {
|
||||
true => conv_transpose3d_weight_grad_no_groups::<B>(
|
||||
x,
|
||||
output_grad.clone(),
|
||||
weight_shape,
|
||||
options,
|
||||
),
|
||||
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
|
||||
pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
weight: FloatTensor<B, 5>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<B, 5> {
|
||||
let weight_shape = B::float_shape(&weight);
|
||||
let weight_device = B::float_device(&weight);
|
||||
|
||||
match options.groups == 1 {
|
||||
true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
|
||||
false => conv_transpose3d_weight_grad_groups::<B>(
|
||||
x,
|
||||
B::float_zeros(weight_shape, &weight_device),
|
||||
output_grad.clone(),
|
||||
output_grad,
|
||||
options,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
|
||||
pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
|
||||
x: FloatTensor<B, 5>,
|
||||
bias: FloatTensor<B, 1>,
|
||||
output_grad: FloatTensor<B, 5>,
|
||||
) -> FloatTensor<B, 1> {
|
||||
let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
|
||||
let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
|
||||
|
||||
Conv3dBackward::new(
|
||||
x_grad,
|
||||
weight_grad,
|
||||
bias.map(|b| {
|
||||
let grad = B::float_swap_dims(output_grad, 0, 1);
|
||||
let grad = B::float_reshape(
|
||||
grad,
|
||||
|
@ -436,9 +491,7 @@ pub(crate) fn conv_transpose3d_backward<B: Backend>(
|
|||
);
|
||||
let grad = B::float_sum_dim(grad, 1);
|
||||
|
||||
B::float_reshape(grad, B::float_shape(&b))
|
||||
}),
|
||||
)
|
||||
B::float_reshape(grad, B::float_shape(&bias))
|
||||
}
|
||||
|
||||
/// Execute a 1D convolution using a 2D convolution.
|
||||
|
|
Loading…
Reference in New Issue