mirror of https://github.com/tracel-ai/burn.git
remove manual option matching (#1948)
This commit is contained in:
parent
41c7a5cf4b
commit
3a9367de73
|
@ -319,12 +319,7 @@ impl MemoryPool {
|
|||
|
||||
let slice_id =
|
||||
self.ring
|
||||
.find_free_slice(effective_size, &mut self.chunks, &mut self.slices);
|
||||
|
||||
let slice_id = match slice_id {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
.find_free_slice(effective_size, &mut self.chunks, &mut self.slices)?;
|
||||
|
||||
let slice = self.slices.get_mut(&slice_id).unwrap();
|
||||
let old_slice_size = slice.effective_size();
|
||||
|
|
|
@ -151,12 +151,7 @@ impl SmallMemoryPool {
|
|||
/// Finds a free slice that can contain the given size
|
||||
/// Returns the chunk's id and size.
|
||||
fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
|
||||
let slice_id = self.find_free_slice();
|
||||
|
||||
let slice_id = match slice_id {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
let slice_id = self.find_free_slice()?;
|
||||
|
||||
let slice = self.slices.get_mut(&slice_id).unwrap();
|
||||
let old_slice_size = slice.effective_size();
|
||||
|
|
|
@ -96,12 +96,7 @@ impl<K: AutotuneKey> TuneCache<K> {
|
|||
}
|
||||
|
||||
pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
|
||||
let result = self.in_memory_cache.get(key);
|
||||
|
||||
let val = match result {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
let val = self.in_memory_cache.get(key)?;
|
||||
|
||||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
if val.checksum_checked {
|
||||
|
|
|
@ -17,10 +17,7 @@ impl VariablePool {
|
|||
let map = self.map.borrow();
|
||||
|
||||
// Filter for candidate variables of the same Item
|
||||
let variables = match map.get(&item) {
|
||||
Some(val) => val,
|
||||
None => return None,
|
||||
};
|
||||
let variables = map.get(&item)?;
|
||||
|
||||
// Among the candidates, take a variable if it's only referenced by the map
|
||||
// Arbitrarily takes the first it finds in reverse order.
|
||||
|
|
|
@ -42,10 +42,7 @@ where
|
|||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let index = match self.indices.get(index) {
|
||||
Some(index) => index,
|
||||
None => return None,
|
||||
};
|
||||
let index = self.indices.get(index)?;
|
||||
self.dataset.get(*index)
|
||||
}
|
||||
|
||||
|
|
|
@ -42,10 +42,7 @@ where
|
|||
where
|
||||
B: Backend,
|
||||
{
|
||||
let grad = match self.tensors.get(id) {
|
||||
Some(grad) => grad,
|
||||
None => return None,
|
||||
};
|
||||
let grad = self.tensors.get(id)?;
|
||||
|
||||
let tensor = grad
|
||||
.downcast_ref::<TensorPrimitive<B, D>>()
|
||||
|
|
|
@ -116,10 +116,7 @@ impl ProgressEstimate {
|
|||
}
|
||||
|
||||
fn secs(&self) -> Option<u64> {
|
||||
let eta = match self.started_after_warmup {
|
||||
Some(started) => started.elapsed(),
|
||||
None => return None,
|
||||
};
|
||||
let eta = self.started_after_warmup?.elapsed();
|
||||
|
||||
let total_estimated = (eta.as_secs() as f64) / self.progress;
|
||||
|
||||
|
|
Loading…
Reference in New Issue