From d1a708e317363368fdb3e0cb55b4c11e4a45702c Mon Sep 17 00:00:00 2001
From: Will Brickner
Date: Tue, 22 Aug 2023 20:37:58 -0500
Subject: [PATCH] Made TchTensor fields public (#677)

---
 burn-core/src/optim/adagrad.rs             |  1 -
 burn-core/src/optim/adamw.rs               |  5 ++---
 burn-core/src/optim/rmsprop.rs             | 17 +++++------------
 burn-import/src/onnx/dim_inference.rs      |  2 +-
 burn-import/src/onnx/from_onnx.rs          |  2 +-
 burn-tch/src/tensor.rs                     |  6 ++++--
 burn-tensor/src/tensor/ops/modules/pool.rs |  5 +----
 7 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/burn-core/src/optim/adagrad.rs b/burn-core/src/optim/adagrad.rs
index d44b423cb..3c2826aac 100644
--- a/burn-core/src/optim/adagrad.rs
+++ b/burn-core/src/optim/adagrad.rs
@@ -126,7 +126,6 @@ impl LRDecay {
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
 
         let grad = grad
-            .clone()
             .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
             .mul_scalar(new_lr);
 
diff --git a/burn-core/src/optim/adamw.rs b/burn-core/src/optim/adamw.rs
index 40b39c65a..e4fd23451 100644
--- a/burn-core/src/optim/adamw.rs
+++ b/burn-core/src/optim/adamw.rs
@@ -167,9 +167,8 @@ impl AdaptiveMomentumW {
             .div_scalar(1f32 - self.beta_2.powi(time));
 
         // Compute update delta. This still needs to be scaled by the learning rate.
-        let update_delta = moment_1_corrected
-            .clone()
-            .div(moment_2_corrected.clone().sqrt().add_scalar(self.epsilon));
+        let update_delta =
+            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
 
         (
             update_delta,
diff --git a/burn-core/src/optim/rmsprop.rs b/burn-core/src/optim/rmsprop.rs
index de754755d..90655bbc8 100644
--- a/burn-core/src/optim/rmsprop.rs
+++ b/burn-core/src/optim/rmsprop.rs
@@ -156,7 +156,6 @@ impl SquareAvgState {
            Some(state) => {
                let square_avg = state
                    .square_avg
-                    .clone()
                    .mul_scalar(alpha)
                    .add(grad.clone().powf(2.).mul_scalar(1. - alpha));
                (grad, Self { square_avg })
@@ -205,7 +204,7 @@ impl CenteredState {
            Some(state) => state
                .grad_avg
                .map_or(grad_avg_constant.clone(), move |grad_avg| {
-                    grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
+                    grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                }),
            _ => grad_avg_constant,
        };
@@ -269,18 +268,12 @@ impl RMSPropMomentum {
        CenteredState<B, D>,
        Option<RMSPropMomentumState<B, D>>,
    ) {
-        let grad = grad
-            .clone()
-            .div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
+        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
 
        if self.momentum > 0. {
            let buf = match momentum_state {
-                Some(state) => state
-                    .buf
-                    .clone()
-                    .mul_scalar(self.momentum)
-                    .add(grad.clone()),
-                _ => grad.clone(),
+                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
+                _ => grad,
            };
            (
                buf.clone(),
@@ -288,7 +281,7 @@ impl RMSPropMomentum {
                Some(RMSPropMomentumState { buf }),
            )
        } else {
-            (grad.clone(), centered_state, None)
+            (grad, centered_state, None)
        }
    }
 }
diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs
index 7526c8186..4243f8eac 100644
--- a/burn-import/src/onnx/dim_inference.rs
+++ b/burn-import/src/onnx/dim_inference.rs
@@ -134,7 +134,7 @@ fn linear_update_outputs(node: &mut Node) {
 
    if let ArgType::Tensor(tensor) = node_input.clone().ty {
        // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor.clone());
+        node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
        panic!("Only tensor input is valid");
    }
diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 8c96fcfae..8b9beab95 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -532,7 +532,7 @@ fn lift_constants(nodes: &mut Vec) {
            AttributeValue::Tensor(tensor) => State {
                // if the value is a tensor, create a new State object with the tensor as its type
                name: input.name.clone(),
-                ty: StateType::Tensor(tensor.clone()),
+                ty: StateType::Tensor(tensor),
            },
            _ => todo!("Support non tensor constant type"),
        };
diff --git a/burn-tch/src/tensor.rs b/burn-tch/src/tensor.rs
index b0005aca2..a2fdf332e 100644
--- a/burn-tch/src/tensor.rs
+++ b/burn-tch/src/tensor.rs
@@ -9,8 +9,10 @@ pub type StorageRef = Arc<*mut c_void>;
 
 /// A tensor that uses the tch backend.
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<E: tch::kind::Element, const D: usize> {
-    pub(crate) tensor: tch::Tensor,
-    pub(crate) storage: StorageRef,
+    /// Handle to the tensor. Call methods on this field.
+    pub tensor: tch::Tensor,
+    /// The tensor's storage
+    pub storage: StorageRef,
    phantom: PhantomData<E>,
 }
diff --git a/burn-tensor/src/tensor/ops/modules/pool.rs b/burn-tensor/src/tensor/ops/modules/pool.rs
index 22bf1ffd6..9c531169e 100644
--- a/burn-tensor/src/tensor/ops/modules/pool.rs
+++ b/burn-tensor/src/tensor/ops/modules/pool.rs
@@ -122,10 +122,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d(
    );
 
    let [batch_size, channels, _, length] = B::shape(&x.output).dims;
    let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
-    let indices = B::int_reshape(
-        x.indices.clone(),
-        Shape::from([batch_size, channels, length]),
-    );
+    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
    MaxPool1dWithIndices::new(output, indices)
 }