Implement TTS using windows crate (#2371)

* Implement TTS using windows crate

* Use API calls instead of SSML

* Properly stop player in case of TTS error

* Add context to WindowsErrors

* Validate available voices

* Remove TTS text from synthesize error

* Limit maximum buffer size

* Make validation optional and list it in tts filter

* We no longer need the winrt module (dae)

* Use a separate request object so the meaning of the bool is clear (dae)

* Slightly shorten runtime error message (dae)

The default message appears to clip slightly.

* Alternate buffer implementation (dae)

* Use array instead of vec

* Drop the max buffer size to 128k (dae)
This commit is contained in:
RumovZ 2023-02-17 03:26:07 +01:00 committed by GitHub
parent 5a53da23ca
commit cdfb84f19a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 323 additions and 77 deletions

View File

@ -44,4 +44,4 @@ good-names =
ip,
[IMPORTS]
ignored-modules = anki.*_pb2, anki.sync_pb2, win32file,pywintypes,socket,win32pipe,winrt,pyaudio,anki.scheduler_pb2
ignored-modules = anki.*_pb2, anki.sync_pb2, win32file,pywintypes,socket,win32pipe,pyaudio,anki.scheduler_pb2

25
Cargo.lock generated
View File

@ -144,6 +144,7 @@ dependencies = [
"unicode-normalization",
"utime",
"which",
"windows",
"wiremock",
"workspace-hack",
"zip",
@ -4580,6 +4581,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows"
version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e745dab35a0c4c77aa3ce42d595e13d2003d6902d6b08c9ef5fc326d08da12b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.42.0"
@ -4595,6 +4605,21 @@ dependencies = [
"windows_x86_64_msvc",
]
[[package]]
name = "windows-targets"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.1"

View File

@ -39,7 +39,7 @@ errors-unable-open-collection =
Anki was unable to open your collection file. If problems persist after restarting your computer, please use the Open Backup button in the profile manager.
Debug info:
errors-windows-tts-runtime-error = The TTS service failed. Please ensure Windows updates are installed, try restarting your computer, and try using a different voice.
errors-windows-tts-runtime-error = The TTS service failed. Please ensure Windows updates are installed, try restarting your computer, or try a different voice.
## OBSOLETE; you do not need to translate this

View File

@ -66,6 +66,8 @@ message BackendError {
DELETED = 17;
CARD_TYPE_ERROR = 18;
ANKIDROID_PANIC_ERROR = 19;
// Originated from and usually specific to the OS.
OS_ERROR = 20;
}
// error description, usually localized, suitable for displaying to the user

View File

@ -29,6 +29,8 @@ service CardRenderingService {
rpc CompareAnswer(CompareAnswerRequest) returns (generic.String);
rpc ExtractClozeForTyping(ExtractClozeForTypingRequest)
returns (generic.String);
rpc AllTtsVoices(AllTtsVoicesRequest) returns (AllTtsVoicesResponse);
rpc WriteTtsStream(WriteTtsStreamRequest) returns (generic.Empty);
}
message ExtractAVTagsRequest {
@ -145,3 +147,24 @@ message ExtractClozeForTypingRequest {
string text = 1;
uint32 ordinal = 2;
}
message AllTtsVoicesRequest {
bool validate = 1;
}
message AllTtsVoicesResponse {
message TtsVoice {
string id = 1;
string name = 2;
string language = 3;
optional bool available = 4;
}
repeated TtsVoice voices = 1;
}
message WriteTtsStreamRequest {
string path = 1;
string voice_id = 2;
float speed = 3;
string text = 4;
}

View File

@ -39,6 +39,7 @@ ImportCsvRequest = import_export_pb2.ImportCsvRequest
CsvMetadata = import_export_pb2.CsvMetadata
DupeResolution = CsvMetadata.DupeResolution
Delimiter = import_export_pb2.CsvMetadata.Delimiter
TtsVoice = card_rendering_pb2.AllTtsVoicesResponse.TtsVoice
import copy
import os

View File

@ -7,4 +7,3 @@ send2trash
waitress>=2.0.0
psutil; sys.platform == "win32"
pywin32; sys.platform == "win32"
winrt; sys.platform == "win32"

View File

@ -1,3 +1,2 @@
pywin32
winrt

View File

@ -14,11 +14,3 @@ pywin32==305 \
--hash=sha256:9dd98384da775afa009bc04863426cb30596fd78c6f8e4e2e5bbf4edf8029504 \
--hash=sha256:a55db448124d1c1484df22fa8bbcbc45c64da5e6eae74ab095b9ea62e6d00496
# via -r requirements.win.in
winrt==1.0.21033.1 \
--hash=sha256:224e13eb172435aaabdc7066752898a61dae0fcc3022f6f8cbd1ce953be3358c \
--hash=sha256:9d7b7d2e48c301855afd3280aaf51ea0d3c683450f46de2db813f71ee1cd5937 \
--hash=sha256:ad4afd1c7b041a6b770256d70e07093920fa83eecd80e42cac2704cd03902243 \
--hash=sha256:d035570ce2cf7e8caa785abb43f25a6ede600c2cde0378c931495bdbeaf1a075 \
--hash=sha256:da3ca3626fb992f2efa4528993d4760b298f399a7f459f7e070a2f8681d82106 \
--hash=sha256:f5ab502117da4777ab49b846ad1919fbf448bd5e49b4aca00cc59667bae2c362
# via -r requirements.win.in

View File

@ -28,11 +28,9 @@ expose the name of the engine, which would mean the user could write
from __future__ import annotations
import asyncio
import os
import re
import subprocess
import threading
from concurrent.futures import Future
from dataclasses import dataclass
from operator import attrgetter
@ -40,7 +38,9 @@ from typing import Any, cast
import anki
import anki.template
import aqt
from anki import hooks
from anki.collection import TtsVoice as BackendVoice
from anki.sound import AVTag, TTSTag
from anki.utils import checksum, is_win, tmpdir
from aqt import gui_hooks
@ -52,6 +52,16 @@ from aqt.utils import tooltip, tr
class TTSVoice:
name: str
lang: str
available: bool | None
def __str__(self) -> str:
out = f"{{{{tts {self.lang} voices={self.name}}}}}"
if self.unavailable():
out += " (unavailable)"
return out
def unavailable(self) -> bool:
return self.available is False
@dataclass
@ -124,9 +134,8 @@ def all_tts_voices() -> list[TTSVoice]:
all_voices: list[TTSVoice] = []
for p in av_player.players:
getter = getattr(p, "voices", None)
if not getter:
continue
getter = getattr(p, "validated_voices", getattr(p, "voices", None))
if getter:
all_voices.extend(getter())
return all_voices
@ -137,14 +146,13 @@ def on_tts_voices(
if filter != "tts-voices":
return text
voices = all_tts_voices()
voices.sort(key=attrgetter("name"))
voices.sort(key=attrgetter("lang"))
voices.sort(key=attrgetter("lang", "name"))
buf = "<div style='font-size: 14px; text-align: left;'>TTS voices available:<br>"
buf += "<br>".join(
f"{{{{tts {v.lang} voices={v.name}}}}}" # pylint: disable=no-member
for v in voices
)
buf += "<br>".join(map(str, voices))
if any(v.unavailable() for v in voices):
buf += "<div>One or more voices are unavailable."
buf += " Installing a Windows language pack may help.</div>"
return f"{buf}</div>"
@ -205,7 +213,9 @@ class MacTTSPlayer(TTSProcessPlayer):
original_name = m.group(1).strip()
tidy_name = f"Apple_{original_name.replace(' ', '_')}"
return MacVoice(name=tidy_name, original_name=original_name, lang=m.group(2))
return MacVoice(
name=tidy_name, original_name=original_name, lang=m.group(2), available=None
)
class MacTTSFilePlayer(MacTTSPlayer):
@ -509,7 +519,10 @@ if is_win:
# some voices may not have a name
name = "unknown"
name = self._tidy_name(name)
return [WindowsVoice(name=name, lang=lang, handle=voice) for lang in langs]
return [
WindowsVoice(name=name, lang=lang, handle=voice, available=None)
for lang in langs
]
def _play(self, tag: AVTag) -> None:
assert isinstance(tag, TTSTag)
@ -546,35 +559,36 @@ if is_win:
@dataclass
class WindowsRTVoice(TTSVoice):
id: Any
id: str
class WindowsRTTTSFilePlayer(TTSProcessPlayer):
voice_list: list[Any] = []
tmppath = os.path.join(tmpdir(), "tts.wav")
def import_voices(self) -> None:
import winrt.windows.media.speechsynthesis as speechsynthesis # type: ignore
try:
self.voice_list = speechsynthesis.SpeechSynthesizer.get_all_voices() # type: ignore
except Exception as e:
print("winrt tts voices unavailable:", e)
self.voice_list = []
def get_available_voices(self) -> list[TTSVoice]:
t = threading.Thread(target=self.import_voices)
t.start()
t.join()
return list(map(self._voice_to_object, self.voice_list))
def _voice_to_object(self, voice: Any) -> TTSVoice:
return WindowsRTVoice(
@classmethod
def from_backend_voice(cls, voice: BackendVoice) -> WindowsRTVoice:
return cls(
id=voice.id,
name=voice.display_name.replace(" ", "_"),
name=voice.name.replace(" ", "_"),
lang=voice.language.replace("-", "_"),
available=voice.available,
)
class WindowsRTTTSFilePlayer(TTSProcessPlayer):
tmppath = os.path.join(tmpdir(), "tts.wav")
def validated_voices(self) -> list[TTSVoice]:
self._available_voices = self._get_available_voices(validate=True)
return self._available_voices
@classmethod
def get_available_voices(cls) -> list[TTSVoice]:
return cls._get_available_voices(validate=False)
@staticmethod
def _get_available_voices(validate: bool) -> list[TTSVoice]:
assert aqt.mw
voices = aqt.mw.backend.all_tts_voices(validate=validate)
return list(map(WindowsRTVoice.from_backend_voice, voices))
def _play(self, tag: AVTag) -> None:
assert aqt.mw
assert isinstance(tag, TTSTag)
match = self.voice_for_tag(tag)
assert match
@ -583,13 +597,18 @@ if is_win:
self._taskman.run_on_main(
lambda: gui_hooks.av_player_did_begin_playing(self, tag)
)
asyncio.run(self.speakText(tag, voice.id))
aqt.mw.backend.write_tts_stream(
path=self.tmppath,
voice_id=voice.id,
speed=tag.speed,
text=tag.field_text,
)
def _on_done(self, ret: Future, cb: OnDoneCallback) -> None:
try:
ret.result()
except RuntimeError:
if exception := ret.exception():
print(str(exception))
tooltip(tr.errors_windows_tts_runtime_error())
cb()
return
# inject file into the top of the audio queue
@ -599,26 +618,3 @@ if is_win:
# then tell player to advance, which will cause the file to be played
cb()
async def speakText(self, tag: TTSTag, voice_id: Any) -> None:
import winrt.windows.media.speechsynthesis as speechsynthesis # type: ignore
import winrt.windows.storage.streams as streams # type: ignore
synthesizer = speechsynthesis.SpeechSynthesizer()
voices = speechsynthesis.SpeechSynthesizer.get_all_voices() # type: ignore
voice_match = next(filter(lambda v: v.id == voice_id, voices))
assert voice_match
synthesizer.voice = voice_match
synthesizer.options.speaking_rate = tag.speed
stream = await synthesizer.synthesize_text_to_stream_async(tag.field_text)
inputStream = stream.get_input_stream_at(0)
dataReader = streams.DataReader(inputStream)
dataReader.load_async(stream.size)
f = open(self.tmppath, "wb")
for x in range(stream.size):
f.write(bytes([dataReader.read_byte()]))
f.close()

View File

@ -109,3 +109,7 @@ utime = "0.3.1"
workspace-hack = { version = "0.1", path = "../tools/workspace-hack" }
zip = { version = "0.6.3", default-features = false, features = ["deflate", "time"] }
zstd = { version = "0.12.2", features = ["zstdmt"] }
[target.'cfg(windows)'.dependencies.windows]
version = "0.44.0"
features = ["Media_SpeechSynthesis", "Foundation_Collections", "Storage_Streams"]

View File

@ -4,6 +4,7 @@
use super::Backend;
use crate::card_rendering::extract_av_tags;
use crate::card_rendering::strip_av_tags;
use crate::card_rendering::tts;
use crate::cloze::extract_cloze_for_typing;
use crate::latex::extract_latex;
use crate::latex::extract_latex_expanding_clozes;
@ -175,6 +176,27 @@ impl CardRenderingService for Backend {
.to_string()
.into())
}
fn all_tts_voices(
&self,
input: pb::card_rendering::AllTtsVoicesRequest,
) -> Result<pb::card_rendering::AllTtsVoicesResponse> {
tts::all_voices(input.validate)
.map(|voices| pb::card_rendering::AllTtsVoicesResponse { voices })
}
fn write_tts_stream(
&self,
request: pb::card_rendering::WriteTtsStreamRequest,
) -> Result<pb::generic::Empty> {
tts::write_stream(
&request.path,
&request.voice_id,
request.speed,
&request.text,
)
.map(Into::into)
}
}
fn rendered_nodes_to_proto(

View File

@ -40,6 +40,8 @@ impl AnkiError {
AnkiError::FileIoError { .. } => Kind::IoError,
AnkiError::MediaCheckRequired => Kind::InvalidInput,
AnkiError::InvalidId => Kind::InvalidInput,
#[cfg(windows)]
AnkiError::WindowsError { .. } => Kind::OsError,
};
pb::backend::BackendError {

View File

@ -7,6 +7,7 @@ use crate::pb;
use crate::prelude::*;
mod parser;
pub mod tts;
mod writer;
pub fn strip_av_tags<S: Into<String> + AsRef<str>>(txt: S) -> String {

View File

@ -0,0 +1,20 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::pb::card_rendering::all_tts_voices_response::TtsVoice;
use crate::prelude::*;
#[cfg(windows)]
#[path = "windows.rs"]
mod inner;
#[cfg(not(windows))]
#[path = "other.rs"]
mod inner;
pub fn all_voices(validate: bool) -> Result<Vec<TtsVoice>> {
inner::all_voices(validate)
}
pub fn write_stream(path: &str, voice_id: &str, speed: f32, text: &str) -> Result<()> {
inner::write_stream(path, voice_id, speed, text)
}

View File

@ -0,0 +1,13 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use crate::pb::card_rendering::all_tts_voices_response::TtsVoice;
use crate::prelude::*;
pub(super) fn all_voices(_validate: bool) -> Result<Vec<TtsVoice>> {
invalid_input!("not implemented for this OS");
}
pub(super) fn write_stream(_path: &str, _voice_id: &str, _speed: f32, _text: &str) -> Result<()> {
invalid_input!("not implemented for this OS");
}

View File

@ -0,0 +1,106 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use std::fs::File;
use std::io::Write;
use futures::executor::block_on;
use windows::core::HSTRING;
use windows::Media::SpeechSynthesis::SpeechSynthesisStream;
use windows::Media::SpeechSynthesis::SpeechSynthesizer;
use windows::Media::SpeechSynthesis::VoiceInformation;
use windows::Storage::Streams::DataReader;
use crate::error::windows::WindowsErrorDetails;
use crate::error::windows::WindowsSnafu;
use crate::pb::card_rendering::all_tts_voices_response::TtsVoice;
use crate::prelude::*;
const MAX_BUFFER_SIZE: usize = 128 * 1024;
pub(super) fn all_voices(validate: bool) -> Result<Vec<TtsVoice>> {
SpeechSynthesizer::AllVoices()?
.into_iter()
.map(|info| TtsVoice::from_voice_information(info, validate))
.collect()
}
pub(super) fn write_stream(path: &str, voice_id: &str, speed: f32, text: &str) -> Result<()> {
let voice = find_voice(voice_id)?;
let stream = synthesize_stream(&voice, speed, text)?;
write_stream_to_path(stream, path)?;
Ok(())
}
fn find_voice(voice_id: &str) -> Result<VoiceInformation> {
SpeechSynthesizer::AllVoices()?
.into_iter()
.find(|info| {
info.Id()
.map(|id| id.to_string_lossy().eq(voice_id))
.unwrap_or_default()
})
.or_invalid("voice id not found")
}
fn to_hstring(text: &str) -> HSTRING {
let utf16: Vec<u16> = text.encode_utf16().collect();
HSTRING::from_wide(&utf16).expect("Strings are valid Unicode")
}
fn synthesize_stream(
voice: &VoiceInformation,
speed: f32,
text: &str,
) -> Result<SpeechSynthesisStream> {
let synthesizer = SpeechSynthesizer::new()?;
synthesizer.SetVoice(voice).with_context(|_| WindowsSnafu {
details: WindowsErrorDetails::SettingVoice(voice.clone()),
})?;
synthesizer
.Options()?
.SetSpeakingRate(speed as f64)
.context(WindowsSnafu {
details: WindowsErrorDetails::SettingRate(speed),
})?;
let async_op = synthesizer.SynthesizeTextToStreamAsync(&to_hstring(text))?;
let stream = block_on(async_op).context(WindowsSnafu {
details: WindowsErrorDetails::Synthesizing,
})?;
Ok(stream)
}
fn write_stream_to_path(stream: SpeechSynthesisStream, path: &str) -> Result<()> {
let input_stream = stream.GetInputStreamAt(0)?;
let date_reader = DataReader::CreateDataReader(&input_stream)?;
let stream_size = stream.Size()?.try_into().or_invalid("stream too large")?;
date_reader.LoadAsync(stream_size)?;
let mut file = File::create(path)?;
write_reader_to_file(date_reader, &mut file, stream_size as usize)
}
fn write_reader_to_file(reader: DataReader, file: &mut File, stream_size: usize) -> Result<()> {
let mut bytes_remaining = stream_size;
let mut buf = [0u8; MAX_BUFFER_SIZE];
while bytes_remaining > 0 {
let chunk_size = bytes_remaining.min(MAX_BUFFER_SIZE);
reader.ReadBytes(&mut buf[..chunk_size])?;
file.write_all(&buf[..chunk_size])?;
bytes_remaining -= chunk_size;
}
Ok(())
}
impl TtsVoice {
fn from_voice_information(info: VoiceInformation, validate: bool) -> Result<Self> {
Ok(Self {
id: info.Id()?.to_string_lossy(),
name: info.DisplayName()?.to_string_lossy(),
language: info.Language()?.to_string_lossy(),
// Windows lists voices that fail when actually trying to use them. This has been
// observed with voices from an uninstalled language pack.
// Validation is optional because it may be slow.
available: validate.then(|| synthesize_stream(&info, 1.0, "").is_ok()),
})
}
}

View File

@ -8,6 +8,8 @@ mod invalid_input;
pub(crate) mod network;
mod not_found;
mod search;
#[cfg(windows)]
pub mod windows;
pub use db::DbError;
pub use db::DbErrorKind;
@ -33,7 +35,7 @@ use crate::links::HelpPage;
pub type Result<T, E = AnkiError> = std::result::Result<T, E>;
#[derive(Debug, PartialEq, Eq, Snafu)]
#[derive(Debug, PartialEq, Snafu)]
pub enum AnkiError {
#[snafu(context(false))]
InvalidInput {
@ -105,6 +107,11 @@ pub enum AnkiError {
source: ImportError,
},
InvalidId,
#[cfg(windows)]
#[snafu(context(false))]
WindowsError {
source: windows::WindowsError,
},
}
// error helpers
@ -154,6 +161,8 @@ impl AnkiError {
AnkiError::FileIoError { source } => source.message(),
AnkiError::InvalidInput { source } => source.message(),
AnkiError::NotFound { source } => source.message(tr),
#[cfg(windows)]
AnkiError::WindowsError { source } => format!("{source:?}"),
}
}

View File

@ -0,0 +1,32 @@
// Copyright: Ankitects Pty Ltd and contributors
// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
use snafu::Snafu;
use super::AnkiError;
#[derive(Debug, PartialEq, Snafu)]
#[snafu(visibility(pub))]
pub struct WindowsError {
details: WindowsErrorDetails,
source: windows::core::Error,
}
#[derive(Debug, PartialEq)]
pub enum WindowsErrorDetails {
SettingVoice(windows::Media::SpeechSynthesis::VoiceInformation),
SettingRate(f32),
Synthesizing,
Other,
}
impl From<windows::core::Error> for AnkiError {
fn from(source: windows::core::Error) -> Self {
AnkiError::WindowsError {
source: WindowsError {
source,
details: WindowsErrorDetails::Other,
},
}
}
}