Keep previous FSRS parameters if they get worse when optimizing (#2996)

* Update to fsrs-rs 0.3.0

* Keep previous FSRS parameters if they get worse when optimizing
This commit is contained in:
Abdo 2024-02-11 09:26:04 +03:00 committed by GitHub
parent e136ec65e9
commit 4ef389b580
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 22 additions and 8 deletions

4
Cargo.lock generated
View File

@ -1794,9 +1794,9 @@ dependencies = [
[[package]] [[package]]
name = "fsrs" name = "fsrs"
version = "0.2.0" version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4938928321ed55a54cd8b90f0aa9dad79ee223aed6462cc37e83eb80b8bddb5a" checksum = "eece5d2325704da8667c3a796e64a3c0b046100ff758ba88f91d0dee80d6deb9"
dependencies = [ dependencies = [
"burn", "burn",
"itertools 0.12.0", "itertools 0.12.0",

View File

@ -35,7 +35,7 @@ git = "https://github.com/ankitects/linkcheck.git"
rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca" rev = "184b2ca50ed39ca43da13f0b830a463861adb9ca"
[workspace.dependencies.fsrs] [workspace.dependencies.fsrs]
version = "0.2.0" version = "0.3.0"
# git = "https://github.com/open-spaced-repetition/fsrs-rs.git" # git = "https://github.com/open-spaced-repetition/fsrs-rs.git"
# rev = "58ca25ed2bc4bb1dc376208bbcaed7f5a501b941" # rev = "58ca25ed2bc4bb1dc376208bbcaed7f5a501b941"
# path = "../../../fsrs-rs" # path = "../../../fsrs-rs"

View File

@ -1216,7 +1216,7 @@
}, },
{ {
"name": "fsrs", "name": "fsrs",
"version": "0.2.0", "version": "0.3.0",
"authors": "Open Spaced Repetition", "authors": "Open Spaced Repetition",
"repository": "https://github.com/open-spaced-repetition/fsrs-rs", "repository": "https://github.com/open-spaced-repetition/fsrs-rs",
"license": "BSD-3-Clause", "license": "BSD-3-Clause",

View File

@ -338,6 +338,7 @@ message RepositionDefaultsResponse {
message ComputeFsrsWeightsRequest { message ComputeFsrsWeightsRequest {
/// The search used to gather cards for training /// The search used to gather cards for training
string search = 1; string search = 1;
repeated float current_weights = 2;
} }
message ComputeFsrsWeightsResponse { message ComputeFsrsWeightsResponse {

View File

@ -332,7 +332,12 @@ impl Collection {
} else { } else {
config.inner.weight_search.clone() config.inner.weight_search.clone()
}; };
match self.compute_weights(&search, idx as u32 + 1, config_len) { match self.compute_weights(
&search,
idx as u32 + 1,
config_len,
&config.inner.fsrs_weights,
) {
Ok(weights) => { Ok(weights) => {
if weights.fsrs_items >= 1000 { if weights.fsrs_items >= 1000 {
println!("{}: {:?}", config.name, weights.weights); println!("{}: {:?}", config.name, weights.weights);

View File

@ -35,6 +35,7 @@ impl Collection {
search: &str, search: &str,
current_preset: u32, current_preset: u32,
total_presets: u32, total_presets: u32,
current_weights: &Weights,
) -> Result<ComputeFsrsWeightsResponse> { ) -> Result<ComputeFsrsWeightsResponse> {
let mut anki_progress = self.new_progress_handler::<ComputeWeightsProgress>(); let mut anki_progress = self.new_progress_handler::<ComputeWeightsProgress>();
let timing = self.timing_today()?; let timing = self.timing_today()?;
@ -69,8 +70,14 @@ impl Collection {
} }
} }
}); });
let fsrs = FSRS::new(None)?; let fsrs = FSRS::new(Some(current_weights))?;
let weights = fsrs.compute_weights(items, revlogs.len() < 1000, Some(progress2))?; let mut weights =
fsrs.compute_weights(items.clone(), revlogs.len() < 1000, Some(progress2))?;
let metrics = fsrs.universal_metrics(items, &weights, |_| true)?;
if metrics.0 < metrics.1 {
weights = current_weights.to_vec();
}
Ok(ComputeFsrsWeightsResponse { Ok(ComputeFsrsWeightsResponse {
weights, weights,
fsrs_items, fsrs_items,

View File

@ -254,7 +254,7 @@ impl crate::services::SchedulerService for Collection {
&mut self, &mut self,
input: scheduler::ComputeFsrsWeightsRequest, input: scheduler::ComputeFsrsWeightsRequest,
) -> Result<scheduler::ComputeFsrsWeightsResponse> { ) -> Result<scheduler::ComputeFsrsWeightsResponse> {
self.compute_weights(&input.search, 1, 1) self.compute_weights(&input.search, 1, 1, &input.current_weights)
} }
fn compute_optimal_retention( fn compute_optimal_retention(

View File

@ -104,6 +104,7 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
search: $config.weightSearch search: $config.weightSearch
? $config.weightSearch ? $config.weightSearch
: defaultWeightSearch, : defaultWeightSearch,
currentWeights: $config.fsrsWeights,
}); });
if (computeWeightsProgress) { if (computeWeightsProgress) {
computeWeightsProgress.current = computeWeightsProgress.total; computeWeightsProgress.current = computeWeightsProgress.total;