Rework syncing code, and replace local sync server (#2329)

This PR replaces the existing Python-driven sync server with a new one in Rust.
The new server supports both collection and media syncing, and is compatible
with both the new protocol described below and older clients. A setting has
been added to the preferences screen to point Anki to a local server, and a
similar setting is likely to come to AnkiMobile soon.

Documentation is available here: <https://docs.ankiweb.net/sync-server.html>

In addition to the new server and refactoring, this PR also makes changes to the
sync protocol. The existing sync protocol places payloads and metadata inside a
multipart POST body, which causes a few headaches:

- Legacy clients build the request in a non-deterministic order, meaning the
entire request needs to be scanned to extract the metadata.
- Reqwest's multipart API directly writes the multipart body, without exposing
the resulting stream to us, making it harder to track the progress of the
transfer. We've been relying on a patched version of reqwest for timeouts,
which is a pain to keep up to date.

To address these issues, the metadata is now sent in an HTTP header, with the
data payload sent directly in the body. Instead of the slower gzip, we now
use zstd. The old timeout-handling code has been replaced with a new implementation
that wraps the request and response body streams to track progress, allowing us
to drop the patched reqwest git dependency, along with hyper-timeout and
tokio-io-timeout.
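
To make the difference concrete, here is a rough sketch of what a client request
can look like under the new scheme. This is an illustration only, not code from
this PR: the header name, metadata fields, port and path are made up, and the
payload is compressed in one shot with the `zstd` crate, whereas the real
implementation streams the body so that transfer progress can be reported.

```rust
// A minimal, hypothetical client request under the new scheme. The header
// name, JSON fields, URL and use of the `zstd` crate are illustrative only.
use reqwest::Client;
use serde_json::json;

async fn post_sync_chunk(
    client: &Client,
    hkey: &str,
    payload: &[u8],
) -> reqwest::Result<reqwest::Response> {
    // All metadata travels in a single JSON header, so the server can
    // authenticate and route the request without scanning the body.
    let meta = json!({
        "sync_key": hkey,
        "client_version": "anki,2.1.x (example),linux",
    });
    // The body is just the zstd-compressed payload; no multipart framing.
    let compressed = zstd::encode_all(payload, 0).expect("zstd compression failed");
    client
        .post("http://127.0.0.1:8080/sync/applyChanges")
        .header("anki-sync", meta.to_string())
        .body(compressed)
        .send()
        .await
}
```

Because the metadata lives in a fixed header, the server can decide how to handle
the request without buffering or scanning a multipart body first.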

The other main protocol change is that one-way syncs no longer need to
downgrade the collection to schema 11 before sending.
Damien Elmes 2023-01-18 12:43:46 +10:00 committed by GitHub
parent 1be30573e1
commit cf45cbf429
135 changed files with 8490 additions and 4540 deletions


@ -1,8 +1,3 @@
hakari-package = "workspace-hack"
dep-format-version = "2"
resolver = "2"
[traversal-excludes]
third-party = [
{ name = "reqwest", git = "https://github.com/ankitects/reqwest.git", rev = "7591444614de02b658ddab125efba7b2bb4e2335" },
]

.gitignore (2 changed lines)

@ -5,7 +5,7 @@ target
/user.bazelrc
.dmypy.json
/.idea/
/.vscode/
/.vscode
/.bazel
/windows.bazelrc
/out


@ -2,5 +2,7 @@
# useful for manual invocation with 'cargo +nightly fmt'
imports_granularity = "Crate"
group_imports = "StdExternalCrate"
# wrap_comments = true
# imports_granularity = "Item"
# imports_layout = "Vertical"

Cargo.lock (generated, 617 changed lines)

File diff suppressed because it is too large.


@ -24,10 +24,6 @@ members = [
exclude = ["qt/bundle"]
resolver = "2"
[patch.crates-io]
# If updating rev, hakari.toml needs updating too.
reqwest = { git = "https://github.com/ankitects/reqwest.git", rev = "7591444614de02b658ddab125efba7b2bb4e2335" }
# Apply mild optimizations to our dependencies in dev mode, which among other things
# improves sha2 performance by about 21x. Opt 1 chosen due to
# https://doc.rust-lang.org/cargo/reference/profiles.html#overrides-and-generics. This


@ -24,5 +24,5 @@ rustls = ["reqwest/rustls-tls", "reqwest/rustls-tls-native-roots"]
native-tls = ["reqwest/native-tls"]
[dependencies.reqwest]
version = "=0.11.3"
version = "0.11.13"
default-features = false


@ -116,9 +116,54 @@
"license_file": null,
"description": "Like percent_encoding, but does not encode non-ASCII characters."
},
{
"name": "assert-json-diff",
"version": "2.0.2",
"authors": "David Pedersen <david.pdrsn@gmail.com>",
"repository": "https://github.com/davidpdrsn/assert-json-diff.git",
"license": "MIT",
"license_file": null,
"description": "Easily compare two JSON values and get great output"
},
{
"name": "async-channel",
"version": "1.8.0",
"authors": "Stjepan Glavina <stjepang@gmail.com>",
"repository": "https://github.com/smol-rs/async-channel",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Async multi-producer multi-consumer channel"
},
{
"name": "async-compression",
"version": "0.3.15",
"authors": "Wim Looman <wim@nemo157.com>|Allen Bui <fairingrey@gmail.com>",
"repository": "https://github.com/Nemo157/async-compression",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Adaptors between compression crates and Rust's modern asynchronous IO types."
},
{
"name": "async-stream",
"version": "0.3.3",
"authors": "Carl Lerche <me@carllerche.com>",
"repository": "https://github.com/tokio-rs/async-stream",
"license": "MIT",
"license_file": null,
"description": "Asynchronous streams using async & await notation"
},
{
"name": "async-stream-impl",
"version": "0.3.3",
"authors": "Carl Lerche <me@carllerche.com>",
"repository": "https://github.com/tokio-rs/async-stream",
"license": "MIT",
"license_file": null,
"description": "proc macros for async-stream crate"
},
{
"name": "async-trait",
"version": "0.1.59",
"version": "0.1.60",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/async-trait",
"license": "Apache-2.0 OR MIT",
@ -134,6 +179,42 @@
"license_file": null,
"description": "Automatic cfg for Rust compiler features"
},
{
"name": "axum",
"version": "0.6.1",
"authors": null,
"repository": "https://github.com/tokio-rs/axum",
"license": "MIT",
"license_file": null,
"description": "Web framework that focuses on ergonomics and modularity"
},
{
"name": "axum-client-ip",
"version": "0.3.1",
"authors": null,
"repository": "https://github.com/imbolc/axum-client-ip",
"license": "MIT",
"license_file": null,
"description": "A client IP address extractor for Axum"
},
{
"name": "axum-core",
"version": "0.3.0",
"authors": null,
"repository": "https://github.com/tokio-rs/axum",
"license": "MIT",
"license_file": null,
"description": "Core types and traits for axum"
},
{
"name": "axum-macros",
"version": "0.3.0",
"authors": null,
"repository": "https://github.com/tokio-rs/axum",
"license": "MIT",
"license_file": null,
"description": "Macros for axum"
},
{
"name": "backtrace",
"version": "0.3.66",
@ -152,6 +233,15 @@
"license_file": null,
"description": "encodes and decodes base64 as bytes or utf8"
},
{
"name": "base64",
"version": "0.21.0",
"authors": "Alice Maz <alice@alicemaz.com>|Marshall Pierce <marshall@mpierce.org>",
"repository": "https://github.com/marshallpierce/rust-base64",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "encodes and decodes base64 as bytes or utf8"
},
{
"name": "base64ct",
"version": "1.5.3",
@ -296,6 +386,15 @@
"license_file": null,
"description": "Beautiful diagnostic reporting for text-based programming languages"
},
{
"name": "concurrent-queue",
"version": "2.0.0",
"authors": "Stjepan Glavina <stjepang@gmail.com>|Taiki Endo <te316e89@gmail.com>|John Nunley <jtnunley01@gmail.com>",
"repository": "https://github.com/smol-rs/concurrent-queue",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Concurrent multi-producer multi-consumer queue"
},
{
"name": "constant_time_eq",
"version": "0.1.5",
@ -440,6 +539,24 @@
"license_file": null,
"description": "Implementation detail of the `cxx` crate."
},
{
"name": "deadpool",
"version": "0.9.5",
"authors": "Michael P. Jung <michael.jung@terreon.de>",
"repository": "https://github.com/bikeshedder/deadpool",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Dead simple async pool"
},
{
"name": "deadpool-runtime",
"version": "0.1.2",
"authors": "Michael P. Jung <michael.jung@terreon.de>",
"repository": "https://github.com/bikeshedder/deadpool",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Dead simple async pool utitities for sync managers"
},
{
"name": "digest",
"version": "0.10.6",
@ -521,6 +638,15 @@
"license_file": null,
"description": "Exposes errno functionality to stable Rust on DragonFlyBSD"
},
{
"name": "event-listener",
"version": "2.5.3",
"authors": "Stjepan Glavina <stjepang@gmail.com>",
"repository": "https://github.com/smol-rs/event-listener",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Notify async tasks or threads"
},
{
"name": "fallible-iterator",
"version": "0.2.0",
@ -638,6 +764,15 @@
"license_file": null,
"description": "Parser and serializer for the application/x-www-form-urlencoded syntax, as used by HTML forms."
},
{
"name": "forwarded-header-value",
"version": "0.1.1",
"authors": "James Brown <jbrown@easypost.com>",
"repository": "https://github.com/EasyPost/rust-forwarded-header-value",
"license": "ISC",
"license_file": null,
"description": "Parser for values from the Forwarded header (RFC 7239)"
},
{
"name": "futf",
"version": "0.1.5",
@ -692,6 +827,15 @@
"license_file": null,
"description": "The `AsyncRead`, `AsyncWrite`, `AsyncSeek`, and `AsyncBufRead` traits for the futures-rs library."
},
{
"name": "futures-lite",
"version": "1.12.0",
"authors": "Stjepan Glavina <stjepang@gmail.com>|Contributors to futures-rs",
"repository": "https://github.com/smol-rs/futures-lite",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Futures, streams, and async I/O combinators"
},
{
"name": "futures-macro",
"version": "0.3.25",
@ -719,6 +863,15 @@
"license_file": null,
"description": "Tools for working with tasks."
},
{
"name": "futures-timer",
"version": "3.0.2",
"authors": "Alex Crichton <alex@alexcrichton.com>",
"repository": "https://github.com/async-rs/futures-timer",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Timeouts for futures."
},
{
"name": "futures-util",
"version": "0.3.25",
@ -746,6 +899,15 @@
"license_file": null,
"description": "getopts-like option parsing."
},
{
"name": "getrandom",
"version": "0.1.16",
"authors": "The Rand Project Developers",
"repository": "https://github.com/rust-random/getrandom",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "A small cross-platform library for retrieving random data from system source"
},
{
"name": "getrandom",
"version": "0.2.8",
@ -791,6 +953,24 @@
"license_file": null,
"description": "HashMap-like containers that hold their key-value pairs in a user controllable order"
},
{
"name": "headers",
"version": "0.3.8",
"authors": "Sean McArthur <sean@seanmonstar.com>",
"repository": "https://github.com/hyperium/headers",
"license": "MIT",
"license_file": null,
"description": "typed HTTP headers"
},
{
"name": "headers-core",
"version": "0.2.0",
"authors": "Sean McArthur <sean@seanmonstar.com>",
"repository": "https://github.com/hyperium/headers",
"license": "MIT",
"license_file": null,
"description": "typed HTTP headers core trait"
},
{
"name": "heck",
"version": "0.4.0",
@ -800,15 +980,6 @@
"license_file": null,
"description": "heck is a case conversion library."
},
{
"name": "hermit-abi",
"version": "0.1.19",
"authors": "Stefan Lankes",
"repository": "https://github.com/hermitcore/libhermit-rs",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "hermit-abi is small interface to call functions from the unikernel RustyHermit. It is used to build the target `x86_64-unknown-hermit`."
},
{
"name": "hermit-abi",
"version": "0.2.6",
@ -872,6 +1043,24 @@
"license_file": null,
"description": "Trait representing an asynchronous, streaming, HTTP request or response body."
},
{
"name": "http-range-header",
"version": "0.3.0",
"authors": null,
"repository": "https://github.com/MarcusGrass/parse-range-headers",
"license": "MIT",
"license_file": null,
"description": "No-dep range header parser"
},
{
"name": "http-types",
"version": "2.12.0",
"authors": "Yoshua Wuyts <yoshuawuyts@gmail.com>",
"repository": "https://github.com/http-rs/http-types",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Common types for HTTP operations."
},
{
"name": "httparse",
"version": "1.8.0",
@ -910,22 +1099,13 @@
},
{
"name": "hyper-rustls",
"version": "0.22.1",
"authors": "Joseph Birr-Pixton <jpixton@gmail.com>",
"version": "0.23.2",
"authors": null,
"repository": "https://github.com/ctz/hyper-rustls",
"license": "Apache-2.0 OR ISC OR MIT",
"license_file": null,
"description": "Rustls+hyper integration for pure rust HTTPS"
},
{
"name": "hyper-timeout",
"version": "0.4.1",
"authors": "Herman J. Radtke III <herman@hermanradtke.com>",
"repository": "https://github.com/hjr3/hyper-timeout",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "A connect, read and write timeout aware connector to be used with hyper Client."
},
{
"name": "hyper-tls",
"version": "0.5.0",
@ -980,6 +1160,15 @@
"license_file": null,
"description": "A hash table with consistent order and fast iteration."
},
{
"name": "infer",
"version": "0.2.3",
"authors": "Bojan <dbojan@gmail.com>",
"repository": "https://github.com/bojand/infer",
"license": "MIT",
"license_file": null,
"description": "Small crate to infer file types based on its magic number signature"
},
{
"name": "inflections",
"version": "1.1.1",
@ -1054,7 +1243,7 @@
},
{
"name": "itoa",
"version": "1.0.4",
"version": "1.0.5",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/itoa",
"license": "Apache-2.0 OR MIT",
@ -1090,7 +1279,7 @@
},
{
"name": "libc",
"version": "0.2.138",
"version": "0.2.139",
"authors": "The Rust Project Developers",
"repository": "https://github.com/rust-lang/libc",
"license": "Apache-2.0 OR MIT",
@ -1187,6 +1376,15 @@
"license_file": null,
"description": "A macro to evaluate, as a boolean, whether an expression matches a pattern."
},
{
"name": "matchit",
"version": "0.7.0",
"authors": "Ibraheem Ahmed <ibraheem@ibraheem.ca>",
"repository": "https://github.com/ibraheemdev/matchit",
"license": "MIT",
"license_file": null,
"description": "A blazing fast URL router."
},
{
"name": "memchr",
"version": "2.5.0",
@ -1250,6 +1448,15 @@
"license_file": null,
"description": "Lightweight non-blocking IO"
},
{
"name": "multer",
"version": "2.0.4",
"authors": "Rousan Ali <hello@rousan.io>",
"repository": "https://github.com/rousan/multer-rs",
"license": "MIT",
"license_file": null,
"description": "An async parser for `multipart/form-data` content-type in Rust."
},
{
"name": "multimap",
"version": "0.8.3",
@ -1286,6 +1493,15 @@
"license_file": null,
"description": "A byte-oriented, zero-copy, parser combinators library"
},
{
"name": "nonempty",
"version": "0.7.0",
"authors": "Alexis Sellier <self@cloudhead.io>",
"repository": "https://github.com/cloudhead/nonempty",
"license": "MIT",
"license_file": null,
"description": "Correct by construction non-empty vector"
},
{
"name": "nu-ansi-term",
"version": "0.46.0",
@ -1324,7 +1540,7 @@
},
{
"name": "num_cpus",
"version": "1.14.0",
"version": "1.15.0",
"authors": "Sean McArthur <sean@seanmonstar.com>",
"repository": "https://github.com/seanmonstar/num_cpus",
"license": "Apache-2.0 OR MIT",
@ -1349,15 +1565,6 @@
"license_file": null,
"description": "Internal implementation details for ::num_enum (Procedural macros to make inter-operation between primitives and enums easier)"
},
{
"name": "num_threads",
"version": "0.1.6",
"authors": "Jacob Pratt <open-source@jhpratt.dev>",
"repository": "https://github.com/jhpratt/num_threads",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "A minimal library that determines the number of running threads for the current process."
},
{
"name": "object",
"version": "0.29.0",
@ -1369,7 +1576,7 @@
},
{
"name": "once_cell",
"version": "1.16.0",
"version": "1.17.0",
"authors": "Aleksey Kladov <aleksey.kladov@gmail.com>",
"repository": "https://github.com/matklad/once_cell",
"license": "Apache-2.0 OR MIT",
@ -1430,6 +1637,15 @@
"license_file": null,
"description": "Provides a macro to simplify operator overloading."
},
{
"name": "parking",
"version": "2.0.0",
"authors": "Stjepan Glavina <stjepang@gmail.com>|The Rust Project Developers",
"repository": "https://github.com/stjepang/parking",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Thread parking and unparking"
},
{
"name": "parking_lot",
"version": "0.12.1",
@ -1648,7 +1864,7 @@
},
{
"name": "proc-macro2",
"version": "1.0.47",
"version": "1.0.49",
"authors": "David Tolnay <dtolnay@gmail.com>|Alex Crichton <alex@alexcrichton.com>",
"repository": "https://github.com/dtolnay/proc-macro2",
"license": "Apache-2.0 OR MIT",
@ -1702,13 +1918,22 @@
},
{
"name": "quote",
"version": "1.0.21",
"version": "1.0.23",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/quote",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Quasi-quoting macro quote!(...)"
},
{
"name": "rand",
"version": "0.7.3",
"authors": "The Rand Project Developers|The Rust Project Developers",
"repository": "https://github.com/rust-random/rand",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Random number generators and other randomness functionality."
},
{
"name": "rand",
"version": "0.8.5",
@ -1718,6 +1943,15 @@
"license_file": null,
"description": "Random number generators and other randomness functionality."
},
{
"name": "rand_chacha",
"version": "0.2.2",
"authors": "The Rand Project Developers|The Rust Project Developers|The CryptoCorrosion Contributors",
"repository": "https://github.com/rust-random/rand",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "ChaCha random number generator"
},
{
"name": "rand_chacha",
"version": "0.3.1",
@ -1727,6 +1961,15 @@
"license_file": null,
"description": "ChaCha random number generator"
},
{
"name": "rand_core",
"version": "0.5.1",
"authors": "The Rand Project Developers|The Rust Project Developers",
"repository": "https://github.com/rust-random/rand",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Core random number generator traits and tools for implementation."
},
{
"name": "rand_core",
"version": "0.6.4",
@ -1736,6 +1979,24 @@
"license_file": null,
"description": "Core random number generator traits and tools for implementation."
},
{
"name": "rand_hc",
"version": "0.2.0",
"authors": "The Rand Project Developers",
"repository": "https://github.com/rust-random/rand",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "HC128 random number generator"
},
{
"name": "rand_pcg",
"version": "0.2.1",
"authors": "The Rand Project Developers",
"repository": "https://github.com/rust-random/rand",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Selected PCG random number generators"
},
{
"name": "redox_syscall",
"version": "0.2.16",
@ -1783,13 +2044,22 @@
},
{
"name": "reqwest",
"version": "0.11.3",
"version": "0.11.13",
"authors": "Sean McArthur <sean@seanmonstar.com>",
"repository": "https://github.com/seanmonstar/reqwest",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "higher level HTTP client library"
},
{
"name": "retain_mut",
"version": "0.1.9",
"authors": "Xidorn Quan <me@upsuper.org>",
"repository": "https://github.com/upsuper/retain_mut",
"license": "MIT",
"license_file": null,
"description": "Provide retain_mut method that has the same functionality as retain but gives mutable borrow to the predicate."
},
{
"name": "ring",
"version": "0.16.20",
@ -1837,25 +2107,34 @@
},
{
"name": "rustls",
"version": "0.19.1",
"authors": "Joseph Birr-Pixton <jpixton@gmail.com>",
"repository": "https://github.com/ctz/rustls",
"version": "0.20.8",
"authors": null,
"repository": "https://github.com/rustls/rustls",
"license": "Apache-2.0 OR ISC OR MIT",
"license_file": null,
"description": "Rustls is a modern TLS library written in Rust."
},
{
"name": "rustls-native-certs",
"version": "0.5.0",
"version": "0.6.2",
"authors": "Joseph Birr-Pixton <jpixton@gmail.com>",
"repository": "https://github.com/ctz/rustls-native-certs",
"license": "Apache-2.0 OR ISC OR MIT",
"license_file": null,
"description": "rustls-native-certs allows rustls to use the platform native certificate store"
},
{
"name": "rustls-pemfile",
"version": "1.0.2",
"authors": null,
"repository": "https://github.com/rustls/pemfile",
"license": "Apache-2.0 OR ISC OR MIT",
"license_file": null,
"description": "Basic .pem file parser for keys and certificates"
},
{
"name": "rustversion",
"version": "1.0.9",
"version": "1.0.11",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/rustversion",
"license": "Apache-2.0 OR MIT",
@ -1864,7 +2143,7 @@
},
{
"name": "ryu",
"version": "1.0.11",
"version": "1.0.12",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/ryu",
"license": "Apache-2.0 OR BSL-1.0",
@ -1900,7 +2179,7 @@
},
{
"name": "sct",
"version": "0.6.1",
"version": "0.7.0",
"authors": "Joseph Birr-Pixton <jpixton@gmail.com>",
"repository": "https://github.com/ctz/sct.rs",
"license": "Apache-2.0 OR ISC OR MIT",
@ -1936,7 +2215,7 @@
},
{
"name": "serde",
"version": "1.0.149",
"version": "1.0.152",
"authors": "Erick Tryzelaar <erick.tryzelaar@gmail.com>|David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/serde-rs/serde",
"license": "Apache-2.0 OR MIT",
@ -1954,7 +2233,7 @@
},
{
"name": "serde_derive",
"version": "1.0.149",
"version": "1.0.152",
"authors": "Erick Tryzelaar <erick.tryzelaar@gmail.com>|David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/serde-rs/serde",
"license": "Apache-2.0 OR MIT",
@ -1963,13 +2242,31 @@
},
{
"name": "serde_json",
"version": "1.0.89",
"version": "1.0.91",
"authors": "Erick Tryzelaar <erick.tryzelaar@gmail.com>|David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/serde-rs/json",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "A JSON serialization file format"
},
{
"name": "serde_path_to_error",
"version": "0.1.9",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/path-to-error",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Path to the element that failed to deserialize"
},
{
"name": "serde_qs",
"version": "0.8.5",
"authors": "Sam Scott <sam@osohq.com>",
"repository": "https://github.com/samscott89/serde_qs",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Querystrings for Serde"
},
{
"name": "serde_repr",
"version": "0.1.9",
@ -2114,6 +2411,15 @@
"license_file": null,
"description": "Synchronization primitives based on spinning. They may contain data, are usable without `std`, and static initializers are available."
},
{
"name": "spin",
"version": "0.9.4",
"authors": "Mathijs van de Nes <git@mathijs.vd-nes.nl>|John Ericson <git@JohnEricson.me>|Joshua Barretto <joshua.s.barretto@gmail.com>",
"repository": "https://github.com/mvdnes/spin-rs.git",
"license": "MIT",
"license_file": null,
"description": "Spin-based synchronization primitives"
},
{
"name": "string_cache",
"version": "0.8.4",
@ -2161,13 +2467,22 @@
},
{
"name": "syn",
"version": "1.0.105",
"version": "1.0.107",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/syn",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Parser for Rust source code"
},
{
"name": "sync_wrapper",
"version": "0.1.1",
"authors": "Actyx AG <developer@actyx.io>",
"repository": "https://github.com/Actyx/sync_wrapper",
"license": "Apache-2.0",
"license_file": null,
"description": "A tool for enlisting the compilers help in proving the absence of concurrency"
},
{
"name": "tempfile",
"version": "3.3.0",
@ -2197,7 +2512,7 @@
},
{
"name": "thiserror",
"version": "1.0.37",
"version": "1.0.38",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/thiserror",
"license": "Apache-2.0 OR MIT",
@ -2206,7 +2521,7 @@
},
{
"name": "thiserror-impl",
"version": "1.0.37",
"version": "1.0.38",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/thiserror",
"license": "Apache-2.0 OR MIT",
@ -2278,22 +2593,13 @@
},
{
"name": "tokio",
"version": "1.23.0",
"version": "1.24.1",
"authors": "Tokio Contributors <team@tokio.rs>",
"repository": "https://github.com/tokio-rs/tokio",
"license": "MIT",
"license_file": null,
"description": "An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications."
},
{
"name": "tokio-io-timeout",
"version": "1.1.1",
"authors": "Steven Fackler <sfackler@gmail.com>",
"repository": "https://github.com/sfackler/tokio-io-timeout",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Tokio wrappers which apply timeouts to IO operations"
},
{
"name": "tokio-macros",
"version": "1.8.2",
@ -2314,7 +2620,7 @@
},
{
"name": "tokio-rustls",
"version": "0.22.0",
"version": "0.23.4",
"authors": "quininer kel <quininer@live.com>",
"repository": "https://github.com/tokio-rs/tls",
"license": "Apache-2.0 OR MIT",
@ -2348,6 +2654,33 @@
"license_file": null,
"description": "A native Rust encoder and decoder of TOML-formatted files and streams. Provides implementations of the standard Serialize/Deserialize traits for TOML data to facilitate deserializing and serializing Rust structures."
},
{
"name": "tower",
"version": "0.4.13",
"authors": "Tower Maintainers <team@tower-rs.com>",
"repository": "https://github.com/tower-rs/tower",
"license": "MIT",
"license_file": null,
"description": "Tower is a library of modular and reusable components for building robust clients and servers."
},
{
"name": "tower-http",
"version": "0.3.5",
"authors": "Tower Maintainers <team@tower-rs.com>",
"repository": "https://github.com/tower-rs/tower-http",
"license": "MIT",
"license_file": null,
"description": "Tower middleware and utilities for HTTP clients and servers"
},
{
"name": "tower-layer",
"version": "0.3.2",
"authors": "Tower Maintainers <team@tower-rs.com>",
"repository": "https://github.com/tower-rs/tower",
"license": "MIT",
"license_file": null,
"description": "Decorates a `Service` to allow easy composition between `Service`s."
},
{
"name": "tower-service",
"version": "0.3.2",
@ -2539,7 +2872,7 @@
},
{
"name": "unicode-ident",
"version": "1.0.5",
"version": "1.0.6",
"authors": "David Tolnay <dtolnay@gmail.com>",
"repository": "https://github.com/dtolnay/unicode-ident",
"license": "(MIT OR Apache-2.0) AND Unicode-DFS-2016",
@ -2636,6 +2969,15 @@
"license_file": null,
"description": "Tiny crate to check the version of the installed/running rustc."
},
{
"name": "waker-fn",
"version": "1.1.0",
"authors": "Stjepan Glavina <stjepang@gmail.com>",
"repository": "https://github.com/stjepang/waker-fn",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "Convert closures into wakers"
},
{
"name": "want",
"version": "0.3.0",
@ -2645,6 +2987,15 @@
"license_file": null,
"description": "Detect when another Future wants a result."
},
{
"name": "wasi",
"version": "0.9.0+wasi-snapshot-preview1",
"authors": "The Cranelift Project Developers",
"repository": "https://github.com/bytecodealliance/wasi",
"license": "Apache-2.0 OR Apache-2.0 WITH LLVM-exception OR MIT",
"license_file": null,
"description": "Experimental WASI API bindings for Rust"
},
{
"name": "wasi",
"version": "0.11.0+wasi-snapshot-preview1",
@ -2719,7 +3070,7 @@
},
{
"name": "webpki",
"version": "0.21.4",
"version": "0.22.0",
"authors": "Brian Smith <brian@briansmith.org>",
"repository": "https://github.com/briansmith/webpki",
"license": null,
@ -2728,9 +3079,9 @@
},
{
"name": "webpki-roots",
"version": "0.21.1",
"version": "0.22.6",
"authors": "Joseph Birr-Pixton <jpixton@gmail.com>",
"repository": "https://github.com/ctz/webpki-roots",
"repository": "https://github.com/rustls/webpki-roots",
"license": "MPL-2.0",
"license_file": null,
"description": "Mozilla's CA root certificates for use with webpki"
@ -2908,13 +3259,22 @@
},
{
"name": "winreg",
"version": "0.7.0",
"version": "0.10.1",
"authors": "Igor Shaula <gentoo90@gmail.com>",
"repository": "https://github.com/gentoo90/winreg-rs",
"license": "MIT",
"license_file": null,
"description": "Rust bindings to MS Windows Registry API"
},
{
"name": "wiremock",
"version": "0.5.17",
"authors": "Luca Palmieri <rust@lpalmieri.com>",
"repository": "https://github.com/LukeMathWalker/wiremock-rs",
"license": "Apache-2.0 OR MIT",
"license_file": null,
"description": "HTTP mocking to test Rust applications."
},
{
"name": "workspace-hack",
"version": "0.1.0",


@ -1,103 +0,0 @@
# Local sync server
A local sync server is bundled with Anki. If you cannot or do not wish to
use AnkiWeb, you can run the server on a machine on your local network.
Things to be aware of:
- Media syncing is not currently supported. You will either need to disable
syncing of sounds and images in the preferences screen, sync your media via
AnkiWeb, or use some other solution.
- AnkiMobile does not yet provide an option for using a local sync server,
so for now this will only be usable with the computer version of Anki, and
AnkiDroid.
- This code is partly new, and while it has had some testing, it's possible
something has been missed. Please make backups, and report any bugs you run
into.
- The server runs over an unencrypted HTTP connection and does not require
authentication, so it is only suitable for use on a private network.
- This is an advanced feature, targeted at users who are comfortable with
networking and the command line. If you use this, the expectation is you
can resolve any setup/network/firewall issues you run into yourself, and
use of this is entirely at your own risk.
## From source
If you run Anki from git, you can run a sync server with:
```
./tools/runopt --syncserver
```
## From a packaged build
From 2.1.39beta1+, the sync server is included in the packaged binaries.
On Windows in a cmd.exe session:
```
"\program files\anki\anki-console.exe" --syncserver
```
Or MacOS, in Terminal.app:
```
/Applications/Anki.app/Contents/MacOS/AnkiMac --syncserver
```
Or Linux:
```
anki --syncserver
```
## Without Qt dependencies
You can run the server without installing the GUI portion of Anki. Once Anki
2.1.39 is released, the following will work:
```
pip install anki[syncserver]
python -m anki.syncserver
```
## Server setup
The server needs to store a copy of your collection in a folder.
By default it is ~/.syncserver; you can change this by defining
a `FOLDER` environmental variable. This should not be the same location
as your normal Anki data folder.
You can also define `HOST` and `PORT`.
## Client setup
When the server starts, it will print the address it is listening on.
You need to set an environmental variable before starting your Anki
clients to tell them where to connect to. Eg:
```
set SYNC_ENDPOINT="http://10.0.0.5:8080/sync/"
anki
```
Currently any username and password will be accepted. If you wish to
keep using AnkiWeb for media, sync once with AnkiWeb first, then switch
to your local endpoint - collection syncs will be local, and media syncs
will continue to go to AnkiWeb.
## Contributing
Authentication shouldn't be too hard to add - login() and request() in
http_client.rs can be used as a reference. A PR that accepts a password in an
env var, and generates a stable hkey based on it would be welcome.
Once that is done, basic multi-profile support could be implemented by moving
the col object into an array or dict, and fetching the relevant collection based
on the user's authentication.
Because this server is bundled with Anki, simplicity is a design goal - it is
targeted at individual/family use, only makes use of Python libraries the GUI is
already using, and does not require a configuration file. PRs that deviate from
this are less likely to be merged, so please consider reaching out first if you
are thinking of starting work on a larger change.


@ -8,7 +8,7 @@ preferences-interface-language = Interface language:
preferences-interrupt-current-audio-when-answering = Interrupt current audio when answering
preferences-learn-ahead-limit = Learn ahead limit
preferences-mins = mins
preferences-network = Network
preferences-network = Syncing
preferences-next-day-starts-at = Next day starts at
preferences-note-media-is-not-backed-up = Note: Media is not backed up. Please create a periodic backup of your Anki folder to be safe.
preferences-on-next-sync-force-changes-in = On next sync, force changes in one direction
@ -49,3 +49,5 @@ preferences-minutes-between-backups = Minutes between automatic backups:
preferences-reduce-motion = Reduce motion
preferences-reduce-motion-tooltip = Disable various animations and transitions of the user interface
preferences-collapse-toolbar = Hide top bar during review
preferences-custom-sync-url = Self-hosted sync server
preferences-custom-sync-url-disclaimer = For advanced users - please see the manual


@ -19,18 +19,17 @@ service SyncService {
rpc FullUpload(SyncAuth) returns (generic.Empty);
rpc FullDownload(SyncAuth) returns (generic.Empty);
rpc AbortSync(generic.Empty) returns (generic.Empty);
rpc SyncServerMethod(SyncServerMethodRequest) returns (generic.Json);
}
message SyncAuth {
string hkey = 1;
uint32 host_number = 2;
optional string endpoint = 2;
}
message SyncLoginRequest {
string username = 1;
string password = 2;
optional string endpoint = 3;
}
message SyncStatusResponse {
@ -40,6 +39,7 @@ message SyncStatusResponse {
FULL_SYNC = 2;
}
Required required = 1;
optional string new_endpoint = 4;
}
message SyncCollectionResponse {
@ -56,24 +56,5 @@ message SyncCollectionResponse {
uint32 host_number = 1;
string server_message = 2;
ChangesRequired required = 3;
}
message SyncServerMethodRequest {
enum Method {
HOST_KEY = 0;
META = 1;
START = 2;
APPLY_GRAVES = 3;
APPLY_CHANGES = 4;
CHUNK = 5;
APPLY_CHUNK = 6;
SANITY_CHECK = 7;
FINISH = 8;
ABORT = 9;
// caller must reopen after these two are called
FULL_UPLOAD = 10;
FULL_DOWNLOAD = 11;
}
Method method = 1;
bytes data = 2;
optional string new_endpoint = 4;
}


@ -11,7 +11,6 @@ from weakref import ref
from markdown import markdown
import anki.buildinfo
import anki.lang
from anki import _rsbridge, backend_pb2, i18n_pb2
from anki._backend_generated import RustBackendGenerated
from anki._fluent import GeneratedTranslations
@ -72,6 +71,8 @@ class RustBackend(RustBackendGenerated):
server: bool = False,
) -> None:
# pick up global defaults if not provided
import anki.lang
if langs is None:
langs = [anki.lang.current_lang]
@ -81,6 +82,10 @@ class RustBackend(RustBackendGenerated):
)
self._backend = _rsbridge.open_backend(init_msg.SerializeToString())
@staticmethod
def syncserver() -> None:
_rsbridge.syncserver()
def db_query(
self, sql: str, args: Sequence[ValueForDB], first_row_only: bool
) -> list[DBRow]:


@ -1,6 +1,7 @@
def buildhash() -> str: ...
def open_backend(data: bytes) -> Backend: ...
def initialize_logging(log_file: str | None) -> Backend: ...
def syncserver() -> None: ...
class Backend:
@classmethod


@ -16,6 +16,7 @@ from anki import (
stats_pb2,
)
from anki._legacy import DeprecatedNamesMixin, deprecated
from anki.sync_pb2 import SyncLoginRequest
# protobuf we publicly export - listed first to avoid circular imports
HelpPage = links_pb2.HelpPageLinkRequest.HelpPage
@ -1189,8 +1190,12 @@ class Collection(DeprecatedNamesMixin):
def full_download(self, auth: SyncAuth) -> None:
self._backend.full_download(auth)
def sync_login(self, username: str, password: str) -> SyncAuth:
return self._backend.sync_login(username=username, password=password)
def sync_login(
self, username: str, password: str, endpoint: str | None
) -> SyncAuth:
return self._backend.sync_login(
SyncLoginRequest(username=username, password=password, endpoint=endpoint)
)
def sync_collection(self, auth: SyncAuth) -> SyncOutput:
return self._backend.sync_collection(auth)

pylib/anki/syncserver.py (new file, 24 lines)

@ -0,0 +1,24 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
def run_sync_server() -> None:
import sys
from os import environ as env
from os.path import expanduser
from anki._backend import RustBackend
env["SYNC_BASE"] = env.get("SYNC_BASE", expanduser("~/.syncserver"))
env["RUST_LOG"] = env.get("RUST_LOG", "anki=info")
try:
RustBackend.syncserver()
except Exception as exc:
print("Sync server failed:", exc)
sys.exit(1)
sys.exit(0)
if __name__ == "__main__":
run_sync_server()


@ -1,195 +0,0 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
#
# Please see /docs/syncserver.md
#
from __future__ import annotations
import gzip
import os
import socket
import sys
import time
from http import HTTPStatus
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Iterable, Optional
try:
import flask
from waitress.server import create_server
except ImportError as error:
print(error, "- to use the server, 'pip install anki[syncserver]'")
sys.exit(1)
from flask import Response
from anki.collection import Collection
from anki.sync_pb2 import SyncServerMethodRequest
Method = SyncServerMethodRequest.Method # pylint: disable=no-member
app = flask.Flask(__name__, root_path="/fake")
col: Collection
trace = os.getenv("TRACE")
def get_request_data() -> bytes:
buf = BytesIO()
flask.request.files["data"].save(buf)
buf.seek(0)
zip = gzip.GzipFile(mode="rb", fileobj=buf)
return zip.read()
def get_request_data_into_file() -> bytes:
"Returns the utf8 path to the resulting file."
# this could be optimized to stream the data into a file
# in the future
data = get_request_data()
tempobj = NamedTemporaryFile(dir=folder(), delete=False)
tempobj.write(data)
tempobj.close()
return tempobj.name.encode("utf8")
def handle_sync_request(method_str: str) -> Response:
method = get_method(method_str)
if method is None:
raise Exception(f"unknown method: {method_str}")
if method == Method.FULL_UPLOAD:
data = get_request_data_into_file()
else:
data = get_request_data()
if trace:
print("-->", data)
full = method in (Method.FULL_UPLOAD, Method.FULL_DOWNLOAD)
if full:
col.close_for_full_sync()
try:
outdata = col._backend.sync_server_method(method=method, data=data)
except Exception as error:
if method == Method.META:
# if parallel syncing requests come in, block them
print("exception in meta", error)
return flask.make_response("Conflict", 409)
else:
raise
finally:
if full:
after_full_sync()
resp = None
if method == Method.FULL_UPLOAD:
# upload call expects a raw string literal returned
outdata = b"OK"
elif method == Method.FULL_DOWNLOAD:
path = outdata.decode("utf8")
def stream_reply() -> Iterable[bytes]:
with open(path, "rb") as file:
while chunk := file.read(16 * 1024):
yield chunk
os.unlink(path)
resp = Response(stream_reply())
else:
if trace:
print("<--", outdata)
if not resp:
resp = flask.make_response(outdata)
resp.headers["Content-Type"] = "application/binary"
return resp
def after_full_sync() -> None:
# the server methods do not reopen the collection after a full sync,
# so we need to
col.reopen(after_full_sync=False)
col.db.rollback()
def get_method(
method_str: str,
) -> SyncServerMethodRequest.Method.V | None: # pylint: disable=no-member
if method_str == "hostKey":
return Method.HOST_KEY
elif method_str == "meta":
return Method.META
elif method_str == "start":
return Method.START
elif method_str == "applyGraves":
return Method.APPLY_GRAVES
elif method_str == "applyChanges":
return Method.APPLY_CHANGES
elif method_str == "chunk":
return Method.CHUNK
elif method_str == "applyChunk":
return Method.APPLY_CHUNK
elif method_str == "sanityCheck2":
return Method.SANITY_CHECK
elif method_str == "finish":
return Method.FINISH
elif method_str == "abort":
return Method.ABORT
elif method_str == "upload":
return Method.FULL_UPLOAD
elif method_str == "download":
return Method.FULL_DOWNLOAD
else:
return None
@app.route("/<path:pathin>", methods=["POST"])
def handle_request(pathin: str) -> Response:
path = pathin
print(int(time.time()), flask.request.remote_addr, path)
if path.startswith("sync/"):
return handle_sync_request(path.split("/", maxsplit=1)[1])
else:
return flask.make_response("not found", HTTPStatus.NOT_FOUND)
def folder() -> str:
folder = os.getenv("FOLDER", os.path.expanduser("~/.syncserver"))
if not os.path.exists(folder):
print("creating", folder)
os.mkdir(folder)
return folder
def col_path() -> str:
return os.path.join(folder(), "collection.server.anki2")
def serve() -> None:
global col # pylint: disable=invalid-name
col = Collection(col_path(), server=True)
# don't hold an outer transaction open
col.db.rollback()
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "8080"))
server = create_server(
app,
host=host,
port=port,
clear_untrusted_proxy_headers=True,
)
effective_port = server.effective_port # type: ignore
print(f"Sync server listening on http://{host}:{effective_port}/sync/")
if host == "0.0.0.0":
ip = socket.gethostbyname(socket.gethostname())
print(f"Replace 0.0.0.0 with your machine's IP address (perhaps {ip})")
print(
"For more info, see https://github.com/ankitects/anki/blob/master/docs/syncserver.md"
)
server.run()


@ -1,6 +0,0 @@
# Copyright: Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
from anki.syncserver import serve
serve()


@ -4,6 +4,7 @@
use anki::{
backend::{init_backend, Backend as RustBackend},
log::set_global_logger,
sync::http_server::SimpleServer,
};
use pyo3::{
create_exception, exceptions::PyException, prelude::*, types::PyBytes, wrap_pyfunction,
@ -26,6 +27,12 @@ fn initialize_logging(path: Option<&str>) -> PyResult<()> {
set_global_logger(path).map_err(|e| PyException::new_err(e.to_string()))
}
#[pyfunction]
fn syncserver() -> PyResult<()> {
set_global_logger(None).unwrap();
SimpleServer::run().map_err(|e| PyException::new_err(format!("{e:?}")))
}
#[pyfunction]
fn open_backend(init_msg: &PyBytes) -> PyResult<Backend> {
match init_backend(init_msg.as_bytes()) {
@ -76,6 +83,7 @@ fn _rsbridge(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(buildhash)).unwrap();
m.add_wrapped(wrap_pyfunction!(open_backend)).unwrap();
m.add_wrapped(wrap_pyfunction!(initialize_logging)).unwrap();
m.add_wrapped(wrap_pyfunction!(syncserver)).unwrap();
Ok(())
}


@ -6,6 +6,7 @@ from __future__ import annotations
import sys
from email.message import EmailMessage
from pathlib import Path
from typing import Sequence
from zipfile import ZIP_DEFLATED, ZipInfo
from wheel.wheelfile import WheelFile
@ -158,11 +159,10 @@ else:
contents: dict[str, str] = {}
merge_sources(contents, src_root, exclude)
merge_sources(contents, generated_root, exclude)
all_requires: Sequence[str | ExtraRequires]
if name == "anki":
all_requires = extract_requirements(Path("python/requirements.anki.in")) + [
ExtraRequires("syncserver", ["flask", "waitress"]),
]
all_requires = extract_requirements(Path("python/requirements.anki.in"))
entrypoints = None
top_level = None
else:


@ -16,18 +16,17 @@ except UnicodeEncodeError as exc:
print("Please Google 'how to change locale on [your Linux distro]'")
sys.exit(1)
# if sync server enabled, bypass the rest of the startup
if "--syncserver" in sys.argv:
from anki.syncserver import run_sync_server
# does not return
run_sync_server()
from .package import packaged_build_setup
packaged_build_setup()
# syncserver needs to be run before Qt loaded
if "--syncserver" in sys.argv:
from anki.syncserver import serve
serve()
sys.exit(0)
import argparse
import builtins
import cProfile


@ -492,6 +492,24 @@
</property>
</spacer>
</item>
<item>
<layout class="QGridLayout" name="gridLayout_2">
<item row="0" column="1">
<widget class="QLineEdit" name="custom_sync_url">
<property name="placeholderText">
<string>preferences_custom_sync_url_disclaimer</string>
</property>
</widget>
</item>
<item row="0" column="0">
<widget class="QLabel" name="label">
<property name="text">
<string>preferences_custom_sync_url</string>
</property>
</widget>
</item>
</layout>
</item>
</layout>
</widget>
<widget class="QWidget" name="tab">
@ -722,6 +740,7 @@
<tabstop>fullSync</tabstop>
<tabstop>syncDeauth</tabstop>
<tabstop>media_log</tabstop>
<tabstop>custom_sync_url</tabstop>
<tabstop>minutes_between_backups</tabstop>
<tabstop>daily_backups</tabstop>
<tabstop>weekly_backups</tabstop>


@ -12,12 +12,12 @@ import aqt
import aqt.forms
import aqt.main
from anki.collection import Progress
from anki.errors import Interrupted, NetworkError
from anki.errors import Interrupted
from anki.types import assert_exhaustive
from anki.utils import int_time
from aqt import gui_hooks
from aqt.qt import QDialog, QDialogButtonBox, QPushButton, QTextCursor, QTimer, qconnect
from aqt.utils import disable_help_button, showWarning, tr
from aqt.utils import disable_help_button, tr
LogEntry = Union[Progress.MediaSync, str]
@ -92,14 +92,13 @@ class MediaSyncer:
if isinstance(exc, Interrupted):
self._log_and_notify(tr.sync_media_aborted())
return
elif isinstance(exc, NetworkError):
# avoid popups for network errors
else:
# Avoid popups for errors; they can cause a deadlock if
# a modal window happens to be active, or a duplicate auth
# failed message if the password is changed.
self._log_and_notify(str(exc))
return
self._log_and_notify(tr.sync_media_failed())
showWarning(str(exc))
def entries(self) -> list[LogEntryWithTime]:
return self._log


@ -174,6 +174,7 @@ class Preferences(QDialog):
self.form.syncUser.setText(self.prof.get("syncUser", ""))
qconnect(self.form.syncDeauth.clicked, self.sync_logout)
self.form.syncDeauth.setText(tr.sync_log_out_button())
self.form.custom_sync_url.setText(self.mw.pm.custom_sync_url())
def on_media_log(self) -> None:
self.mw.media_syncer.show_sync_log()
@ -201,6 +202,7 @@ class Preferences(QDialog):
)
if self.form.fullSync.isChecked():
self.mw.col.mod_schema(check=False)
self.mw.pm.set_custom_sync_url(self.form.custom_sync_url.text())
# Global preferences
######################################################################


@ -606,15 +606,33 @@ create table if not exists profiles
return self.profile["autoSync"]
def sync_auth(self) -> SyncAuth | None:
hkey = self.profile.get("syncKey")
if not hkey:
if not (hkey := self.profile.get("syncKey")):
return None
return SyncAuth(hkey=hkey, host_number=self.profile.get("hostNum", 0))
return SyncAuth(hkey=hkey, endpoint=self.sync_endpoint())
def clear_sync_auth(self) -> None:
self.profile["syncKey"] = None
self.profile["syncUser"] = None
self.profile["hostNum"] = 0
self.set_sync_key(None)
self.set_sync_username(None)
self.set_host_number(None)
self.set_current_sync_url(None)
def sync_endpoint(self) -> str | None:
return self._current_sync_url() or self.custom_sync_url() or None
def _current_sync_url(self) -> str | None:
"""The last endpoint the server redirected us to."""
return self.profile.get("currentSyncUrl")
def set_current_sync_url(self, url: str | None) -> None:
self.profile["currentSyncUrl"] = url
def custom_sync_url(self) -> str | None:
"""A custom server provided by the user."""
return self.profile.get("customSyncUrl")
def set_custom_sync_url(self, url: str | None) -> None:
self.set_current_sync_url(None)
self.profile["customSyncUrl"] = url
def auto_sync_media_minutes(self) -> int:
return self.profile.get("autoSyncMediaMinutes", 15)


@ -12,6 +12,7 @@ import aqt.main
from anki.errors import Interrupted, SyncError, SyncErrorKind
from anki.lang import without_unicode_isolation
from anki.sync import SyncOutput, SyncStatus
from anki.sync_pb2 import SyncAuth
from anki.utils import plat_desc
from aqt import gui_hooks
from aqt.qt import (
@ -43,13 +44,15 @@ def get_sync_status(
callback(SyncStatus(required=SyncStatus.NO_CHANGES)) # pylint:disable=no-member
return
def on_future_done(fut: Future) -> None:
def on_future_done(fut: Future[SyncStatus]) -> None:
try:
out = fut.result()
except Exception as e:
# swallow errors
print("sync status check failed:", str(e))
return
if out.new_endpoint:
mw.pm.set_current_sync_url(out.new_endpoint)
callback(out)
mw.taskman.run_in_background(lambda: mw.col.sync_status(auth), on_future_done)
@ -93,18 +96,20 @@ def sync_collection(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> None:
qconnect(timer.timeout, on_timer)
timer.start(150)
def on_future_done(fut: Future) -> None:
def on_future_done(fut: Future[SyncOutput]) -> None:
mw.col.db.begin()
# scheduler version may have changed
mw.col._load_scheduler()
timer.stop()
try:
out: SyncOutput = fut.result()
out = fut.result()
except Exception as err:
handle_sync_error(mw, err)
return on_done()
mw.pm.set_host_number(out.host_number)
if out.new_endpoint:
mw.pm.set_current_sync_url(out.new_endpoint)
if out.server_message:
showText(out.server_message)
if out.required == out.NO_CHANGES:
@ -161,7 +166,7 @@ def confirm_full_download(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> N
mw.closeAllWindows(lambda: full_download(mw, on_done))
def on_full_sync_timer(mw: aqt.main.AnkiQt) -> None:
def on_full_sync_timer(mw: aqt.main.AnkiQt, label: str) -> None:
progress = mw.col.latest_progress()
if not progress.HasField("full_sync"):
return
@ -169,8 +174,6 @@ def on_full_sync_timer(mw: aqt.main.AnkiQt) -> None:
if sync_progress.transferred == sync_progress.total:
label = tr.sync_checking()
else:
label = None
mw.progress.update(
value=sync_progress.transferred,
max=sync_progress.total,
@ -183,8 +186,10 @@ def on_full_sync_timer(mw: aqt.main.AnkiQt) -> None:
def full_download(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> None:
label = tr.sync_downloading_from_ankiweb()
def on_timer() -> None:
on_full_sync_timer(mw)
on_full_sync_timer(mw, label)
timer = QTimer(mw)
qconnect(timer.timeout, on_timer)
@ -212,7 +217,6 @@ def full_download(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> None:
mw.taskman.with_progress(
download,
on_future_done,
label=tr.sync_downloading_from_ankiweb(),
)
@ -220,8 +224,10 @@ def full_upload(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> None:
gui_hooks.collection_will_temporarily_close(mw.col)
mw.col.close_for_full_sync()
label = tr.sync_uploading_to_ankiweb()
def on_timer() -> None:
on_full_sync_timer(mw)
on_full_sync_timer(mw, label)
timer = QTimer(mw)
qconnect(timer.timeout, on_timer)
@ -242,7 +248,6 @@ def full_upload(mw: aqt.main.AnkiQt, on_done: Callable[[], None]) -> None:
mw.taskman.with_progress(
lambda: mw.col.full_upload(mw.pm.sync_auth()),
on_future_done,
label=tr.sync_uploading_to_ankiweb(),
)
@ -259,7 +264,7 @@ def sync_login(
if username and password:
break
def on_future_done(fut: Future) -> None:
def on_future_done(fut: Future[SyncAuth]) -> None:
try:
auth = fut.result()
except SyncError as e:
@ -273,14 +278,15 @@ def sync_login(
handle_sync_error(mw, err)
return
mw.pm.set_host_number(auth.host_number)
mw.pm.set_sync_key(auth.hkey)
mw.pm.set_sync_username(username)
on_success()
mw.taskman.with_progress(
lambda: mw.col.sync_login(username=username, password=password),
lambda: mw.col.sync_login(
username=username, password=password, endpoint=mw.pm.sync_endpoint()
),
on_future_done,
)


@ -31,11 +31,12 @@ prost-build = "0.11.4"
which = "4.3.0"
[dev-dependencies]
async-stream = "0.3.3"
env_logger = "0.10.0"
tokio = { version = "1.23", features = ["macros"] }
wiremock = "0.5.17"
[dependencies.reqwest]
version = "=0.11.3"
version = "0.11.13"
default-features = false
features = ["json", "socks", "stream", "multipart"]
@ -51,7 +52,10 @@ unicase = "=2.6.0"
criterion = { version = "0.4.0", optional = true }
ammonia = "3.3.0"
async-compression = { version = "0.3.15", features = ["zstd", "tokio"] }
async-trait = "0.1.59"
axum = { version = "0.6.1", features = ["multipart", "macros", "headers"] }
axum-client-ip = "0.3.1"
blake3 = "1.3.3"
bytes = "1.3.0"
chrono = { version = "0.4.19", default-features = false, features = ["std", "clock"] }
@ -65,6 +69,7 @@ fnv = "1.0.7"
futures = "0.3.25"
hex = "0.4.3"
htmlescape = "0.3.1"
hyper = "0.14.23"
id_tree = "1.8.0"
itertools = "0.10.5"
lazy_static = "1.4.0"
@ -90,8 +95,9 @@ sha1 = "0.10.5"
snafu = { version = "0.7.3", features = ["backtraces"] }
strum = { version = "0.24.1", features = ["derive"] }
tempfile = "3.3.0"
tokio = { version = "1.23", features = ["fs", "rt-multi-thread"] }
tokio = { version = "1.23", features = ["fs", "rt-multi-thread", "macros"] }
tokio-util = { version = "0.7.4", features = ["io"] }
tower-http = { version = "0.3.5", features = ["trace"] }
tracing = { version = "0.1.37", features = ["max_level_trace", "release_max_level_debug"] }
tracing-appender = "0.2.2"
tracing-subscriber = { version = "0.3.16", features = ["fmt", "env-filter"] }


@ -3,7 +3,9 @@
// copied from https://github.com/projectfluent/fluent-rs/pull/241
use std::fmt::{self, Error, Write};
use std::fmt::{
Error, Write, {self},
};
use fluent_syntax::{ast::*, parser::Slice};


@ -28,5 +28,5 @@ rustls = ["reqwest/rustls-tls", "reqwest/rustls-tls-native-roots"]
native-tls = ["reqwest/native-tls"]
[dependencies.reqwest]
version = "=0.11.3"
version = "0.11.13"
default-features = false


@ -12,13 +12,14 @@ use super::{
pub(super) use crate::pb::ankidroid::ankidroid_service::Service as AnkidroidService;
use crate::{
backend::ankidroid::db::{execute_for_row_count, insert_for_id},
pb,
pb::{
self as pb,
ankidroid::{DbResponse, GetActiveSequenceNumbersResponse, GetNextResultPageRequest},
generic::{self, Empty, Int32, Json},
generic,
generic::{Empty, Int32, Json},
},
prelude::*,
scheduler::timing::{self, fixed_offset_from_minutes},
scheduler::{timing, timing::fixed_offset_from_minutes},
};
impl AnkidroidService for Backend {


@ -8,10 +8,7 @@ use tracing::error;
use super::{progress::Progress, Backend};
pub(super) use crate::pb::collection::collection_service::Service as CollectionService;
use crate::{
backend::progress::progress_to_proto,
collection::CollectionBuilder,
pb::{self as pb},
prelude::*,
backend::progress::progress_to_proto, collection::CollectionBuilder, pb, prelude::*,
storage::SchemaVersion,
};


@ -7,7 +7,7 @@ use super::Backend;
pub(super) use crate::pb::decks::decks_service::Service as DecksService;
use crate::{
decks::{DeckSchema11, FilteredSearchOrder},
pb::{self as pb},
pb,
prelude::*,
scheduler::filtered::FilteredDeckForUpdate,
};


@ -7,10 +7,8 @@ use super::{progress::Progress, Backend};
pub(super) use crate::pb::import_export::importexport_service::Service as ImportExportService;
use crate::{
import_export::{package::import_colpkg, ExportProgress, ImportProgress, NoteLog},
pb::{
import_export::{export_limit, ExportLimit},
{self as pb},
},
pb,
pb::import_export::{export_limit, ExportLimit},
prelude::*,
search::SearchNode,
};


@ -3,11 +3,7 @@
use super::{progress::Progress, Backend};
pub(super) use crate::pb::media::media_service::Service as MediaService;
use crate::{
media::{check::MediaChecker, MediaManager},
pb,
prelude::*,
};
use crate::{media::check::MediaChecker, pb, prelude::*};
impl MediaService for Backend {
// media
@ -18,7 +14,7 @@ impl MediaService for Backend {
let progress_fn =
move |progress| handler.update(Progress::MediaCheck(progress as u32), true);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mgr = col.media()?;
col.transact_no_undo(|ctx| {
let mut checker = MediaChecker::new(ctx, &mgr, progress_fn);
let mut output = checker.check()?;
@ -41,19 +37,17 @@ impl MediaService for Backend {
input: pb::media::TrashMediaFilesRequest,
) -> Result<pb::generic::Empty> {
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mut ctx = mgr.dbctx();
mgr.remove_files(&mut ctx, &input.fnames)
let mgr = col.media()?;
mgr.remove_files(&input.fnames)
})
.map(Into::into)
}
fn add_media_file(&self, input: pb::media::AddMediaFileRequest) -> Result<pb::generic::String> {
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mut ctx = mgr.dbctx();
let mgr = col.media()?;
Ok(mgr
.add_file(&mut ctx, &input.desired_name, &input.data)?
.add_file(&input.desired_name, &input.data)?
.to_string()
.into())
})
@ -65,7 +59,7 @@ impl MediaService for Backend {
move |progress| handler.update(Progress::MediaCheck(progress as u32), true);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mgr = col.media()?;
let mut checker = MediaChecker::new(col, &mgr, progress_fn);
checker.empty_trash()
})
@ -77,7 +71,7 @@ impl MediaService for Backend {
let progress_fn =
move |progress| handler.update(Progress::MediaCheck(progress as u32), true);
self.with_col(|col| {
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mgr = col.media()?;
let mut checker = MediaChecker::new(col, &mgr, progress_fn);
checker.restore_trash()
})


@ -38,9 +38,7 @@ use std::{
use once_cell::sync::OnceCell;
use progress::AbortHandleSlot;
use prost::Message;
use tokio::runtime::{
Runtime, {self},
};
use tokio::{runtime, runtime::Runtime};
use self::{
ankidroid::AnkidroidService,


@ -5,11 +5,7 @@ use std::collections::HashSet;
use super::Backend;
pub(super) use crate::pb::notes::notes_service::Service as NotesService;
use crate::{
cloze::add_cloze_numbers_in_string,
pb::{self as pb},
prelude::*,
};
use crate::{cloze::add_cloze_numbers_in_string, pb, prelude::*};
impl NotesService for Backend {
fn new_note(&self, input: pb::notetypes::NotetypeId) -> Result<pb::notes::Note> {

View File

@ -10,9 +10,14 @@ use crate::{
dbcheck::DatabaseCheckProgress,
i18n::I18n,
import_export::{ExportProgress, ImportProgress},
media::sync::MediaSyncProgress,
pb,
sync::{FullSyncProgress, NormalSyncProgress, SyncStage},
sync::{
collection::{
normal::NormalSyncProgress,
progress::{FullSyncProgress, SyncStage},
},
media::progress::MediaSyncProgress,
},
};
pub(super) struct ThrottlingProgressHandler {

View File

@ -1,22 +1,27 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod server;
use std::sync::Arc;
use futures::future::{AbortHandle, AbortRegistration, Abortable};
use pb::sync::sync_status_response::Required;
use reqwest::Url;
use tracing::warn;
use super::{progress::AbortHandleSlot, Backend};
pub(super) use crate::pb::sync::sync_service::Service as SyncService;
use crate::{
media::MediaManager,
pb,
pb::sync::SyncStatusResponse,
prelude::*,
sync::{
get_remote_sync_meta, http::SyncRequest, sync_abort, sync_login, FullSyncProgress,
LocalServer, NormalSyncProgress, SyncActionRequired, SyncAuth, SyncMeta, SyncOutput,
collection::{
normal::{ClientSyncState, NormalSyncProgress, SyncActionRequired, SyncOutput},
progress::{sync_abort, FullSyncProgress},
status::online_sync_status_check,
},
http_client::HttpSyncClient,
login::{sync_login, SyncAuth},
},
};
@ -24,17 +29,16 @@ use crate::{
pub(super) struct SyncState {
remote_sync_status: RemoteSyncStatus,
media_sync_abort: Option<AbortHandle>,
http_sync_server: Option<LocalServer>,
}
#[derive(Default, Debug)]
pub(super) struct RemoteSyncStatus {
pub last_check: TimestampSecs,
pub last_response: pb::sync::sync_status_response::Required,
pub last_response: Required,
}
impl RemoteSyncStatus {
pub(super) fn update(&mut self, required: pb::sync::sync_status_response::Required) {
pub(super) fn update(&mut self, required: Required) {
self.last_check = TimestampSecs::now();
self.last_response = required
}
@ -45,6 +49,7 @@ impl From<SyncOutput> for pb::sync::SyncCollectionResponse {
pb::sync::SyncCollectionResponse {
host_number: o.host_number,
server_message: o.server_message,
new_endpoint: o.new_endpoint,
required: match o.required {
SyncActionRequired::NoChanges => {
pb::sync::sync_collection_response::ChangesRequired::NoChanges as i32
@ -69,12 +74,20 @@ impl From<SyncOutput> for pb::sync::SyncCollectionResponse {
}
}
impl From<pb::sync::SyncAuth> for SyncAuth {
fn from(a: pb::sync::SyncAuth) -> Self {
SyncAuth {
hkey: a.hkey,
host_number: a.host_number,
}
impl TryFrom<pb::sync::SyncAuth> for SyncAuth {
type Error = AnkiError;
fn try_from(value: pb::sync::SyncAuth) -> std::result::Result<Self, Self::Error> {
Ok(SyncAuth {
hkey: value.hkey,
endpoint: value
.endpoint
.map(|v| {
Url::try_from(v.as_str())
.or_invalid("Invalid sync server specified. Please check the preferences.")
})
.transpose()?,
})
}
}
@ -123,14 +136,6 @@ impl SyncService for Backend {
self.full_sync_inner(input, false)?;
Ok(().into())
}
fn sync_server_method(
&self,
input: pb::sync::SyncServerMethodRequest,
) -> Result<pb::generic::Json> {
let req = SyncRequest::from_method_and_data(input.method(), input.data)?;
self.sync_server_method_inner(req).map(Into::into)
}
}
impl Backend {
@ -160,7 +165,8 @@ impl Backend {
Ok((guard, abort_reg))
}
pub(super) fn sync_media_inner(&self, input: pb::sync::SyncAuth) -> Result<()> {
pub(super) fn sync_media_inner(&self, auth: pb::sync::SyncAuth) -> Result<()> {
let auth = auth.try_into()?;
// mark media sync as active
let (abort_handle, abort_reg) = AbortHandle::new_pair();
{
@ -173,20 +179,13 @@ impl Backend {
}
}
// get required info from collection
let mut guard = self.col.lock().unwrap();
let col = guard.as_mut().unwrap();
let folder = col.media_folder.clone();
let db = col.media_db.clone();
drop(guard);
// start the sync
let mgr = self.col.lock().unwrap().as_mut().unwrap().media()?;
let mut handler = self.new_progress_handler();
let progress_fn = move |progress| handler.update(progress, true);
let mgr = MediaManager::new(&folder, &db)?;
let rt = self.runtime_handle();
let sync_fut = mgr.sync_media(progress_fn, input.host_number, &input.hkey);
let sync_fut = mgr.sync_media(progress_fn, auth);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
let result = rt.block_on(abortable_sync);
@ -226,7 +225,7 @@ impl Backend {
let (_guard, abort_reg) = self.sync_abort_handle()?;
let rt = self.runtime_handle();
let sync_fut = sync_login(&input.username, &input.password);
let sync_fut = sync_login(input.username, input.password, input.endpoint);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
let ret = match rt.block_on(abortable_sync) {
Ok(sync_result) => sync_result,
@ -234,7 +233,7 @@ impl Backend {
};
ret.map(|a| pb::sync::SyncAuth {
hkey: a.hkey,
host_number: a.host_number,
endpoint: None,
})
}
@ -243,8 +242,8 @@ impl Backend {
input: pb::sync::SyncAuth,
) -> Result<pb::sync::SyncStatusResponse> {
// any local changes mean we can skip the network round-trip
let req = self.with_col(|col| col.get_local_sync_status())?;
if req != pb::sync::sync_status_response::Required::NoChanges {
let req = self.with_col(|col| col.sync_status_offline())?;
if req != Required::NoChanges {
return Ok(req.into());
}
@ -257,11 +256,12 @@ impl Backend {
}
// fetch and cache result
let auth = input.try_into()?;
let rt = self.runtime_handle();
let time_at_check_begin = TimestampSecs::now();
let remote: SyncMeta = rt.block_on(get_remote_sync_meta(input.into()))?;
let response = self.with_col(|col| col.get_sync_status(remote).map(Into::into))?;
let local = self.with_col(|col| col.sync_meta())?;
let mut client = HttpSyncClient::new(auth);
let state = rt.block_on(online_sync_status_check(local, &mut client))?;
{
let mut guard = self.state.lock().unwrap();
// On startup, the sync status check will block on network access, and then automatic syncing begins,
@ -269,21 +269,21 @@ impl Backend {
// so we discard it if stale.
if guard.sync.remote_sync_status.last_check < time_at_check_begin {
guard.sync.remote_sync_status.last_check = time_at_check_begin;
guard.sync.remote_sync_status.last_response = response;
guard.sync.remote_sync_status.last_response = state.required.into();
}
}
Ok(response.into())
Ok(state.into())
}
pub(super) fn sync_collection_inner(
&self,
input: pb::sync::SyncAuth,
) -> Result<pb::sync::SyncCollectionResponse> {
let auth: SyncAuth = input.try_into()?;
let (_guard, abort_reg) = self.sync_abort_handle()?;
let rt = self.runtime_handle();
let input_copy = input.clone();
let ret = self.with_col(|col| {
let mut handler = self.new_progress_handler();
@ -291,7 +291,7 @@ impl Backend {
handler.update(progress, throttle);
};
let sync_fut = col.normal_sync(input.into(), progress_fn);
let sync_fut = col.normal_sync(auth.clone(), progress_fn);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
match rt.block_on(abortable_sync) {
@ -301,7 +301,7 @@ impl Backend {
col.storage.rollback_trx()?;
// and tell AnkiWeb to clean up
let _handle = std::thread::spawn(move || {
let _ = rt.block_on(sync_abort(input_copy.hkey, input_copy.host_number));
let _ = rt.block_on(sync_abort(auth));
});
Err(AnkiError::Interrupted)
@ -320,6 +320,7 @@ impl Backend {
}
pub(super) fn full_sync_inner(&self, input: pb::sync::SyncAuth, upload: bool) -> Result<()> {
let auth = input.try_into()?;
self.abort_media_sync_and_wait();
let rt = self.runtime_handle();
@ -336,16 +337,16 @@ impl Backend {
let builder = col_inner.as_builder();
let mut handler = self.new_progress_handler();
let progress_fn = move |progress: FullSyncProgress, throttle: bool| {
let progress_fn = Box::new(move |progress: FullSyncProgress, throttle: bool| {
handler.update(progress, throttle);
};
});
let result = if upload {
let sync_fut = col_inner.full_upload(input.into(), Box::new(progress_fn));
let sync_fut = col_inner.full_upload(auth, progress_fn);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
rt.block_on(abortable_sync)
} else {
let sync_fut = col_inner.full_download(input.into(), Box::new(progress_fn));
let sync_fut = col_inner.full_download(auth, progress_fn);
let abortable_sync = Abortable::new(sync_fut, abort_reg);
rt.block_on(abortable_sync)
};
@ -361,7 +362,7 @@ impl Backend {
.unwrap()
.sync
.remote_sync_status
.update(pb::sync::sync_status_response::Required::NoChanges);
.update(Required::NoChanges);
}
sync_result
}
@ -369,3 +370,31 @@ impl Backend {
}
}
}
impl From<Required> for SyncStatusResponse {
fn from(r: Required) -> Self {
SyncStatusResponse {
required: r.into(),
new_endpoint: None,
}
}
}
impl From<ClientSyncState> for SyncStatusResponse {
fn from(r: ClientSyncState) -> Self {
SyncStatusResponse {
required: Required::from(r.required).into(),
new_endpoint: r.new_endpoint,
}
}
}
impl From<SyncActionRequired> for Required {
fn from(r: SyncActionRequired) -> Self {
match r {
SyncActionRequired::NoChanges => Required::NoChanges,
SyncActionRequired::FullSyncRequired { .. } => Required::FullSync,
SyncActionRequired::NormalSyncRequired => Required::NormalSync,
}
}
}
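
An aside on the TryFrom<pb::sync::SyncAuth> conversion near the top of this file: the optional endpoint field carries a custom sync server address, and an address that fails to parse surfaces as the "check the preferences" error rather than a panic. Below is an editor's sketch of that parsing step, not part of the patch; the helper name and the sample addresses are invented for illustration.

// Editor's sketch, not part of the diff: parse an optional endpoint string
// into a reqwest::Url, mapping a parse failure to a user-facing message.
use reqwest::Url;

fn parse_endpoint(endpoint: Option<&str>) -> Result<Option<Url>, String> {
    endpoint
        .map(|s| {
            Url::parse(s).map_err(|_| {
                "Invalid sync server specified. Please check the preferences.".to_string()
            })
        })
        .transpose()
}

fn main() {
    // a typical local-server address; the port is an assumption for the example
    assert!(parse_endpoint(Some("http://127.0.0.1:8080/")).is_ok());
    assert!(parse_endpoint(Some("not a url")).is_err());
    assert!(parse_endpoint(None).unwrap().is_none());
}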

View File

@ -1,211 +0,0 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{path::PathBuf, sync::MutexGuard};
use tokio::runtime::Runtime;
use crate::{
backend::{Backend, BackendState},
error::SyncErrorKind,
prelude::*,
sync::{
http::{
ApplyChangesRequest, ApplyChunkRequest, ApplyGravesRequest, HostKeyRequest,
HostKeyResponse, MetaRequest, SanityCheckRequest, StartRequest, SyncRequest,
},
Chunk, Graves, LocalServer, SanityCheckResponse, SanityCheckStatus, SyncMeta, SyncServer,
UnchunkedChanges, SYNC_VERSION_MAX, SYNC_VERSION_MIN,
},
};
impl Backend {
fn with_sync_server<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&mut LocalServer) -> Result<T>,
{
let mut state_guard = self.state.lock().unwrap();
let out = func(
state_guard
.sync
.http_sync_server
.as_mut()
.ok_or_else(|| AnkiError::sync_error("", SyncErrorKind::SyncNotStarted))?,
);
if out.is_err() {
self.abort_and_restore_collection(Some(state_guard))
}
out
}
/// Gives out a dummy hkey - auth should be implemented at a higher layer.
fn host_key(&self, _input: HostKeyRequest) -> Result<HostKeyResponse> {
Ok(HostKeyResponse {
key: "unimplemented".into(),
})
}
fn meta(&self, input: MetaRequest) -> Result<SyncMeta> {
if input.sync_version < SYNC_VERSION_MIN || input.sync_version > SYNC_VERSION_MAX {
return Ok(SyncMeta {
server_message: "Your Anki version is either too old, or too new.".into(),
should_continue: false,
..Default::default()
});
}
let server = self.col_into_server()?;
let rt = Runtime::new().unwrap();
let meta = rt.block_on(server.meta())?;
self.server_into_col(server);
Ok(meta)
}
/// Takes the collection from the backend, places it into a server, and returns it.
fn col_into_server(&self) -> Result<LocalServer> {
self.col
.lock()
.unwrap()
.take()
.map(LocalServer::new)
.ok_or(AnkiError::CollectionNotOpen)
}
fn server_into_col(&self, server: LocalServer) {
let col = server.into_col();
let mut col_guard = self.col.lock().unwrap();
assert!(col_guard.replace(col).is_none());
}
fn take_server(&self, state_guard: Option<MutexGuard<BackendState>>) -> Result<LocalServer> {
let mut state_guard = state_guard.unwrap_or_else(|| self.state.lock().unwrap());
state_guard
.sync
.http_sync_server
.take()
.ok_or_else(|| AnkiError::sync_error("", SyncErrorKind::SyncNotStarted))
}
fn start(&self, input: StartRequest) -> Result<Graves> {
// place col into new server
let server = self.col_into_server()?;
let mut state_guard = self.state.lock().unwrap();
assert!(state_guard.sync.http_sync_server.replace(server).is_none());
drop(state_guard);
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.start(
input.client_usn,
input.local_is_newer,
input.deprecated_client_graves,
))
})
}
fn apply_graves(&self, input: ApplyGravesRequest) -> Result<()> {
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.apply_graves(input.chunk))
})
}
fn apply_changes(&self, input: ApplyChangesRequest) -> Result<UnchunkedChanges> {
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.apply_changes(input.changes))
})
}
fn chunk(&self) -> Result<Chunk> {
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.chunk())
})
}
fn apply_chunk(&self, input: ApplyChunkRequest) -> Result<()> {
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.apply_chunk(input.chunk))
})
}
fn sanity_check(&self, input: SanityCheckRequest) -> Result<SanityCheckResponse> {
self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.sanity_check(input.client))
})
.map(|out| {
if out.status != SanityCheckStatus::Ok {
// sanity check failures are an implicit abort
self.abort_and_restore_collection(None);
}
out
})
}
fn finish(&self) -> Result<TimestampMillis> {
let out = self.with_sync_server(|server| {
let rt = Runtime::new().unwrap();
rt.block_on(server.finish())
});
self.server_into_col(self.take_server(None)?);
out
}
fn abort(&self) -> Result<()> {
self.abort_and_restore_collection(None);
Ok(())
}
fn abort_and_restore_collection(&self, state_guard: Option<MutexGuard<BackendState>>) {
if let Ok(mut server) = self.take_server(state_guard) {
let rt = Runtime::new().unwrap();
// attempt to roll back
if let Err(abort_err) = rt.block_on(server.abort()) {
println!("abort failed: {:?}", abort_err);
}
self.server_into_col(server);
}
}
/// Caller must re-open collection after this request. Provided file will be
/// consumed.
fn upload(&self, input: PathBuf) -> Result<()> {
// spool input into a file
let server = Box::new(self.col_into_server()?);
// then process upload
let rt = Runtime::new().unwrap();
rt.block_on(server.full_upload(&input, true))
}
/// Caller must re-open collection after this request, and is responsible
/// for cleaning up the returned file.
fn download(&self) -> Result<Vec<u8>> {
let server = Box::new(self.col_into_server()?);
let rt = Runtime::new().unwrap();
let file = rt.block_on(server.full_download(None))?;
let path = file.into_temp_path().keep()?;
Ok(path.to_str().expect("path was not in utf8").into())
}
pub(crate) fn sync_server_method_inner(&self, req: SyncRequest) -> Result<Vec<u8>> {
use serde_json::to_vec;
match req {
SyncRequest::HostKey(v) => to_vec(&self.host_key(v)?),
SyncRequest::Meta(v) => to_vec(&self.meta(v)?),
SyncRequest::Start(v) => to_vec(&self.start(v)?),
SyncRequest::ApplyGraves(v) => to_vec(&self.apply_graves(v)?),
SyncRequest::ApplyChanges(v) => to_vec(&self.apply_changes(v)?),
SyncRequest::Chunk => to_vec(&self.chunk()?),
SyncRequest::ApplyChunk(v) => to_vec(&self.apply_chunk(v)?),
SyncRequest::SanityCheck(v) => to_vec(&self.sanity_check(v)?),
SyncRequest::Finish => to_vec(&self.finish()?),
SyncRequest::Abort => to_vec(&self.abort()?),
SyncRequest::FullUpload(v) => to_vec(&self.upload(v)?),
SyncRequest::FullDownload => return self.download(),
}
.map_err(Into::into)
}
}

View File

@ -5,9 +5,8 @@ use std::{
ffi::OsStr,
fs::{read_dir, remove_file, DirEntry},
path::{Path, PathBuf},
thread::{
JoinHandle, {self},
},
thread,
thread::JoinHandle,
time::SystemTime,
};

View File

@ -33,6 +33,7 @@ pub struct CollectionBuilder {
media_db: Option<PathBuf>,
server: Option<bool>,
tr: Option<I18n>,
check_integrity: bool,
// temporary option for AnkiDroid
force_schema11: Option<bool>,
}
@ -56,7 +57,13 @@ impl CollectionBuilder {
let media_folder = self.media_folder.clone().unwrap_or_default();
let media_db = self.media_db.clone().unwrap_or_default();
let force_schema11 = self.force_schema11.unwrap_or_default();
let storage = SqliteStorage::open_or_create(&col_path, &tr, server, force_schema11)?;
let storage = SqliteStorage::open_or_create(
&col_path,
&tr,
server,
self.check_integrity,
force_schema11,
)?;
let col = Collection {
storage,
col_path,
@ -95,6 +102,11 @@ impl CollectionBuilder {
self.force_schema11 = Some(force);
self
}
pub fn set_check_integrity(&mut self, check_integrity: bool) -> &mut Self {
self.check_integrity = check_integrity;
self
}
}
#[cfg(test)]
@ -147,7 +159,13 @@ impl Collection {
builder
}
pub(crate) fn close(self, desired_version: Option<SchemaVersion>) -> Result<()> {
// A count of all changed rows since the collection was opened, which can be used to detect
// if the collection was modified or not.
pub fn changes_since_open(&self) -> u64 {
self.storage.db.changes()
}
pub fn close(self, desired_version: Option<SchemaVersion>) -> Result<()> {
self.storage.close(desired_version)
}
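
The two additions in this file - the check_integrity builder flag and changes_since_open() - let a caller request a database check while the collection is opened, and find out afterwards whether any rows were written. A rough usage sketch follows, not part of the patch; the crate and module paths are assumptions and may differ from how rslib is actually consumed.

// Editor's sketch, not part of the diff.
use anki::collection::{Collection, CollectionBuilder};
use anki::error::Result;

fn open_checked(col_path: &str) -> Result<Collection> {
    let mut builder = CollectionBuilder::new(col_path);
    // ask SqliteStorage to validate the database while opening
    builder.set_check_integrity(true);
    builder.build()
}

fn was_modified(col: &Collection) -> bool {
    // non-zero means at least one row changed since the collection was opened
    col.changes_since_open() > 0
}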

View File

@ -12,7 +12,6 @@ pub(crate) mod undo;
use serde::{de::DeserializeOwned, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use strum::IntoStaticStr;
use tracing::warn;
pub use self::{
bool::BoolKey, deck::DeckConfigKey, notetype::get_aux_notetype_config_key,
@ -108,10 +107,8 @@ impl Collection {
match self.storage.get_config_value(key) {
Ok(Some(val)) => Some(val),
Ok(None) => None,
Err(e) => {
warn!(key, ?e, "error accessing config key");
None
}
// If the key is missing or invalid, we use the default value.
Err(_) => None,
}
}

View File

@ -23,8 +23,9 @@ pub enum DeckSchema11 {
// serde doesn't support integer/bool enum tags, so we manually pick the correct variant
mod dynfix {
use serde::de::{
Deserialize, Deserializer, {self},
use serde::{
de,
de::{Deserialize, Deserializer},
};
use serde_json::{Map, Value};

View File

@ -8,7 +8,7 @@ use snafu::Snafu;
/// Wrapper for [std::io::Error] with additional information on the attempted
/// operation.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
#[snafu(visibility(pub), display("{op:?} {path:?}"))]
pub struct FileIoError {
pub path: PathBuf,
pub op: FileOp,
@ -29,6 +29,7 @@ pub enum FileOp {
Open,
Create,
Write,
Remove,
CopyFrom(PathBuf),
Persist,
Sync,
@ -52,6 +53,7 @@ impl FileIoError {
FileOp::Read => "read".into(),
FileOp::Create => "create file in".into(),
FileOp::Write => "write".into(),
FileOp::Remove => "remove".into(),
FileOp::CopyFrom(p) => format!("copy from '{}' to", p.to_string_lossy()),
FileOp::Persist => "persist".into(),
FileOp::Sync => "sync".into(),

View File

@ -5,7 +5,7 @@ mod db;
mod file_io;
mod filtered;
mod invalid_input;
mod network;
pub(crate) mod network;
mod not_found;
mod search;

View File

@ -1,13 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::error::Error;
use anki_i18n::I18n;
use reqwest::StatusCode;
use snafu::Snafu;
use super::AnkiError;
use crate::sync::{collection::sanity::SanityCheckCounts, error::HttpError};
#[derive(Debug, PartialEq, Eq, Snafu)]
#[snafu(visibility(pub(crate)))]
pub struct NetworkError {
pub info: String,
pub kind: NetworkErrorKind,
@ -40,6 +44,10 @@ pub enum SyncErrorKind {
DatabaseCheckRequired,
SyncNotStarted,
UploadTooLarge,
SanityCheckFailed {
client: Option<SanityCheckCounts>,
server: Option<SanityCheckCounts>,
},
}
impl AnkiError {
@ -172,7 +180,9 @@ impl SyncError {
SyncErrorKind::AuthFailed => tr.sync_wrong_pass(),
SyncErrorKind::ResyncRequired => tr.sync_resync_required(),
SyncErrorKind::ClockIncorrect => tr.sync_clock_off(),
SyncErrorKind::DatabaseCheckRequired => tr.sync_sanity_check_failed(),
SyncErrorKind::DatabaseCheckRequired | SyncErrorKind::SanityCheckFailed { .. } => {
tr.sync_sanity_check_failed()
}
SyncErrorKind::SyncNotStarted => "sync not started".into(),
SyncErrorKind::UploadTooLarge => tr.sync_upload_too_large(&self.info),
}
@ -192,3 +202,31 @@ impl NetworkError {
format!("{}\n\n{}", summary, details)
}
}
// This needs rethinking; we should be attaching error context as errors are encountered
// instead of trying to determine the problem later.
impl From<HttpError> for AnkiError {
fn from(err: HttpError) -> Self {
if let Some(source) = &err.source {
if let Some(err) = source.downcast_ref::<reqwest::Error>() {
if let Some(status) = err.status() {
let kind = match status {
StatusCode::CONFLICT => SyncErrorKind::Conflict,
StatusCode::NOT_IMPLEMENTED => SyncErrorKind::ClientTooOld,
StatusCode::FORBIDDEN => SyncErrorKind::AuthFailed,
StatusCode::INTERNAL_SERVER_ERROR => SyncErrorKind::ServerError,
StatusCode::BAD_REQUEST => SyncErrorKind::DatabaseCheckRequired,
_ => SyncErrorKind::Other,
};
let info = format!("{:?}", err);
// in the future we should chain the error instead of discarding it
return AnkiError::sync_error(info, kind);
} else if let Some(source) = err.source() {
let info = format!("{:?}", source);
return AnkiError::sync_error(info, SyncErrorKind::Other);
}
}
}
AnkiError::sync_error(format!("{:?}", err), SyncErrorKind::Other)
}
}

View File

@ -38,7 +38,9 @@ impl Context<'_> {
}
let db_progress_fn = self.progress.media_db_fn(ImportProgress::MediaCheck)?;
let existing_sha1s = self.media_manager.all_checksums(db_progress_fn)?;
let existing_sha1s = self
.media_manager
.all_checksums_after_checking(db_progress_fn)?;
prepare_media(
media_entries,
@ -50,9 +52,8 @@ impl Context<'_> {
pub(super) fn copy_media(&mut self, media_map: &mut MediaUseMap) -> Result<()> {
let mut incrementor = self.progress.incrementor(ImportProgress::Media);
let mut dbctx = self.media_manager.dbctx();
let mut copier = MediaCopier::new(false);
self.media_manager.transact(&mut dbctx, |dbctx| {
self.media_manager.transact(|_db| {
for entry in media_map.used_entries() {
incrementor.increment()?;
entry.copy_and_ensure_sha1_set(
@ -61,7 +62,7 @@ impl Context<'_> {
&mut copier,
)?;
self.media_manager
.add_entry(dbctx, &entry.name, entry.sha1.unwrap())?;
.add_entry(&entry.name, entry.sha1.unwrap())?;
}
Ok(())
})

View File

@ -59,7 +59,7 @@ impl<'a> Context<'a> {
) -> Result<Self> {
let mut progress = IncrementableProgress::new(progress_fn);
progress.call(ImportProgress::Extracting)?;
let media_manager = MediaManager::new(&target_col.media_folder, &target_col.media_db)?;
let media_manager = target_col.media()?;
let meta = Meta::from_archive(&mut archive)?;
let data = ExchangeData::gather_from_archive(
&mut archive,

View File

@ -6,9 +6,8 @@ use std::{
collections::HashMap,
ffi::OsStr,
fs::File,
io::{
Read, Write, {self},
},
io,
io::{Read, Write},
path::{Path, PathBuf},
};

View File

@ -3,17 +3,13 @@
use std::{
fs::File,
io::{
Write, {self},
},
io,
io::Write,
path::{Path, PathBuf},
};
use zip::{read::ZipFile, ZipArchive};
use zstd::{
stream::copy_decode,
{self},
};
use zstd::{self, stream::copy_decode};
use crate::{
collection::CollectionBuilder,

View File

@ -27,11 +27,10 @@ fn collection_with_media(dir: &Path, name: &str) -> Result<Collection> {
let mut note = nt.new_note();
col.add_note(&mut note, DeckId(1))?;
// add sample media
let mgr = MediaManager::new(&col.media_folder, &col.media_db)?;
let mut ctx = mgr.dbctx();
mgr.add_file(&mut ctx, "1", b"1")?;
mgr.add_file(&mut ctx, "2", b"2")?;
mgr.add_file(&mut ctx, "3", b"3")?;
let mgr = col.media()?;
mgr.add_file("1", b"1")?;
mgr.add_file("2", b"2")?;
mgr.add_file("3", b"3")?;
Ok(col)
}

View File

@ -4,9 +4,8 @@
use std::{
borrow::Cow,
collections::HashMap,
fs::{
File, {self},
},
fs,
fs::File,
io,
path::{Path, PathBuf},
};

View File

@ -1,12 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
fs::File,
io::{
Read, {self},
},
};
use std::{fs::File, io, io::Read};
use prost::Message;
use zip::ZipArchive;

View File

@ -31,6 +31,15 @@ pub(crate) fn write_file(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) ->
})
}
/// See [std::fs::remove_file].
#[allow(dead_code)]
pub(crate) fn remove_file(path: impl AsRef<Path>) -> Result<()> {
std::fs::remove_file(&path).context(FileIoSnafu {
path: path.as_ref(),
op: FileOp::Remove,
})
}
/// See [std::fs::create_dir].
pub(crate) fn create_dir(path: impl AsRef<Path>) -> Result<()> {
std::fs::create_dir(&path).context(FileIoSnafu {

View File

@ -36,7 +36,7 @@ pub mod search;
pub mod serde;
mod stats;
pub mod storage;
mod sync;
pub mod sync;
pub mod tags;
pub mod template;
pub mod template_filters;

View File

@ -14,16 +14,20 @@ const LOG_ROTATE_BYTES: u64 = 50 * 1024 * 1024;
/// Enable logging to the console, and optionally also to a file.
pub fn set_global_logger(path: Option<&str>) -> Result<()> {
let file_writer = if let Some(path) = path {
Some(Layer::new().with_writer(get_appender(path)?))
} else {
None
};
let subscriber = tracing_subscriber::registry()
.with(fmt::layer())
.with(file_writer)
.with(EnvFilter::from_default_env());
set_global_default(subscriber).or_invalid("global subscriber already set")?;
static ONCE: OnceCell<()> = OnceCell::new();
ONCE.get_or_try_init(|| -> Result<()> {
let file_writer = if let Some(path) = path {
Some(Layer::new().with_writer(get_appender(path)?))
} else {
None
};
let subscriber = tracing_subscriber::registry()
.with(fmt::layer().with_target(false))
.with(file_writer)
.with(EnvFilter::from_default_env());
set_global_default(subscriber).or_invalid("global subscriber already set")?;
Ok(())
})?;
Ok(())
}
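
set_global_logger() is now guarded by a OnceCell, so calling it more than once becomes a no-op instead of failing when a global subscriber has already been installed. The pattern in isolation, as an editor's sketch rather than the actual implementation:

// Editor's sketch, not part of the diff: idempotent one-time initialisation
// with once_cell.
use once_cell::sync::OnceCell;

fn init_once() -> Result<(), String> {
    static ONCE: OnceCell<()> = OnceCell::new();
    ONCE.get_or_try_init(|| -> Result<(), String> {
        // one-time global setup goes here; an Err is propagated and not cached,
        // while an Ok result is remembered for all later calls
        println!("initialised");
        Ok(())
    })?;
    Ok(())
}

fn main() {
    init_once().unwrap();
    init_once().unwrap(); // second call returns Ok(()) without running the closure again
}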

View File

@ -16,14 +16,11 @@ use crate::{
error::{AnkiError, DbErrorKind, Result},
latex::extract_latex_expanding_clozes,
media::{
database::MediaDatabaseContext,
files::{
data_for_file, filename_if_normalized, normalize_nfc_filename, trash_folder,
MEDIA_SYNC_FILESIZE_LIMIT,
},
files::{data_for_file, filename_if_normalized, normalize_nfc_filename, trash_folder},
MediaManager,
},
notes::Note,
sync::media::MAX_INDIVIDUAL_MEDIA_FILE_SIZE,
text::{extract_media_refs, normalize_to_nfc, MediaRef, REMOTE_FILENAME},
};
@ -74,9 +71,7 @@ where
}
pub fn check(&mut self) -> Result<MediaCheckOutput> {
let mut ctx = self.mgr.dbctx();
let folder_check = self.check_media_folder(&mut ctx)?;
let folder_check = self.check_media_folder()?;
let referenced_files = self.check_media_references(&folder_check.renamed)?;
let (unused, missing) = find_unused_and_missing(folder_check.files, referenced_files);
let (trash_count, trash_bytes) = self.files_in_trash()?;
@ -186,7 +181,7 @@ where
/// - Renames files with invalid names
/// - Notes folders/oversized files
/// - Gathers a list of all files
fn check_media_folder(&mut self, ctx: &mut MediaDatabaseContext) -> Result<MediaFolderCheck> {
fn check_media_folder(&mut self) -> Result<MediaFolderCheck> {
let mut out = MediaFolderCheck::default();
for dentry in self.mgr.media_folder.read_dir()? {
let dentry = dentry?;
@ -211,7 +206,7 @@ where
// ignore large files and zero byte files
let metadata = dentry.metadata()?;
if metadata.len() > MEDIA_SYNC_FILESIZE_LIMIT as u64 {
if metadata.len() > MAX_INDIVIDUAL_MEDIA_FILE_SIZE as u64 {
out.oversize.push(disk_fname.to_string());
continue;
}
@ -224,7 +219,7 @@ where
} else {
match data_for_file(&self.mgr.media_folder, disk_fname)? {
Some(data) => {
let norm_name = self.normalize_file(ctx, disk_fname, data)?;
let norm_name = self.normalize_file(disk_fname, data)?;
out.renamed
.insert(disk_fname.to_string(), norm_name.to_string());
out.files.push(norm_name.into_owned());
@ -242,14 +237,9 @@ where
}
/// Write file data to normalized location, moving old file to trash.
fn normalize_file<'a>(
&mut self,
ctx: &mut MediaDatabaseContext,
disk_fname: &'a str,
data: Vec<u8>,
) -> Result<Cow<'a, str>> {
fn normalize_file<'a>(&mut self, disk_fname: &'a str, data: Vec<u8>) -> Result<Cow<'a, str>> {
// add a copy of the file using the correct name
let fname = self.mgr.add_file(ctx, disk_fname, &data)?;
let fname = self.mgr.add_file(disk_fname, &data)?;
debug!(from = disk_fname, to = &fname.as_ref(), "renamed");
assert_ne!(fname.as_ref(), disk_fname);
@ -336,9 +326,7 @@ where
let fname_os = dentry.file_name();
let fname = fname_os.to_string_lossy();
if let Some(data) = data_for_file(&trash, fname.as_ref())? {
let _new_fname =
self.mgr
.add_file(&mut self.mgr.dbctx(), fname.as_ref(), &data)?;
let _new_fname = self.mgr.add_file(fname.as_ref(), &data)?;
} else {
debug!(?fname, "file disappeared while restoring trash");
}

View File

@ -20,17 +20,9 @@ use crate::{
error::{FileIoError, FileIoSnafu, FileOp},
io::{create_dir, open_file, write_file},
prelude::*,
sync::media::MAX_MEDIA_FILENAME_LENGTH,
};
/// The maximum length we allow a filename to be. When combined
/// with the rest of the path, the full path needs to be under ~240 chars
/// on some platforms, and some filesystems like eCryptFS will increase
/// the length of the filename.
pub(super) static MAX_FILENAME_LENGTH: usize = 120;
/// Media syncing does not support files over 100MiB.
pub(super) static MEDIA_SYNC_FILESIZE_LIMIT: usize = 100 * 1024 * 1024;
lazy_static! {
static ref WINDOWS_DEVICE_NAME: Regex = Regex::new(
r#"(?xi)
@ -56,7 +48,7 @@ lazy_static! {
"#
)
.unwrap();
pub(super) static ref NONSYNCABLE_FILENAME: Regex = Regex::new(
pub(crate) static ref NONSYNCABLE_FILENAME: Regex = Regex::new(
r#"(?xi)
^
(:?
@ -119,7 +111,7 @@ pub(crate) fn normalize_nfc_filename(mut fname: Cow<str>) -> Cow<str> {
fname = format!("{}_", fname.as_ref()).into();
}
if let Cow::Owned(o) = truncate_filename(fname.as_ref(), MAX_FILENAME_LENGTH) {
if let Cow::Owned(o) = truncate_filename(fname.as_ref(), MAX_MEDIA_FILENAME_LENGTH) {
fname = o.into();
}
@ -198,7 +190,7 @@ where
/// Convert foo.jpg into foo-abcde12345679.jpg
pub(crate) fn add_hash_suffix_to_file_stem(fname: &str, hash: &Sha1Hash) -> String {
// when appending a hash to make unique, it will be 40 bytes plus the hyphen.
let max_len = MAX_FILENAME_LENGTH - 40 - 1;
let max_len = MAX_MEDIA_FILENAME_LENGTH - 40 - 1;
let (stem, ext) = split_and_truncate_filename(fname, max_len);
@ -308,14 +300,14 @@ pub(crate) fn sha1_of_data(data: &[u8]) -> Sha1Hash {
hasher.finalize().into()
}
pub(super) fn mtime_as_i64<P: AsRef<Path>>(path: P) -> io::Result<i64> {
pub(crate) fn mtime_as_i64<P: AsRef<Path>>(path: P) -> io::Result<i64> {
Ok(path
.as_ref()
.metadata()?
.modified()?
.duration_since(time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64)
.as_millis() as i64)
}
pub fn remove_files<S>(media_folder: &Path, files: &[S]) -> Result<()>
@ -375,7 +367,7 @@ pub(super) fn trash_folder(media_folder: &Path) -> Result<PathBuf> {
}
}
pub(super) struct AddedFile {
pub struct AddedFile {
pub fname: String,
pub sha1: Sha1Hash,
pub mtime: i64,
@ -386,7 +378,7 @@ pub(super) struct AddedFile {
///
/// Because AnkiWeb did not previously enforce file name limits and invalid
/// characters, we'll need to rename the file if it is not valid.
pub(super) fn add_file_from_ankiweb(
pub(crate) fn add_file_from_ankiweb(
media_folder: &Path,
fname: &str,
data: &[u8],
@ -424,7 +416,7 @@ pub(super) fn add_file_from_ankiweb(
})
}
pub(super) fn data_for_file(media_folder: &Path, fname: &str) -> Result<Option<Vec<u8>>> {
pub(crate) fn data_for_file(media_folder: &Path, fname: &str) -> Result<Option<Vec<u8>>> {
let mut file = match open_file(&media_folder.join(fname)) {
Err(e) if e.is_not_found() => return Ok(None),
res => res?,
@ -440,9 +432,12 @@ mod test {
use tempfile::tempdir;
use crate::media::files::{
add_data_to_folder_uniquely, add_hash_suffix_to_file_stem, normalize_filename,
remove_files, sha1_of_data, truncate_filename, MAX_FILENAME_LENGTH,
use crate::{
media::files::{
add_data_to_folder_uniquely, add_hash_suffix_to_file_stem, normalize_filename,
remove_files, sha1_of_data, truncate_filename,
},
sync::media::MAX_MEDIA_FILENAME_LENGTH,
};
#[test]
@ -456,9 +451,12 @@ mod test {
assert_eq!(normalize_filename("test.").as_ref(), "test._");
assert_eq!(normalize_filename("test ").as_ref(), "test _");
let expected_stem_len = MAX_FILENAME_LENGTH - ".jpg".len() - 1;
let expected_stem_len = MAX_MEDIA_FILENAME_LENGTH - ".jpg".len() - 1;
assert_eq!(
normalize_filename(&format!("{}.jpg", "x".repeat(MAX_FILENAME_LENGTH * 2))),
normalize_filename(&format!(
"{}.jpg",
"x".repeat(MAX_MEDIA_FILENAME_LENGTH * 2)
)),
"x".repeat(expected_stem_len) + ".jpg"
);
}
@ -516,29 +514,32 @@ mod test {
#[test]
fn truncation() {
let one_less = "x".repeat(MAX_FILENAME_LENGTH - 1);
let one_less = "x".repeat(MAX_MEDIA_FILENAME_LENGTH - 1);
assert_eq!(
truncate_filename(&one_less, MAX_FILENAME_LENGTH),
truncate_filename(&one_less, MAX_MEDIA_FILENAME_LENGTH),
Cow::Borrowed(&one_less)
);
let equal = "x".repeat(MAX_FILENAME_LENGTH);
let equal = "x".repeat(MAX_MEDIA_FILENAME_LENGTH);
assert_eq!(
truncate_filename(&equal, MAX_FILENAME_LENGTH),
truncate_filename(&equal, MAX_MEDIA_FILENAME_LENGTH),
Cow::Borrowed(&equal)
);
let equal = format!("{}.jpg", "x".repeat(MAX_FILENAME_LENGTH - 4));
let equal = format!("{}.jpg", "x".repeat(MAX_MEDIA_FILENAME_LENGTH - 4));
assert_eq!(
truncate_filename(&equal, MAX_FILENAME_LENGTH),
truncate_filename(&equal, MAX_MEDIA_FILENAME_LENGTH),
Cow::Borrowed(&equal)
);
let one_more = "x".repeat(MAX_FILENAME_LENGTH + 1);
let one_more = "x".repeat(MAX_MEDIA_FILENAME_LENGTH + 1);
assert_eq!(
truncate_filename(&one_more, MAX_FILENAME_LENGTH),
Cow::<str>::Owned("x".repeat(MAX_FILENAME_LENGTH - 2))
truncate_filename(&one_more, MAX_MEDIA_FILENAME_LENGTH),
Cow::<str>::Owned("x".repeat(MAX_MEDIA_FILENAME_LENGTH - 2))
);
assert_eq!(
truncate_filename(&" ".repeat(MAX_FILENAME_LENGTH + 1), MAX_FILENAME_LENGTH),
Cow::<str>::Owned(format!("{}_", " ".repeat(MAX_FILENAME_LENGTH - 2)))
truncate_filename(
&" ".repeat(MAX_MEDIA_FILENAME_LENGTH + 1),
MAX_MEDIA_FILENAME_LENGTH
),
Cow::<str>::Owned(format!("{}_", " ".repeat(MAX_MEDIA_FILENAME_LENGTH - 2)))
);
}
}

View File

@ -1,35 +1,41 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub mod check;
pub mod files;
use std::{
borrow::Cow,
collections::HashMap,
path::{Path, PathBuf},
};
use rusqlite::Connection;
use self::changetracker::ChangeTracker;
use crate::{
media::{
database::{open_or_create, MediaDatabaseContext, MediaEntry},
files::{add_data_to_folder_uniquely, mtime_as_i64, remove_files, sha1_of_data},
sync::{MediaSyncProgress, MediaSyncer},
},
io::create_dir_all,
media::files::{add_data_to_folder_uniquely, mtime_as_i64, remove_files, sha1_of_data},
prelude::*,
sync::{
http_client::HttpSyncClient,
login::SyncAuth,
media::{
database::client::{changetracker::ChangeTracker, MediaDatabase, MediaEntry},
progress::MediaSyncProgress,
syncer::MediaSyncer,
},
},
};
pub mod changetracker;
pub mod check;
pub mod database;
pub mod files;
pub mod sync;
pub type Sha1Hash = [u8; 20];
impl Collection {
pub fn media(&self) -> Result<MediaManager> {
MediaManager::new(&self.media_folder, &self.media_db)
}
}
pub struct MediaManager {
db: Connection,
media_folder: PathBuf,
pub(crate) db: MediaDatabase,
pub(crate) media_folder: PathBuf,
}
impl MediaManager {
@ -38,10 +44,11 @@ impl MediaManager {
P: Into<PathBuf>,
P2: AsRef<Path>,
{
let db = open_or_create(media_db.as_ref())?;
let media_folder = media_folder.into();
create_dir_all(&media_folder)?;
Ok(MediaManager {
db,
media_folder: media_folder.into(),
db: MediaDatabase::new(media_db.as_ref())?,
media_folder,
})
}
@ -51,26 +58,21 @@ impl MediaManager {
/// appended to the name.
///
/// Also notes the file in the media database.
pub fn add_file<'a>(
&self,
ctx: &mut MediaDatabaseContext,
desired_name: &'a str,
data: &[u8],
) -> Result<Cow<'a, str>> {
pub fn add_file<'a>(&self, desired_name: &'a str, data: &[u8]) -> Result<Cow<'a, str>> {
let data_hash = sha1_of_data(data);
self.transact(ctx, |ctx| {
self.transact(|db| {
let chosen_fname =
add_data_to_folder_uniquely(&self.media_folder, desired_name, data, data_hash)?;
let file_mtime = mtime_as_i64(self.media_folder.join(chosen_fname.as_ref()))?;
let existing_entry = ctx.get_entry(&chosen_fname)?;
let existing_entry = db.get_entry(&chosen_fname)?;
let new_sha1 = Some(data_hash);
let entry_update_required = existing_entry.map(|e| e.sha1 != new_sha1).unwrap_or(true);
if entry_update_required {
ctx.set_entry(&MediaEntry {
db.set_entry(&MediaEntry {
fname: chosen_fname.to_string(),
sha1: new_sha1,
mtime: file_mtime,
@ -82,18 +84,18 @@ impl MediaManager {
})
}
pub fn remove_files<S>(&self, ctx: &mut MediaDatabaseContext, filenames: &[S]) -> Result<()>
pub fn remove_files<S>(&self, filenames: &[S]) -> Result<()>
where
S: AsRef<str> + std::fmt::Debug,
{
self.transact(ctx, |ctx| {
self.transact(|db| {
remove_files(&self.media_folder, filenames)?;
for fname in filenames {
if let Some(mut entry) = ctx.get_entry(fname.as_ref())? {
if let Some(mut entry) = db.get_entry(fname.as_ref())? {
entry.sha1 = None;
entry.mtime = 0;
entry.sync_required = true;
ctx.set_entry(&entry)?;
db.set_entry(&entry)?;
}
}
Ok(())
@ -102,21 +104,17 @@ impl MediaManager {
/// Opens a transaction and manages the folder mtime, so the user should perform
/// not only db ops, but also all file ops inside the closure.
pub(crate) fn transact<T>(
&self,
ctx: &mut MediaDatabaseContext,
func: impl FnOnce(&mut MediaDatabaseContext) -> Result<T>,
) -> Result<T> {
pub(crate) fn transact<T>(&self, func: impl FnOnce(&MediaDatabase) -> Result<T>) -> Result<T> {
let start_folder_mtime = mtime_as_i64(&self.media_folder)?;
ctx.transact(|ctx| {
let out = func(ctx)?;
self.db.transact(|db| {
let out = func(db)?;
let mut meta = ctx.get_meta()?;
let mut meta = db.get_meta()?;
if meta.folder_mtime == start_folder_mtime {
// if media db was in sync with folder prior to this add,
// we can keep it in sync
meta.folder_mtime = mtime_as_i64(&self.media_folder)?;
ctx.set_meta(&meta)?;
db.set_meta(&meta)?;
} else {
// otherwise, leave it alone so that other pending changes
// get picked up later
@ -127,15 +125,10 @@ impl MediaManager {
}
/// Set entry for a newly added file. Caller must ensure transaction.
pub(crate) fn add_entry(
&self,
ctx: &mut MediaDatabaseContext,
fname: impl Into<String>,
sha1: [u8; 20],
) -> Result<()> {
pub(crate) fn add_entry(&self, fname: impl Into<String>, sha1: [u8; 20]) -> Result<()> {
let fname = fname.into();
let mtime = mtime_as_i64(self.media_folder.join(&fname))?;
ctx.set_entry(&MediaEntry {
self.db.set_entry(&MediaEntry {
fname,
mtime,
sha1: Some(sha1),
@ -144,55 +137,38 @@ impl MediaManager {
}
/// Sync media.
pub async fn sync_media<'a, F>(
&'a self,
progress: F,
host_number: u32,
hkey: &'a str,
) -> Result<()>
pub async fn sync_media<F>(self, progress: F, auth: SyncAuth) -> Result<()>
where
F: FnMut(MediaSyncProgress) -> bool,
{
let mut syncer = MediaSyncer::new(self, progress, host_number);
syncer.sync(hkey).await
let client = HttpSyncClient::new(auth);
let mut syncer = MediaSyncer::new(self, progress, client)?;
syncer.sync().await
}
pub fn dbctx(&self) -> MediaDatabaseContext {
MediaDatabaseContext::new(&self.db)
}
pub fn all_checksums(
pub fn all_checksums_after_checking(
&self,
progress: impl FnMut(usize) -> bool,
) -> Result<HashMap<String, Sha1Hash>> {
let mut dbctx = self.dbctx();
ChangeTracker::new(&self.media_folder, progress).register_changes(&mut dbctx)?;
dbctx.all_checksums()
ChangeTracker::new(&self.media_folder, progress).register_changes(&self.db)?;
self.db.all_registered_checksums()
}
pub fn checksum_getter(&self) -> impl FnMut(&str) -> Result<Option<Sha1Hash>> + '_ {
let mut dbctx = self.dbctx();
move |fname: &str| {
dbctx
|fname: &str| {
self.db
.get_entry(fname)
.map(|opt| opt.and_then(|entry| entry.sha1))
}
}
pub fn register_changes(&self, progress: &mut impl FnMut(usize) -> bool) -> Result<()> {
ChangeTracker::new(&self.media_folder, progress).register_changes(&mut self.dbctx())
}
}
#[cfg(test)]
mod test {
use super::*;
impl MediaManager {
/// All checksums without registering changes first.
pub(crate) fn all_checksums_as_is(&self) -> HashMap<String, [u8; 20]> {
let mut dbctx = self.dbctx();
dbctx.all_checksums().unwrap()
}
ChangeTracker::new(&self.media_folder, progress).register_changes(&self.db)
}
/// All checksums without registering changes first.
#[cfg(test)]
pub(crate) fn all_checksums_as_is(&self) -> HashMap<String, [u8; 20]> {
self.db.all_registered_checksums().unwrap()
}
}
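
Taken together, the changes in this file mean MediaManager now owns its database handle: callers obtain a manager via col.media() and no longer thread a MediaDatabaseContext through add_file, remove_files and transact. A small usage sketch, not part of the patch; the crate paths and file contents are invented for the example.

// Editor's sketch, not part of the diff.
use anki::collection::CollectionBuilder;
use anki::error::Result;

fn add_and_remove(col_path: &str) -> Result<()> {
    let col = CollectionBuilder::new(col_path).build()?;
    // replaces MediaManager::new(&col.media_folder, &col.media_db)
    let mgr = col.media()?;
    // the manager opens its own transaction and records the entry in the media DB
    let stored_name = mgr.add_file("example.jpg", b"fake image bytes")?.into_owned();
    mgr.remove_files(&[stored_name])?;
    Ok(())
}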

View File

@ -1,839 +0,0 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
borrow::Cow,
collections::HashMap,
io,
io::{Read, Write},
path::Path,
time,
};
use bytes::Bytes;
use reqwest::{multipart, Client, Response};
use serde_derive::{Deserialize, Serialize};
use serde_tuple::Serialize_tuple;
use time::Duration;
use tracing::debug;
use version::sync_client_version;
use crate::{
error::{AnkiError, Result, SyncErrorKind},
media::{
changetracker::ChangeTracker,
database::{MediaDatabaseContext, MediaDatabaseMetadata, MediaEntry},
files::{
add_file_from_ankiweb, data_for_file, mtime_as_i64, normalize_filename, AddedFile,
},
MediaManager,
},
sync::Timeouts,
version,
};
static SYNC_MAX_FILES: usize = 25;
static SYNC_MAX_BYTES: usize = (2.5 * 1024.0 * 1024.0) as usize;
static SYNC_SINGLE_FILE_MAX_BYTES: usize = 100 * 1024 * 1024;
#[derive(Debug, Default, Clone, Copy)]
pub struct MediaSyncProgress {
pub checked: usize,
pub downloaded_files: usize,
pub downloaded_deletions: usize,
pub uploaded_files: usize,
pub uploaded_deletions: usize,
}
pub struct MediaSyncer<'a, P>
where
P: FnMut(MediaSyncProgress) -> bool,
{
mgr: &'a MediaManager,
ctx: MediaDatabaseContext<'a>,
skey: Option<String>,
client: Client,
progress_cb: P,
progress: MediaSyncProgress,
endpoint: String,
}
#[derive(Debug, Deserialize)]
struct SyncBeginResult {
data: Option<SyncBeginResponse>,
err: String,
}
#[derive(Debug, Deserialize)]
struct SyncBeginResponse {
#[serde(rename = "sk")]
sync_key: String,
usn: i32,
}
#[derive(Debug, Clone, Copy)]
enum LocalState {
NotInDb,
InDbNotPending,
InDbAndPending,
}
#[derive(PartialEq, Eq, Debug)]
enum RequiredChange {
// none also covers the case where we'll later upload
None,
Download,
Delete,
RemovePending,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct RecordBatchRequest {
last_usn: i32,
}
#[derive(Debug, Deserialize)]
struct RecordBatchResult {
data: Option<Vec<ServerMediaRecord>>,
err: String,
}
#[derive(Debug, Deserialize)]
struct ServerMediaRecord {
fname: String,
usn: i32,
sha1: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ZipRequest<'a> {
files: &'a [&'a String],
}
#[derive(Serialize_tuple)]
struct UploadEntry<'a> {
fname: &'a str,
in_zip_name: Option<String>,
}
#[derive(Deserialize, Debug)]
struct UploadResult {
data: Option<UploadReply>,
err: String,
}
#[derive(Deserialize, Debug)]
struct UploadReply {
processed: usize,
current_usn: i32,
}
#[derive(Serialize)]
struct FinalizeRequest {
local: u32,
}
#[derive(Debug, Deserialize)]
struct FinalizeResponse {
data: Option<String>,
err: String,
}
fn media_sync_endpoint(host_number: u32) -> String {
if let Ok(endpoint) = std::env::var("SYNC_ENDPOINT_MEDIA") {
endpoint
} else {
let suffix = if host_number > 0 {
format!("{}", host_number)
} else {
"".to_string()
};
format!("https://sync{}.ankiweb.net/msync/", suffix)
}
}
impl<P> MediaSyncer<'_, P>
where
P: FnMut(MediaSyncProgress) -> bool,
{
pub fn new(mgr: &MediaManager, progress_cb: P, host_number: u32) -> MediaSyncer<'_, P> {
let timeouts = Timeouts::new();
let client = Client::builder()
.connect_timeout(Duration::from_secs(timeouts.connect_secs))
.timeout(Duration::from_secs(timeouts.request_secs))
.io_timeout(Duration::from_secs(timeouts.io_secs))
.build()
.unwrap();
let endpoint = media_sync_endpoint(host_number);
let ctx = mgr.dbctx();
MediaSyncer {
mgr,
ctx,
skey: None,
client,
progress_cb,
progress: Default::default(),
endpoint,
}
}
fn skey(&self) -> &str {
self.skey.as_ref().unwrap()
}
pub async fn sync(&mut self, hkey: &str) -> Result<()> {
self.sync_inner(hkey).await.map_err(|e| {
debug!("sync error: {:?}", e);
e
})
}
#[allow(clippy::useless_let_if_seq)]
async fn sync_inner(&mut self, hkey: &str) -> Result<()> {
self.register_changes()?;
let meta = self.ctx.get_meta()?;
let client_usn = meta.last_sync_usn;
debug!("begin media sync");
let (sync_key, server_usn) = self.sync_begin(hkey).await?;
self.skey = Some(sync_key);
debug!("server usn was {}", server_usn);
let mut actions_performed = false;
// need to fetch changes from server?
if client_usn != server_usn {
debug!("differs from local usn {}, fetching changes", client_usn);
self.fetch_changes(meta).await?;
actions_performed = true;
}
// need to send changes to server?
let changes_pending = !self.ctx.get_pending_uploads(1)?.is_empty();
if changes_pending {
self.send_changes().await?;
actions_performed = true;
}
if actions_performed {
self.finalize_sync().await?;
}
self.fire_progress_cb()?;
debug!("media sync complete");
Ok(())
}
/// Make sure media DB is up to date.
fn register_changes(&mut self) -> Result<()> {
// make borrow checker happy
let progress = &mut self.progress;
let progress_cb = &mut self.progress_cb;
let progress = |checked| {
progress.checked = checked;
(progress_cb)(*progress)
};
ChangeTracker::new(self.mgr.media_folder.as_path(), progress)
.register_changes(&mut self.ctx)
}
async fn sync_begin(&self, hkey: &str) -> Result<(String, i32)> {
let url = format!("{}begin", self.endpoint);
let resp = self
.client
.get(&url)
.query(&[("k", hkey), ("v", sync_client_version())])
.send()
.await?
.error_for_status()?;
let reply: SyncBeginResult = resp.json().await?;
if let Some(data) = reply.data {
Ok((data.sync_key, data.usn))
} else {
Err(AnkiError::server_message(reply.err))
}
}
async fn fetch_changes(&mut self, mut meta: MediaDatabaseMetadata) -> Result<()> {
let mut last_usn = meta.last_sync_usn;
loop {
debug!(start_usn = last_usn, "fetching record batch");
let batch = self.fetch_record_batch(last_usn).await?;
if batch.is_empty() {
debug!("empty batch, done");
break;
}
last_usn = batch.last().unwrap().usn;
self.progress.checked += batch.len();
self.fire_progress_cb()?;
let (to_download, to_delete, to_remove_pending) =
determine_required_changes(&mut self.ctx, &batch)?;
// file removal
self.mgr.remove_files(&mut self.ctx, to_delete.as_slice())?;
self.progress.downloaded_deletions += to_delete.len();
self.fire_progress_cb()?;
// file download
let mut downloaded = vec![];
let mut dl_fnames = to_download.as_slice();
while !dl_fnames.is_empty() {
let batch: Vec<_> = dl_fnames
.iter()
.take(SYNC_MAX_FILES)
.map(ToOwned::to_owned)
.collect();
let zip_data = self.fetch_zip(batch.as_slice()).await?;
let download_batch =
extract_into_media_folder(self.mgr.media_folder.as_path(), zip_data)?
.into_iter();
let len = download_batch.len();
dl_fnames = &dl_fnames[len..];
downloaded.extend(download_batch);
self.progress.downloaded_files += len;
self.fire_progress_cb()?;
}
// then update the DB
let dirmod = mtime_as_i64(&self.mgr.media_folder)?;
self.ctx.transact(|ctx| {
record_clean(ctx, &to_remove_pending)?;
record_removals(ctx, &to_delete)?;
record_additions(ctx, downloaded)?;
// update usn
meta.last_sync_usn = last_usn;
meta.folder_mtime = dirmod;
ctx.set_meta(&meta)?;
Ok(())
})?;
}
Ok(())
}
async fn send_changes(&mut self) -> Result<()> {
loop {
let pending: Vec<MediaEntry> = self.ctx.get_pending_uploads(SYNC_MAX_FILES as u32)?;
if pending.is_empty() {
break;
}
let zip_data = zip_files(&mut self.ctx, &self.mgr.media_folder, &pending)?;
if zip_data.is_none() {
self.progress.checked += pending.len();
self.fire_progress_cb()?;
// discard zip info and retry batch - not particularly efficient,
// but this is a corner case
continue;
}
let reply = self.send_zip_data(zip_data.unwrap()).await?;
let (processed_files, processed_deletions): (Vec<_>, Vec<_>) = pending
.iter()
.take(reply.processed)
.partition(|e| e.sha1.is_some());
self.progress.uploaded_files += processed_files.len();
self.progress.uploaded_deletions += processed_deletions.len();
self.fire_progress_cb()?;
let fnames: Vec<_> = processed_files
.iter()
.chain(processed_deletions.iter())
.map(|e| &e.fname)
.collect();
let fname_cnt = fnames.len() as i32;
self.ctx.transact(|ctx| {
record_clean(ctx, fnames.as_slice())?;
let mut meta = ctx.get_meta()?;
if meta.last_sync_usn + fname_cnt == reply.current_usn {
meta.last_sync_usn = reply.current_usn;
ctx.set_meta(&meta)?;
} else {
debug!(
"server usn {} is not {}, skipping usn update",
reply.current_usn,
meta.last_sync_usn + fname_cnt
);
}
Ok(())
})?;
}
Ok(())
}
async fn finalize_sync(&mut self) -> Result<()> {
let url = format!("{}mediaSanity", self.endpoint);
let local = self.ctx.count()?;
let obj = FinalizeRequest { local };
let resp = ankiweb_json_request(&self.client, &url, &obj, self.skey(), false).await?;
let resp: FinalizeResponse = resp.json().await?;
if let Some(data) = resp.data {
if data == "OK" {
Ok(())
} else {
self.ctx.transact(|ctx| ctx.force_resync())?;
Err(AnkiError::sync_error("", SyncErrorKind::ResyncRequired))
}
} else {
Err(AnkiError::server_message(resp.err))
}
}
fn fire_progress_cb(&mut self) -> Result<()> {
if (self.progress_cb)(self.progress) {
Ok(())
} else {
Err(AnkiError::Interrupted)
}
}
async fn fetch_record_batch(&self, last_usn: i32) -> Result<Vec<ServerMediaRecord>> {
let url = format!("{}mediaChanges", self.endpoint);
let req = RecordBatchRequest { last_usn };
let resp = ankiweb_json_request(&self.client, &url, &req, self.skey(), false).await?;
let res: RecordBatchResult = resp.json().await?;
if let Some(batch) = res.data {
Ok(batch)
} else {
Err(AnkiError::server_message(res.err))
}
}
async fn fetch_zip(&self, files: &[&String]) -> Result<Bytes> {
let url = format!("{}downloadFiles", self.endpoint);
debug!("requesting files: {:?}", files);
let req = ZipRequest { files };
let resp = ankiweb_json_request(&self.client, &url, &req, self.skey(), true).await?;
resp.bytes().await.map_err(Into::into)
}
async fn send_zip_data(&self, data: Vec<u8>) -> Result<UploadReply> {
let url = format!("{}uploadChanges", self.endpoint);
let resp = ankiweb_bytes_request(&self.client, &url, data, self.skey(), true).await?;
let res: UploadResult = resp.json().await?;
if let Some(reply) = res.data {
Ok(reply)
} else {
Err(AnkiError::server_message(res.err))
}
}
}
fn determine_required_change(
local_sha1: &str,
remote_sha1: &str,
local_state: LocalState,
) -> RequiredChange {
use LocalState as L;
use RequiredChange as R;
match (local_sha1, remote_sha1, local_state) {
// both deleted, not in local DB
("", "", L::NotInDb) => R::None,
// both deleted, in local DB
("", "", _) => R::Delete,
// added on server, add even if local deletion pending
("", _, _) => R::Download,
// deleted on server but added locally; upload later
(_, "", L::InDbAndPending) => R::None,
// deleted on server and not pending sync
(_, "", _) => R::Delete,
// if pending but the same as server, don't need to upload
(lsum, rsum, L::InDbAndPending) if lsum == rsum => R::RemovePending,
(lsum, rsum, _) => {
if lsum == rsum {
// not pending and same as server, nothing to do
R::None
} else {
// differs from server, favour server
R::Download
}
}
}
}
/// Get a list of server filenames and the actions required on them.
/// Returns filenames in (to_download, to_delete).
fn determine_required_changes<'a>(
ctx: &mut MediaDatabaseContext,
records: &'a [ServerMediaRecord],
) -> Result<(Vec<&'a String>, Vec<&'a String>, Vec<&'a String>)> {
let mut to_download = vec![];
let mut to_delete = vec![];
let mut to_remove_pending = vec![];
for remote in records {
let (local_sha1, local_state) = match ctx.get_entry(&remote.fname)? {
Some(entry) => (
match entry.sha1 {
Some(arr) => hex::encode(arr),
None => "".to_string(),
},
if entry.sync_required {
LocalState::InDbAndPending
} else {
LocalState::InDbNotPending
},
),
None => ("".to_string(), LocalState::NotInDb),
};
let req_change = determine_required_change(&local_sha1, &remote.sha1, local_state);
debug!(
fname = &remote.fname,
lsha = local_sha1.chars().take(8).collect::<String>(),
rsha = remote.sha1.chars().take(8).collect::<String>(),
state = ?local_state,
action = ?req_change,
"determine action"
);
match req_change {
RequiredChange::Download => to_download.push(&remote.fname),
RequiredChange::Delete => to_delete.push(&remote.fname),
RequiredChange::RemovePending => to_remove_pending.push(&remote.fname),
RequiredChange::None => (),
};
}
Ok((to_download, to_delete, to_remove_pending))
}
async fn ankiweb_json_request<T>(
client: &Client,
url: &str,
json: &T,
skey: &str,
timeout_long: bool,
) -> Result<Response>
where
T: serde::Serialize,
{
let req_json = serde_json::to_string(json)?;
let part = multipart::Part::text(req_json);
ankiweb_request(client, url, part, skey, timeout_long).await
}
async fn ankiweb_bytes_request(
client: &Client,
url: &str,
bytes: Vec<u8>,
skey: &str,
timeout_long: bool,
) -> Result<Response> {
let part = multipart::Part::bytes(bytes);
ankiweb_request(client, url, part, skey, timeout_long).await
}
async fn ankiweb_request(
client: &Client,
url: &str,
data_part: multipart::Part,
skey: &str,
timeout_long: bool,
) -> Result<Response> {
let data_part = data_part.file_name("data");
let form = multipart::Form::new()
.part("data", data_part)
.text("sk", skey.to_string());
let mut req = client.post(url).multipart(form);
if timeout_long {
req = req.timeout(Duration::from_secs(60 * 60));
}
req.send().await?.error_for_status().map_err(Into::into)
}
fn extract_into_media_folder(media_folder: &Path, zip: Bytes) -> Result<Vec<AddedFile>> {
let reader = io::Cursor::new(zip);
let mut zip = zip::ZipArchive::new(reader)?;
let meta_file = zip.by_name("_meta")?;
let fmap: HashMap<String, String> = serde_json::from_reader(meta_file)?;
let mut output = Vec::with_capacity(fmap.len());
for i in 0..zip.len() {
let mut file = zip.by_index(i)?;
let name = file.name();
if name == "_meta" {
continue;
}
let real_name = fmap
.get(name)
.ok_or_else(|| AnkiError::sync_error("malformed zip", SyncErrorKind::Other))?;
let mut data = Vec::with_capacity(file.size() as usize);
file.read_to_end(&mut data)?;
let added = add_file_from_ankiweb(media_folder, real_name, &data)?;
output.push(added);
}
Ok(output)
}
fn record_removals(ctx: &mut MediaDatabaseContext, removals: &[&String]) -> Result<()> {
for &fname in removals {
debug!(fname, "mark removed");
ctx.remove_entry(fname)?;
}
Ok(())
}
fn record_additions(ctx: &mut MediaDatabaseContext, additions: Vec<AddedFile>) -> Result<()> {
for file in additions {
if let Some(renamed) = file.renamed_from {
// the file AnkiWeb sent us wasn't normalized, so we need to record
// the old file name as a deletion
debug!("marking non-normalized file as deleted: {}", renamed);
let mut entry = MediaEntry {
fname: renamed,
sha1: None,
mtime: 0,
sync_required: true,
};
ctx.set_entry(&entry)?;
// and upload the new filename to ankiweb
debug!("marking renamed file as needing upload: {}", file.fname);
entry = MediaEntry {
fname: file.fname.to_string(),
sha1: Some(file.sha1),
mtime: file.mtime,
sync_required: true,
};
ctx.set_entry(&entry)?;
} else {
// a normal addition
let entry = MediaEntry {
fname: file.fname.to_string(),
sha1: Some(file.sha1),
mtime: file.mtime,
sync_required: false,
};
debug!(
fname = &entry.fname,
sha1 = hex::encode(&entry.sha1.as_ref().unwrap()[0..4]),
"mark added"
);
ctx.set_entry(&entry)?;
}
}
Ok(())
}
fn record_clean(ctx: &mut MediaDatabaseContext, clean: &[&String]) -> Result<()> {
for &fname in clean {
if let Some(mut entry) = ctx.get_entry(fname)? {
if entry.sync_required {
entry.sync_required = false;
debug!(fname = &entry.fname, "mark clean");
ctx.set_entry(&entry)?;
}
}
}
Ok(())
}
fn zip_files<'a>(
ctx: &mut MediaDatabaseContext,
media_folder: &Path,
files: &'a [MediaEntry],
) -> Result<Option<Vec<u8>>> {
let buf = vec![];
let mut invalid_entries = vec![];
let w = io::Cursor::new(buf);
let mut zip = zip::ZipWriter::new(w);
let options =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
let mut accumulated_size = 0;
let mut entries = vec![];
for (idx, file) in files.iter().enumerate() {
if accumulated_size > SYNC_MAX_BYTES {
break;
}
#[cfg(target_vendor = "apple")]
{
use unicode_normalization::is_nfc;
if !is_nfc(&file.fname) {
// older Anki versions stored non-normalized filenames in the DB; clean them up
debug!(fname = file.fname, "clean up non-nfc entry");
invalid_entries.push(&file.fname);
continue;
}
}
let file_data = if file.sha1.is_some() {
match data_for_file(media_folder, &file.fname) {
Ok(data) => data,
Err(e) => {
debug!("error accessing {}: {}", &file.fname, e);
invalid_entries.push(&file.fname);
continue;
}
}
} else {
// uploading deletion
None
};
if let Some(data) = &file_data {
let normalized = normalize_filename(&file.fname);
if let Cow::Owned(o) = normalized {
debug!("media check required: {} should be {}", &file.fname, o);
invalid_entries.push(&file.fname);
continue;
}
if data.is_empty() {
invalid_entries.push(&file.fname);
continue;
}
if data.len() > SYNC_SINGLE_FILE_MAX_BYTES {
invalid_entries.push(&file.fname);
continue;
}
accumulated_size += data.len();
zip.start_file(format!("{}", idx), options)?;
zip.write_all(data)?;
}
debug!(
fname = &file.fname,
kind = if file_data.is_some() {
"addition "
} else {
"removal"
},
"will upload"
);
entries.push(UploadEntry {
fname: &file.fname,
in_zip_name: if file_data.is_some() {
Some(format!("{}", idx))
} else {
None
},
});
}
if !invalid_entries.is_empty() {
// clean up invalid entries; we'll build a new zip
ctx.transact(|ctx| {
for fname in invalid_entries {
ctx.remove_entry(fname)?;
}
Ok(())
})?;
return Ok(None);
}
let meta = serde_json::to_string(&entries)?;
zip.start_file("_meta", options)?;
zip.write_all(meta.as_bytes())?;
let w = zip.finish()?;
Ok(Some(w.into_inner()))
}
#[cfg(test)]
mod test {
use tempfile::tempdir;
use tokio::runtime::Runtime;
use crate::{
error::Result,
io::{create_dir, write_file},
media::{
sync::{determine_required_change, LocalState, MediaSyncProgress, RequiredChange},
MediaManager,
},
};
async fn test_sync(hkey: &str) -> Result<()> {
let dir = tempdir()?;
let media_dir = dir.path().join("media");
create_dir(&media_dir)?;
let media_db = dir.path().join("media.db");
write_file(media_dir.join("test.file").as_path(), "hello")?;
let progress = |progress: MediaSyncProgress| {
println!("got progress: {:?}", progress);
true
};
let mgr = MediaManager::new(&media_dir, &media_db)?;
mgr.sync_media(progress, 0, hkey).await?;
Ok(())
}
#[test]
fn sync() {
let hkey = match std::env::var("TEST_HKEY") {
Ok(s) => s,
Err(_) => {
return;
}
};
let rt = Runtime::new().unwrap();
rt.block_on(test_sync(&hkey)).unwrap()
}
#[test]
fn required_change() {
use determine_required_change as d;
use LocalState as L;
use RequiredChange as R;
assert_eq!(d("", "", L::NotInDb), R::None);
assert_eq!(d("", "", L::InDbNotPending), R::Delete);
assert_eq!(d("", "1", L::InDbAndPending), R::Download);
assert_eq!(d("1", "", L::InDbAndPending), R::None);
assert_eq!(d("1", "", L::InDbNotPending), R::Delete);
assert_eq!(d("1", "1", L::InDbNotPending), R::None);
assert_eq!(d("1", "1", L::InDbAndPending), R::RemovePending);
assert_eq!(d("a", "b", L::InDbAndPending), R::Download);
assert_eq!(d("a", "b", L::InDbNotPending), R::Download);
}
}

View File

@ -120,7 +120,7 @@ impl Note {
pub(crate) fn new(notetype: &Notetype) -> Self {
Note {
id: NoteId(0),
guid: guid(),
guid: base91_u64(),
notetype_id: notetype.id,
mtime: TimestampSecs(0),
usn: Usn(0),
@ -297,20 +297,25 @@ pub(crate) fn field_checksum(text: &str) -> u32 {
u32::from_be_bytes(digest[..4].try_into().unwrap())
}
pub(crate) fn guid() -> String {
pub(crate) fn base91_u64() -> String {
anki_base91(rand::random())
}
fn anki_base91(mut n: u64) -> String {
let table = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\
0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~";
fn anki_base91(n: u64) -> String {
to_base_n(
n,
b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\
0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~",
)
}
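/// For example, to_base_n(0, table) is "" (the loop never runs), and with the
/// base91 table above, to_base_n(91, table) is "ba": 91 = 1*91 + 0 pushes 'a',
/// the quotient 1 then pushes 'b', and the buffer is reversed.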
pub fn to_base_n(mut n: u64, table: &[u8]) -> String {
let mut buf = String::new();
while n > 0 {
let (q, r) = n.div_rem(&(table.len() as u64));
buf.push(table[r as usize] as char);
n = q;
}
buf.chars().rev().collect()
}

View File

@ -8,10 +8,8 @@ use crate::{
config::DeckConfigKey,
decks::{FilteredDeck, FilteredSearchOrder, FilteredSearchTerm},
error::{CustomStudyError, FilteredDeckError},
pb::{
scheduler::custom_study_request::{cram::CramKind, Cram, Value as CustomStudyValue},
{self as pb},
},
pb,
pb::scheduler::custom_study_request::{cram::CramKind, Cram, Value as CustomStudyValue},
prelude::*,
search::{JoinSearches, Negated, PropertyKind, RatingKind, SearchNode, StateKind},
};

View File

@ -763,7 +763,7 @@ mod test {
fn add_card() {
let tr = I18n::template_only();
let storage =
SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false, false).unwrap();
SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false, false, false).unwrap();
let mut card = Card::default();
storage.add_card(&mut card).unwrap();
let id1 = card.id;

View File

@ -7,7 +7,7 @@ use num_enum::TryFromPrimitive;
use rusqlite::params;
use super::SqliteStorage;
use crate::{prelude::*, sync::Graves};
use crate::{prelude::*, sync::collection::graves::Graves};
#[derive(TryFromPrimitive)]
#[repr(u8)]

View File

@ -19,10 +19,9 @@ mod upgrades;
use std::fmt::Write;
pub(crate) use sqlite::SqliteStorage;
pub(crate) use sync::open_and_check_sqlite_file;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SchemaVersion {
pub enum SchemaVersion {
V11,
V18,
}

View File

@ -208,6 +208,7 @@ impl SqliteStorage {
path: &Path,
tr: &I18n,
server: bool,
check_integrity: bool,
force_schema11: bool,
) -> Result<Self> {
let db = open_or_create_collection_db(path)?;
@ -227,6 +228,13 @@ impl SqliteStorage {
return Err(AnkiError::db_error("", kind));
}
if check_integrity {
match db.pragma_query_value(None, "integrity_check", |row| row.get::<_, String>(0)) {
Ok(s) => require!(s == "ok", "corrupt: {s}"),
Err(e) => return Err(e.into()),
};
}
let upgrade = ver != SCHEMA_MAX_VERSION;
if create || upgrade {
db.execute("begin exclusive", [])?;

View File

@ -1,9 +1,7 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::path::Path;
use rusqlite::{params, types::FromSql, Connection, ToSql};
use rusqlite::{params, types::FromSql, ToSql};
use super::*;
use crate::prelude::*;
@ -49,9 +47,9 @@ impl SqliteStorage {
&self,
table: &str,
ids: &[I],
new_usn: Option<Usn>,
server_usn_if_client: Option<Usn>,
) -> Result<()> {
if let Some(new_usn) = new_usn {
if let Some(new_usn) = server_usn_if_client {
let mut stmt = self
.db
.prepare_cached(&format!("update {} set usn=? where id=?", table))?;
@ -63,18 +61,13 @@ impl SqliteStorage {
}
}
/// Return error if file is unreadable, fails the sqlite
/// integrity check, or is not in the 'delete' journal mode.
/// On success, returns the opened DB.
pub(crate) fn open_and_check_sqlite_file(path: &Path) -> Result<Connection> {
let db = Connection::open(path)?;
match db.pragma_query_value(None, "integrity_check", |row| row.get::<_, String>(0)) {
Ok(s) => require!(s == "ok", "corrupt: {s}"),
Err(e) => return Err(e.into()),
};
match db.pragma_query_value(None, "journal_mode", |row| row.get::<_, String>(0)) {
Ok(s) if s == "delete" => Ok(db),
Ok(s) => invalid_input!("corrupt: {s}"),
Err(e) => Err(e.into()),
impl Usn {
/// Used when gathering pending objects during sync.
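/// On a client the pending usn is -1, so only -1 rows match; on the server it
/// is the client's last-known usn, so any row at or above it matches.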
pub(crate) fn pending_object_clause(self) -> &'static str {
if self.0 == -1 {
"usn = ?"
} else {
"usn >= ?"
}
}
}

View File

@ -5,7 +5,7 @@ use super::*;
use crate::{
error::SyncErrorKind,
prelude::*,
sync::{SanityCheckCounts, SanityCheckDueCounts},
sync::collection::sanity::{SanityCheckCounts, SanityCheckDueCounts},
};
impl SqliteStorage {

View File

@ -0,0 +1,327 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
//! The current sync protocol sends changed notetypes, decks, tags and config
//! all in a single request.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_tuple::Serialize_tuple;
use tracing::{debug, trace};
use crate::{
deckconfig::DeckConfSchema11,
decks::DeckSchema11,
error::SyncErrorKind,
notetype::NotetypeSchema11,
prelude::*,
sync::{
collection::{
normal::{ClientSyncState, NormalSyncProgress, NormalSyncer},
protocol::SyncProtocol,
start::ServerSyncState,
},
request::IntoSyncRequest,
},
tags::Tag,
};
#[derive(Serialize, Deserialize, Debug)]
pub struct ApplyChangesRequest {
pub changes: UnchunkedChanges,
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct UnchunkedChanges {
#[serde(rename = "models")]
notetypes: Vec<NotetypeSchema11>,
#[serde(rename = "decks")]
decks_and_config: DecksAndConfig,
tags: Vec<String>,
// the following are only sent if local is newer
#[serde(skip_serializing_if = "Option::is_none", rename = "conf")]
config: Option<HashMap<String, Value>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "crt")]
creation_stamp: Option<TimestampSecs>,
}
#[derive(Serialize_tuple, Deserialize, Debug, Default)]
pub struct DecksAndConfig {
decks: Vec<DeckSchema11>,
config: Vec<DeckConfSchema11>,
}
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
// This was assumed to be a cheap operation when originally written - it didn't anticipate
// the large deck trees and note types some users would create. They should be chunked
// in the future, like other objects. Syncing tags explicitly is also probably of limited
// usefulness.
pub(in crate::sync) async fn process_unchunked_changes(
&mut self,
state: &ClientSyncState,
) -> Result<()> {
debug!("gathering local changes");
let local = self.col.local_unchunked_changes(
state.pending_usn,
Some(state.server_usn),
state.local_is_newer,
)?;
debug!(
notetypes = local.notetypes.len(),
decks = local.decks_and_config.decks.len(),
deck_config = local.decks_and_config.config.len(),
tags = local.tags.len(),
"sending"
);
self.progress.local_update += local.notetypes.len()
+ local.decks_and_config.decks.len()
+ local.decks_and_config.config.len()
+ local.tags.len();
let remote = self
.server
.apply_changes(ApplyChangesRequest { changes: local }.try_into_sync_request()?)
.await?
.json()?;
self.fire_progress_cb(true);
debug!(
notetypes = remote.notetypes.len(),
decks = remote.decks_and_config.decks.len(),
deck_config = remote.decks_and_config.config.len(),
tags = remote.tags.len(),
"received"
);
self.progress.remote_update += remote.notetypes.len()
+ remote.decks_and_config.decks.len()
+ remote.decks_and_config.config.len()
+ remote.tags.len();
self.col.apply_changes(remote, state.server_usn)?;
self.fire_progress_cb(true);
Ok(())
}
}
impl Collection {
// Local->remote unchunked changes
//----------------------------------------------------------------
pub(in crate::sync) fn local_unchunked_changes(
&mut self,
pending_usn: Usn,
server_usn_if_client: Option<Usn>,
local_is_newer: bool,
) -> Result<UnchunkedChanges> {
let mut changes = UnchunkedChanges {
notetypes: self.changed_notetypes(pending_usn, server_usn_if_client)?,
decks_and_config: DecksAndConfig {
decks: self.changed_decks(pending_usn, server_usn_if_client)?,
config: self.changed_deck_config(pending_usn, server_usn_if_client)?,
},
tags: self.changed_tags(pending_usn, server_usn_if_client)?,
..Default::default()
};
if local_is_newer {
changes.config = Some(self.changed_config()?);
changes.creation_stamp = Some(self.storage.creation_stamp()?);
}
Ok(changes)
}
fn changed_notetypes(
&mut self,
pending_usn: Usn,
server_usn_if_client: Option<Usn>,
) -> Result<Vec<NotetypeSchema11>> {
let ids = self
.storage
.objects_pending_sync("notetypes", pending_usn)?;
self.storage
.maybe_update_object_usns("notetypes", &ids, server_usn_if_client)?;
self.state.notetype_cache.clear();
ids.into_iter()
.map(|id| {
self.storage.get_notetype(id).map(|opt| {
let mut nt: NotetypeSchema11 = opt.unwrap().into();
nt.usn = server_usn_if_client.unwrap_or(nt.usn);
nt
})
})
.collect()
}
fn changed_decks(
&mut self,
pending_usn: Usn,
server_usn_if_client: Option<Usn>,
) -> Result<Vec<DeckSchema11>> {
let ids = self.storage.objects_pending_sync("decks", pending_usn)?;
self.storage
.maybe_update_object_usns("decks", &ids, server_usn_if_client)?;
self.state.deck_cache.clear();
ids.into_iter()
.map(|id| {
self.storage.get_deck(id).map(|opt| {
let mut deck = opt.unwrap();
deck.usn = server_usn_if_client.unwrap_or(deck.usn);
deck.into()
})
})
.collect()
}
fn changed_deck_config(
&self,
pending_usn: Usn,
server_usn_if_client: Option<Usn>,
) -> Result<Vec<DeckConfSchema11>> {
let ids = self
.storage
.objects_pending_sync("deck_config", pending_usn)?;
self.storage
.maybe_update_object_usns("deck_config", &ids, server_usn_if_client)?;
ids.into_iter()
.map(|id| {
self.storage.get_deck_config(id).map(|opt| {
let mut conf: DeckConfSchema11 = opt.unwrap().into();
conf.usn = server_usn_if_client.unwrap_or(conf.usn);
conf
})
})
.collect()
}
fn changed_tags(
&self,
pending_usn: Usn,
server_usn_if_client: Option<Usn>,
) -> Result<Vec<String>> {
let changed = self.storage.tags_pending_sync(pending_usn)?;
if let Some(usn) = server_usn_if_client {
self.storage.update_tag_usns(&changed, usn)?;
}
Ok(changed)
}
/// Currently this sends all config, as legacy clients overwrite the local items
/// with the provided values.
fn changed_config(&self) -> Result<HashMap<String, Value>> {
let conf = self.storage.get_all_config()?;
self.storage.clear_config_usns()?;
Ok(conf)
}
// Remote->local unchunked changes
//----------------------------------------------------------------
pub(in crate::sync) fn apply_changes(
&mut self,
remote: UnchunkedChanges,
latest_usn: Usn,
) -> Result<()> {
self.merge_notetypes(remote.notetypes, latest_usn)?;
self.merge_decks(remote.decks_and_config.decks, latest_usn)?;
self.merge_deck_config(remote.decks_and_config.config)?;
self.merge_tags(remote.tags, latest_usn)?;
if let Some(crt) = remote.creation_stamp {
self.set_creation_stamp(crt)?;
}
if let Some(config) = remote.config {
self.storage
.set_all_config(config, latest_usn, TimestampSecs::now())?;
}
Ok(())
}
fn merge_notetypes(&mut self, notetypes: Vec<NotetypeSchema11>, latest_usn: Usn) -> Result<()> {
for nt in notetypes {
let mut nt: Notetype = nt.into();
let proceed = if let Some(existing_nt) = self.storage.get_notetype(nt.id)? {
if existing_nt.mtime_secs <= nt.mtime_secs {
if (existing_nt.fields.len() != nt.fields.len())
|| (existing_nt.templates.len() != nt.templates.len())
{
return Err(AnkiError::sync_error(
"notetype schema changed",
SyncErrorKind::ResyncRequired,
));
}
true
} else {
false
}
} else {
true
};
if proceed {
self.ensure_notetype_name_unique(&mut nt, latest_usn)?;
self.storage.add_or_update_notetype_with_existing_id(&nt)?;
self.state.notetype_cache.remove(&nt.id);
}
}
Ok(())
}
fn merge_decks(&mut self, decks: Vec<DeckSchema11>, latest_usn: Usn) -> Result<()> {
for deck in decks {
let proceed = if let Some(existing_deck) = self.storage.get_deck(deck.id())? {
existing_deck.mtime_secs <= deck.common().mtime
} else {
true
};
if proceed {
let mut deck = deck.into();
self.ensure_deck_name_unique(&mut deck, latest_usn)?;
self.storage.add_or_update_deck_with_existing_id(&deck)?;
self.state.deck_cache.remove(&deck.id);
}
}
Ok(())
}
fn merge_deck_config(&self, dconf: Vec<DeckConfSchema11>) -> Result<()> {
for conf in dconf {
let proceed = if let Some(existing_conf) = self.storage.get_deck_config(conf.id)? {
existing_conf.mtime_secs <= conf.mtime
} else {
true
};
if proceed {
let conf = conf.into();
self.storage
.add_or_update_deck_config_with_existing_id(&conf)?;
}
}
Ok(())
}
fn merge_tags(&mut self, tags: Vec<String>, latest_usn: Usn) -> Result<()> {
for tag in tags {
self.register_tag(&mut Tag::new(tag, latest_usn))?;
}
Ok(())
}
}
pub fn server_apply_changes(
req: ApplyChangesRequest,
col: &mut Collection,
state: &mut ServerSyncState,
) -> Result<UnchunkedChanges> {
let server_changes =
col.local_unchunked_changes(state.client_usn, None, !state.client_is_newer)?;
trace!(?req.changes, ?server_changes);
col.apply_changes(req.changes, state.server_usn)?;
Ok(server_changes)
}

View File

@ -0,0 +1,431 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use serde_tuple::Serialize_tuple;
use tracing::debug;
use crate::{
card::{Card, CardQueue, CardType},
notes::Note,
prelude::*,
revlog::RevlogEntry,
serde::deserialize_int_from_number,
storage::card::data::{card_data_string, CardData},
sync::{
collection::{
normal::{ClientSyncState, NormalSyncProgress, NormalSyncer},
protocol::{EmptyInput, SyncProtocol},
start::ServerSyncState,
},
request::IntoSyncRequest,
},
tags::{join_tags, split_tags},
};
pub(in crate::sync) struct ChunkableIds {
revlog: Vec<RevlogId>,
cards: Vec<CardId>,
notes: Vec<NoteId>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct Chunk {
#[serde(default)]
pub done: bool,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub revlog: Vec<RevlogEntry>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub cards: Vec<CardEntry>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub notes: Vec<NoteEntry>,
}
#[derive(Serialize_tuple, Deserialize, Debug)]
pub struct NoteEntry {
pub id: NoteId,
pub guid: String,
#[serde(rename = "mid")]
pub ntid: NotetypeId,
#[serde(rename = "mod")]
pub mtime: TimestampSecs,
pub usn: Usn,
pub tags: String,
pub fields: String,
pub sfld: String, // always empty
pub csum: String, // always empty
pub flags: u32,
pub data: String,
}
#[derive(Serialize_tuple, Deserialize, Debug)]
pub struct CardEntry {
pub id: CardId,
pub nid: NoteId,
pub did: DeckId,
pub ord: u16,
#[serde(deserialize_with = "deserialize_int_from_number")]
pub mtime: TimestampSecs,
pub usn: Usn,
pub ctype: CardType,
pub queue: CardQueue,
#[serde(deserialize_with = "deserialize_int_from_number")]
pub due: i32,
#[serde(deserialize_with = "deserialize_int_from_number")]
pub ivl: u32,
pub factor: u16,
pub reps: u32,
pub lapses: u32,
pub left: u32,
#[serde(deserialize_with = "deserialize_int_from_number")]
pub odue: i32,
pub odid: DeckId,
pub flags: u8,
pub data: String,
}
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
pub(in crate::sync) async fn process_chunks_from_server(
&mut self,
state: &ClientSyncState,
) -> Result<()> {
loop {
let chunk = self.server.chunk(EmptyInput::request()).await?.json()?;
debug!(
done = chunk.done,
cards = chunk.cards.len(),
notes = chunk.notes.len(),
revlog = chunk.revlog.len(),
"received"
);
self.progress.remote_update +=
chunk.cards.len() + chunk.notes.len() + chunk.revlog.len();
let done = chunk.done;
self.col.apply_chunk(chunk, state.pending_usn)?;
self.fire_progress_cb(true);
if done {
return Ok(());
}
}
}
pub(in crate::sync) async fn send_chunks_to_server(
&mut self,
state: &ClientSyncState,
) -> Result<()> {
let mut ids = self.col.get_chunkable_ids(state.pending_usn)?;
loop {
let chunk: Chunk = self.col.get_chunk(&mut ids, Some(state.server_usn))?;
let done = chunk.done;
debug!(
done = chunk.done,
cards = chunk.cards.len(),
notes = chunk.notes.len(),
revlog = chunk.revlog.len(),
"sending"
);
self.progress.local_update +=
chunk.cards.len() + chunk.notes.len() + chunk.revlog.len();
self.server
.apply_chunk(ApplyChunkRequest { chunk }.try_into_sync_request()?)
.await?;
self.fire_progress_cb(true);
if done {
return Ok(());
}
}
}
}
impl Collection {
// Remote->local chunks
//----------------------------------------------------------------
/// pending_usn is used to decide whether the local objects are newer.
/// If the provided objects are not modified locally, the USN inside
/// the individual objects is used.
pub(in crate::sync) fn apply_chunk(&mut self, chunk: Chunk, pending_usn: Usn) -> Result<()> {
self.merge_revlog(chunk.revlog)?;
self.merge_cards(chunk.cards, pending_usn)?;
self.merge_notes(chunk.notes, pending_usn)
}
fn merge_revlog(&self, entries: Vec<RevlogEntry>) -> Result<()> {
for entry in entries {
self.storage.add_revlog_entry(&entry, false)?;
}
Ok(())
}
fn merge_cards(&self, entries: Vec<CardEntry>, pending_usn: Usn) -> Result<()> {
for entry in entries {
self.add_or_update_card_if_newer(entry, pending_usn)?;
}
Ok(())
}
fn add_or_update_card_if_newer(&self, entry: CardEntry, pending_usn: Usn) -> Result<()> {
let proceed = if let Some(existing_card) = self.storage.get_card(entry.id)? {
!existing_card.usn.is_pending_sync(pending_usn) || existing_card.mtime < entry.mtime
} else {
true
};
if proceed {
let card = entry.into();
self.storage.add_or_update_card(&card)?;
}
Ok(())
}
fn merge_notes(&mut self, entries: Vec<NoteEntry>, pending_usn: Usn) -> Result<()> {
for entry in entries {
self.add_or_update_note_if_newer(entry, pending_usn)?;
}
Ok(())
}
fn add_or_update_note_if_newer(&mut self, entry: NoteEntry, pending_usn: Usn) -> Result<()> {
let proceed = if let Some(existing_note) = self.storage.get_note(entry.id)? {
!existing_note.usn.is_pending_sync(pending_usn) || existing_note.mtime < entry.mtime
} else {
true
};
if proceed {
let mut note: Note = entry.into();
let nt = self
.get_notetype(note.notetype_id)?
.or_invalid("note missing notetype")?;
note.prepare_for_update(&nt, false)?;
self.storage.add_or_update_note(&note)?;
}
Ok(())
}
// Local->remote chunks
//----------------------------------------------------------------
pub(in crate::sync) fn get_chunkable_ids(&self, pending_usn: Usn) -> Result<ChunkableIds> {
Ok(ChunkableIds {
revlog: self.storage.objects_pending_sync("revlog", pending_usn)?,
cards: self.storage.objects_pending_sync("cards", pending_usn)?,
notes: self.storage.objects_pending_sync("notes", pending_usn)?,
})
}
/// Fetch a chunk of ids from `ids`, returning the referenced objects.
pub(in crate::sync) fn get_chunk(
&self,
ids: &mut ChunkableIds,
server_usn_if_client: Option<Usn>,
) -> Result<Chunk> {
// get a bunch of IDs
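// round-robin one id at a time from revlog, notes and cards until CHUNK_SIZE
// objects are gathered or all three lists are exhausted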
let mut limit = CHUNK_SIZE as i32;
let mut revlog_ids = vec![];
let mut card_ids = vec![];
let mut note_ids = vec![];
let mut chunk = Chunk::default();
while limit > 0 {
let last_limit = limit;
if let Some(id) = ids.revlog.pop() {
revlog_ids.push(id);
limit -= 1;
}
if let Some(id) = ids.notes.pop() {
note_ids.push(id);
limit -= 1;
}
if let Some(id) = ids.cards.pop() {
card_ids.push(id);
limit -= 1;
}
if limit == last_limit {
// all empty
break;
}
}
if limit > 0 {
chunk.done = true;
}
// remove pending status
if !self.server {
self.storage
.maybe_update_object_usns("revlog", &revlog_ids, server_usn_if_client)?;
self.storage
.maybe_update_object_usns("cards", &card_ids, server_usn_if_client)?;
self.storage
.maybe_update_object_usns("notes", &note_ids, server_usn_if_client)?;
}
// then fetch the associated objects, and return
chunk.revlog = revlog_ids
.into_iter()
.map(|id| {
self.storage.get_revlog_entry(id).map(|e| {
let mut e = e.unwrap();
e.usn = server_usn_if_client.unwrap_or(e.usn);
e
})
})
.collect::<Result<_>>()?;
chunk.cards = card_ids
.into_iter()
.map(|id| {
self.storage.get_card(id).map(|e| {
let mut e: CardEntry = e.unwrap().into();
e.usn = server_usn_if_client.unwrap_or(e.usn);
e
})
})
.collect::<Result<_>>()?;
chunk.notes = note_ids
.into_iter()
.map(|id| {
self.storage.get_note(id).map(|e| {
let mut e: NoteEntry = e.unwrap().into();
e.usn = server_usn_if_client.unwrap_or(e.usn);
e
})
})
.collect::<Result<_>>()?;
Ok(chunk)
}
}
impl From<CardEntry> for Card {
fn from(e: CardEntry) -> Self {
let CardData {
original_position,
custom_data,
} = CardData::from_str(&e.data);
Card {
id: e.id,
note_id: e.nid,
deck_id: e.did,
template_idx: e.ord,
mtime: e.mtime,
usn: e.usn,
ctype: e.ctype,
queue: e.queue,
due: e.due,
interval: e.ivl,
ease_factor: e.factor,
reps: e.reps,
lapses: e.lapses,
remaining_steps: e.left,
original_due: e.odue,
original_deck_id: e.odid,
flags: e.flags,
original_position,
custom_data,
}
}
}
impl From<Card> for CardEntry {
fn from(e: Card) -> Self {
CardEntry {
id: e.id,
nid: e.note_id,
did: e.deck_id,
ord: e.template_idx,
mtime: e.mtime,
usn: e.usn,
ctype: e.ctype,
queue: e.queue,
due: e.due,
ivl: e.interval,
factor: e.ease_factor,
reps: e.reps,
lapses: e.lapses,
left: e.remaining_steps,
odue: e.original_due,
odid: e.original_deck_id,
flags: e.flags,
data: card_data_string(&e),
}
}
}
impl From<NoteEntry> for Note {
fn from(e: NoteEntry) -> Self {
let fields = e.fields.split('\x1f').map(ToString::to_string).collect();
Note::new_from_storage(
e.id,
e.guid,
e.ntid,
e.mtime,
e.usn,
split_tags(&e.tags).map(ToString::to_string).collect(),
fields,
None,
None,
)
}
}
impl From<Note> for NoteEntry {
fn from(e: Note) -> Self {
NoteEntry {
id: e.id,
fields: e.fields().iter().join("\x1f"),
guid: e.guid,
ntid: e.notetype_id,
mtime: e.mtime,
usn: e.usn,
tags: join_tags(&e.tags),
sfld: String::new(),
csum: String::new(),
flags: 0,
data: String::new(),
}
}
}
pub fn server_chunk(col: &mut Collection, state: &mut ServerSyncState) -> Result<Chunk> {
if state.server_chunk_ids.is_none() {
state.server_chunk_ids = Some(col.get_chunkable_ids(state.client_usn)?);
}
col.get_chunk(state.server_chunk_ids.as_mut().unwrap(), None)
}
pub fn server_apply_chunk(
req: ApplyChunkRequest,
col: &mut Collection,
state: &mut ServerSyncState,
) -> Result<()> {
col.apply_chunk(req.chunk, state.client_usn)
}
impl Usn {
pub(crate) fn is_pending_sync(self, pending_usn: Usn) -> bool {
if pending_usn.0 == -1 {
self.0 == -1
} else {
self.0 >= pending_usn.0
}
}
}
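/// Upper bound on the number of objects packed into a single chunk (used for
/// both grave chunks and object chunks).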
pub const CHUNK_SIZE: usize = 250;
#[derive(Serialize, Deserialize, Debug)]
pub struct ApplyChunkRequest {
pub chunk: Chunk,
}

View File

@ -0,0 +1,64 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::{
collection::CollectionBuilder,
io::{atomic_rename, new_tempfile_in_parent_of, read_file, write_file},
prelude::*,
storage::SchemaVersion,
sync::{
collection::{
progress::FullSyncProgressFn,
protocol::{EmptyInput, SyncProtocol},
},
error::{HttpResult, OrHttpErr},
http_client::HttpSyncClient,
login::SyncAuth,
},
};
impl Collection {
/// Download collection from AnkiWeb. Caller must re-open afterwards.
pub async fn full_download(
self,
auth: SyncAuth,
progress_fn: FullSyncProgressFn,
) -> Result<()> {
let mut server = HttpSyncClient::new(auth);
server.set_full_sync_progress_fn(Some(progress_fn));
self.full_download_with_server(server).await
}
pub(crate) async fn full_download_with_server(self, server: HttpSyncClient) -> Result<()> {
let col_path = self.col_path.clone();
let _col_folder = col_path.parent().or_invalid("couldn't get col_folder")?;
self.close(None)?;
let out_data = server.download(EmptyInput::request()).await?.data;
// verify the downloaded file is a valid collection before swapping it in
let temp_file = new_tempfile_in_parent_of(&col_path)?;
write_file(temp_file.path(), out_data)?;
let col = CollectionBuilder::new(temp_file.path())
.set_check_integrity(true)
.build()?;
col.storage.db.execute_batch("update col set ls=mod")?;
col.close(None)?;
atomic_rename(temp_file, &col_path, true)?;
Ok(())
}
}
pub fn server_download(
col: &mut Option<Collection>,
schema_version: SchemaVersion,
) -> HttpResult<Vec<u8>> {
let col_path = {
let mut col = col.take().or_internal_err("take col")?;
let path = col.col_path.clone();
col.transact_no_undo(|col| col.storage.increment_usn())
.or_internal_err("incr usn")?;
col.close(Some(schema_version)).or_internal_err("close")?;
path
};
let data = read_file(col_path).or_internal_err("read col")?;
Ok(data)
}

View File

@ -0,0 +1,43 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::{
prelude::*,
sync::collection::{
normal::{ClientSyncState, NormalSyncProgress, NormalSyncer},
protocol::{EmptyInput, SyncProtocol},
},
};
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
pub(in crate::sync) async fn finalize(&mut self, state: &ClientSyncState) -> Result<()> {
let new_server_mtime = self.server.finish(EmptyInput::request()).await?.json()?;
self.col.finalize_sync(state, new_server_mtime)
}
}
impl Collection {
fn finalize_sync(
&self,
state: &ClientSyncState,
new_server_mtime: TimestampMillis,
) -> Result<()> {
self.storage.set_last_sync(new_server_mtime)?;
let mut usn = state.server_usn;
usn.0 += 1;
self.storage.set_usn(usn)?;
self.storage.set_modified_time(new_server_mtime)
}
}
pub fn server_finish(col: &mut Collection) -> Result<TimestampMillis> {
let now = TimestampMillis::now();
col.storage.set_last_sync(now)?;
col.storage.increment_usn()?;
col.storage.commit_rust_trx()?;
col.storage.set_modified_time(now)?;
Ok(now)
}

View File

@ -0,0 +1,71 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use serde::{Deserialize, Serialize};
use crate::{
prelude::*,
sync::collection::{chunks::CHUNK_SIZE, start::ServerSyncState},
};
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct ApplyGravesRequest {
pub chunk: Graves,
}
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct Graves {
pub(crate) cards: Vec<CardId>,
pub(crate) decks: Vec<DeckId>,
pub(crate) notes: Vec<NoteId>,
}
impl Graves {
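/// Pop up to CHUNK_SIZE ids (cards first, then notes, then decks) into a new
/// Graves; returns None once there is nothing left to send.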
pub(in crate::sync) fn take_chunk(&mut self) -> Option<Graves> {
let mut limit = CHUNK_SIZE;
let mut out = Graves::default();
while limit > 0 && !self.cards.is_empty() {
out.cards.push(self.cards.pop().unwrap());
limit -= 1;
}
while limit > 0 && !self.notes.is_empty() {
out.notes.push(self.notes.pop().unwrap());
limit -= 1;
}
while limit > 0 && !self.decks.is_empty() {
out.decks.push(self.decks.pop().unwrap());
limit -= 1;
}
if limit == CHUNK_SIZE {
None
} else {
Some(out)
}
}
}
impl Collection {
pub fn apply_graves(&self, graves: Graves, latest_usn: Usn) -> Result<()> {
for nid in graves.notes {
self.storage.remove_note(nid)?;
self.storage.add_note_grave(nid, latest_usn)?;
}
for cid in graves.cards {
self.storage.remove_card(cid)?;
self.storage.add_card_grave(cid, latest_usn)?;
}
for did in graves.decks {
self.storage.remove_deck(did)?;
self.storage.add_deck_grave(did, latest_usn)?;
}
Ok(())
}
}
pub fn server_apply_graves(
req: ApplyGravesRequest,
col: &mut Collection,
state: &mut ServerSyncState,
) -> Result<()> {
col.apply_graves(req.chunk, state.server_usn)
}

View File

@ -0,0 +1,170 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use ammonia::Url;
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::{
config::SchedulerVersion,
prelude::*,
sync::{
collection::{
normal::{ClientSyncState, SyncActionRequired},
protocol::SyncProtocol,
},
error::{HttpError, HttpResult, HttpSnafu, OrHttpErr},
http_client::HttpSyncClient,
request::{IntoSyncRequest, SyncRequest},
version::{
SYNC_VERSION_09_V2_SCHEDULER, SYNC_VERSION_10_V2_TIMEZONE, SYNC_VERSION_MAX,
SYNC_VERSION_MIN,
},
},
version::sync_client_version,
};
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct SyncMeta {
#[serde(rename = "mod")]
pub modified: TimestampMillis,
#[serde(rename = "scm")]
pub schema: TimestampMillis,
pub usn: Usn,
#[serde(rename = "ts")]
pub current_time: TimestampSecs,
#[serde(rename = "msg")]
pub server_message: String,
#[serde(rename = "cont")]
pub should_continue: bool,
/// Used by clients prior to sync version 11
#[serde(rename = "hostNum")]
pub host_number: u32,
#[serde(default)]
pub empty: bool,
#[serde(skip)]
pub v2_scheduler_or_later: bool,
#[serde(skip)]
pub v2_timezone: bool,
}
impl SyncMeta {
pub(in crate::sync) fn compared_to_remote(
&self,
remote: SyncMeta,
new_endpoint: Option<String>,
) -> ClientSyncState {
let local = self;
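// identical modification times mean nothing to do; a schema mismatch forces a
// one-way (full) sync; anything else can be handled by a normal sync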
let required = if remote.modified == local.modified {
SyncActionRequired::NoChanges
} else if remote.schema != local.schema {
let upload_ok = !local.empty || remote.empty;
let download_ok = !remote.empty || local.empty;
SyncActionRequired::FullSyncRequired {
upload_ok,
download_ok,
}
} else {
SyncActionRequired::NormalSyncRequired
};
ClientSyncState {
required,
local_is_newer: local.modified > remote.modified,
usn_at_last_sync: local.usn,
server_usn: remote.usn,
pending_usn: Usn(-1),
server_message: remote.server_message,
host_number: remote.host_number,
new_endpoint,
}
}
}
impl HttpSyncClient {
/// Fetch server meta. Returns a new endpoint if one was provided.
pub(in crate::sync) async fn meta_with_redirect(
&mut self,
) -> Result<(SyncMeta, Option<String>)> {
let mut new_endpoint = None;
let response = match self.meta(MetaRequest::request()).await {
Ok(remote) => remote,
Err(HttpError {
code: StatusCode::PERMANENT_REDIRECT,
context,
..
}) => {
debug!(endpoint = context, "redirect to new location");
let url = Url::try_from(context.as_str())
.or_bad_request("couldn't parse new location")?;
new_endpoint = Some(context);
self.endpoint = url;
self.meta(MetaRequest::request()).await?
}
err => err?,
};
let remote = response.json()?;
Ok((remote, new_endpoint))
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MetaRequest {
#[serde(rename = "v")]
pub sync_version: u8,
#[serde(rename = "cv")]
pub client_version: String,
}
impl Collection {
pub fn sync_meta(&self) -> Result<SyncMeta> {
let stamps = self.storage.get_collection_timestamps()?;
Ok(SyncMeta {
modified: stamps.collection_change,
schema: stamps.schema_change,
// server=true is used for the client case as well, as we
// want the actual usn and not -1
usn: self.storage.usn(true)?,
current_time: TimestampSecs::now(),
server_message: "".into(),
should_continue: true,
host_number: 0,
empty: !self.storage.have_at_least_one_card()?,
v2_scheduler_or_later: self.scheduler_version() == SchedulerVersion::V2,
v2_timezone: self.get_creation_utc_offset().is_some(),
})
}
}
pub fn server_meta(req: MetaRequest, col: &mut Collection) -> HttpResult<SyncMeta> {
if !matches!(req.sync_version, SYNC_VERSION_MIN..=SYNC_VERSION_MAX) {
return HttpSnafu {
// old clients expected this code
code: StatusCode::NOT_IMPLEMENTED,
context: "unsupported version",
source: None,
}
.fail();
}
let mut meta = col.sync_meta().or_internal_err("sync meta")?;
if meta.v2_scheduler_or_later && req.sync_version < SYNC_VERSION_09_V2_SCHEDULER {
meta.server_message = "Your client does not support the v2 scheduler".into();
meta.should_continue = false;
} else if meta.v2_timezone && req.sync_version < SYNC_VERSION_10_V2_TIMEZONE {
meta.server_message = "Your client does not support the new timezone handling.".into();
meta.should_continue = false;
}
Ok(meta)
}
impl MetaRequest {
pub fn request() -> SyncRequest<Self> {
MetaRequest {
sync_version: SYNC_VERSION_MAX,
client_version: sync_client_version().into(),
}
.try_into_sync_request()
.expect("infallible meta request")
}
}

View File

@ -0,0 +1,17 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub mod changes;
pub mod chunks;
pub mod download;
pub mod finish;
pub mod graves;
pub mod meta;
pub mod normal;
pub mod progress;
pub mod protocol;
pub mod sanity;
pub mod start;
pub mod status;
pub mod tests;
pub mod upload;

View File

@ -0,0 +1,182 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use tracing::debug;
use crate::{
collection::Collection,
error,
error::{AnkiError, SyncError, SyncErrorKind},
prelude::Usn,
sync::{
collection::{
progress::SyncStage,
protocol::{EmptyInput, SyncProtocol},
status::online_sync_status_check,
},
http_client::HttpSyncClient,
login::SyncAuth,
},
};
pub struct NormalSyncer<'a, F> {
pub(in crate::sync) col: &'a mut Collection,
pub(in crate::sync) server: HttpSyncClient,
pub(in crate::sync) progress: NormalSyncProgress,
pub(in crate::sync) progress_fn: F,
}
#[derive(Default, Debug, Clone, Copy)]
pub struct NormalSyncProgress {
pub stage: SyncStage,
pub local_update: usize,
pub local_remove: usize,
pub remote_update: usize,
pub remote_remove: usize,
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum SyncActionRequired {
NoChanges,
FullSyncRequired { upload_ok: bool, download_ok: bool },
NormalSyncRequired,
}
#[derive(Debug)]
pub struct ClientSyncState {
pub required: SyncActionRequired,
pub server_message: String,
pub host_number: u32,
pub new_endpoint: Option<String>,
pub(in crate::sync) local_is_newer: bool,
pub(in crate::sync) usn_at_last_sync: Usn,
// latest server usn; local -1 entries will be rewritten to this
pub(in crate::sync) server_usn: Usn,
// -1 in client case; used to locate pending entries
pub(in crate::sync) pending_usn: Usn,
}
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
pub fn new(col: &mut Collection, server: HttpSyncClient, progress_fn: F) -> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
NormalSyncer {
col,
server,
progress: NormalSyncProgress::default(),
progress_fn,
}
}
pub(in crate::sync) fn fire_progress_cb(&mut self, throttle: bool) {
(self.progress_fn)(self.progress, throttle)
}
pub async fn sync(&mut self) -> error::Result<SyncOutput> {
debug!("fetching meta...");
self.fire_progress_cb(false);
let local = self.col.sync_meta()?;
let state = online_sync_status_check(local, &mut self.server).await?;
debug!(?state, "fetched");
match state.required {
SyncActionRequired::NoChanges => Ok(state.into()),
SyncActionRequired::FullSyncRequired { .. } => Ok(state.into()),
SyncActionRequired::NormalSyncRequired => {
self.col.discard_undo_and_study_queues();
let timing = self.col.timing_today()?;
self.col.unbury_if_day_rolled_over(timing)?;
self.col.storage.begin_trx()?;
match self.normal_sync_inner(state).await {
Ok(success) => {
self.col.storage.commit_trx()?;
Ok(success)
}
Err(e) => {
self.col.storage.rollback_trx()?;
let _ = self.server.abort(EmptyInput::request()).await;
if let AnkiError::SyncError {
source:
SyncError {
kind: SyncErrorKind::SanityCheckFailed { client, server },
..
},
} = &e
{
debug!(?client, ?server, "sanity check failed");
self.col.set_schema_modified()?;
}
Err(e)
}
}
}
}
}
/// Sync. Caller must have created a transaction, and should call
/// abort on failure.
async fn normal_sync_inner(&mut self, mut state: ClientSyncState) -> error::Result<SyncOutput> {
self.progress.stage = SyncStage::Syncing;
self.fire_progress_cb(false);
debug!("start");
self.start_and_process_deletions(&state).await?;
debug!("unchunked changes");
self.process_unchunked_changes(&state).await?;
debug!("begin stream from server");
self.process_chunks_from_server(&state).await?;
debug!("begin stream to server");
self.send_chunks_to_server(&state).await?;
self.progress.stage = SyncStage::Finalizing;
self.fire_progress_cb(false);
debug!("sanity check");
self.sanity_check().await?;
debug!("finalize");
self.finalize(&state).await?;
state.required = SyncActionRequired::NoChanges;
Ok(state.into())
}
}
#[derive(Debug)]
pub struct SyncOutput {
pub required: SyncActionRequired,
pub server_message: String,
pub host_number: u32,
pub new_endpoint: Option<String>,
}
impl From<ClientSyncState> for SyncOutput {
fn from(s: ClientSyncState) -> Self {
SyncOutput {
required: s.required,
server_message: s.server_message,
host_number: s.host_number,
new_endpoint: s.new_endpoint,
}
}
}
impl Collection {
pub async fn normal_sync<F>(
&mut self,
auth: SyncAuth,
progress_fn: F,
) -> error::Result<SyncOutput>
where
F: FnMut(NormalSyncProgress, bool),
{
NormalSyncer::new(self, HttpSyncClient::new(auth), progress_fn)
.sync()
.await
}
}

View File

@ -0,0 +1,39 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::{
error,
sync::{
collection::protocol::{EmptyInput, SyncProtocol},
http_client::HttpSyncClient,
login::SyncAuth,
},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncStage {
Connecting,
Syncing,
Finalizing,
}
impl Default for SyncStage {
fn default() -> Self {
SyncStage::Connecting
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct FullSyncProgress {
pub transferred_bytes: usize,
pub total_bytes: usize,
}
pub async fn sync_abort(auth: SyncAuth) -> error::Result<()> {
HttpSyncClient::new(auth)
.abort(EmptyInput::request())
.await?
.json()
}
pub type FullSyncProgressFn = Box<dyn FnMut(FullSyncProgress, bool) + Send + Sync + 'static>;

View File

@ -0,0 +1,110 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::marker::PhantomData;
use ammonia::Url;
use async_trait::async_trait;
use serde_derive::{Deserialize, Serialize};
use strum::IntoStaticStr;
use crate::{
prelude::TimestampMillis,
sync::{
collection::{
changes::{ApplyChangesRequest, UnchunkedChanges},
chunks::{ApplyChunkRequest, Chunk},
graves::{ApplyGravesRequest, Graves},
meta::{MetaRequest, SyncMeta},
sanity::{SanityCheckRequest, SanityCheckResponse},
start::StartRequest,
upload::UploadResponse,
},
error::HttpResult,
login::{HostKeyRequest, HostKeyResponse},
request::{IntoSyncRequest, SyncRequest},
response::SyncResponse,
},
};
#[derive(IntoStaticStr, Deserialize, PartialEq, Eq, Debug)]
#[serde(rename_all = "camelCase")]
#[strum(serialize_all = "camelCase")]
pub enum SyncMethod {
HostKey,
Meta,
Start,
ApplyGraves,
ApplyChanges,
Chunk,
ApplyChunk,
SanityCheck2,
Finish,
Abort,
Upload,
Download,
}
pub trait AsSyncEndpoint: Into<&'static str> {
fn as_sync_endpoint(&self, base: &Url) -> Url;
}
impl AsSyncEndpoint for SyncMethod {
fn as_sync_endpoint(&self, base: &Url) -> Url {
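// e.g. a base of "http://localhost:8080/" and SyncMethod::HostKey map to
// "http://localhost:8080/sync/hostKey" (illustrative values only)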
base.join("sync/").unwrap().join(self.into()).unwrap()
}
}
#[async_trait]
pub trait SyncProtocol: Send + Sync + 'static {
async fn host_key(
&self,
req: SyncRequest<HostKeyRequest>,
) -> HttpResult<SyncResponse<HostKeyResponse>>;
async fn meta(&self, req: SyncRequest<MetaRequest>) -> HttpResult<SyncResponse<SyncMeta>>;
async fn start(&self, req: SyncRequest<StartRequest>) -> HttpResult<SyncResponse<Graves>>;
async fn apply_graves(
&self,
req: SyncRequest<ApplyGravesRequest>,
) -> HttpResult<SyncResponse<()>>;
async fn apply_changes(
&self,
req: SyncRequest<ApplyChangesRequest>,
) -> HttpResult<SyncResponse<UnchunkedChanges>>;
async fn chunk(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Chunk>>;
async fn apply_chunk(
&self,
req: SyncRequest<ApplyChunkRequest>,
) -> HttpResult<SyncResponse<()>>;
async fn sanity_check(
&self,
req: SyncRequest<SanityCheckRequest>,
) -> HttpResult<SyncResponse<SanityCheckResponse>>;
async fn finish(
&self,
req: SyncRequest<EmptyInput>,
) -> HttpResult<SyncResponse<TimestampMillis>>;
async fn abort(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<()>>;
async fn upload(&self, req: SyncRequest<Vec<u8>>) -> HttpResult<SyncResponse<UploadResponse>>;
async fn download(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Vec<u8>>>;
}
/// The sync protocol expects '{}' to be sent in requests without args.
/// Serde serializes/deserializes empty structs as 'null', so we add an empty value
/// to cause it to produce a map instead. This only applies to inputs; empty outputs
/// are returned as ()/null.
#[derive(Serialize, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct EmptyInput {
#[serde(default)]
_pad: PhantomData<()>,
}
impl EmptyInput {
pub(crate) fn request() -> SyncRequest<Self> {
Self::default()
.try_into_sync_request()
// should be infallible
.expect("empty input into request")
}
}

View File

@ -0,0 +1,129 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use serde::{Deserialize, Serialize};
use serde_tuple::Serialize_tuple;
use tracing::{debug, info};
use crate::{
error::SyncErrorKind,
prelude::*,
serde::default_on_invalid,
sync::{
collection::{
normal::{NormalSyncProgress, NormalSyncer},
protocol::SyncProtocol,
},
request::IntoSyncRequest,
},
};
#[derive(Serialize, Deserialize, Debug)]
pub struct SanityCheckResponse {
pub status: SanityCheckStatus,
#[serde(rename = "c", default, deserialize_with = "default_on_invalid")]
pub client: Option<SanityCheckCounts>,
#[serde(rename = "s", default, deserialize_with = "default_on_invalid")]
pub server: Option<SanityCheckCounts>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SanityCheckStatus {
Ok,
Bad,
}
#[derive(Serialize_tuple, Deserialize, Debug, PartialEq, Eq)]
pub struct SanityCheckCounts {
pub counts: SanityCheckDueCounts,
pub cards: u32,
pub notes: u32,
pub revlog: u32,
pub graves: u32,
#[serde(rename = "models")]
pub notetypes: u32,
pub decks: u32,
pub deck_config: u32,
}
#[derive(Serialize_tuple, Deserialize, Debug, Default, PartialEq, Eq)]
pub struct SanityCheckDueCounts {
pub new: u32,
pub learn: u32,
pub review: u32,
}
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
/// Caller should force full sync after rolling back.
pub(in crate::sync) async fn sanity_check(&mut self) -> Result<()> {
let local_counts = self.col.storage.sanity_check_info()?;
debug!("gathered local counts; waiting for server reply");
let SanityCheckResponse {
status,
client,
server,
} = self
.server
.sanity_check(
SanityCheckRequest {
client: local_counts,
}
.try_into_sync_request()?,
)
.await?
.json()?;
debug!("got server reply");
if status != SanityCheckStatus::Ok {
Err(AnkiError::sync_error(
"",
SyncErrorKind::SanityCheckFailed { client, server },
))
} else {
Ok(())
}
}
}
pub fn server_sanity_check(
SanityCheckRequest { mut client }: SanityCheckRequest,
col: &mut Collection,
) -> Result<SanityCheckResponse> {
let mut server = match col.storage.sanity_check_info() {
Ok(info) => info,
Err(err) => {
info!(?client, ?err, "sanity check failed");
return Ok(SanityCheckResponse {
status: SanityCheckStatus::Bad,
client: Some(client),
server: None,
});
}
};
client.counts = Default::default();
// clients on schema 17 and below may send duplicate
// deletion markers, so we can't compare graves until
// the minimum syncing version is schema 18.
client.graves = 0;
server.graves = 0;
Ok(SanityCheckResponse {
status: if client == server {
SanityCheckStatus::Ok
} else {
info!(?client, ?server, "sanity check failed");
SanityCheckStatus::Bad
},
client: Some(client),
server: Some(server),
})
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SanityCheckRequest {
pub client: SanityCheckCounts,
}

View File

@ -0,0 +1,186 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use serde::{Deserialize, Deserializer, Serialize};
use tracing::debug;
use crate::{
prelude::*,
sync::{
collection::{
chunks::ChunkableIds,
graves::{ApplyGravesRequest, Graves},
normal::{ClientSyncState, NormalSyncProgress, NormalSyncer},
protocol::SyncProtocol,
},
request::IntoSyncRequest,
},
};
impl<F> NormalSyncer<'_, F>
where
F: FnMut(NormalSyncProgress, bool),
{
pub(in crate::sync) async fn start_and_process_deletions(
&mut self,
state: &ClientSyncState,
) -> Result<()> {
let remote: Graves = self
.server
.start(
StartRequest {
client_usn: state.usn_at_last_sync,
local_is_newer: state.local_is_newer,
deprecated_client_graves: None,
}
.try_into_sync_request()?,
)
.await?
.json()?;
debug!(
cards = remote.cards.len(),
notes = remote.notes.len(),
decks = remote.decks.len(),
"removed on remote"
);
let mut local = self.col.storage.pending_graves(state.pending_usn)?;
self.col
.storage
.update_pending_grave_usns(state.server_usn)?;
debug!(
cards = local.cards.len(),
notes = local.notes.len(),
decks = local.decks.len(),
"locally removed "
);
while let Some(chunk) = local.take_chunk() {
debug!("sending graves chunk");
self.progress.local_remove += chunk.cards.len() + chunk.notes.len() + chunk.decks.len();
self.server
.apply_graves(ApplyGravesRequest { chunk }.try_into_sync_request()?)
.await?;
self.fire_progress_cb(true);
}
self.progress.remote_remove = remote.cards.len() + remote.notes.len() + remote.decks.len();
self.col.apply_graves(remote, state.server_usn)?;
self.fire_progress_cb(true);
debug!("applied server graves");
Ok(())
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StartRequest {
#[serde(rename = "minUsn")]
pub client_usn: Usn,
#[serde(rename = "lnewer")]
pub local_is_newer: bool,
/// Used by old clients, and still used by AnkiDroid.
#[serde(rename = "graves", default, deserialize_with = "legacy_graves")]
pub deprecated_client_graves: Option<Graves>,
}
pub fn server_start(
req: StartRequest,
col: &mut Collection,
state: &mut ServerSyncState,
) -> Result<Graves> {
state.server_usn = col.usn()?;
state.client_usn = req.client_usn;
state.client_is_newer = req.local_is_newer;
col.discard_undo_and_study_queues();
col.storage.begin_rust_trx()?;
// make sure any pending cards have been unburied first if necessary
let timing = col.timing_today()?;
col.unbury_if_day_rolled_over(timing)?;
// fetch local graves
let server_graves = col.storage.pending_graves(state.client_usn)?;
// handle AnkiDroid using old protocol
if let Some(graves) = req.deprecated_client_graves {
col.apply_graves(graves, state.server_usn)?;
}
Ok(server_graves)
}
/// The current sync protocol is stateful, so unfortunately we need to
/// retain a bunch of information across requests. These are set either
/// on start, or on subsequent methods.
pub struct ServerSyncState {
/// The session key. This is sent on every http request, but is ignored for methods
/// where there is no active sync state.
pub skey: String,
pub(in crate::sync) server_usn: Usn,
pub(in crate::sync) client_usn: Usn,
/// Only used to determine whether we should send our
/// config to the client.
pub(in crate::sync) client_is_newer: bool,
/// Set on the first call to chunk()
pub(in crate::sync) server_chunk_ids: Option<ChunkableIds>,
}
impl ServerSyncState {
pub fn new(skey: impl Into<String>) -> Self {
Self {
skey: skey.into(),
server_usn: Default::default(),
client_usn: Default::default(),
client_is_newer: false,
server_chunk_ids: None,
}
}
}
pub(crate) fn legacy_graves<'de, D>(deserializer: D) -> Result<Option<Graves>, D::Error>
where
D: Deserializer<'de>,
{
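// accepts the current form ({"cards":[123],...}), the legacy AnkiMobile form
// with string ids ({"cards":["123"],...}), or null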
#[derive(Deserialize)]
#[serde(untagged)]
enum GraveType {
Normal(Graves),
Legacy(StringGraves),
Null,
}
match GraveType::deserialize(deserializer)? {
GraveType::Normal(normal) => Ok(Some(normal)),
GraveType::Legacy(stringly) => Ok(Some(Graves {
cards: string_list_to_ids(stringly.cards)?,
decks: string_list_to_ids(stringly.decks)?,
notes: string_list_to_ids(stringly.notes)?,
})),
GraveType::Null => Ok(None),
}
}
// old AnkiMobile versions
#[derive(Deserialize)]
struct StringGraves {
cards: Vec<String>,
decks: Vec<String>,
notes: Vec<String>,
}
fn string_list_to_ids<T, E>(list: Vec<String>) -> Result<Vec<T>, E>
where
T: From<i64>,
E: serde::de::Error,
{
list.into_iter()
.map(|s| {
s.parse::<i64>()
.map_err(serde::de::Error::custom)
.map(Into::into)
})
.collect::<Result<Vec<T>, E>>()
}

View File

@ -0,0 +1,58 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use tracing::debug;
use crate::{
error::SyncErrorKind,
pb::sync::sync_status_response,
prelude::*,
sync::{
collection::{meta::SyncMeta, normal::ClientSyncState},
http_client::HttpSyncClient,
},
};
impl Collection {
/// Checks local collection only. If local collection is clean but changes are pending
/// on AnkiWeb, NoChanges will be returned.
pub fn sync_status_offline(&mut self) -> Result<sync_status_response::Required> {
let stamps = self.storage.get_collection_timestamps()?;
let required = if stamps.schema_changed_since_sync() {
sync_status_response::Required::FullSync
} else if stamps.collection_changed_since_sync() {
sync_status_response::Required::NormalSync
} else {
sync_status_response::Required::NoChanges
};
Ok(required)
}
}
/// Should be called if a call to sync_status_offline() returns NoChanges, to check
/// if AnkiWeb has pending changes. Caller should persist new endpoint if returned.
///
/// This routine is outside of the collection, as we don't want to block collection access
/// for a potentially slow network request that happens in the background.
pub async fn online_sync_status_check(
local: SyncMeta,
server: &mut HttpSyncClient,
) -> Result<ClientSyncState, AnkiError> {
let (remote, new_endpoint) = server.meta_with_redirect().await?;
debug!(?remote, "meta");
debug!(?local, "meta");
if !remote.should_continue {
debug!(remote.server_message, "server says abort");
return Err(AnkiError::sync_error(
remote.server_message,
SyncErrorKind::ServerMessage,
));
}
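// refuse to sync if the client and server clocks differ by more than 5 minutes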
let delta = remote.current_time.0 - local.current_time.0;
if delta.abs() > 300 {
debug!(delta, "clock off");
return Err(AnkiError::sync_error("", SyncErrorKind::ClockIncorrect));
}
Ok(local.compared_to_remote(remote, new_endpoint))
}

View File

@ -0,0 +1,752 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
#![cfg(test)]
use std::future::Future;
use axum::http::StatusCode;
use once_cell::sync::Lazy;
use reqwest::Url;
use serde_json::json;
use tempfile::{tempdir, TempDir};
use tokio::sync::{Mutex, MutexGuard};
use tracing::{Instrument, Span};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use crate::{
card::CardQueue,
collection::CollectionBuilder,
deckconfig::DeckConfig,
decks::DeckKind,
error::{SyncError, SyncErrorKind},
log::set_global_logger,
notetype::all_stock_notetypes,
prelude::*,
revlog::RevlogEntry,
search::SortMode,
sync::{
collection::{
graves::ApplyGravesRequest,
meta::MetaRequest,
normal::{NormalSyncProgress, NormalSyncer, SyncActionRequired, SyncOutput},
progress::FullSyncProgress,
protocol::{EmptyInput, SyncProtocol},
start::StartRequest,
upload::{UploadResponse, CORRUPT_MESSAGE},
},
http_client::HttpSyncClient,
http_server::SimpleServer,
login::{HostKeyRequest, SyncAuth},
request::IntoSyncRequest,
},
};
struct TestAuth {
username: String,
password: String,
host_key: String,
}
static AUTH: Lazy<TestAuth> = Lazy::new(|| {
if let Ok(auth) = std::env::var("TEST_AUTH") {
let mut auth = auth.split(':');
TestAuth {
username: auth.next().unwrap().into(),
password: auth.next().unwrap().into(),
host_key: auth.next().unwrap().into(),
}
} else {
TestAuth {
username: "user".to_string(),
password: "pass".to_string(),
host_key: "b2619aa1529dfdc4248e6edbf3c1b2a2b014cf6d".to_string(),
}
}
});
pub(in crate::sync) async fn with_active_server<F, O>(op: F) -> Result<()>
where
F: FnOnce(HttpSyncClient) -> O,
O: Future<Output = Result<()>>,
{
let _ = set_global_logger(None);
// start server
let base_folder = tempdir()?;
std::env::set_var("SYNC_USER1", "user:pass");
let (addr, server_fut) = SimpleServer::make_server(None, base_folder.path()).unwrap();
tokio::spawn(server_fut.instrument(Span::current()));
// when not using ephemeral servers, tests need to be serialized
static LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
let _lock: MutexGuard<()>;
// setup client to connect to it
let endpoint = if let Ok(endpoint) = std::env::var("TEST_ENDPOINT") {
_lock = LOCK.lock().await;
endpoint
} else {
format!("http://{}/", addr)
};
let endpoint = Url::try_from(endpoint.as_str()).unwrap();
let auth = SyncAuth {
hkey: AUTH.host_key.clone(),
endpoint: Some(endpoint),
};
let client = HttpSyncClient::new(auth);
op(client).await
}
fn unwrap_sync_err_kind(err: AnkiError) -> SyncErrorKind {
let AnkiError::SyncError { source: SyncError { kind, .. } } = err else {
panic!("not sync err: {err:?}");
};
kind
}
fn norm_progress(_: NormalSyncProgress, _: bool) {}
fn full_progress(_: FullSyncProgress, _: bool) {}
#[tokio::test]
async fn host_key() -> Result<()> {
with_active_server(|mut client| async move {
let err = client
.host_key(
HostKeyRequest {
username: "bad".to_string(),
password: "bad".to_string(),
}
.try_into_sync_request()?,
)
.await
.unwrap_err();
assert_eq!(err.code, StatusCode::FORBIDDEN);
assert_eq!(
unwrap_sync_err_kind(AnkiError::from(err)),
SyncErrorKind::AuthFailed
);
// hkey should be automatically set after successful login
client.sync_key = String::new();
let resp = client
.host_key(
HostKeyRequest {
username: AUTH.username.clone(),
password: AUTH.password.clone(),
}
.try_into_sync_request()?,
)
.await?
.json()?;
assert_eq!(resp.key, *AUTH.host_key);
Ok(())
})
.await
}
#[tokio::test]
async fn meta() -> Result<()> {
with_active_server(|client| async move {
// unsupported sync version
assert_eq!(
SyncProtocol::meta(
&client,
MetaRequest {
sync_version: 0,
client_version: "".to_string(),
}
.try_into_sync_request()?,
)
.await
.unwrap_err()
.code,
StatusCode::NOT_IMPLEMENTED
);
Ok(())
})
.await
}
#[tokio::test]
async fn aborting_is_idempotent() -> Result<()> {
with_active_server(|mut client| async move {
// abort is a no-op if no sync in progress
client.abort(EmptyInput::request()).await?;
// start a sync
let _graves = client
.start(
StartRequest {
client_usn: Default::default(),
local_is_newer: false,
deprecated_client_graves: None,
}
.try_into_sync_request()?,
)
.await?;
// an abort request with the wrong key is ignored
let orig_key = client.skey().to_string();
client.set_skey("aabbccdd".into());
client.abort(EmptyInput::request()).await?;
// it should succeed with the correct key
client.set_skey(orig_key);
client.abort(EmptyInput::request()).await?;
Ok(())
})
.await
}
#[tokio::test]
async fn new_syncs_cancel_old_ones() -> Result<()> {
with_active_server(|mut client| async move {
let ctx = SyncTestContext::new(client.partial_clone());
// start a sync
let req = StartRequest {
client_usn: Default::default(),
local_is_newer: false,
deprecated_client_graves: None,
}
.try_into_sync_request()?;
let _ = client.start(req.clone()).await?;
// a new sync aborts the previous one
let orig_key = client.skey().to_string();
client.set_skey("1".into());
let _ = client.start(req.clone()).await?;
// old sync can no longer proceed
client.set_skey(orig_key);
let graves_req = ApplyGravesRequest::default().try_into_sync_request()?;
assert_eq!(
client
.apply_graves(graves_req.clone())
.await
.unwrap_err()
.code,
StatusCode::CONFLICT
);
// with the correct key, it can continue
client.set_skey("1".into());
client.apply_graves(graves_req.clone()).await?;
// but a full upload will break the lock
ctx.full_upload(ctx.col1()).await;
assert_eq!(
client
.apply_graves(graves_req.clone())
.await
.unwrap_err()
.code,
StatusCode::CONFLICT
);
// likewise with download
let _ = client.start(req.clone()).await?;
ctx.full_download(ctx.col1()).await;
assert_eq!(
client
.apply_graves(graves_req.clone())
.await
.unwrap_err()
.code,
StatusCode::CONFLICT
);
Ok(())
})
.await
}
#[tokio::test]
async fn sync_roundtrip() -> Result<()> {
with_active_server(|client| async move {
let ctx = SyncTestContext::new(client);
upload_download(&ctx).await?;
regular_sync(&ctx).await?;
Ok(())
})
.await
}
#[tokio::test]
async fn sanity_check_should_roll_back_and_force_full_sync() -> Result<()> {
with_active_server(|client| async move {
let ctx = SyncTestContext::new(client);
upload_download(&ctx).await?;
let mut col1 = ctx.col1();
// add a deck but don't mark it as requiring a sync, which will trigger the sanity
// check to fail
let mut deck = col1.get_or_create_normal_deck("unsynced deck")?;
col1.add_or_update_deck(&mut deck)?;
col1.storage
.db
.execute("update decks set usn=0 where id=?", [deck.id])?;
// the sync should fail
let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress)
.sync()
.await
.unwrap_err();
assert!(matches!(
err,
AnkiError::SyncError {
source: SyncError {
kind: SyncErrorKind::SanityCheckFailed { .. },
..
}
}
));
// the server should have rolled back
let mut col2 = ctx.col2();
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
// and the client should have forced a one-way sync
let out = ctx.normal_sync(&mut col1).await;
assert_eq!(
out.required,
SyncActionRequired::FullSyncRequired {
upload_ok: true,
download_ok: true,
}
);
Ok(())
})
.await
}
#[tokio::test]
async fn sync_errors_should_prompt_db_check() -> Result<()> {
with_active_server(|client| async move {
let ctx = SyncTestContext::new(client);
upload_download(&ctx).await?;
let mut col1 = ctx.col1();
// Add a new notetype, and a note that uses it, but don't mark the notetype as
// requiring a sync, which will cause the sync to fail as the note is added.
let mut nt = all_stock_notetypes(&col1.tr).remove(0);
nt.name = "new".into();
col1.add_notetype(&mut nt, false)?;
let mut note = nt.new_note();
note.set_field(0, "test")?;
col1.add_note(&mut note, DeckId(1))?;
col1.storage.db.execute("update notetypes set usn=0", [])?;
// the sync should fail
let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress)
.sync()
.await
.unwrap_err();
let AnkiError::SyncError { source: SyncError { info: _, kind } } = err else { panic!() };
assert_eq!(kind, SyncErrorKind::DatabaseCheckRequired);
// the server should have rolled back
let mut col2 = ctx.col2();
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
// and the client should be able to sync again without a forced one-way sync
let err = NormalSyncer::new(&mut col1, ctx.cloned_client(), norm_progress)
.sync()
.await
.unwrap_err();
let AnkiError::SyncError { source: SyncError { info: _, kind } } = err else { panic!() };
assert_eq!(kind, SyncErrorKind::DatabaseCheckRequired);
Ok(())
})
.await
}
/// Old AnkiMobile versions sent grave ids as strings
#[tokio::test]
async fn string_grave_ids_are_handled() -> Result<()> {
with_active_server(|client| async move {
let req = json!({
"minUsn": 0,
"lnewer": false,
"graves": {
"cards": vec!["1"],
"decks": vec!["2", "3"],
"notes": vec!["4"],
}
});
let req = serde_json::to_vec(&req)
.unwrap()
.try_into_sync_request()
.unwrap();
// should not return err 400
client.start(req.into_output_type()).await.unwrap();
client.abort(EmptyInput::request()).await?;
Ok(())
})
.await?;
// a request with the graves value missing should also be handled
with_active_server(|client| async move {
let req = json!({
"minUsn": 0,
"lnewer": false,
});
let req = serde_json::to_vec(&req)
.unwrap()
.try_into_sync_request()
.unwrap();
client.start(req.into_output_type()).await.unwrap();
client.abort(EmptyInput::request()).await?;
Ok(())
})
.await
}
#[tokio::test]
async fn invalid_uploads_should_be_handled() -> Result<()> {
with_active_server(|client| async move {
let mut ctx = SyncTestContext::new(client);
ctx.client
.set_full_sync_progress_fn(Some(Box::new(full_progress)));
let res = ctx
.client
.upload(b"fake data".to_vec().try_into_sync_request()?)
.await?;
assert_eq!(
res.upload_response(),
UploadResponse::Err(CORRUPT_MESSAGE.into())
);
Ok(())
})
.await
}
#[tokio::test]
async fn meta_redirect_is_handled() -> Result<()> {
with_active_server(|client| async move {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/sync/meta"))
.respond_with(
ResponseTemplate::new(308).insert_header("location", client.endpoint.as_str()),
)
.mount(&mock_server)
.await;
// starting from in-sync state
let mut ctx = SyncTestContext::new(client);
upload_download(&ctx).await?;
// add another note to trigger a normal sync
let mut col1 = ctx.col1();
col1_setup(&mut col1);
// switch to the redirecting endpoint
let orig_url = ctx.client.endpoint.to_string();
ctx.client.endpoint = Url::try_from(mock_server.uri().as_str()).unwrap();
// sync should succeed
let out = ctx.normal_sync(&mut col1).await;
// client should have received new endpoint
assert_eq!(out.new_endpoint, Some(orig_url));
// the redirecting endpoint should only have been hit once
assert_eq!(mock_server.received_requests().await.unwrap().len(), 1);
Ok(())
})
.await
}
pub(in crate::sync) struct SyncTestContext {
pub folder: TempDir,
pub client: HttpSyncClient,
}
impl SyncTestContext {
pub fn new(client: HttpSyncClient) -> Self {
Self {
folder: tempdir().expect("create temp dir"),
client,
}
}
pub fn col1(&self) -> Collection {
let base = self.folder.path();
CollectionBuilder::new(base.join("col1.anki2"))
.set_media_paths(base.join("col1.media"), base.join("col1.media.db"))
.build()
.unwrap()
}
pub fn col2(&self) -> Collection {
let base = self.folder.path();
CollectionBuilder::new(base.join("col2.anki2"))
.set_media_paths(base.join("col2.media"), base.join("col2.media.db"))
.build()
.unwrap()
}
async fn normal_sync(&self, col: &mut Collection) -> SyncOutput {
NormalSyncer::new(col, self.cloned_client(), norm_progress)
.sync()
.await
.unwrap()
}
async fn full_upload(&self, col: Collection) {
col.full_upload_with_server(self.cloned_client())
.await
.unwrap()
}
async fn full_download(&self, col: Collection) {
col.full_download_with_server(self.cloned_client())
.await
.unwrap()
}
fn cloned_client(&self) -> HttpSyncClient {
let mut client = self.client.partial_clone();
client.set_full_sync_progress_fn(Some(Box::new(full_progress)));
client
}
}
// Setup + full syncs
/////////////////////
fn col1_setup(col: &mut Collection) {
let nt = col.get_notetype_by_name("Basic").unwrap().unwrap();
let mut note = nt.new_note();
note.set_field(0, "1").unwrap();
col.add_note(&mut note, DeckId(1)).unwrap();
}
async fn upload_download(ctx: &SyncTestContext) -> Result<()> {
let mut col1 = ctx.col1();
col1_setup(&mut col1);
let out = ctx.normal_sync(&mut col1).await;
assert!(matches!(
out.required,
SyncActionRequired::FullSyncRequired { .. }
));
ctx.full_upload(col1).await;
// another collection
let mut col2 = ctx.col2();
// the server won't let a normal sync clobber its newer collection
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(
out.required,
SyncActionRequired::FullSyncRequired {
upload_ok: false,
download_ok: true,
}
);
// fetch so we're in sync
ctx.full_download(col2).await;
Ok(())
}
// Regular syncs
/////////////////////
async fn regular_sync(ctx: &SyncTestContext) -> Result<()> {
// add a deck
let mut col1 = ctx.col1();
let mut col2 = ctx.col2();
let mut deck = col1.get_or_create_normal_deck("new deck")?;
// give it a new option group
let mut dconf = DeckConfig {
name: "new dconf".into(),
..Default::default()
};
col1.add_or_update_deck_config(&mut dconf)?;
if let DeckKind::Normal(deck) = &mut deck.kind {
deck.config_id = dconf.id.0;
}
col1.add_or_update_deck(&mut deck)?;
// and a new notetype
let mut nt = all_stock_notetypes(&col1.tr).remove(0);
nt.name = "new".into();
col1.add_notetype(&mut nt, false)?;
// add another note+card+tag
let mut note = nt.new_note();
note.set_field(0, "2")?;
note.tags.push("tag".into());
col1.add_note(&mut note, deck.id)?;
// mock revlog entry
col1.storage.add_revlog_entry(
&RevlogEntry {
id: RevlogId(123),
cid: CardId(456),
usn: Usn(-1),
interval: 10,
..Default::default()
},
true,
)?;
// config + creation
col1.set_config("test", &"test1")?;
// bumping this will affect 'last studied at' on decks at the moment
// col1.storage.set_creation_stamp(TimestampSecs(12345))?;
// and sync our changes
let remote_meta = ctx
.client
.meta(MetaRequest::request())
.await
.unwrap()
.json()
.unwrap();
let out = col1.sync_meta()?.compared_to_remote(remote_meta, None);
assert_eq!(out.required, SyncActionRequired::NormalSyncRequired);
let out = ctx.normal_sync(&mut col1).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
// sync the other collection
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
let ntid = nt.id;
let deckid = deck.id;
let dconfid = dconf.id;
let noteid = note.id;
let cardid = col1.search_cards(note.id, SortMode::NoOrder)?[0];
let revlogid = RevlogId(123);
let compare_sides = |col1: &mut Collection, col2: &mut Collection| -> Result<()> {
assert_eq!(
col1.get_notetype(ntid)?.unwrap(),
col2.get_notetype(ntid)?.unwrap()
);
assert_eq!(
col1.get_deck(deckid)?.unwrap(),
col2.get_deck(deckid)?.unwrap()
);
assert_eq!(
col1.get_deck_config(dconfid, false)?.unwrap(),
col2.get_deck_config(dconfid, false)?.unwrap()
);
assert_eq!(
col1.storage.get_note(noteid)?.unwrap(),
col2.storage.get_note(noteid)?.unwrap()
);
assert_eq!(
col1.storage.get_card(cardid)?.unwrap(),
col2.storage.get_card(cardid)?.unwrap()
);
assert_eq!(
col1.storage.get_revlog_entry(revlogid)?,
col2.storage.get_revlog_entry(revlogid)?,
);
assert_eq!(
col1.storage.get_all_config()?,
col2.storage.get_all_config()?
);
assert_eq!(
col1.storage.creation_stamp()?,
col2.storage.creation_stamp()?
);
// the server doesn't send tag usns, so we can only compare tag names,
// as the usns may not match
assert_eq!(
col1.storage
.all_tags()?
.into_iter()
.map(|t| t.name)
.collect::<Vec<_>>(),
col2.storage
.all_tags()?
.into_iter()
.map(|t| t.name)
.collect::<Vec<_>>()
);
std::thread::sleep(std::time::Duration::from_millis(1));
Ok(())
};
// make sure everything has been transferred across
compare_sides(&mut col1, &mut col2)?;
// make some modifications
let mut note = col2.storage.get_note(note.id)?.unwrap();
note.set_field(1, "new")?;
note.tags.push("tag2".into());
col2.update_note(&mut note)?;
col2.get_and_update_card(cardid, |card| {
card.queue = CardQueue::Review;
Ok(())
})?;
let mut deck = col2.storage.get_deck(deck.id)?.unwrap();
deck.name = NativeDeckName::from_native_str("newer");
col2.add_or_update_deck(&mut deck)?;
let mut nt = col2.storage.get_notetype(nt.id)?.unwrap();
nt.name = "newer".into();
col2.update_notetype(&mut nt, false)?;
// sync the changes back
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
let out = ctx.normal_sync(&mut col1).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
// should still match
compare_sides(&mut col1, &mut col2)?;
// deletions should sync too
for table in &["cards", "notes", "decks"] {
assert_eq!(
col1.storage
.db_scalar::<u8>(&format!("select count() from {}", table))?,
2
);
}
// fixme: inconsistent usn arg
std::thread::sleep(std::time::Duration::from_millis(1));
col1.remove_cards_and_orphaned_notes(&[cardid])?;
let usn = col1.usn()?;
col1.remove_note_only_undoable(noteid, usn)?;
col1.remove_decks_and_child_decks(&[deckid])?;
let out = ctx.normal_sync(&mut col1).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
let out = ctx.normal_sync(&mut col2).await;
assert_eq!(out.required, SyncActionRequired::NoChanges);
for table in &["cards", "notes", "decks"] {
assert_eq!(
col2.storage
.db_scalar::<u8>(&format!("select count() from {}", table))?,
1
);
}
// removing things like a notetype forces a full sync
std::thread::sleep(std::time::Duration::from_millis(1));
col2.remove_notetype(ntid)?;
let out = ctx.normal_sync(&mut col2).await;
assert!(matches!(
out.required,
SyncActionRequired::FullSyncRequired { .. }
));
Ok(())
}

View File

@ -0,0 +1,136 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{fs, io::Write};
use axum::response::{IntoResponse, Response};
use flate2::{write::GzEncoder, Compression};
use futures::StreamExt;
use tokio_util::io::ReaderStream;
use crate::{
collection::CollectionBuilder,
error::SyncErrorKind,
io::{atomic_rename, new_tempfile_in_parent_of, write_file},
prelude::*,
storage::SchemaVersion,
sync::{
collection::{progress::FullSyncProgressFn, protocol::SyncProtocol},
error::{HttpResult, OrHttpErr},
http_client::HttpSyncClient,
login::SyncAuth,
request::{IntoSyncRequest, MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED},
},
};
/// Old clients didn't display a useful message on HTTP 400, and were expected to show the error message
/// returned by the server.
pub const CORRUPT_MESSAGE: &str =
"Your upload was corrupt. Please use Check Database, or restore from backup.";
impl Collection {
/// Upload collection to AnkiWeb. Caller must re-open afterwards.
pub async fn full_upload(self, auth: SyncAuth, progress_fn: FullSyncProgressFn) -> Result<()> {
let mut server = HttpSyncClient::new(auth);
server.set_full_sync_progress_fn(Some(progress_fn));
self.full_upload_with_server(server).await
}
pub(crate) async fn full_upload_with_server(mut self, server: HttpSyncClient) -> Result<()> {
self.before_upload()?;
let col_path = self.col_path.clone();
self.close(Some(SchemaVersion::V18))?;
let col_data = fs::read(&col_path)?;
let total_bytes = col_data.len();
if server.endpoint.as_str().contains("ankiweb") {
check_upload_limit(
total_bytes,
*MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED as usize,
)?;
}
match server
.upload(col_data.try_into_sync_request()?)
.await?
.upload_response()
{
UploadResponse::Ok => Ok(()),
UploadResponse::Err(msg) => {
Err(AnkiError::sync_error(msg, SyncErrorKind::ServerMessage))
}
}
}
}
/// Collection must already be open, and will be replaced on success.
pub fn handle_received_upload(
col: &mut Option<Collection>,
new_data: Vec<u8>,
) -> HttpResult<UploadResponse> {
let max_bytes = *MAXIMUM_SYNC_PAYLOAD_BYTES_UNCOMPRESSED as usize;
if new_data.len() >= max_bytes {
return Ok(UploadResponse::Err("collection exceeds size limit".into()));
}
let path = col
.as_ref()
.or_internal_err("col was closed")?
.col_path
.clone();
// write to temp file
let temp_file = new_tempfile_in_parent_of(&path).or_internal_err("temp file")?;
write_file(temp_file.path(), &new_data).or_internal_err("temp file")?;
// check the collection is valid
if let Err(err) = CollectionBuilder::new(temp_file.path())
.set_check_integrity(true)
.build()
{
tracing::info!(?err, "uploaded file was corrupt/failed to open");
return Ok(UploadResponse::Err(CORRUPT_MESSAGE.into()));
}
// close collection and rename
if let Some(col) = col.take() {
col.close(None)
.or_internal_err("closing current collection")?;
}
atomic_rename(temp_file, &path, true).or_internal_err("rename upload")?;
Ok(UploadResponse::Ok)
}
impl IntoResponse for UploadResponse {
fn into_response(self) -> Response {
match self {
// the legacy protocol expects this exact string
UploadResponse::Ok => "OK".to_string(),
UploadResponse::Err(e) => e,
}
.into_response()
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UploadResponse {
Ok,
Err(String),
}
pub fn check_upload_limit(size: usize, limit: usize) -> Result<()> {
if size >= limit {
Err(AnkiError::sync_error(
format!("{size} > {limit}"),
SyncErrorKind::UploadTooLarge,
))
} else {
Ok(())
}
}
pub async fn gzipped_data_from_vec(vec: Vec<u8>) -> Result<Vec<u8>> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
let mut stream = ReaderStream::new(&vec[..]);
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
encoder.write_all(&chunk)?;
}
encoder.finish().map_err(Into::into)
}
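A minimal sketch of the boundary behaviour of check_upload_limit above (a hypothetical test, not part of this commit): the limit is exclusive, so a payload whose size equals the limit is already rejected with UploadTooLarge.
#[test]
fn upload_limit_is_exclusive() {
// one byte under the limit passes
assert!(check_upload_limit(99, 100).is_ok());
// reaching the limit exactly is already an error
assert!(check_upload_limit(100, 100).is_err());
}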

150
rslib/src/sync/error.rs Normal file
View File

@ -0,0 +1,150 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use axum::{
http::StatusCode,
response::{IntoResponse, Redirect, Response},
};
use snafu::{OptionExt, Snafu};
pub type HttpResult<T, E = HttpError> = std::result::Result<T, E>;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub struct HttpError {
pub code: StatusCode,
pub context: String,
// snafu's automatic error conversion only supports Option if
// the whatever trait is derived, and deriving whatever means we
// can't have extra fields like `code`. Even without Option, the
// error conversion requires us to manually box the error, so we end
// up having to disable the default behaviour and add the error to the
// snafu ourselves
#[snafu(source(false))]
pub source: Option<Box<dyn std::error::Error + Send + Sync>>,
}
impl HttpError {
pub fn new_without_source(code: StatusCode, context: impl Into<String>) -> Self {
Self {
code,
context: context.into(),
source: None,
}
}
/// Compatibility with ensure!() macro
pub fn fail<T>(self) -> Result<T, Self> {
Err(self)
}
}
impl IntoResponse for HttpError {
fn into_response(self) -> Response {
let HttpError {
code,
context,
source,
} = self;
if code.is_server_error() && code != StatusCode::NOT_IMPLEMENTED {
tracing::error!(context, ?source, httpstatus = code.as_u16(),);
} else {
tracing::info!(context, ?source, httpstatus = code.as_u16(),);
}
if code == StatusCode::PERMANENT_REDIRECT {
Redirect::permanent(&context).into_response()
} else {
(code, code.as_str().to_string()).into_response()
}
}
}
pub trait OrHttpErr {
type Value;
fn or_http_err(
self,
code: StatusCode,
context: impl Into<String>,
) -> Result<Self::Value, HttpError>;
fn or_bad_request(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::BAD_REQUEST, context)
}
fn or_internal_err(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::INTERNAL_SERVER_ERROR, context)
}
fn or_forbidden(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::FORBIDDEN, context)
}
fn or_conflict(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::CONFLICT, context)
}
fn or_not_found(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::NOT_FOUND, context)
}
fn or_permanent_redirect(self, context: impl Into<String>) -> Result<Self::Value, HttpError>
where
Self: Sized,
{
self.or_http_err(StatusCode::PERMANENT_REDIRECT, context)
}
}
impl<T, E> OrHttpErr for Result<T, E>
where
E: std::error::Error + Send + Sync + 'static,
{
type Value = T;
fn or_http_err(
self,
code: StatusCode,
context: impl Into<String>,
) -> Result<Self::Value, HttpError> {
self.map_err(|err| {
HttpSnafu {
code,
context: context.into(),
source: Some(Box::new(err) as _),
}
.build()
})
}
}
impl<T> OrHttpErr for Option<T> {
type Value = T;
fn or_http_err(
self,
code: StatusCode,
context: impl Into<String>,
) -> Result<Self::Value, HttpError> {
self.context(HttpSnafu {
code,
context,
source: None,
})
}
}
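As a hedged illustration of how the helpers above are meant to be used (hypothetical code, not part of this commit, assuming the imports already present in this file): Options and Results are lifted into HttpError with an attached status code, and HttpError::fail gives an ensure!()-style early return.
fn require_session(skey: Option<&str>) -> HttpResult<&str> {
// a missing session key becomes a 403 with the given context
skey.or_forbidden("missing session key")
}
fn reject_oversized(len: usize, max: usize) -> HttpResult<()> {
if len > max {
// explicit early return, mirroring ensure!()
return HttpError::new_without_source(StatusCode::BAD_REQUEST, "payload too large").fail();
}
Ok(())
}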

View File

@ -1,124 +0,0 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use super::{Chunk, Graves, SanityCheckCounts, UnchunkedChanges};
use crate::{io::read_file, pb::sync::sync_server_method_request::Method, prelude::*};
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub enum SyncRequest {
HostKey(HostKeyRequest),
Meta(MetaRequest),
Start(StartRequest),
ApplyGraves(ApplyGravesRequest),
ApplyChanges(ApplyChangesRequest),
Chunk,
ApplyChunk(ApplyChunkRequest),
#[serde(rename = "sanityCheck2")]
SanityCheck(SanityCheckRequest),
Finish,
Abort,
#[serde(rename = "upload")]
FullUpload(PathBuf),
#[serde(rename = "download")]
FullDownload,
}
impl SyncRequest {
/// Return method name and payload bytes.
pub(crate) fn into_method_and_data(self) -> Result<(&'static str, Vec<u8>)> {
use serde_json::to_vec;
Ok(match self {
SyncRequest::HostKey(v) => ("hostKey", to_vec(&v)?),
SyncRequest::Meta(v) => ("meta", to_vec(&v)?),
SyncRequest::Start(v) => ("start", to_vec(&v)?),
SyncRequest::ApplyGraves(v) => ("applyGraves", to_vec(&v)?),
SyncRequest::ApplyChanges(v) => ("applyChanges", to_vec(&v)?),
SyncRequest::Chunk => ("chunk", b"{}".to_vec()),
SyncRequest::ApplyChunk(v) => ("applyChunk", to_vec(&v)?),
SyncRequest::SanityCheck(v) => ("sanityCheck2", to_vec(&v)?),
SyncRequest::Finish => ("finish", b"{}".to_vec()),
SyncRequest::Abort => ("abort", b"{}".to_vec()),
SyncRequest::FullUpload(v) => {
// fixme: stream in the data instead, in a different call
("upload", read_file(&v)?)
}
SyncRequest::FullDownload => ("download", b"{}".to_vec()),
})
}
pub(crate) fn from_method_and_data(method: Method, data: Vec<u8>) -> Result<Self> {
use serde_json::from_slice;
Ok(match method {
Method::HostKey => SyncRequest::HostKey(from_slice(&data)?),
Method::Meta => SyncRequest::Meta(from_slice(&data)?),
Method::Start => SyncRequest::Start(from_slice(&data)?),
Method::ApplyGraves => SyncRequest::ApplyGraves(from_slice(&data)?),
Method::ApplyChanges => SyncRequest::ApplyChanges(from_slice(&data)?),
Method::Chunk => SyncRequest::Chunk,
Method::ApplyChunk => SyncRequest::ApplyChunk(from_slice(&data)?),
Method::SanityCheck => SyncRequest::SanityCheck(from_slice(&data)?),
Method::Finish => SyncRequest::Finish,
Method::Abort => SyncRequest::Abort,
Method::FullUpload => {
let path = PathBuf::from(String::from_utf8(data).expect("path was not in utf8"));
SyncRequest::FullUpload(path)
}
Method::FullDownload => SyncRequest::FullDownload,
})
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct HostKeyRequest {
#[serde(rename = "u")]
pub username: String,
#[serde(rename = "p")]
pub password: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct HostKeyResponse {
pub key: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct MetaRequest {
#[serde(rename = "v")]
pub sync_version: u8,
#[serde(rename = "cv")]
pub client_version: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct StartRequest {
#[serde(rename = "minUsn")]
pub client_usn: Usn,
#[serde(rename = "lnewer")]
pub local_is_newer: bool,
/// Unfortunately AnkiDroid is still using this
#[serde(rename = "graves", default)]
pub deprecated_client_graves: Option<Graves>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ApplyGravesRequest {
pub chunk: Graves,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ApplyChangesRequest {
pub changes: UnchunkedChanges,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ApplyChunkRequest {
pub chunk: Chunk,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SanityCheckRequest {
pub client: SanityCheckCounts,
}

View File

@ -1,499 +0,0 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
env,
io::{prelude::*, Cursor},
mem::MaybeUninit,
path::Path,
time::Duration,
};
use async_trait::async_trait;
use bytes::Bytes;
use flate2::{write::GzEncoder, Compression};
use futures::{Stream, StreamExt};
use lazy_static::lazy_static;
use reqwest::{multipart, Body, Client, Response};
use serde::de::DeserializeOwned;
use tempfile::NamedTempFile;
use tokio_util::io::ReaderStream;
use super::{
http::{
ApplyChangesRequest, ApplyChunkRequest, ApplyGravesRequest, HostKeyRequest,
HostKeyResponse, MetaRequest, SanityCheckRequest, StartRequest, SyncRequest,
},
server::SyncServer,
Chunk, FullSyncProgress, Graves, SanityCheckCounts, SanityCheckResponse, SyncMeta,
UnchunkedChanges, SYNC_VERSION_MAX,
};
use crate::{
error::SyncErrorKind,
io::{new_tempfile, new_tempfile_in},
notes::guid,
prelude::*,
version::sync_client_version,
};
lazy_static! {
// These limits are enforced server-side, but are made adjustable for users
// who are using a custom sync server.
static ref MAXIMUM_UPLOAD_MEGS_UNCOMPRESSED: usize = env::var("MAX_UPLOAD_MEGS_UNCOMP")
.map(|v| v.parse().expect("invalid upload limit"))
.unwrap_or(250);
static ref MAXIMUM_UPLOAD_MEGS_COMPRESSED: usize = env::var("MAX_UPLOAD_MEGS_COMP")
.map(|v| v.parse().expect("invalid upload limit"))
.unwrap_or(100);
}
pub type FullSyncProgressFn = Box<dyn FnMut(FullSyncProgress, bool) + Send + Sync + 'static>;
pub struct HttpSyncClient {
hkey: Option<String>,
skey: String,
client: Client,
endpoint: String,
full_sync_progress_fn: Option<FullSyncProgressFn>,
}
pub struct Timeouts {
pub connect_secs: u64,
pub request_secs: u64,
pub io_secs: u64,
}
impl Timeouts {
pub fn new() -> Self {
let io_secs = if env::var("LONG_IO_TIMEOUT").is_ok() {
3600
} else {
300
};
Timeouts {
connect_secs: 30,
/// This is smaller than the I/O limit because it is just a
/// default - some longer-running requests override it.
request_secs: 60,
io_secs,
}
}
}
#[async_trait(?Send)]
impl SyncServer for HttpSyncClient {
async fn meta(&self) -> Result<SyncMeta> {
let input = SyncRequest::Meta(MetaRequest {
sync_version: SYNC_VERSION_MAX,
client_version: sync_client_version().to_string(),
});
self.json_request(input).await
}
async fn start(
&mut self,
client_usn: Usn,
local_is_newer: bool,
deprecated_client_graves: Option<Graves>,
) -> Result<Graves> {
let input = SyncRequest::Start(StartRequest {
client_usn,
local_is_newer,
deprecated_client_graves,
});
self.json_request(input).await
}
async fn apply_graves(&mut self, chunk: Graves) -> Result<()> {
let input = SyncRequest::ApplyGraves(ApplyGravesRequest { chunk });
self.json_request(input).await
}
async fn apply_changes(&mut self, changes: UnchunkedChanges) -> Result<UnchunkedChanges> {
let input = SyncRequest::ApplyChanges(ApplyChangesRequest { changes });
self.json_request(input).await
}
async fn chunk(&mut self) -> Result<Chunk> {
let input = SyncRequest::Chunk;
self.json_request(input).await
}
async fn apply_chunk(&mut self, chunk: Chunk) -> Result<()> {
let input = SyncRequest::ApplyChunk(ApplyChunkRequest { chunk });
self.json_request(input).await
}
async fn sanity_check(&mut self, client: SanityCheckCounts) -> Result<SanityCheckResponse> {
let input = SyncRequest::SanityCheck(SanityCheckRequest { client });
self.json_request(input).await
}
async fn finish(&mut self) -> Result<TimestampMillis> {
let input = SyncRequest::Finish;
self.json_request(input).await
}
async fn abort(&mut self) -> Result<()> {
let input = SyncRequest::Abort;
self.json_request(input).await
}
async fn full_upload(mut self: Box<Self>, col_path: &Path, _can_consume: bool) -> Result<()> {
let file = tokio::fs::File::open(col_path).await?;
let total_bytes = file.metadata().await?.len() as usize;
check_upload_limit(total_bytes, *MAXIMUM_UPLOAD_MEGS_UNCOMPRESSED)?;
let compressed_data: Vec<u8> = gzipped_data_from_tokio_file(file).await?;
let compressed_size = compressed_data.len();
check_upload_limit(compressed_size, *MAXIMUM_UPLOAD_MEGS_COMPRESSED)?;
let progress_fn = self
.full_sync_progress_fn
.take()
.expect("progress func was not set");
let with_progress = ProgressWrapper {
reader: Cursor::new(compressed_data),
progress_fn,
progress: FullSyncProgress {
transferred_bytes: 0,
total_bytes: compressed_size,
},
};
let body = Body::wrap_stream(with_progress);
self.upload_inner(body).await?;
Ok(())
}
/// Download collection into a temporary file, returning it. Caller should
/// persist the file in the correct path after checking it. Progress func
/// must be set first. The caller should pass the collection's folder in as
/// the temp folder if it wishes to atomically .persist() it.
async fn full_download(
mut self: Box<Self>,
col_folder: Option<&Path>,
) -> Result<NamedTempFile> {
let mut temp_file = if let Some(folder) = col_folder {
new_tempfile_in(folder)
} else {
new_tempfile()
}?;
let (size, mut stream) = self.download_inner().await?;
let mut progress = FullSyncProgress {
transferred_bytes: 0,
total_bytes: size,
};
let mut progress_fn = self
.full_sync_progress_fn
.take()
.expect("progress func was not set");
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
temp_file.write_all(&chunk)?;
progress.transferred_bytes += chunk.len();
progress_fn(progress, true);
}
progress_fn(progress, false);
Ok(temp_file)
}
}
async fn gzipped_data_from_tokio_file(file: tokio::fs::File) -> Result<Vec<u8>> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
let mut stream = ReaderStream::new(file);
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
encoder.write_all(&chunk)?;
}
encoder.finish().map_err(Into::into)
}
fn check_upload_limit(size: usize, limit_mb: usize) -> Result<()> {
let size_mb = size / 1024 / 1024;
if size_mb >= limit_mb {
Err(AnkiError::sync_error(
format!("{}MB > {}MB", size_mb, limit_mb),
SyncErrorKind::UploadTooLarge,
))
} else {
Ok(())
}
}
impl HttpSyncClient {
pub fn new(hkey: Option<String>, host_number: u32) -> HttpSyncClient {
let timeouts = Timeouts::new();
let client = Client::builder()
.connect_timeout(Duration::from_secs(timeouts.connect_secs))
.timeout(Duration::from_secs(timeouts.request_secs))
.io_timeout(Duration::from_secs(timeouts.io_secs))
.build()
.unwrap();
let skey = guid();
let endpoint = sync_endpoint(host_number);
HttpSyncClient {
hkey,
skey,
client,
endpoint,
full_sync_progress_fn: None,
}
}
pub fn set_full_sync_progress_fn(&mut self, func: Option<FullSyncProgressFn>) {
self.full_sync_progress_fn = func;
}
async fn json_request<T>(&self, req: SyncRequest) -> Result<T>
where
T: DeserializeOwned,
{
let (method, req_json) = req.into_method_and_data()?;
self.request_bytes(method, &req_json, false)
.await?
.json()
.await
.map_err(Into::into)
}
async fn request_bytes(
&self,
method: &str,
req: &[u8],
timeout_long: bool,
) -> Result<Response> {
let mut gz = GzEncoder::new(Vec::new(), Compression::fast());
gz.write_all(req)?;
let part = multipart::Part::bytes(gz.finish()?);
let resp = self.request(method, part, timeout_long).await?;
resp.error_for_status().map_err(Into::into)
}
async fn request(
&self,
method: &str,
data_part: multipart::Part,
timeout_long: bool,
) -> Result<Response> {
let data_part = data_part.file_name("data");
let mut form = multipart::Form::new()
.part("data", data_part)
.text("c", "1");
if let Some(hkey) = &self.hkey {
form = form.text("k", hkey.clone()).text("s", self.skey.clone());
}
let url = format!("{}{}", self.endpoint, method);
let mut req = self.client.post(&url).multipart(form);
if timeout_long {
req = req.timeout(Duration::from_secs(60 * 60));
}
req.send().await?.error_for_status().map_err(Into::into)
}
pub(crate) async fn login<S: Into<String>>(&mut self, username: S, password: S) -> Result<()> {
let input = SyncRequest::HostKey(HostKeyRequest {
username: username.into(),
password: password.into(),
});
let output: HostKeyResponse = self.json_request(input).await?;
self.hkey = Some(output.key);
Ok(())
}
pub(crate) fn hkey(&self) -> &str {
self.hkey.as_ref().unwrap()
}
async fn download_inner(
&self,
) -> Result<(
usize,
impl Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
)> {
let resp: Response = self.request_bytes("download", b"{}", true).await?;
let len = resp.content_length().unwrap_or_default();
Ok((len as usize, resp.bytes_stream()))
}
async fn upload_inner(&self, body: Body) -> Result<()> {
let data_part = multipart::Part::stream(body);
let resp = self.request("upload", data_part, true).await?;
resp.error_for_status_ref()?;
let text = resp.text().await?;
if text != "OK" {
Err(AnkiError::sync_error(text, SyncErrorKind::Other))
} else {
Ok(())
}
}
}
use std::pin::Pin;
use futures::{
ready,
task::{Context, Poll},
};
use pin_project::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
#[pin_project]
struct ProgressWrapper<S, P> {
#[pin]
reader: S,
progress_fn: P,
progress: FullSyncProgress,
}
impl<S, P> Stream for ProgressWrapper<S, P>
where
S: AsyncRead,
P: FnMut(FullSyncProgress, bool),
{
type Item = std::result::Result<Bytes, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut buf = [MaybeUninit::<u8>::uninit(); 8192];
let mut buf = ReadBuf::uninit(&mut buf);
let this = self.project();
let res = ready!(this.reader.poll_read(cx, &mut buf));
match res {
Ok(()) => {
let filled = buf.filled().to_vec();
Poll::Ready(if filled.is_empty() {
(this.progress_fn)(*this.progress, false);
None
} else {
this.progress.transferred_bytes += filled.len();
(this.progress_fn)(*this.progress, true);
Some(Ok(Bytes::from(filled)))
})
}
Err(e) => Poll::Ready(Some(Err(e))),
}
}
}
fn sync_endpoint(host_number: u32) -> String {
if let Ok(endpoint) = env::var("SYNC_ENDPOINT") {
endpoint
} else {
let suffix = if host_number > 0 {
format!("{}", host_number)
} else {
"".to_string()
};
format!("https://sync{}.ankiweb.net/sync/", suffix)
}
}
#[cfg(test)]
mod test {
use tokio::runtime::Runtime;
use super::*;
use crate::{
error::{SyncError, SyncErrorKind},
sync::SanityCheckDueCounts,
};
async fn http_client_inner(username: String, password: String) -> Result<()> {
let mut syncer = Box::new(HttpSyncClient::new(None, 0));
assert!(matches!(
syncer.login("nosuchuser", "nosuchpass").await,
Err(AnkiError::SyncError {
source: SyncError {
kind: SyncErrorKind::AuthFailed,
..
}
})
));
assert!(syncer.login(&username, &password).await.is_ok());
let _meta = syncer.meta().await?;
// aborting before a start is a conflict
assert!(matches!(
syncer.abort().await,
Err(AnkiError::SyncError {
source: SyncError {
kind: SyncErrorKind::Conflict,
..
}
})
));
let _graves = syncer.start(Usn(1), true, None).await?;
// aborting should now work
syncer.abort().await?;
// start again, and continue
let _graves = syncer.start(Usn(1), true, None).await?;
syncer.apply_graves(Graves::default()).await?;
let _changes = syncer.apply_changes(UnchunkedChanges::default()).await?;
let _chunk = syncer.chunk().await?;
syncer
.apply_chunk(Chunk {
done: true,
..Default::default()
})
.await?;
let _out = syncer
.sanity_check(SanityCheckCounts {
counts: SanityCheckDueCounts {
new: 0,
learn: 0,
review: 0,
},
cards: 0,
notes: 0,
revlog: 0,
graves: 0,
notetypes: 0,
decks: 0,
deck_config: 0,
})
.await?;
// failed sanity check will have cleaned up; can't finish
// syncer.finish().await?;
syncer.set_full_sync_progress_fn(Some(Box::new(|progress, _throttle| {
println!("progress: {:?}", progress);
})));
let out_path = syncer.full_download(None).await?;
let mut syncer = Box::new(HttpSyncClient::new(None, 0));
syncer.set_full_sync_progress_fn(Some(Box::new(|progress, _throttle| {
println!("progress {:?}", progress);
})));
syncer.full_upload(out_path.path(), false).await?;
Ok(())
}
#[test]
fn http_client() -> Result<()> {
let user = match env::var("TEST_SYNC_USER") {
Ok(s) => s,
Err(_) => {
return Ok(());
}
};
let pass = env::var("TEST_SYNC_PASS").unwrap();
env_logger::init();
let rt = Runtime::new().unwrap();
rt.block_on(http_client_inner(user, pass))
}
}

View File

@ -0,0 +1,82 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{future::Future, time::Duration};
use tokio::{select, time::interval};
use crate::sync::{
collection::{
progress::{FullSyncProgress, FullSyncProgressFn},
protocol::{EmptyInput, SyncMethod},
upload::UploadResponse,
},
error::HttpResult,
http_client::{io_monitor::IoMonitor, HttpSyncClient},
request::SyncRequest,
response::SyncResponse,
};
impl HttpSyncClient {
pub fn set_full_sync_progress_fn(&mut self, func: Option<FullSyncProgressFn>) {
*self.full_sync_progress_fn.lock().unwrap() = func;
}
fn full_sync_progress_monitor(&self, sending: bool) -> (IoMonitor, impl Future<Output = ()>) {
let mut progress = FullSyncProgress {
transferred_bytes: 0,
total_bytes: 0,
};
let mut progress_fn = self
.full_sync_progress_fn
.lock()
.unwrap()
.take()
.expect("progress func was not set");
let io_monitor = IoMonitor::new();
let io_monitor2 = io_monitor.clone();
let update_progress = async move {
let mut interval = interval(Duration::from_millis(100));
loop {
interval.tick().await;
let guard = io_monitor2.0.lock().unwrap();
progress.total_bytes = if sending {
guard.total_bytes_to_send
} else {
guard.total_bytes_to_receive
} as usize;
progress.transferred_bytes = if sending {
guard.bytes_sent
} else {
guard.bytes_received
} as usize;
progress_fn(progress, true)
}
};
(io_monitor, update_progress)
}
pub(super) async fn download_inner(
&self,
req: SyncRequest<EmptyInput>,
) -> HttpResult<SyncResponse<Vec<u8>>> {
let (io_monitor, progress_fut) = self.full_sync_progress_monitor(false);
let output = self.request_ext(SyncMethod::Download, req, io_monitor);
select! {
_ = progress_fut => unreachable!(),
out = output => out
}
}
pub(super) async fn upload_inner(
&self,
req: SyncRequest<Vec<u8>>,
) -> HttpResult<SyncResponse<UploadResponse>> {
let (io_monitor, progress_fut) = self.full_sync_progress_monitor(true);
let output = self.request_ext(SyncMethod::Upload, req, io_monitor);
select! {
_ = progress_fut => unreachable!(),
out = output => out
}
}
}

View File

@ -0,0 +1,292 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{
io::{Cursor, ErrorKind},
sync::{Arc, Mutex},
time::Duration,
};
use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::{
header::{CONTENT_TYPE, LOCATION},
Body, RequestBuilder, Response, StatusCode,
};
use tokio::{
io::AsyncReadExt,
select,
time::{interval, Instant},
};
use tokio_util::io::{ReaderStream, StreamReader};
use crate::{
error::Result,
sync::{
error::{HttpError, HttpResult, HttpSnafu, OrHttpErr},
request::header_and_stream::{decode_zstd_body_stream, encode_zstd_body_stream},
response::ORIGINAL_SIZE,
},
};
/// Serves two purposes:
/// - allows us to monitor data sending/receiving and abort if
///   the transfer stalls
/// - allows us to monitor amount of data moving, to provide progress
///   reporting
#[derive(Clone)]
pub struct IoMonitor(pub Arc<Mutex<IoMonitorInner>>);
impl IoMonitor {
pub fn new() -> Self {
Self(Arc::new(Mutex::new(IoMonitorInner {
last_activity: Instant::now(),
bytes_sent: 0,
total_bytes_to_send: 0,
bytes_received: 0,
total_bytes_to_receive: 0,
})))
}
pub fn wrap_stream<S, E>(
&self,
sending: bool,
total_bytes: u32,
stream: S,
) -> impl Stream<Item = HttpResult<Bytes>> + Send + Sync + 'static
where
S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
let inner = self.0.clone();
{
let mut inner = inner.lock().unwrap();
inner.last_activity = Instant::now();
if sending {
inner.total_bytes_to_send += total_bytes
} else {
inner.total_bytes_to_receive += total_bytes
}
}
stream.map(move |res| match res {
Ok(bytes) => {
let mut inner = inner.lock().unwrap();
inner.last_activity = Instant::now();
if sending {
inner.bytes_sent += bytes.len() as u32;
} else {
inner.bytes_received += bytes.len() as u32;
}
Ok(bytes)
}
err => err.or_http_err(StatusCode::SEE_OTHER, "stream failure"),
})
}
/// Returns once no I/O activity has been observed for `stall_time`.
pub async fn timeout(&self, stall_time: Duration) {
let poll_interval = Duration::from_millis(if cfg!(test) { 10 } else { 1000 });
let mut interval = interval(poll_interval);
loop {
let now = interval.tick().await;
let last_activity = self.0.lock().unwrap().last_activity;
if now.duration_since(last_activity) > stall_time {
return;
}
}
}
/// Encodes the provided request data, sets the content type to binary, and returns
/// the decompressed response body.
pub async fn zstd_request_with_timeout(
&self,
request: RequestBuilder,
request_body: Vec<u8>,
stall_duration: Duration,
) -> HttpResult<Vec<u8>> {
let request_total = request_body.len() as u32;
let request_body_stream = encode_zstd_body_stream(self.wrap_stream(
true,
request_total,
ReaderStream::new(Cursor::new(request_body)),
));
let response_body_stream = async move {
let resp = request
.header(CONTENT_TYPE, "application/octet-stream")
.body(Body::wrap_stream(request_body_stream))
.send()
.await?
.error_for_status()?;
map_redirect_to_error(&resp)?;
let response_total = resp
.headers()
.get(&ORIGINAL_SIZE)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u32>().ok())
.or_bad_request("missing original size")?;
let response_stream = self.wrap_stream(
false,
response_total,
decode_zstd_body_stream(resp.bytes_stream()),
);
let mut reader =
StreamReader::new(response_stream.map_err(|e| {
std::io::Error::new(ErrorKind::ConnectionAborted, format!("{e}"))
}));
let mut buf = Vec::with_capacity(response_total as usize);
reader
.read_to_end(&mut buf)
.await
.or_http_err(StatusCode::SEE_OTHER, "reading stream")?;
Ok::<_, HttpError>(buf)
};
select! {
// happy path
data = response_body_stream => Ok(data?),
// timeout
_ = self.timeout(stall_duration) => {
HttpSnafu {
code: StatusCode::REQUEST_TIMEOUT,
context: "timeout monitor",
source: None,
}.fail()
}
}
}
}
/// Reqwest can't retry a redirected request as the body has been consumed, so we need
/// to bubble it up to the sync driver to retry.
fn map_redirect_to_error(resp: &Response) -> HttpResult<()> {
if resp.status() == StatusCode::PERMANENT_REDIRECT {
let location = resp
.headers()
.get(LOCATION)
.or_bad_request("missing location header")?;
let location = String::from_utf8(location.as_bytes().to_vec())
.or_bad_request("location was not in utf8")?;
None.or_permanent_redirect(location)?;
}
Ok(())
}
#[derive(Debug)]
pub struct IoMonitorInner {
last_activity: Instant,
pub bytes_sent: u32,
pub total_bytes_to_send: u32,
pub bytes_received: u32,
pub total_bytes_to_receive: u32,
}
impl IoMonitor {}
#[cfg(test)]
mod test {
use async_stream::stream;
use futures::{pin_mut, StreamExt};
use tokio::{select, time::sleep};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use super::*;
use crate::sync::error::HttpError;
/// Use longer delays on Windows.
fn millis(millis: u64) -> Duration {
Duration::from_millis(millis * if cfg!(windows) { 10 } else { 1 })
}
#[tokio::test]
async fn can_fail_before_any_bytes() {
let monitor = IoMonitor::new();
let stream = monitor.wrap_stream(
true,
0,
stream! {
sleep(millis(2000)).await;
yield Ok::<_, HttpError>(Bytes::from("1"))
},
);
pin_mut!(stream);
select! {
_ = stream.next() => panic!("expected failure"),
_ = monitor.timeout(millis(100)) => ()
};
}
#[tokio::test]
async fn fails_when_data_stops_moving() {
let monitor = IoMonitor::new();
let stream = monitor.wrap_stream(
true,
0,
stream! {
for _ in 0..10 {
sleep(millis(10)).await;
yield Ok::<_, HttpError>(Bytes::from("1"))
}
sleep(millis(50)).await;
yield Ok::<_, HttpError>(Bytes::from("1"))
},
);
pin_mut!(stream);
for _ in 0..10 {
select! {
_ = stream.next() => (),
_ = monitor.timeout(millis(20)) => panic!("expected success")
};
}
select! {
_ = stream.next() => panic!("expected timeout"),
_ = monitor.timeout(millis(20)) => ()
};
}
#[tokio::test]
async fn connect_timeout_works() {
let monitor = IoMonitor::new();
let req = monitor.zstd_request_with_timeout(
reqwest::Client::new().post("http://0.0.0.1"),
vec![],
millis(50),
);
req.await.unwrap_err();
}
#[tokio::test]
async fn http_success() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/"))
.respond_with(ResponseTemplate::new(200).insert_header(ORIGINAL_SIZE.clone(), "0"))
.mount(&mock_server)
.await;
let monitor = IoMonitor::new();
let req = monitor.zstd_request_with_timeout(
reqwest::Client::new().post(mock_server.uri()),
vec![],
millis(10),
);
req.await.unwrap();
}
#[tokio::test]
async fn delay_before_reply_fails() {
let mock_server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/"))
.respond_with(ResponseTemplate::new(200).set_delay(millis(50)))
.mount(&mock_server)
.await;
let monitor = IoMonitor::new();
let req = monitor.zstd_request_with_timeout(
reqwest::Client::new().post(mock_server.uri()),
vec![],
millis(10),
);
req.await.unwrap_err();
}
}

View File

@ -0,0 +1,124 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub(crate) mod full_sync;
pub(crate) mod io_monitor;
mod protocol;
use std::{sync::Mutex, time::Duration};
use reqwest::{Client, Error, StatusCode, Url};
use crate::{
notes,
sync::{
collection::{progress::FullSyncProgressFn, protocol::AsSyncEndpoint},
error::{HttpError, HttpResult, HttpSnafu},
http_client::io_monitor::IoMonitor,
login::SyncAuth,
request::{
header_and_stream::{SyncHeader, SYNC_HEADER_NAME},
SyncRequest,
},
response::SyncResponse,
},
};
pub struct HttpSyncClient {
/// Set to the empty string for initial login
pub sync_key: String,
session_key: String,
client: Client,
pub endpoint: Url,
full_sync_progress_fn: Mutex<Option<FullSyncProgressFn>>,
}
impl HttpSyncClient {
pub fn new(auth: SyncAuth) -> HttpSyncClient {
HttpSyncClient {
sync_key: auth.hkey,
session_key: simple_session_id(),
client: Client::new(),
endpoint: auth
.endpoint
.unwrap_or_else(|| Url::try_from("https://sync.ankiweb.net/").unwrap()),
full_sync_progress_fn: Mutex::new(None),
}
}
#[cfg(test)]
pub fn partial_clone(&self) -> Self {
Self {
sync_key: self.sync_key.clone(),
session_key: self.session_key.clone(),
client: self.client.clone(),
endpoint: self.endpoint.clone(),
full_sync_progress_fn: Mutex::new(None),
}
}
async fn request<I, O>(
&self,
method: impl AsSyncEndpoint,
request: SyncRequest<I>,
) -> HttpResult<SyncResponse<O>> {
self.request_ext(method, request, IoMonitor::new()).await
}
async fn request_ext<I, O>(
&self,
method: impl AsSyncEndpoint,
request: SyncRequest<I>,
io_monitor: IoMonitor,
) -> HttpResult<SyncResponse<O>> {
let header = SyncHeader {
sync_version: request.sync_version,
sync_key: self.sync_key.clone(),
client_ver: request.client_version,
session_key: self.session_key.clone(),
};
let data = request.data;
let url = method.as_sync_endpoint(&self.endpoint);
let request = self
.client
.post(url)
.header(&SYNC_HEADER_NAME, serde_json::to_string(&header).unwrap());
io_monitor
.zstd_request_with_timeout(request, data, Duration::from_secs(30))
.await
.map(SyncResponse::from_vec)
}
#[cfg(test)]
pub(crate) fn endpoint(&self) -> &Url {
&self.endpoint
}
#[cfg(test)]
pub(crate) fn set_skey(&mut self, skey: String) {
self.session_key = skey;
}
#[cfg(test)]
pub(crate) fn skey(&self) -> &str {
&self.session_key
}
}
impl From<Error> for HttpError {
fn from(err: Error) -> Self {
HttpSnafu {
// we should perhaps make this Optional instead
code: err.status().unwrap_or(StatusCode::SEE_OTHER),
context: "from reqwest",
source: Some(Box::new(err) as _),
}
.build()
}
}
fn simple_session_id() -> String {
let table = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\
0123456789";
notes::to_base_n(rand::random::<u32>() as u64, table)
}
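A small usage sketch (hypothetical, not part of this commit): SyncAuth carries the host key plus an optional endpoint, and when the endpoint is None, HttpSyncClient::new falls back to https://sync.ankiweb.net/ as shown above. A self-hosted server is selected simply by supplying its URL.
fn client_for_self_hosted(hkey: String) -> HttpSyncClient {
let auth = SyncAuth {
hkey,
// point at a local sync server instead of AnkiWeb
endpoint: Some(Url::try_from("http://127.0.0.1:8080/").unwrap()),
};
HttpSyncClient::new(auth)
}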

View File

@ -0,0 +1,139 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use async_trait::async_trait;
use crate::{
prelude::TimestampMillis,
sync::{
collection::{
changes::{ApplyChangesRequest, UnchunkedChanges},
chunks::{ApplyChunkRequest, Chunk},
graves::{ApplyGravesRequest, Graves},
meta::{MetaRequest, SyncMeta},
protocol::{EmptyInput, SyncMethod, SyncProtocol},
sanity::{SanityCheckRequest, SanityCheckResponse},
start::StartRequest,
upload::UploadResponse,
},
error::HttpResult,
http_client::HttpSyncClient,
login::{HostKeyRequest, HostKeyResponse},
media::{
begin::{SyncBeginRequest, SyncBeginResponse},
changes::{MediaChangesRequest, MediaChangesResponse},
download::DownloadFilesRequest,
protocol::{JsonResult, MediaSyncMethod, MediaSyncProtocol},
sanity, upload,
},
request::SyncRequest,
response::SyncResponse,
},
};
#[async_trait]
impl SyncProtocol for HttpSyncClient {
async fn host_key(
&self,
req: SyncRequest<HostKeyRequest>,
) -> HttpResult<SyncResponse<HostKeyResponse>> {
self.request(SyncMethod::HostKey, req).await
}
async fn meta(&self, req: SyncRequest<MetaRequest>) -> HttpResult<SyncResponse<SyncMeta>> {
self.request(SyncMethod::Meta, req).await
}
async fn start(&self, req: SyncRequest<StartRequest>) -> HttpResult<SyncResponse<Graves>> {
self.request(SyncMethod::Start, req).await
}
async fn apply_graves(
&self,
req: SyncRequest<ApplyGravesRequest>,
) -> HttpResult<SyncResponse<()>> {
self.request(SyncMethod::ApplyGraves, req).await
}
async fn apply_changes(
&self,
req: SyncRequest<ApplyChangesRequest>,
) -> HttpResult<SyncResponse<UnchunkedChanges>> {
self.request(SyncMethod::ApplyChanges, req).await
}
async fn chunk(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Chunk>> {
self.request(SyncMethod::Chunk, req).await
}
async fn apply_chunk(
&self,
req: SyncRequest<ApplyChunkRequest>,
) -> HttpResult<SyncResponse<()>> {
self.request(SyncMethod::ApplyChunk, req).await
}
async fn sanity_check(
&self,
req: SyncRequest<SanityCheckRequest>,
) -> HttpResult<SyncResponse<SanityCheckResponse>> {
self.request(SyncMethod::SanityCheck2, req).await
}
async fn finish(
&self,
req: SyncRequest<EmptyInput>,
) -> HttpResult<SyncResponse<TimestampMillis>> {
self.request(SyncMethod::Finish, req).await
}
async fn abort(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<()>> {
self.request(SyncMethod::Abort, req).await
}
async fn upload(&self, req: SyncRequest<Vec<u8>>) -> HttpResult<SyncResponse<UploadResponse>> {
self.upload_inner(req).await
}
async fn download(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Vec<u8>>> {
self.download_inner(req).await
}
}
#[async_trait]
impl MediaSyncProtocol for HttpSyncClient {
async fn begin(
&self,
req: SyncRequest<SyncBeginRequest>,
) -> HttpResult<SyncResponse<JsonResult<SyncBeginResponse>>> {
self.request(MediaSyncMethod::Begin, req).await
}
async fn media_changes(
&self,
req: SyncRequest<MediaChangesRequest>,
) -> HttpResult<SyncResponse<JsonResult<MediaChangesResponse>>> {
self.request(MediaSyncMethod::MediaChanges, req).await
}
async fn upload_changes(
&self,
req: SyncRequest<Vec<u8>>,
) -> HttpResult<SyncResponse<JsonResult<upload::MediaUploadResponse>>> {
self.request(MediaSyncMethod::UploadChanges, req).await
}
async fn download_files(
&self,
req: SyncRequest<DownloadFilesRequest>,
) -> HttpResult<SyncResponse<Vec<u8>>> {
self.request(MediaSyncMethod::DownloadFiles, req).await
}
async fn media_sanity_check(
&self,
req: SyncRequest<sanity::SanityCheckRequest>,
) -> HttpResult<SyncResponse<JsonResult<sanity::MediaSanityCheckResponse>>> {
self.request(MediaSyncMethod::MediaSanity, req).await
}
}

View File

@ -0,0 +1,244 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::sync::Arc;
use async_trait::async_trait;
use media::{sanity::MediaSanityCheckResponse, upload::MediaUploadResponse};
use crate::{
prelude::*,
sync::{
collection::{
changes::{server_apply_changes, ApplyChangesRequest, UnchunkedChanges},
chunks::{server_apply_chunk, server_chunk, ApplyChunkRequest, Chunk},
download::server_download,
finish::server_finish,
graves::{server_apply_graves, ApplyGravesRequest, Graves},
meta::{server_meta, MetaRequest, SyncMeta},
protocol::{EmptyInput, SyncProtocol},
sanity::{
server_sanity_check, SanityCheckRequest, SanityCheckResponse, SanityCheckStatus,
},
start::{server_start, StartRequest},
upload::{handle_received_upload, UploadResponse},
},
error::{HttpResult, OrHttpErr},
http_server::SimpleServer,
login::{HostKeyRequest, HostKeyResponse},
media,
media::{
begin::{SyncBeginRequest, SyncBeginResponse},
changes::{MediaChangesRequest, MediaChangesResponse},
download::DownloadFilesRequest,
protocol::{JsonResult, MediaSyncProtocol},
},
request::SyncRequest,
response::SyncResponse,
},
};
#[async_trait]
impl SyncProtocol for Arc<SimpleServer> {
async fn host_key(
&self,
req: SyncRequest<HostKeyRequest>,
) -> HttpResult<SyncResponse<HostKeyResponse>> {
self.get_host_key(req.json()?)
}
async fn meta(&self, req: SyncRequest<MetaRequest>) -> HttpResult<SyncResponse<SyncMeta>> {
self.with_authenticated_user(req, |user, req| {
let req = req.json()?;
user.with_col(|col| server_meta(req, col))
})
.await
.and_then(SyncResponse::try_from_obj)
}
async fn start(&self, req: SyncRequest<StartRequest>) -> HttpResult<SyncResponse<Graves>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let req = req.json()?;
user.start_new_sync(skey)?;
user.with_sync_state(skey, |col, state| server_start(req, col, state))
.and_then(SyncResponse::try_from_obj)
})
.await
}
async fn apply_graves(
&self,
req: SyncRequest<ApplyGravesRequest>,
) -> HttpResult<SyncResponse<()>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let req = req.json()?;
user.with_sync_state(skey, |col, state| server_apply_graves(req, col, state))
.and_then(SyncResponse::try_from_obj)
})
.await
}
async fn apply_changes(
&self,
req: SyncRequest<ApplyChangesRequest>,
) -> HttpResult<SyncResponse<UnchunkedChanges>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let req = req.json()?;
user.with_sync_state(skey, |col, state| server_apply_changes(req, col, state))
.and_then(SyncResponse::try_from_obj)
})
.await
}
async fn chunk(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Chunk>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let _ = req.json()?;
user.with_sync_state(skey, server_chunk)
.and_then(SyncResponse::try_from_obj)
})
.await
}
async fn apply_chunk(
&self,
req: SyncRequest<ApplyChunkRequest>,
) -> HttpResult<SyncResponse<()>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let req = req.json()?;
user.with_sync_state(skey, |col, state| server_apply_chunk(req, col, state))
.and_then(SyncResponse::try_from_obj)
})
.await
}
async fn sanity_check(
&self,
req: SyncRequest<SanityCheckRequest>,
) -> HttpResult<SyncResponse<SanityCheckResponse>> {
self.with_authenticated_user(req, |user, req| {
let skey = req.skey()?;
let req = req.json()?;
let resp = user.with_sync_state(skey, |col, _state| server_sanity_check(req, col))?;
if resp.status == SanityCheckStatus::Bad {
// don't wait for an abort to roll back
let _ = user.col.take();
}
SyncResponse::try_from_obj(resp)
})
.await
}
async fn finish(
&self,
req: SyncRequest<EmptyInput>,
) -> HttpResult<SyncResponse<TimestampMillis>> {
self.with_authenticated_user(req, |user, req| {
let _ = req.json()?;
let now = user.with_sync_state(req.skey()?, |col, _state| server_finish(col))?;
user.sync_state = None;
SyncResponse::try_from_obj(now)
})
.await
}
async fn abort(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<()>> {
self.with_authenticated_user(req, |user, req| {
let _ = req.json()?;
user.abort_stateful_sync_if_active();
SyncResponse::try_from_obj(())
})
.await
}
async fn upload(&self, req: SyncRequest<Vec<u8>>) -> HttpResult<SyncResponse<UploadResponse>> {
self.with_authenticated_user(req, |user, req| {
user.abort_stateful_sync_if_active();
user.ensure_col_open()?;
handle_received_upload(&mut user.col, req.data).map(SyncResponse::from_upload_response)
})
.await
}
async fn download(&self, req: SyncRequest<EmptyInput>) -> HttpResult<SyncResponse<Vec<u8>>> {
self.with_authenticated_user(req, |user, req| {
let schema_version = req.sync_version.collection_schema();
let _ = req.json()?;
user.abort_stateful_sync_if_active();
user.ensure_col_open()?;
server_download(&mut user.col, schema_version).map(SyncResponse::from_vec)
})
.await
}
}
#[async_trait]
impl MediaSyncProtocol for Arc<SimpleServer> {
async fn begin(
&self,
req: SyncRequest<SyncBeginRequest>,
) -> HttpResult<SyncResponse<JsonResult<SyncBeginResponse>>> {
let hkey = req.sync_key.clone();
self.with_authenticated_user(req, |user, req| {
let req = req.json()?;
if req.client_version.is_empty() {
None.or_bad_request("missing client version")?;
}
SyncResponse::try_from_obj(JsonResult::ok(SyncBeginResponse {
usn: user.media.last_usn()?,
host_key: hkey,
}))
})
.await
}
async fn media_changes(
&self,
req: SyncRequest<MediaChangesRequest>,
) -> HttpResult<SyncResponse<JsonResult<MediaChangesResponse>>> {
self.with_authenticated_user(req, |user, req| {
SyncResponse::try_from_obj(JsonResult::ok(
user.media.media_changes_chunk(req.json()?.last_usn)?,
))
})
.await
}
async fn upload_changes(
&self,
req: SyncRequest<Vec<u8>>,
) -> HttpResult<SyncResponse<JsonResult<MediaUploadResponse>>> {
self.with_authenticated_user(req, |user, req| {
SyncResponse::try_from_obj(JsonResult::ok(
user.media.process_uploaded_changes(req.data)?,
))
})
.await
}
async fn download_files(
&self,
req: SyncRequest<DownloadFilesRequest>,
) -> HttpResult<SyncResponse<Vec<u8>>> {
self.with_authenticated_user(req, |user, req| {
Ok(SyncResponse::from_vec(
user.media.zip_files_for_download(req.json()?.files)?,
))
})
.await
}
async fn media_sanity_check(
&self,
req: SyncRequest<media::sanity::SanityCheckRequest>,
) -> HttpResult<SyncResponse<JsonResult<MediaSanityCheckResponse>>> {
self.with_authenticated_user(req, |user, req| {
SyncResponse::try_from_obj(JsonResult::ok(user.media.sanity_check(req.json()?.local)?))
})
.await
}
}


@ -0,0 +1,33 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::time::Duration;
use axum::{body::Body, http::Request, response::Response, Router};
use tower_http::trace::TraceLayer;
use tracing::{info_span, Span};
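// Wrap the router so every request gets a tracing span. The ip/uid/client/session
// fields start out empty and are recorded later by the request handling code (e.g.
// SimpleServer::with_authenticated_user fills in uid, client and session).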
pub fn with_logging_layer(router: Router) -> Router {
router.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<Body>| {
info_span!(
"request",
uri = request.uri().path(),
ip = tracing::field::Empty,
uid = tracing::field::Empty,
client = tracing::field::Empty,
session = tracing::field::Empty,
)
})
.on_request(())
.on_response(|response: &Response, latency: Duration, _span: &Span| {
tracing::info!(
elap_ms = latency.as_millis() as u32,
httpstatus = response.status().as_u16(),
"finished"
);
})
.on_failure(()),
)
}


@ -0,0 +1,50 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{fs, io::ErrorKind};
use snafu::ResultExt;
use crate::{
error::{FileIoSnafu, FileOp},
sync::{
error::{HttpResult, OrHttpErr},
http_server::media_manager::ServerMediaManager,
media::{database::server::entry::MediaEntry, zip::zip_files_for_download},
},
};
impl ServerMediaManager {
pub fn zip_files_for_download(&mut self, files: Vec<String>) -> HttpResult<Vec<u8>> {
let entries = self.db.get_entries_for_download(&files)?;
let filenames_with_data = self.gather_file_data(&entries)?;
zip_files_for_download(filenames_with_data).or_internal_err("zip files")
}
/// Mutable for the missing file case.
fn gather_file_data(&mut self, entries: &[MediaEntry]) -> HttpResult<Vec<(String, Vec<u8>)>> {
let mut out = vec![];
for entry in entries {
let path = self.media_folder.join(&entry.nfc_filename);
match fs::read(&path) {
Ok(data) => out.push((entry.nfc_filename.clone(), data)),
Err(err) if err.kind() == ErrorKind::NotFound => {
self.db
.forget_missing_file(entry)
.or_internal_err("forget missing")?;
None.or_conflict(format!(
"requested a file that doesn't exist: {}",
entry.nfc_filename
))?;
}
Err(err) => Err(err)
.context(FileIoSnafu {
path,
op: FileOp::Read,
})
.or_internal_err("gather file data")?,
}
}
Ok(out)
}
}


@ -0,0 +1,58 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
pub mod download;
pub mod upload;
use std::path::{Path, PathBuf};
use crate::{
io::create_dir_all,
prelude::*,
sync::{
error::{HttpResult, OrHttpErr},
media::{
changes::MediaChange, database::server::ServerMediaDatabase,
sanity::MediaSanityCheckResponse,
},
},
};
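/// Per-user media state on the server: the on-disk media folder plus the SQLite
/// database tracking each file's metadata.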
pub(crate) struct ServerMediaManager {
pub media_folder: PathBuf,
pub db: ServerMediaDatabase,
}
impl ServerMediaManager {
pub(crate) fn new(user_folder: &Path) -> HttpResult<ServerMediaManager> {
let media_folder = user_folder.join("media");
create_dir_all(&media_folder).or_internal_err("media folder create")?;
Ok(Self {
media_folder,
db: ServerMediaDatabase::new(&user_folder.join("media.db"))
.or_internal_err("open media db")?,
})
}
pub fn last_usn(&self) -> HttpResult<Usn> {
self.db.last_usn().or_internal_err("get last usn")
}
pub fn media_changes_chunk(&self, after_usn: Usn) -> HttpResult<Vec<MediaChange>> {
self.db
.media_changes_chunk(after_usn)
.or_internal_err("changes chunk")
}
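    /// Compare the client's count of non-deleted files against the server's; any
    /// mismatch fails the sanity check.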
pub fn sanity_check(&self, client_file_count: u32) -> HttpResult<MediaSanityCheckResponse> {
let server = self
.db
.nonempty_file_count()
.or_internal_err("get nonempty count")?;
Ok(if server == client_file_count {
MediaSanityCheckResponse::Ok
} else {
MediaSanityCheckResponse::SanityCheckFailed
})
}
}


@ -0,0 +1,96 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::{fs, io::ErrorKind, path::Path};
use snafu::ResultExt;
use tracing::info;
use crate::{
error,
error::{FileIoError, FileIoSnafu, FileOp},
io::write_file,
sync::{
error::{HttpResult, OrHttpErr},
http_server::media_manager::ServerMediaManager,
media::{
database::server::entry::upload::UploadedChangeResult, upload::MediaUploadResponse,
zip::unzip_and_validate_files,
},
},
};
impl ServerMediaManager {
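    /// Unzip the uploaded changes and apply them to the media DB inside a single
    /// transaction, mirroring additions, replacements and removals to the media
    /// folder on disk. Returns the number of changes processed and the new server USN.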
pub fn process_uploaded_changes(
&mut self,
zip_data: Vec<u8>,
) -> HttpResult<MediaUploadResponse> {
let extracted = unzip_and_validate_files(&zip_data).or_bad_request("unzip files")?;
let folder = &self.media_folder;
let mut processed = 0;
let new_usn = self
.db
.with_transaction(|db, meta| {
for change in extracted {
match db.register_uploaded_change(meta, change)? {
UploadedChangeResult::FileAlreadyDeleted { filename } => {
info!(filename, "already deleted");
}
UploadedChangeResult::FileIdentical { filename, sha1 } => {
info!(filename, sha1 = hex::encode(sha1), "already have");
}
UploadedChangeResult::Added {
filename,
data,
sha1,
} => {
info!(filename, sha1 = hex::encode(sha1), "added");
add_or_replace_file(&folder.join(filename), data)?;
}
UploadedChangeResult::Replaced {
filename,
data,
old_sha1,
new_sha1,
} => {
info!(
filename,
old_sha1 = hex::encode(old_sha1),
new_sha1 = hex::encode(new_sha1),
"replaced"
);
add_or_replace_file(&folder.join(filename), data)?;
}
UploadedChangeResult::Removed { filename, sha1 } => {
info!(filename, sha1 = hex::encode(sha1), "removed");
remove_file(&folder.join(filename))?;
}
}
processed += 1;
}
Ok(())
})
.or_internal_err("handle uploaded change")?;
Ok(MediaUploadResponse {
processed,
current_usn: new_usn,
})
}
}
fn add_or_replace_file(path: &Path, data: Vec<u8>) -> error::Result<(), FileIoError> {
write_file(path, data).map_err(Into::into)
}
fn remove_file(path: &Path) -> error::Result<(), FileIoError> {
if let Err(err) = fs::remove_file(path) {
// if transaction was previously aborted, the file may have already been deleted
if err.kind() != ErrorKind::NotFound {
return Err(err).context(FileIoSnafu {
path,
op: FileOp::Remove,
});
}
}
Ok(())
}


@ -0,0 +1,176 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
mod handlers;
mod logging;
mod media_manager;
mod routes;
mod user;
use std::{
collections::HashMap,
env,
future::Future,
net::{SocketAddr, TcpListener},
path::{Path, PathBuf},
pin::Pin,
sync::{Arc, Mutex},
};
use axum::{extract::DefaultBodyLimit, Router};
use snafu::{whatever, OptionExt, ResultExt, Whatever};
use tracing::Span;
use crate::{
error,
io::create_dir_all,
media::files::sha1_of_data,
sync::{
error::{HttpResult, OrHttpErr},
http_server::{
logging::with_logging_layer,
media_manager::ServerMediaManager,
routes::{collection_sync_router, media_sync_router},
user::User,
},
login::{HostKeyRequest, HostKeyResponse},
request::{SyncRequest, MAXIMUM_SYNC_PAYLOAD_BYTES},
response::SyncResponse,
},
};
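/// A self-contained sync server for personal use. All user state lives behind a
/// single mutex, so requests are effectively handled one at a time.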
pub struct SimpleServer {
state: Mutex<SimpleServerInner>,
}
pub struct SimpleServerInner {
/// hkey->user
users: HashMap<String, User>,
}
impl SimpleServerInner {
fn new_from_env(base_folder: &Path) -> error::Result<Self, Whatever> {
let mut idx = 1;
let mut users: HashMap<String, User> = Default::default();
loop {
let envvar = format!("SYNC_USER{idx}");
match std::env::var(&envvar) {
Ok(val) => {
let hkey = derive_hkey(&val);
let (name, _) = val.split_once(':').with_whatever_context(|| {
format!("{envvar} should be in 'username:password' format.")
})?;
let folder = base_folder.join(name);
create_dir_all(&folder).whatever_context("creating SYNC_BASE")?;
let media =
ServerMediaManager::new(&folder).whatever_context("opening media")?;
users.insert(
hkey,
User {
name: name.into(),
col: None,
sync_state: None,
media,
folder,
},
);
idx += 1;
}
Err(_) => break,
}
}
if users.is_empty() {
whatever!("No users defined; SYNC_USER1 env var should be set.");
}
Ok(Self { users })
}
}
// This is not what AnkiWeb does, but should suffice for this use case.
fn derive_hkey(user_and_pass: &str) -> String {
hex::encode(sha1_of_data(user_and_pass.as_bytes()))
}
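// For example (sketch): with SYNC_USER1="alice:secret", the stored hkey is
// hex(sha1("alice:secret")); get_host_key() derives the same value from the
// username/password a client submits, so the lookup matches.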
impl SimpleServer {
pub(in crate::sync) async fn with_authenticated_user<F, I, O>(
&self,
req: SyncRequest<I>,
op: F,
) -> HttpResult<O>
where
F: FnOnce(&mut User, SyncRequest<I>) -> HttpResult<O>,
{
let mut state = self.state.lock().unwrap();
let user = state
.users
.get_mut(&req.sync_key)
.or_forbidden("invalid hkey")?;
Span::current().record("uid", &user.name);
Span::current().record("client", &req.client_version);
Span::current().record("session", &req.session_key);
op(user, req)
}
pub(in crate::sync) fn get_host_key(
&self,
request: HostKeyRequest,
) -> HttpResult<SyncResponse<HostKeyResponse>> {
let state = self.state.lock().unwrap();
let key = derive_hkey(&format!("{}:{}", request.username, request.password));
if state.users.contains_key(&key) {
SyncResponse::try_from_obj(HostKeyResponse { key })
} else {
None.or_forbidden("invalid user/pass in get_host_key")
}
}
pub fn new(base_folder: &Path) -> error::Result<Self, Whatever> {
let inner = SimpleServerInner::new_from_env(base_folder)?;
Ok(SimpleServer {
state: Mutex::new(inner),
})
}
pub fn make_server(
address: Option<&str>,
base_folder: &Path,
) -> error::Result<(SocketAddr, ServerFuture), Whatever> {
let server =
Arc::new(SimpleServer::new(base_folder).whatever_context("unable to create server")?);
let address = address.unwrap_or("127.0.0.1:0");
let listener = TcpListener::bind(address)
.with_whatever_context(|_| format!("couldn't bind to {address}"))?;
let addr = listener.local_addr().unwrap();
let server = with_logging_layer(
Router::new()
.nest("/sync", collection_sync_router())
.nest("/msync", media_sync_router())
.with_state(server)
.layer(DefaultBodyLimit::max(*MAXIMUM_SYNC_PAYLOAD_BYTES)),
);
let future = axum::Server::from_tcp(listener)
.whatever_context("listen failed")?
.serve(server.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(async {
let _ = tokio::signal::ctrl_c().await;
});
tracing::info!(%addr, "listening");
Ok((addr, Box::pin(future)))
}
#[snafu::report]
#[tokio::main]
pub async fn run() -> error::Result<(), Whatever> {
let host = env::var("SYNC_HOST").unwrap_or_else(|_| "0.0.0.0".into());
let port = env::var("SYNC_PORT").unwrap_or_else(|_| "8080".into());
let base_folder =
PathBuf::from(env::var("SYNC_BASE").whatever_context("missing SYNC_BASE")?);
let addr = format!("{host}:{port}");
let (_addr, server_fut) = SimpleServer::make_server(Some(&addr), &base_folder)?;
server_fut.await.whatever_context("await server")?;
Ok(())
}
}
pub type ServerFuture = Pin<Box<dyn Future<Output = error::Result<(), hyper::Error>> + Send>>;


@ -0,0 +1,108 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use axum::{
extract::{Path, Query, State},
response::Response,
routing::{get, post},
Router,
};
use crate::sync::{
collection::protocol::{SyncMethod, SyncProtocol},
error::{HttpResult, OrHttpErr},
media::{
begin::{SyncBeginQuery, SyncBeginRequest},
protocol::{MediaSyncMethod, MediaSyncProtocol},
},
request::{IntoSyncRequest, SyncRequest},
version::SyncVersion,
};
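// Forward a raw sync request to the named protocol method, then turn the typed
// result back into an HTTP response for the sync version the client sent.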
macro_rules! sync_method {
($server:ident, $req:ident, $method:ident) => {{
let sync_version = $req.sync_version;
let obj = $server.$method($req.into_output_type()).await?;
obj.make_response(sync_version)
}};
}
async fn sync_handler<P: SyncProtocol>(
Path(method): Path<SyncMethod>,
State(server): State<P>,
request: SyncRequest<Vec<u8>>,
) -> HttpResult<Response> {
Ok(match method {
SyncMethod::HostKey => sync_method!(server, request, host_key),
SyncMethod::Meta => sync_method!(server, request, meta),
SyncMethod::Start => sync_method!(server, request, start),
SyncMethod::ApplyGraves => sync_method!(server, request, apply_graves),
SyncMethod::ApplyChanges => sync_method!(server, request, apply_changes),
SyncMethod::Chunk => sync_method!(server, request, chunk),
SyncMethod::ApplyChunk => sync_method!(server, request, apply_chunk),
SyncMethod::SanityCheck2 => sync_method!(server, request, sanity_check),
SyncMethod::Finish => sync_method!(server, request, finish),
SyncMethod::Abort => sync_method!(server, request, abort),
SyncMethod::Upload => sync_method!(server, request, upload),
SyncMethod::Download => sync_method!(server, request, download),
})
}
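/// All collection sync methods share a single POST route; the method name is the
/// final path segment (e.g. POST /sync/meta).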
pub fn collection_sync_router<P: SyncProtocol + Clone>() -> Router<P> {
Router::new().route("/:method", post(sync_handler::<P>))
}
/// The Rust code used to send a GET with query params, which was inconsistent with the
/// rest of our code; map the request into our standard structure.
async fn media_begin_get<P: MediaSyncProtocol>(
Query(req): Query<SyncBeginQuery>,
server: State<P>,
) -> HttpResult<Response> {
let host_key = req.host_key;
let mut req = SyncBeginRequest {
client_version: req.client_version,
}
.try_into_sync_request()
.or_bad_request("convert begin")?;
req.sync_key = host_key;
req.sync_version = SyncVersion::multipart();
media_begin_post(server, req).await
}
/// Older clients would send client info in the multipart wrapper instead of the inner
/// JSON; inject it into the JSON if provided.
async fn media_begin_post<P: MediaSyncProtocol>(
server: State<P>,
mut req: SyncRequest<SyncBeginRequest>,
) -> HttpResult<Response> {
if let Some(ver) = &req.media_client_version {
req.data = serde_json::to_vec(&SyncBeginRequest {
client_version: ver.clone(),
})
.or_internal_err("serialize begin request")?;
}
media_sync_handler(Path(MediaSyncMethod::Begin), server, req.into_output_type()).await
}
async fn media_sync_handler<P: MediaSyncProtocol>(
Path(method): Path<MediaSyncMethod>,
State(server): State<P>,
request: SyncRequest<Vec<u8>>,
) -> HttpResult<Response> {
Ok(match method {
MediaSyncMethod::Begin => sync_method!(server, request, begin),
MediaSyncMethod::MediaChanges => sync_method!(server, request, media_changes),
MediaSyncMethod::UploadChanges => sync_method!(server, request, upload_changes),
MediaSyncMethod::DownloadFiles => sync_method!(server, request, download_files),
MediaSyncMethod::MediaSanity => sync_method!(server, request, media_sanity_check),
})
}
pub fn media_sync_router<P: MediaSyncProtocol + Clone>() -> Router<P> {
Router::new()
.route(
"/begin",
get(media_begin_get::<P>).post(media_begin_post::<P>),
)
.route("/:method", post(media_sync_handler::<P>))
}


@ -0,0 +1,95 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::path::PathBuf;
use tracing::info;
use crate::{
collection::{Collection, CollectionBuilder},
error,
sync::{
collection::start::ServerSyncState,
error::{HttpResult, OrHttpErr},
http_server::media_manager::ServerMediaManager,
},
};
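/// Server-side state for one user (defined via a SYNC_USERn env var): their
/// lazily-opened collection, media manager, and any in-progress stateful sync.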
pub(in crate::sync) struct User {
pub name: String,
pub col: Option<Collection>,
pub sync_state: Option<ServerSyncState>,
pub media: ServerMediaManager,
pub folder: PathBuf,
}
impl User {
/// Run op with access to the collection. If a sync is active, it's aborted.
pub(crate) fn with_col<F, T>(&mut self, op: F) -> HttpResult<T>
where
F: FnOnce(&mut Collection) -> HttpResult<T>,
{
self.abort_stateful_sync_if_active();
self.ensure_col_open()?;
op(self.col.as_mut().unwrap())
}
/// Run op with the existing sync state created by start_new_sync(). If there is no
/// existing state, or the current state's key does not match, abort the request with
/// a conflict.
pub(crate) fn with_sync_state<F, T>(&mut self, skey: &str, op: F) -> HttpResult<T>
where
F: FnOnce(&mut Collection, &mut ServerSyncState) -> error::Result<T>,
{
match &self.sync_state {
None => None.or_conflict("no active sync")?,
Some(state) => {
if state.skey != skey {
None.or_conflict("active sync with different key")?;
}
}
};
self.ensure_col_open()?;
let state = self.sync_state.as_mut().unwrap();
let col = self.col.as_mut().or_internal_err("open col")?;
        // Failures in a sync op are usually caused by referential integrity issues (e.g. they've sent
        // a note without sending its associated notetype). Returning HTTP 400 will inform the client that
        // a DB check + full sync is required to fix the issue.
op(col, state)
.map_err(|e| {
self.col = None;
self.sync_state = None;
e
})
.or_bad_request("op failed in sync_state")
}
pub(crate) fn abort_stateful_sync_if_active(&mut self) {
if self.sync_state.is_some() {
info!("aborting active sync");
self.sync_state = None;
self.col = None;
}
}
pub(crate) fn start_new_sync(&mut self, skey: &str) -> HttpResult<()> {
self.abort_stateful_sync_if_active();
self.sync_state = Some(ServerSyncState::new(skey));
Ok(())
}
pub(crate) fn ensure_col_open(&mut self) -> HttpResult<()> {
if self.col.is_none() {
self.col = Some(self.open_collection()?);
}
Ok(())
}
fn open_collection(&mut self) -> HttpResult<Collection> {
CollectionBuilder::new(self.folder.join("collection.anki2"))
.set_server(true)
.build()
.or_internal_err("open collection")
}
}

rslib/src/sync/login.rs

@ -0,0 +1,58 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use reqwest::Url;
use serde_derive::{Deserialize, Serialize};
use crate::{
prelude::*,
sync::{
collection::protocol::SyncProtocol, http_client::HttpSyncClient, request::IntoSyncRequest,
},
};
#[derive(Clone, Default)]
pub struct SyncAuth {
pub hkey: String,
pub endpoint: Option<Url>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct HostKeyRequest {
#[serde(rename = "u")]
pub username: String,
#[serde(rename = "p")]
pub password: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct HostKeyResponse {
pub key: String,
}
pub async fn sync_login<S: Into<String>>(
username: S,
password: S,
endpoint: Option<String>,
) -> Result<SyncAuth> {
let auth = crate::pb::sync::SyncAuth {
endpoint,
..Default::default()
}
.try_into()?;
let client = HttpSyncClient::new(auth);
let resp = client
.host_key(
HostKeyRequest {
username: username.into(),
password: password.into(),
}
.try_into_sync_request()?,
)
.await?
.json()?;
Ok(SyncAuth {
hkey: resp.key,
endpoint: None,
})
}


@ -0,0 +1,34 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use serde_derive::{Deserialize, Serialize};
use crate::prelude::*;
// The old Rust code sent the host key in a query string
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncBeginQuery {
#[serde(rename = "k")]
pub host_key: String,
#[serde(rename = "v")]
pub client_version: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncBeginRequest {
/// Older clients provide this in the multipart wrapper; our router will
/// inject the value in if necessary. The route handler should check that
/// a value has actually been provided.
#[serde(rename = "v", default)]
pub client_version: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncBeginResponse {
pub usn: Usn,
/// The server used to send back a session key used for following requests,
/// but this is no longer required. To avoid breaking older clients, the host
/// key is returned in its place.
#[serde(rename = "sk")]
pub host_key: String,
}


@ -0,0 +1,135 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use serde_derive::{Deserialize, Serialize};
use serde_tuple::Serialize_tuple;
use tracing::debug;
use crate::{error, prelude::Usn, sync::media::database::client::MediaDatabase};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MediaChangesRequest {
pub last_usn: Usn,
}
pub type MediaChangesResponse = Vec<MediaChange>;
#[derive(Debug, Serialize_tuple, Deserialize)]
pub struct MediaChange {
pub fname: String,
pub usn: Usn,
pub sha1: String,
}
#[derive(Debug, Clone, Copy)]
pub enum LocalState {
NotInDb,
InDbNotPending,
InDbAndPending,
}
#[derive(PartialEq, Eq, Debug)]
pub enum RequiredChange {
    // None also covers the case where we'll upload the file later
None,
Download,
Delete,
RemovePending,
}
pub fn determine_required_change(
local_sha1: &str,
remote_sha1: &str,
local_state: LocalState,
) -> RequiredChange {
match (local_sha1, remote_sha1, local_state) {
// both deleted, not in local DB
("", "", LocalState::NotInDb) => RequiredChange::None,
// both deleted, in local DB
("", "", _) => RequiredChange::Delete,
// added on server, add even if local deletion pending
("", _, _) => RequiredChange::Download,
// deleted on server but added locally; upload later
(_, "", LocalState::InDbAndPending) => RequiredChange::None,
// deleted on server and not pending sync
(_, "", _) => RequiredChange::Delete,
// if pending but the same as server, don't need to upload
(lsum, rsum, LocalState::InDbAndPending) if lsum == rsum => RequiredChange::RemovePending,
(lsum, rsum, _) => {
if lsum == rsum {
// not pending and same as server, nothing to do
RequiredChange::None
} else {
// differs from server, favour server
RequiredChange::Download
}
}
}
}
/// Get a list of server filenames and the actions required on them.
/// Returns filenames in (to_download, to_delete, to_remove_pending).
pub fn determine_required_changes(
ctx: &MediaDatabase,
records: Vec<MediaChange>,
) -> error::Result<(Vec<String>, Vec<String>, Vec<String>)> {
let mut to_download = vec![];
let mut to_delete = vec![];
let mut to_remove_pending = vec![];
for remote in records {
let (local_sha1, local_state) = match ctx.get_entry(&remote.fname)? {
Some(entry) => (
match entry.sha1 {
Some(arr) => hex::encode(arr),
None => "".to_string(),
},
if entry.sync_required {
LocalState::InDbAndPending
} else {
LocalState::InDbNotPending
},
),
None => ("".to_string(), LocalState::NotInDb),
};
let req_change = determine_required_change(&local_sha1, &remote.sha1, local_state);
debug!(
fname = &remote.fname,
lsha = local_sha1.chars().take(8).collect::<String>(),
rsha = remote.sha1.chars().take(8).collect::<String>(),
state = ?local_state,
action = ?req_change,
"determine action"
);
match req_change {
RequiredChange::Download => to_download.push(remote.fname),
RequiredChange::Delete => to_delete.push(remote.fname),
RequiredChange::RemovePending => to_remove_pending.push(remote.fname),
RequiredChange::None => (),
};
}
Ok((to_download, to_delete, to_remove_pending))
}
#[cfg(test)]
mod test {
#[test]
fn required_change() {
use crate::sync::media::changes::{
determine_required_change as d, LocalState as L, RequiredChange as R,
};
assert_eq!(d("", "", L::NotInDb), R::None);
assert_eq!(d("", "", L::InDbNotPending), R::Delete);
assert_eq!(d("", "1", L::InDbAndPending), R::Download);
assert_eq!(d("1", "", L::InDbAndPending), R::None);
assert_eq!(d("1", "", L::InDbNotPending), R::Delete);
assert_eq!(d("1", "1", L::InDbNotPending), R::None);
assert_eq!(d("1", "1", L::InDbAndPending), R::RemovePending);
assert_eq!(d("a", "b", L::InDbAndPending), R::Download);
assert_eq!(d("a", "b", L::InDbNotPending), R::Download);
}
}


@ -7,14 +7,12 @@ use tracing::debug;
use crate::{
io::read_dir_files,
media::{
database::{MediaDatabaseContext, MediaEntry},
files::{
filename_if_normalized, mtime_as_i64, sha1_of_file, MEDIA_SYNC_FILESIZE_LIMIT,
NONSYNCABLE_FILENAME,
},
},
media::files::{filename_if_normalized, mtime_as_i64, sha1_of_file, NONSYNCABLE_FILENAME},
prelude::*,
sync::media::{
database::client::{MediaDatabase, MediaEntry},
MAX_INDIVIDUAL_MEDIA_FILE_SIZE,
},
};
struct FilesystemEntry {
@ -24,7 +22,7 @@ struct FilesystemEntry {
is_new: bool,
}
pub(super) struct ChangeTracker<'a, F>
pub(crate) struct ChangeTracker<'a, F>
where
F: FnMut(usize) -> bool,
{
@ -37,7 +35,7 @@ impl<F> ChangeTracker<'_, F>
where
F: FnMut(usize) -> bool,
{
pub(super) fn new(media_folder: &Path, progress: F) -> ChangeTracker<'_, F> {
pub(crate) fn new(media_folder: &Path, progress: F) -> ChangeTracker<'_, F> {
ChangeTracker {
media_folder,
progress_cb: progress,
@ -53,7 +51,7 @@ where
}
}
pub(super) fn register_changes(&mut self, ctx: &mut MediaDatabaseContext) -> Result<()> {
pub(crate) fn register_changes(&mut self, ctx: &MediaDatabase) -> Result<()> {
ctx.transact(|ctx| {
// folder mtime unchanged?
let dirmod = mtime_as_i64(self.media_folder)?;
@ -125,7 +123,7 @@ where
// ignore large files and zero byte files
let metadata = dentry.metadata()?;
if metadata.len() > MEDIA_SYNC_FILESIZE_LIMIT as u64 {
if metadata.len() > MAX_INDIVIDUAL_MEDIA_FILE_SIZE as u64 {
continue;
}
if metadata.len() == 0 {
@ -184,7 +182,7 @@ where
/// Skip files where the mod time differed, but checksums are the same.
fn add_updated_entries(
&mut self,
ctx: &mut MediaDatabaseContext,
ctx: &MediaDatabase,
entries: Vec<FilesystemEntry>,
) -> Result<()> {
for fentry in entries {
@ -217,11 +215,7 @@ where
}
/// Remove deleted files from the media DB.
fn remove_deleted_files(
&mut self,
ctx: &mut MediaDatabaseContext,
removed: Vec<String>,
) -> Result<()> {
fn remove_deleted_files(&mut self, ctx: &MediaDatabase, removed: Vec<String>) -> Result<()> {
for fname in removed {
ctx.set_entry(&MediaEntry {
fname,
@ -246,12 +240,12 @@ mod test {
use tempfile::tempdir;
use super::*;
use crate::{
error::Result,
io::{create_dir, write_file},
media::{
changetracker::ChangeTracker, database::MediaEntry, files::sha1_of_data, MediaManager,
},
media::{files::sha1_of_data, MediaManager},
sync::media::database::client::MediaEntry,
};
// helper
@ -273,9 +267,7 @@ mod test {
let media_db = dir.path().join("media.db");
let mgr = MediaManager::new(&media_dir, media_db)?;
let mut ctx = mgr.dbctx();
assert_eq!(ctx.count()?, 0);
assert_eq!(mgr.db.count()?, 0);
// add a file and check it's picked up
let f1 = media_dir.join("file.jpg");
@ -283,11 +275,11 @@ mod test {
change_mtime(&media_dir);
let progress_cb = |_n| true;
let mut progress_cb = |_n| true;
ChangeTracker::new(&mgr.media_folder, progress_cb).register_changes(&mut ctx)?;
mgr.register_changes(&mut progress_cb)?;
let mut entry = ctx.transact(|ctx| {
let mut entry = mgr.db.transact(|ctx| {
assert_eq!(ctx.count()?, 1);
assert!(!ctx.get_pending_uploads(1)?.is_empty());
let mut entry = ctx.get_entry("file.jpg")?.unwrap();
@ -320,9 +312,9 @@ mod test {
Ok(entry)
})?;
ChangeTracker::new(&mgr.media_folder, progress_cb).register_changes(&mut ctx)?;
ChangeTracker::new(&mgr.media_folder, progress_cb).register_changes(&mgr.db)?;
ctx.transact(|ctx| {
mgr.db.transact(|ctx| {
assert_eq!(ctx.count()?, 1);
assert!(!ctx.get_pending_uploads(1)?.is_empty());
assert_eq!(
@ -353,12 +345,12 @@ mod test {
change_mtime(&media_dir);
ChangeTracker::new(&mgr.media_folder, progress_cb).register_changes(&mut ctx)?;
ChangeTracker::new(&mgr.media_folder, progress_cb).register_changes(&mgr.db)?;
assert_eq!(ctx.count()?, 0);
assert!(!ctx.get_pending_uploads(1)?.is_empty());
assert_eq!(mgr.db.count()?, 0);
assert!(!mgr.db.get_pending_uploads(1)?.is_empty());
assert_eq!(
ctx.get_entry("file.jpg")?.unwrap(),
mgr.db.get_entry("file.jpg")?.unwrap(),
MediaEntry {
fname: "file.jpg".into(),
sha1: None,

Some files were not shown because too many files have changed in this diff.