mirror of https://github.com/tracel-ai/burn.git
Made TchTensor fields public (#677)
commit d1a708e317 (parent 2fefc82099)

Besides making the `tensor` and `storage` fields of TchTensor public, the diff removes redundant `.clone()` calls in the optimizers (AdaGrad's learning-rate decay, AdamW, RMSProp) and in the ONNX import pass, and documents the newly public fields.
@@ -126,7 +126,6 @@ impl LRDecay {
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);

         let grad = grad
-            .clone()
             .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
             .mul_scalar(new_lr);

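The pattern in this hunk recurs throughout the optimizer changes below: Burn's tensor ops take self by value, so when a chained call is the last use of a binding, the leading .clone() is redundant and the value can simply be moved. A minimal self-contained sketch with a mock type (not Burn's actual API):

    // Mock tensor whose ops consume `self`, mirroring Burn's style.
    #[derive(Clone, Debug)]
    struct Tensor(Vec<f64>);

    impl Tensor {
        fn mul_scalar(self, s: f64) -> Tensor {
            Tensor(self.0.into_iter().map(|x| x * s).collect())
        }
        fn add(self, rhs: Tensor) -> Tensor {
            Tensor(self.0.iter().zip(rhs.0).map(|(a, b)| a + b).collect())
        }
    }

    fn main() {
        let grad = Tensor(vec![1.0, 2.0]);
        // `grad.clone().mul_scalar(...)` also compiles, but this is the
        // last use of `grad`, so moving it avoids a needless copy.
        let update = grad.mul_scalar(0.9).add(Tensor(vec![0.1, 0.1]));
        println!("{update:?}");
    }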
@@ -167,9 +167,8 @@ impl AdaptiveMomentumW {
             .div_scalar(1f32 - self.beta_2.powi(time));

         // Compute update delta. This still needs to be scaled by the learning rate.
-        let update_delta = moment_1_corrected
-            .clone()
-            .div(moment_2_corrected.clone().sqrt().add_scalar(self.epsilon));
+        let update_delta =
+            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

         (
             update_delta,
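For reference, the quantity being simplified is the standard Adam/AdamW update direction. In the usual notation (a restatement of the textbook formulas, not something this diff introduces):

    \hat m_t = m_t / (1 - \beta_1^t), \quad
    \hat v_t = v_t / (1 - \beta_2^t), \quad
    \Delta_t = \hat m_t / (\sqrt{\hat v_t} + \epsilon)

As the comment in the hunk says, the caller still scales \Delta_t by the learning rate.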
@@ -156,7 +156,6 @@ impl<B: Backend, const D: usize> SquareAvgState<B, D> {
             Some(state) => {
                 let square_avg = state
                     .square_avg
-                    .clone()
                     .mul_scalar(alpha)
                     .add(grad.clone().powf(2.).mul_scalar(1. - alpha));
                 (grad, Self { square_avg })
@@ -205,7 +204,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
             Some(state) => state
                 .grad_avg
                 .map_or(grad_avg_constant.clone(), move |grad_avg| {
-                    grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
+                    grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                 }),
             _ => grad_avg_constant,
         };
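The two states touched above maintain RMSProp's running statistics. In standard notation (restated for context; this assumes grad_avg_constant is the freshly weighted term (1 - \alpha) g_t, which is computed outside the hunk):

    s_t = \alpha s_{t-1} + (1 - \alpha) g_t^2, \quad
    \bar g_t = \alpha \bar g_{t-1} + (1 - \alpha) g_t

The centered variant then normalizes by \sqrt{s_t - \bar g_t^2} instead of \sqrt{s_t}.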
@@ -269,18 +268,12 @@ impl RMSPropMomentum {
         CenteredState<B, D>,
         Option<RMSPropMomentumState<B, D>>,
     ) {
-        let grad = grad
-            .clone()
-            .div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
+        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));

         if self.momentum > 0. {
             let buf = match momentum_state {
-                Some(state) => state
-                    .buf
-                    .clone()
-                    .mul_scalar(self.momentum)
-                    .add(grad.clone()),
-                _ => grad.clone(),
+                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
+                _ => grad,
             };
             (
                 buf.clone(),
@@ -288,7 +281,7 @@ impl RMSPropMomentum {
                 Some(RMSPropMomentumState { buf }),
             )
         } else {
-            (grad.clone(), centered_state, None)
+            (grad, centered_state, None)
         }
     }
 }
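The branch simplified in these two hunks is the usual momentum buffer applied to the RMSProp-scaled gradient. Restated (with \mu = self.momentum and avg the centered average from above):

    \tilde g_t = g_t / (\sqrt{avg_t} + \epsilon), \quad
    b_t = \mu b_{t-1} + \tilde g_t

When \mu = 0, the scaled gradient is returned as-is, which is what the else branch does.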
@@ -134,7 +134,7 @@ fn linear_update_outputs(node: &mut Node) {

     if let ArgType::Tensor(tensor) = node_input.clone().ty {
         // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor.clone());
+        node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
         panic!("Only tensor input is valid");
     }
@@ -532,7 +532,7 @@ fn lift_constants(nodes: &mut Vec<Node>) {
             AttributeValue::Tensor(tensor) => State {
                 // if the value is a tensor, create a new State object with the tensor as its type
                 name: input.name.clone(),
-                ty: StateType::Tensor(tensor.clone()),
+                ty: StateType::Tensor(tensor),
             },
             _ => todo!("Support non tensor constant type"),
         };
@@ -9,8 +9,10 @@ pub type StorageRef = Arc<*mut c_void>;
 /// A tensor that uses the tch backend.
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<E: tch::kind::Element, const D: usize> {
-    pub(crate) tensor: tch::Tensor,
-    pub(crate) storage: StorageRef,
+    /// Handle to the tensor. Call methods on this field.
+    pub tensor: tch::Tensor,
+    /// The tensor's storage
+    pub storage: StorageRef,
     phantom: PhantomData<E>,
 }

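This is the headline change: code outside burn-tch can now reach the wrapped tch::Tensor directly. A hedged usage sketch, assuming burn-tch and tch as dependencies and that TchTensor is re-exported at the crate root; the field names come from the diff, the rest is illustrative:

    use burn_tch::TchTensor;

    // With `tensor` public, tch methods can be called on it directly;
    // tch::Tensor::size() returns the dimensions as a Vec<i64>.
    fn describe<E: tch::kind::Element, const D: usize>(t: &TchTensor<E, D>) {
        let dims: Vec<i64> = t.tensor.size();
        println!("rank-{} tensor with dims {dims:?}", dims.len());
    }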
@@ -122,10 +122,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
     );
     let [batch_size, channels, _, length] = B::shape(&x.output).dims;
     let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
-    let indices = B::int_reshape(
-        x.indices.clone(),
-        Shape::from([batch_size, channels, length]),
-    );
+    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
     MaxPool1dWithIndices::new(output, indices)
 }

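For context on the collapsed reshape: the 1D pool is computed by running the 2D kernel over a view with a dummy height axis, then squeezing that axis back out, which is what both reshapes here do. A standalone sketch of just the shape bookkeeping (plain arrays, not Burn's tensor types):

    // [B, C, 1, L] -> [B, C, L]: drop the dummy height axis that was
    // inserted so the 2D pooling kernel could handle the 1D case.
    fn squeeze_dummy_height(dims: [usize; 4]) -> [usize; 3] {
        let [batch_size, channels, _height, length] = dims;
        [batch_size, channels, length]
    }

    fn main() {
        assert_eq!(squeeze_dummy_height([8, 3, 1, 128]), [8, 3, 128]);
    }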