Made TchTensor fields public (#677)

Will Brickner 2023-08-22 20:37:58 -05:00 committed by GitHub
parent 2fefc82099
commit d1a708e317
7 changed files with 14 additions and 24 deletions

View File

@@ -126,7 +126,6 @@ impl LRDecay {
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
         let grad = grad
-            .clone()
             .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
             .mul_scalar(new_lr);
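
The dropped .clone() is safe because div consumes its receiver, so the original grad is never used again. For reference, a scalar sketch of the same learning-rate-decay arithmetic, with names mirroring the diff and everything else illustrative:

// Scalar sketch of the Adagrad step with LR decay shown above; not the
// library's tensor code, just the same arithmetic on a single value.
fn lr_decay_step(
    lr: f64,
    time: f64,
    lr_decay: f64,
    epsilon: f64,
    sum: f64,
    grad: f64,
) -> f64 {
    let new_lr = lr / (1. + (time - 1.) * lr_decay);
    grad / (sum.sqrt() + epsilon) * new_lr
}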

View File

@@ -167,9 +167,8 @@ impl AdaptiveMomentumW {
             .div_scalar(1f32 - self.beta_2.powi(time));
         // Compute update delta. This still needs to be scaled by the learning rate.
-        let update_delta = moment_1_corrected
-            .clone()
-            .div(moment_2_corrected.clone().sqrt().add_scalar(self.epsilon));
+        let update_delta =
+            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
         (
             update_delta,
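
The quantity computed here is the standard bias-corrected Adam ratio: each moment is divided by 1 - beta^t before forming m / (sqrt(v) + eps). A scalar sketch, assuming beta_1 corrects moment_1 the same way the visible line corrects moment_2:

// Scalar sketch of the update delta from the hunk above; the beta_1
// correction is assumed by symmetry, only the beta_2 line is in the diff.
fn adamw_update_delta(
    moment_1: f32,
    moment_2: f32,
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
    time: i32,
) -> f32 {
    let moment_1_corrected = moment_1 / (1f32 - beta_1.powi(time));
    let moment_2_corrected = moment_2 / (1f32 - beta_2.powi(time));
    moment_1_corrected / (moment_2_corrected.sqrt() + epsilon)
}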

View File

@@ -156,7 +156,6 @@ impl<B: Backend, const D: usize> SquareAvgState<B, D> {
             Some(state) => {
                 let square_avg = state
                     .square_avg
-                    .clone()
                     .mul_scalar(alpha)
                     .add(grad.clone().powf(2.).mul_scalar(1. - alpha));
                 (grad, Self { square_avg })
@@ -205,7 +204,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
             Some(state) => state
                 .grad_avg
                 .map_or(grad_avg_constant.clone(), move |grad_avg| {
-                    grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
+                    grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                 }),
             _ => grad_avg_constant,
         };
@@ -269,18 +268,12 @@ impl RMSPropMomentum {
         CenteredState<B, D>,
         Option<RMSPropMomentumState<B, D>>,
     ) {
-        let grad = grad
-            .clone()
-            .div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
+        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
         if self.momentum > 0. {
             let buf = match momentum_state {
-                Some(state) => state
-                    .buf
-                    .clone()
-                    .mul_scalar(self.momentum)
-                    .add(grad.clone()),
-                _ => grad.clone(),
+                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
+                _ => grad,
             };
             (
                 buf.clone(),
@@ -288,7 +281,7 @@ impl RMSPropMomentum {
                 Some(RMSPropMomentumState { buf }),
             )
         } else {
-            (grad.clone(), centered_state, None)
+            (grad, centered_state, None)
         }
     }
 }
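
Taken together, the hunks in this file are centered RMSProp with optional momentum: the gradient is normalized by the centered second-moment estimate, then folded into a momentum buffer when momentum > 0. A scalar sketch of the combined step, with avg standing in for centered_state.avg:

// Scalar sketch of the step assembled from the hunks above; buf is the
// previous momentum buffer, absent on the first step.
fn rmsprop_momentum_step(
    grad: f32,
    avg: f32,
    buf: Option<f32>,
    momentum: f32,
    epsilon: f32,
) -> f32 {
    let grad = grad / (avg.sqrt() + epsilon);
    if momentum > 0. {
        momentum * buf.unwrap_or(0.) + grad
    } else {
        grad
    }
}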

View File

@@ -134,7 +134,7 @@ fn linear_update_outputs(node: &mut Node) {
     if let ArgType::Tensor(tensor) = node_input.clone().ty {
         // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor.clone());
+        node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
         panic!("Only tensor input is valid");
     }
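
Here the clone was redundant because the if-let pattern already binds an owned tensor (ty itself was cloned one line earlier), so the value can move straight into the output slot. A minimal ownership sketch with a stand-in enum, not the real burn-import types:

// Stand-in enum: a value bound by if-let on an owned expression moves out
// without a further clone.
enum ArgType {
    Tensor(Vec<i64>),
    Scalar,
}

fn propagate(ty: ArgType) -> ArgType {
    if let ArgType::Tensor(tensor) = ty {
        ArgType::Tensor(tensor) // moved, not cloned
    } else {
        panic!("Only tensor input is valid");
    }
}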

View File

@@ -532,7 +532,7 @@ fn lift_constants(nodes: &mut Vec<Node>) {
             AttributeValue::Tensor(tensor) => State {
                 // if the value is a tensor, create a new State object with the tensor as its type
                 name: input.name.clone(),
-                ty: StateType::Tensor(tensor.clone()),
+                ty: StateType::Tensor(tensor),
             },
             _ => todo!("Support non tensor constant type"),
         };

View File

@@ -9,8 +9,10 @@ pub type StorageRef = Arc<*mut c_void>;
 /// A tensor that uses the tch backend.
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<E: tch::kind::Element, const D: usize> {
-    pub(crate) tensor: tch::Tensor,
-    pub(crate) storage: StorageRef,
+    /// Handle to the tensor. Call methods on this field.
+    pub tensor: tch::Tensor,
+    /// The tensor's storage
+    pub storage: StorageRef,
     phantom: PhantomData<E>,
 }
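
This is the headline change of the commit: with tensor and storage public, downstream crates can call tch methods on the wrapped tensor directly instead of going through the backend API. A hypothetical sketch (the helper and its name are illustrative, not part of the crate):

use burn_tch::TchTensor;

// Hypothetical helper enabled by the newly public field: read the shape
// straight from the underlying tch::Tensor.
fn raw_shape<E: tch::kind::Element, const D: usize>(t: &TchTensor<E, D>) -> Vec<i64> {
    t.tensor.size()
}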

View File

@@ -122,10 +122,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
     );
     let [batch_size, channels, _, length] = B::shape(&x.output).dims;
     let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
-    let indices = B::int_reshape(
-        x.indices.clone(),
-        Shape::from([batch_size, channels, length]),
-    );
+    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
     MaxPool1dWithIndices::new(output, indices)
 }
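
This helper implements 1d max pooling by reusing the 2d kernel: the input is given an extra axis, and afterwards both the pooled output and the indices are reshaped back to three dimensions, which is what the simplified int_reshape call does. The shape bookkeeping, as a plain backend-independent sketch:

// Plain sketch of the reshape above: drop the dummy axis that was inserted
// so the 2d max-pool kernel could be reused.
fn squeeze_pool_dims(dims: [usize; 4]) -> [usize; 3] {
    let [batch_size, channels, _height, length] = dims;
    [batch_size, channels, length]
}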