mirror of https://github.com/tracel-ai/burn.git
Refactor burn-tensor: Split conv backward ops to allow conditional gradient computation (#2278)
This commit is contained in:
parent 81ec64a929
commit 7ac5deebe2
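This commit replaces the monolithic conv{1,2,3}d_backward and
conv_transpose{1,2,3}d_backward ops, which always returned a
Conv{1,2,3}dBackward struct carrying x_grad, weights_grad, and bias_grad,
with one op per gradient (*_x_backward, *_weight_backward, *_bias_backward),
so the autodiff backward pass computes only the gradients whose parent nodes
are tracked. A minimal self-contained sketch of the idea, using toy types
rather than the crate's tensor primitives (all names are illustrative):

// Toy model of the refactor: each gradient is computed on demand instead of
// being bundled into one struct that forces all three computations.
struct Conv1dState {
    x: Vec<f32>,
    weight: Vec<f32>,
}

// Stand-ins for the split backward ops; the real ones use a transposed
// convolution (x grad), a convolution (weight grad), and a channel-wise
// sum (bias grad).
fn x_backward(s: &Conv1dState, _out_grad: &[f32]) -> Vec<f32> {
    vec![0.0; s.x.len()]
}
fn weight_backward(s: &Conv1dState, _out_grad: &[f32]) -> Vec<f32> {
    vec![0.0; s.weight.len()]
}
fn bias_backward(out_grad: &[f32]) -> f32 {
    out_grad.iter().sum()
}

fn backward(s: &Conv1dState, out_grad: &[f32], tracked: [bool; 3]) {
    // Each branch mirrors `if let Some(node) = node_x { ... }` in the diff:
    // an untracked (frozen) parameter skips its gradient entirely.
    if tracked[0] {
        let _x_grad = x_backward(s, out_grad);
    }
    if tracked[1] {
        let _weight_grad = weight_backward(s, out_grad);
    }
    if tracked[2] {
        let _bias_grad = bias_backward(out_grad);
    }
}

fn main() {
    let s = Conv1dState { x: vec![0.0; 8], weight: vec![0.0; 3] };
    // Frozen weights and bias: only the input gradient is computed.
    backward(&s, &[1.0; 8], [true, false, false]);
}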
@@ -78,20 +78,28 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -109,16 +117,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv1d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -188,20 +202,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose1d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -219,16 +245,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 3>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose1d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<3>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 3>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose1d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 3>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 3>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 3>(node.id, grad)
                 }
             }
         }
@@ -307,20 +338,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv2d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -338,16 +378,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv2d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -419,20 +465,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose2d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -450,16 +508,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 4>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose2d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<4>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 4>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose2d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 4>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 4>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 4>(node.id, grad)
                 }
             }
         }
@@ -540,20 +603,29 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad =
+                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv3d_bias_backward(x, weight, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -571,16 +643,22 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv3d_x_backward(
+                        x.clone(),
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }
@@ -652,20 +730,32 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);

                 let (x_state, weight_state, bias_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-                let bias = Some(checkpointer.retrieve_node_output(bias_state));
-
-                let backward = B::conv_transpose3d_backward(x, weight, bias, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);
+                let bias =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<1>>(bias_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(
+                        x.clone(),
+                        weight,
+                        grad.clone(),
+                        options,
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_bias {
-                    grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
+                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);
+                    grads.register::<B, 1>(node.id, grad)
                 }
             }
         }
@@ -683,16 +773,21 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
                 let grad = grads.consume::<B, 5>(&ops.node);

                 let (x_state, weight_state, options) = ops.state;
-                let x = checkpointer.retrieve_node_output(x_state);
-                let weight = checkpointer.retrieve_node_output(weight_state);
-
-                let backward = B::conv_transpose3d_backward(x, weight, None, grad, options);
+                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(x_state);
+                let weight =
+                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive<5>>(weight_state);

                 if let Some(node) = node_x {
-                    grads.register::<B, 5>(node.id, backward.x_grad)
+                    let grad = B::conv_transpose3d_x_backward(
+                        weight.clone(),
+                        grad.clone(),
+                        options.clone(),
+                    );
+                    grads.register::<B, 5>(node.id, grad)
                 }
                 if let Some(node) = node_weight {
-                    grads.register::<B, 5>(node.id, backward.weights_grad)
+                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);
+                    grads.register::<B, 5>(node.id, grad)
                 }
             }
         }
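Note the ownership pattern in the autodiff hunks above: every branch that
might need a checkpointed tensor later clones it (x.clone(), grad.clone()),
while the last possible consumer takes it by move. A minimal sketch of the
same pattern (toy types, illustrative only):

// Conditional consumers clone while later uses remain; the last use moves.
fn consume(v: Vec<f32>) -> usize {
    v.len()
}

fn main() {
    let x = vec![0.0_f32; 4];
    let (x_tracked, bias_tracked) = (true, true);
    if x_tracked {
        let _ = consume(x.clone()); // `x` is still needed below, so clone
    }
    if bias_tracked {
        let _ = consume(x); // last use: move, no extra copy
    }
}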
@@ -5,32 +5,6 @@ use crate::{
     Shape,
 };

-/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
-#[derive(new)]
-pub struct Conv2dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 4>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 4>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
-/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
-#[derive(new)]
-pub struct Conv3dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 5>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 5>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
 /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
 #[derive(new)]
 pub struct MaxPool1dBackward<B: Backend> {
@@ -65,19 +39,6 @@ pub struct MaxPool2dWithIndices<B: Backend> {
     pub indices: IntTensor<B, 4>,
 }

-/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
-#[derive(new)]
-pub struct Conv1dBackward<B: Backend> {
-    /// Gradient.
-    pub x_grad: FloatTensor<B, 3>,
-
-    /// Weights gradient.
-    pub weights_grad: FloatTensor<B, 3>,
-
-    /// Bias gradient.
-    pub bias_grad: Option<FloatTensor<B, 1>>,
-}
-
 /// Convolution options.
 #[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
 pub struct ConvOptions<const N: usize> {
@@ -221,15 +182,31 @@ pub trait ModuleOps<B: Backend> {
     ) -> FloatTensor<B, 3> {
         conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
-    fn conv1d_backward(
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
+    fn conv1d_x_backward(
         x: FloatTensor<B, 3>,
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
+    fn conv1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
+    fn conv1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
     }
     /// Two dimensional convolution.
     ///
@@ -244,15 +221,32 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
-    fn conv2d_backward(
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
+    fn conv2d_x_backward(
         x: FloatTensor<B, 4>,
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
+    fn conv2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
+    fn conv2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// Three dimensional convolution.
     ///
@@ -267,15 +261,32 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation.
-    fn conv3d_backward(
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
+    fn conv3d_x_backward(
         x: FloatTensor<B, 5>,
         weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 5>,
         options: ConvOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
+    fn conv3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
+    fn conv3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
     }
     /// One dimensional transposed convolution.
     ///
@@ -292,15 +303,30 @@ pub trait ModuleOps<B: Backend> {
     ) -> FloatTensor<B, 3> {
         conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
     }
-    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation.
-    fn conv_transpose1d_backward(
-        x: FloatTensor<B, 3>,
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
+    fn conv_transpose1d_x_backward(
         weight: FloatTensor<B, 3>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 3>,
         options: ConvTransposeOptions<1>,
-    ) -> Conv1dBackward<B> {
-        conv::conv_transpose1d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
+    fn conv_transpose1d_weight_backward(
+        x: FloatTensor<B, 3>,
+        weight: FloatTensor<B, 3>,
+        output_grad: FloatTensor<B, 3>,
+        options: ConvTransposeOptions<1>,
+    ) -> FloatTensor<B, 3> {
+        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
+    fn conv_transpose1d_bias_backward(
+        x: FloatTensor<B, 3>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 3>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
     }

     /// Two dimensional transposed convolution.
@@ -316,15 +342,30 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<2>,
     ) -> FloatTensor<B, 4>;
-    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation.
-    fn conv_transpose2d_backward(
-        x: FloatTensor<B, 4>,
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
+    fn conv_transpose2d_x_backward(
         weight: FloatTensor<B, 4>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 4>,
         options: ConvTransposeOptions<2>,
-    ) -> Conv2dBackward<B> {
-        conv::conv_transpose2d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
+    fn conv_transpose2d_weight_backward(
+        x: FloatTensor<B, 4>,
+        weight: FloatTensor<B, 4>,
+        output_grad: FloatTensor<B, 4>,
+        options: ConvTransposeOptions<2>,
+    ) -> FloatTensor<B, 4> {
+        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
+    fn conv_transpose2d_bias_backward(
+        x: FloatTensor<B, 4>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 4>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
     }

     /// Three dimensional transposed convolution.
@@ -340,15 +381,30 @@ pub trait ModuleOps<B: Backend> {
         bias: Option<FloatTensor<B, 1>>,
         options: ConvTransposeOptions<3>,
     ) -> FloatTensor<B, 5>;
-    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation.
-    fn conv_transpose3d_backward(
-        x: FloatTensor<B, 5>,
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
+    fn conv_transpose3d_x_backward(
         weight: FloatTensor<B, 5>,
-        bias: Option<FloatTensor<B, 1>>,
         output_grad: FloatTensor<B, 5>,
         options: ConvTransposeOptions<3>,
-    ) -> Conv3dBackward<B> {
-        conv::conv_transpose3d_backward(x, weight, bias, output_grad, options)
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
+    fn conv_transpose3d_weight_backward(
+        x: FloatTensor<B, 5>,
+        weight: FloatTensor<B, 5>,
+        output_grad: FloatTensor<B, 5>,
+        options: ConvTransposeOptions<3>,
+    ) -> FloatTensor<B, 5> {
+        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
+    }
+    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
+    fn conv_transpose3d_bias_backward(
+        x: FloatTensor<B, 5>,
+        bias: FloatTensor<B, 1>,
+        output_grad: FloatTensor<B, 5>,
+    ) -> FloatTensor<B, 1> {
+        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
     }

     /// Four-dimensional unfolding.
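The trait hunks above give each new backward op a default body that
delegates to the reference implementations in conv::*, so every backend
gets working gradients for free and can override one gradient kernel at a
time. A minimal sketch of that default-method pattern (toy trait, not the
crate's API):

// Default trait methods let a backend override a single gradient at a time.
fn reference_x_backward() -> f32 {
    1.0
}
fn reference_weight_backward() -> f32 {
    2.0
}

trait ConvBackward {
    fn x_backward(&self) -> f32 {
        reference_x_backward() // shared fallback implementation
    }
    fn weight_backward(&self) -> f32 {
        reference_weight_backward()
    }
}

struct FastBackend;
impl ConvBackward for FastBackend {
    // Override only the input gradient; the weight gradient keeps the default.
    fn x_backward(&self) -> f32 {
        10.0
    }
}

fn main() {
    let b = FastBackend;
    assert_eq!(b.x_backward(), 10.0);
    assert_eq!(b.weight_backward(), 2.0);
}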
@@ -1,5 +1,5 @@
 #![allow(clippy::single_range_in_vec_init)]
-use super::{Conv1dBackward, Conv2dBackward, Conv3dBackward, ConvOptions, ConvTransposeOptions};
+use super::{ConvOptions, ConvTransposeOptions};
 use crate::{backend::Backend, ops::FloatTensor, Shape};

 #[cfg(not(feature = "std"))]
@@ -57,19 +57,17 @@ pub fn calculate_pool_output_size(
     ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
 }

-/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
-pub(crate) fn conv1d_backward<B: Backend>(
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv1d_x_backward<B: Backend>(
     x: FloatTensor<B, 3>,
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvOptions<1>,
-) -> Conv1dBackward<B> {
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);

-    let [batch_size, _, length_in] = B::float_shape(&x).dims;
-    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+    let [_batch_size, _, length_in] = B::float_shape(&x).dims;
+    let [_batch_size, _channels_out, length_out] = B::float_shape(&output_grad).dims;
     let [_, _, kernel_size] = weight_shape.dims;

     let padding_out = calculate_padding_out(
@@ -81,8 +79,8 @@ pub(crate) fn conv1d_backward<B: Backend>(
         length_out,
     );

-    let x_grad = B::conv_transpose1d(
-        output_grad.clone(),
+    B::conv_transpose1d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -92,45 +90,58 @@ pub(crate) fn conv1d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
-
-    let weight_grad = match options.groups == 1 {
-        true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv1d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }

-/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass using convolutions.
-pub(crate) fn conv2d_backward<B: Backend>(
-    x: FloatTensor<B, 4>,
-    weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 4>,
-    options: ConvOptions<2>,
-) -> Conv2dBackward<B> {
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvOptions<1>,
+) -> FloatTensor<B, 3> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);

-    let [batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
+    match options.groups == 1 {
+        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv1d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+
+/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _, _length_in] = B::float_shape(&x).dims;
+    let [_batch_size, channels_out, length_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv2d_x_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [_batch_size, _channels_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;

     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -149,8 +160,8 @@ pub(crate) fn conv2d_backward<B: Backend>(
         width_out,
     );

-    let x_grad = B::conv_transpose2d(
-        output_grad.clone(),
+    B::conv_transpose2d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -160,48 +171,65 @@ pub(crate) fn conv2d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
-
-    let weight_grad = match options.groups == 1 {
-        true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
-        false => conv2d_weight_grad_groups::<B>(
-            x,
-            B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
-            options,
-        ),
-    };
-
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
     )
 }

-/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass using convolutions.
-pub(crate) fn conv3d_backward<B: Backend>(
-    x: FloatTensor<B, 5>,
-    weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
-    output_grad: FloatTensor<B, 5>,
-    options: ConvOptions<3>,
-) -> Conv3dBackward<B> {
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvOptions<2>,
+) -> FloatTensor<B, 4> {
     let weight_shape = B::float_shape(&weight);
     let weight_device = B::float_device(&weight);

-    let [batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
+    match options.groups == 1 {
+        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
+        false => conv2d_weight_grad_groups::<B>(
+            x,
+            B::float_zeros(weight_shape, &weight_device),
+            output_grad,
+            options,
+        ),
+    }
+}
+
+/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [batch_size, _channels_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv3d_x_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [_batch_size, _channels_in, depth_in, height_in, width_in] = B::float_shape(&x).dims;
     let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-    let [channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;
+    let [_channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3] = weight_shape.dims;

     let padding_1_out = calculate_padding_out(
         kernel_size_1,
@@ -228,8 +256,8 @@ pub(crate) fn conv3d_backward<B: Backend>(
         width_out,
     );

-    let x_grad = B::conv_transpose3d(
-        output_grad.clone(),
+    B::conv_transpose3d(
+        output_grad,
         weight,
         None,
         ConvTransposeOptions::new(
@@ -239,53 +267,64 @@ pub(crate) fn conv3d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}

-    let weight_grad = match options.groups == 1 {
-        true => conv3d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }

-/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass using convolutions.
-pub(crate) fn conv_transpose1d_backward<B: Backend>(
-    x: FloatTensor<B, 3>,
+/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let weight_shape = B::float_shape(&weight);
+
+    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = B::float_shape(&x).dims;
+    let [_, _, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+    let [channels_out, _, _kernel_size_1, _kernel_size_2, _kernel_size_3] = weight_shape.dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
     weight: FloatTensor<B, 3>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 3>,
     options: ConvTransposeOptions<1>,
-) -> Conv1dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
-    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv1d(
-        output_grad.clone(),
+) -> FloatTensor<B, 3> {
+    B::conv1d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -294,52 +333,54 @@ pub(crate) fn conv_transpose1d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}

-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose1d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    weight: FloatTensor<B, 3>,
+    output_grad: FloatTensor<B, 3>,
+    options: ConvTransposeOptions<1>,
+) -> FloatTensor<B, 3> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose1d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv1dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }

-/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass using convolutions.
-pub(crate) fn conv_transpose2d_backward<B: Backend>(
-    x: FloatTensor<B, 4>,
+/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 3>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 3>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _] = B::float_shape(&x).dims;
+    let [_, channels_out, length_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
     weight: FloatTensor<B, 4>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 4>,
     options: ConvTransposeOptions<2>,
-) -> Conv2dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv2d(
-        output_grad.clone(),
+) -> FloatTensor<B, 4> {
    B::conv2d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -348,55 +389,57 @@ pub(crate) fn conv_transpose2d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}

-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose2d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    weight: FloatTensor<B, 4>,
+    output_grad: FloatTensor<B, 4>,
+    options: ConvTransposeOptions<2>,
+) -> FloatTensor<B, 4> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose2d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
-
-    Conv2dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([channels_out, batch_size * height_out * width_out]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+    }
 }

-/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass using convolutions.
-pub(crate) fn conv_transpose3d_backward<B: Backend>(
-    x: FloatTensor<B, 5>,
+/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 4>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 4>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([channels_out, batch_size * height_out * width_out]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}
+
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
+pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
     weight: FloatTensor<B, 5>,
-    bias: Option<FloatTensor<B, 1>>,
     output_grad: FloatTensor<B, 5>,
     options: ConvTransposeOptions<3>,
-) -> Conv3dBackward<B> {
-    let weight_shape = B::float_shape(&weight);
-    let weight_device = B::float_device(&weight);
-
-    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
-    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
-
-    let x_grad = B::conv3d(
-        output_grad.clone(),
+) -> FloatTensor<B, 5> {
+    B::conv3d(
+        output_grad,
         weight,
         None,
         ConvOptions::new(
@@ -405,40 +448,50 @@ pub(crate) fn conv_transpose3d_backward<B: Backend>(
             options.dilation,
             options.groups,
         ),
-    );
+    )
+}

-    let weight_grad = match options.groups == 1 {
-        true => conv_transpose3d_weight_grad_no_groups::<B>(
-            x,
-            output_grad.clone(),
-            weight_shape,
-            options,
-        ),
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
+pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    weight: FloatTensor<B, 5>,
+    output_grad: FloatTensor<B, 5>,
+    options: ConvTransposeOptions<3>,
+) -> FloatTensor<B, 5> {
+    let weight_shape = B::float_shape(&weight);
+    let weight_device = B::float_device(&weight);
+
+    match options.groups == 1 {
+        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
         false => conv_transpose3d_weight_grad_groups::<B>(
             x,
             B::float_zeros(weight_shape, &weight_device),
-            output_grad.clone(),
+            output_grad,
             options,
         ),
-    };
+    }
+}

-    Conv3dBackward::new(
-        x_grad,
-        weight_grad,
-        bias.map(|b| {
-            let grad = B::float_swap_dims(output_grad, 0, 1);
-            let grad = B::float_reshape(
-                grad,
-                Shape::new([
-                    channels_out,
-                    batch_size * depth_out * height_out * width_out,
-                ]),
-            );
-            let grad = B::float_sum_dim(grad, 1);
-
-            B::float_reshape(grad, B::float_shape(&b))
-        }),
-    )
+/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
+pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
+    x: FloatTensor<B, 5>,
+    bias: FloatTensor<B, 1>,
+    output_grad: FloatTensor<B, 5>,
+) -> FloatTensor<B, 1> {
+    let [batch_size, _channels_in, _, _, _] = B::float_shape(&x).dims;
+    let [_, channels_out, depth_out, height_out, width_out] = B::float_shape(&output_grad).dims;
+
+    let grad = B::float_swap_dims(output_grad, 0, 1);
+    let grad = B::float_reshape(
+        grad,
+        Shape::new([
+            channels_out,
+            batch_size * depth_out * height_out * width_out,
+        ]),
+    );
+    let grad = B::float_sum_dim(grad, 1);
+
+    B::float_reshape(grad, B::float_shape(&bias))
+}

 /// Execute a 1D convolution using a 2D convolution.
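For reference, every *_bias_backward helper introduced above performs the
same reduction: the output gradient is summed over the batch and spatial
dimensions, leaving one value per output channel. The float_swap_dims(0, 1)
plus float_reshape([channels_out, batch * positions]) plus float_sum_dim(1)
sequence in the diff is exactly that reduction. A minimal sketch with
nested Vecs standing in for tensors (illustrative only):

// bias_grad[c] = sum over batches b and positions p of output_grad[b][c][p],
// mirroring the swap_dims + reshape + sum_dim sequence above.
fn bias_backward(output_grad: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let channels = output_grad[0].len();
    let mut bias_grad = vec![0.0; channels];
    for batch in output_grad {
        for (c, channel) in batch.iter().enumerate() {
            bias_grad[c] += channel.iter().sum::<f32>();
        }
    }
    bias_grad
}

fn main() {
    // Two batches, one channel, two positions each: the gradient sums to 4.0.
    let grad = vec![vec![vec![1.0, 1.0]], vec![vec![1.0, 1.0]]];
    assert_eq!(bias_backward(&grad), vec![4.0]);
}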