mirror of https://github.com/tracel-ai/burn.git
Made TchTensor fields public (#677)
This commit is contained in:
parent 2fefc82099
commit d1a708e317
@@ -126,7 +126,6 @@ impl LRDecay {
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
 
         let grad = grad
-            .clone()
            .div(state.sum.clone().sqrt().add_scalar(self.epsilon))
            .mul_scalar(new_lr);
 
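For context: burn's tensor methods such as `div` and `mul_scalar` take `self` by value, so the removed `.clone()` was an extra copy and the gradient can simply be consumed. On plain scalars, the AdaGrad step above works out to roughly the following sketch (illustrative names, not the crate's API):

// Scalar sketch of the decayed-learning-rate AdaGrad step shown above.
// `sum` stands for the running sum of squared gradients kept in the state.
fn adagrad_step(lr: f64, time: u64, lr_decay: f64, sum: f64, grad: f64, epsilon: f64) -> f64 {
    // The effective learning rate shrinks as the step count grows;
    // at time == 1 it equals the configured rate.
    let new_lr = lr / (1. + (time as f64 - 1.) * lr_decay);
    grad / (sum.sqrt() + epsilon) * new_lr
}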
@@ -167,9 +167,8 @@ impl AdaptiveMomentumW {
             .div_scalar(1f32 - self.beta_2.powi(time));
 
         // Compute update delta. This still needs to be scaled by the learning rate.
-        let update_delta = moment_1_corrected
-            .clone()
-            .div(moment_2_corrected.clone().sqrt().add_scalar(self.epsilon));
+        let update_delta =
+            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
 
         (
             update_delta,
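The rewrite also drops the `.clone()` on `moment_2_corrected` and folds the chain onto one line. Assuming the usual Adam-style bias correction for both moments (only the second moment's correction is visible in this hunk), the delta works out to this scalar sketch:

// Scalar sketch of the bias-corrected update delta (illustrative; the real
// code divides burn tensors). The learning-rate scaling happens later.
fn update_delta(moment_1: f32, moment_2: f32, beta_1: f32, beta_2: f32, time: i32, epsilon: f32) -> f32 {
    let moment_1_corrected = moment_1 / (1f32 - beta_1.powi(time));
    let moment_2_corrected = moment_2 / (1f32 - beta_2.powi(time));
    moment_1_corrected / (moment_2_corrected.sqrt() + epsilon)
}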
@@ -156,7 +156,6 @@ impl<B: Backend, const D: usize> SquareAvgState<B, D> {
             Some(state) => {
                 let square_avg = state
                     .square_avg
-                    .clone()
                     .mul_scalar(alpha)
                     .add(grad.clone().powf(2.).mul_scalar(1. - alpha));
                 (grad, Self { square_avg })
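The state update is an exponential moving average of squared gradients; with the `.clone()` gone, `state.square_avg` is moved into the chain rather than copied. On scalars:

// Scalar sketch of the squared-gradient moving average updated above:
// square_avg <- alpha * square_avg + (1 - alpha) * grad^2
fn square_avg_step(square_avg: f64, grad: f64, alpha: f64) -> f64 {
    square_avg * alpha + grad.powf(2.) * (1. - alpha)
}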
@@ -205,7 +204,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
             Some(state) => state
                 .grad_avg
                 .map_or(grad_avg_constant.clone(), move |grad_avg| {
-                    grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
+                    grad_avg.mul_scalar(alpha).add(grad_avg_constant)
                 }),
             _ => grad_avg_constant,
         };
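The centered variant keeps the same moving-average pattern for the raw gradient. A scalar sketch, assuming `grad_avg_constant` is the `(1 - alpha)`-scaled gradient (the hunk only shows how it is combined, not how it is built):

// Scalar sketch of the gradient moving average kept by CenteredState.
fn grad_avg_step(grad_avg: Option<f64>, grad: f64, alpha: f64) -> f64 {
    let grad_avg_constant = grad * (1. - alpha); // assumed definition
    match grad_avg {
        Some(avg) => avg * alpha + grad_avg_constant,
        None => grad_avg_constant,
    }
}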
@@ -269,18 +268,12 @@ impl RMSPropMomentum {
         CenteredState<B, D>,
         Option<RMSPropMomentumState<B, D>>,
     ) {
-        let grad = grad
-            .clone()
-            .div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
+        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
 
         if self.momentum > 0. {
             let buf = match momentum_state {
-                Some(state) => state
-                    .buf
-                    .clone()
-                    .mul_scalar(self.momentum)
-                    .add(grad.clone()),
-                _ => grad.clone(),
+                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),
+                _ => grad,
             };
             (
                 buf.clone(),
@@ -288,7 +281,7 @@ impl RMSPropMomentum {
                 Some(RMSPropMomentumState { buf }),
             )
         } else {
-            (grad.clone(), centered_state, None)
+            (grad, centered_state, None)
         }
     }
 }
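Taken together, the two RMSProp hunks normalize the gradient by the centered running average and then fold it into a momentum buffer; removing the `.clone()`s lets `grad` be moved into whichever branch consumes it. The control flow on scalars, as a sketch:

// Scalar sketch of the momentum branch above. Returns the effective update.
fn rmsprop_momentum_step(momentum: f64, buf: Option<f64>, avg: f64, grad: f64, epsilon: f64) -> f64 {
    // Normalize by the (centered) running average of squared gradients.
    let grad = grad / (avg.sqrt() + epsilon);
    if momentum > 0. {
        // buf <- momentum * buf + grad; the buffer is the returned update.
        match buf {
            Some(buf) => buf * momentum + grad,
            None => grad,
        }
    } else {
        grad
    }
}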
@@ -134,7 +134,7 @@ fn linear_update_outputs(node: &mut Node) {
 
     if let ArgType::Tensor(tensor) = node_input.clone().ty {
         // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor.clone());
+        node.outputs[0].ty = ArgType::Tensor(tensor);
     } else {
         panic!("Only tensor input is valid");
     }
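This is plain shape propagation: a Linear node's output tensor type mirrors its input, and since `node_input.clone().ty` already yields an owned value, the inner `.clone()` was redundant. A self-contained sketch of the pattern with simplified stand-in types (the real `Node`/`ArgType` live in burn's ONNX import code):

// Simplified stand-in types; illustrative only.
#[derive(Clone)]
struct TensorType { dim: usize }
#[derive(Clone)]
enum ArgType { Tensor(TensorType), Other }
struct Argument { ty: ArgType }
struct Node { inputs: Vec<Argument>, outputs: Vec<Argument> }

fn linear_update_outputs(node: &mut Node) {
    // Copy the input's tensor type onto the first output.
    if let ArgType::Tensor(tensor) = node.inputs[0].ty.clone() {
        node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
        panic!("Only tensor input is valid");
    }
}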
@@ -532,7 +532,7 @@ fn lift_constants(nodes: &mut Vec<Node>) {
         AttributeValue::Tensor(tensor) => State {
             // if the value is a tensor, create a new State object with the tensor as its type
             name: input.name.clone(),
-            ty: StateType::Tensor(tensor.clone()),
+            ty: StateType::Tensor(tensor),
         },
         _ => todo!("Support non tensor constant type"),
     };
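`lift_constants` turns an inline constant attribute into a named state entry; with the `.clone()` removed, the tensor payload is moved into the new `State` instead of copied. A simplified sketch with stand-in types (not the crate's real definitions):

// Stand-in types; illustrative only.
struct TensorData;
enum AttributeValue { Tensor(TensorData), Float(f32) }
enum StateType { Tensor(TensorData) }
struct State { name: String, ty: StateType }

fn lift_constant(name: &str, value: AttributeValue) -> State {
    match value {
        // Move the tensor into the lifted state instead of cloning it.
        AttributeValue::Tensor(tensor) => State {
            name: name.to_string(),
            ty: StateType::Tensor(tensor),
        },
        _ => todo!("Support non tensor constant type"),
    }
}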
@@ -9,8 +9,10 @@ pub type StorageRef = Arc<*mut c_void>;
 /// A tensor that uses the tch backend.
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<E: tch::kind::Element, const D: usize> {
-    pub(crate) tensor: tch::Tensor,
-    pub(crate) storage: StorageRef,
+    /// Handle to the tensor. Call methods on this field.
+    pub tensor: tch::Tensor,
+    /// The tensor's storage
+    pub storage: StorageRef,
     phantom: PhantomData<E>,
 }
 
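This is the change the commit title refers to: `tensor` and `storage` go from `pub(crate)` to `pub`, so code outside the burn-tch crate can reach the underlying LibTorch handle directly. A hypothetical downstream use:

// Hypothetical example of what the public field enables: calling tch::Tensor
// methods (here `size`) directly on the wrapped handle.
fn raw_shape<E: tch::kind::Element, const D: usize>(t: &TchTensor<E, D>) -> Vec<i64> {
    t.tensor.size()
}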
@@ -122,10 +122,7 @@ pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
     );
     let [batch_size, channels, _, length] = B::shape(&x.output).dims;
     let output = B::reshape(x.output, Shape::from([batch_size, channels, length]));
-    let indices = B::int_reshape(
-        x.indices.clone(),
-        Shape::from([batch_size, channels, length]),
-    );
+    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
     MaxPool1dWithIndices::new(output, indices)
 }
 
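As the function name suggests, 1d pooling is implemented by running the 2d kernel over an extra axis (presumably a unit dimension) and squeezing it back out; the destructuring above discards that axis. The shape bookkeeping alone, as a sketch:

// Sketch of the dimension squeeze performed by the reshapes above:
// [batch, channels, 1, length] -> [batch, channels, length].
fn squeeze_unit_dim(dims: [usize; 4]) -> [usize; 3] {
    let [batch_size, channels, _, length] = dims;
    [batch_size, channels, length]
}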