Made TchTensor fields public (#677)

Will Brickner, 2023-08-22 20:37:58 -05:00, committed by GitHub
parent 2fefc82099
commit d1a708e317
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
7 changed files with 14 additions and 24 deletions

View File

@@ -126,7 +126,6 @@ impl LRDecay {
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
         let grad = grad
-            .clone()
             .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
             .mul_scalar(new_lr);
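
The hunk above is the Adagrad step with learning-rate decay. As a scalar sketch (illustrative names and plain f64 math, not code from the crate):

    // Decay the learning rate with the number of steps taken so far, then
    // normalize the gradient by the accumulated squared-gradient sum.
    fn adagrad_lr_decay_step(lr: f64, lr_decay: f64, time: u64, grad: f64, sum: f64, epsilon: f64) -> f64 {
        let new_lr = lr / (1. + (time as f64 - 1.) * lr_decay);
        grad / (sum.sqrt() + epsilon) * new_lr
    }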

View File

@@ -167,9 +167,8 @@ impl AdaptiveMomentumW {
             .div_scalar(1f32 - self.beta_2.powi(time));

         // Compute update delta. This still needs to be scaled by the learning rate.
-        let update_delta = moment_1_corrected
-            .clone()
-            .div(moment_2_corrected.clone().sqrt().add_scalar(self.epsilon));
+        let update_delta =
+            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

         (
             update_delta,

View File

@@ -156,7 +156,6 @@ impl<B: Backend, const D: usize> SquareAvgState<B, D> {
             Some(state) => {
                 let square_avg = state
                     .square_avg
-                    .clone()
                     .mul_scalar(alpha)
                     .add(grad.clone().powf(2.).mul_scalar(1. - alpha));
                 (grad, Self { square_avg })
@@ -205,7 +204,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
             Some(state) => state
                 .grad_avg
                 .map_or(grad_avg_constant.clone(), move |grad_avg| {
-                    grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
+                    grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                 }),
             _ => grad_avg_constant,
         };
@@ -269,18 +268,12 @@ impl RMSPropMomentum {
         CenteredState<B, D>,
         Option<RMSPropMomentumState<B, D>>,
     ) {
-        let grad = grad
-            .clone()
-            .div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
+        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

         if self.momentum > 0. {
             let buf = match momentum_state {
-                Some(state) => state
-                    .buf
-                    .clone()
-                    .mul_scalar(self.momentum)
-                    .add(grad.clone()),
-                _ => grad.clone(),
+                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
+                _ => grad,
             };
             (
                 buf.clone(),
@@ -288,7 +281,7 @@ impl RMSPropMomentum {
                 Some(RMSPropMomentumState { buf }),
             )
         } else {
-            (grad.clone(), centered_state, None)
+            (grad, centered_state, None)
         }
     }
 }
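
For reference, the centered RMSProp-with-momentum update these hunks simplify reduces to the following scalar sketch (illustrative names, plain f64 math, not code from the crate):

    // Normalize the gradient by the centered running average, then fold it into
    // the momentum buffer when momentum is enabled.
    fn rmsprop_momentum_step(grad: f64, centered_avg: f64, buf: f64, momentum: f64, epsilon: f64) -> f64 {
        let grad = grad / (centered_avg.sqrt() + epsilon);
        if momentum > 0. { momentum * buf + grad } else { grad }
    }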

View File

@@ -134,7 +134,7 @@ fn linear_update_outputs(node: &mut Node) {
     if let ArgType::Tensor(tensor) = node_input.clone().ty {
         // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor.clone());
+        node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
         panic!("Only tensor input is valid");
     }

View File

@@ -532,7 +532,7 @@ fn lift_constants(nodes: &mut Vec<Node>) {
             AttributeValue::Tensor(tensor) => State {
                 // if the value is a tensor, create a new State object with the tensor as its type
                 name: input.name.clone(),
-                ty: StateType::Tensor(tensor.clone()),
+                ty: StateType::Tensor(tensor),
             },
             _ => todo!("Support non tensor constant type"),
         };

View File

@@ -9,8 +9,10 @@ pub type StorageRef = Arc<*mut c_void>;

 /// A tensor that uses the tch backend.
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<E: tch::kind::Element, const D: usize> {
-    pub(crate) tensor: tch::Tensor,
-    pub(crate) storage: StorageRef,
+    /// Handle to the tensor. Call methods on this field.
+    pub tensor: tch::Tensor,
+    /// The tensor's storage
+    pub storage: StorageRef,
     phantom: PhantomData<E>,
 }
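
With the fields public, downstream code can reach the wrapped tch::Tensor directly. A minimal sketch, assuming the burn_tch crate root re-exports TchTensor and that a tensor value is already in hand (the helper itself is hypothetical):

    use burn_tch::TchTensor;

    // Hypothetical helper: inspect the underlying LibTorch tensor through the now-public field.
    fn print_dims<E: tch::kind::Element, const D: usize>(t: &TchTensor<E, D>) {
        // tch::Tensor::size returns the shape as a Vec<i64>.
        println!("dims = {:?}", t.tensor.size());
    }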

View File

@@ -122,10 +122,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
     );
     let [batch_size, channels, _, length] = B::shape(&x.output).dims;
     let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
-    let indices = B::int_reshape(
-        x.indices.clone(),
-        Shape::from([batch_size, channels, length]),
-    );
+    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));

     MaxPool1dWithIndices::new(output, indices)
 }