OpenAI sync
This commit is contained in:
parent
bc58e8a310
commit
c3d02f092d
|
@ -3378,6 +3378,7 @@ dependencies = [
|
|||
"tokenizers",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"ureq",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
|
|
@ -91,6 +91,7 @@ liquid = "0.26.4"
|
|||
arroy = "0.2.0"
|
||||
rand = "0.8.5"
|
||||
tracing = "0.1.40"
|
||||
ureq = { version = "2.9.6", features = ["json"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mimalloc = { version = "0.1.39", default-features = false }
|
||||
|
|
|
@ -53,17 +53,17 @@ pub enum EmbedErrorKind {
|
|||
#[error("could not run model: {0}")]
|
||||
ModelForward(candle_core::Error),
|
||||
#[error("could not reach OpenAI: {0}")]
|
||||
OpenAiNetwork(reqwest::Error),
|
||||
OpenAiNetwork(ureq::Transport),
|
||||
#[error("unexpected response from OpenAI: {0}")]
|
||||
OpenAiUnexpected(reqwest::Error),
|
||||
#[error("could not authenticate against OpenAI: {0}")]
|
||||
OpenAiAuth(OpenAiError),
|
||||
#[error("sent too many requests to OpenAI: {0}")]
|
||||
OpenAiTooManyRequests(OpenAiError),
|
||||
OpenAiUnexpected(ureq::Error),
|
||||
#[error("could not authenticate against OpenAI: {0:?}")]
|
||||
OpenAiAuth(Option<OpenAiError>),
|
||||
#[error("sent too many requests to OpenAI: {0:?}")]
|
||||
OpenAiTooManyRequests(Option<OpenAiError>),
|
||||
#[error("received internal error from OpenAI: {0:?}")]
|
||||
OpenAiInternalServerError(Option<OpenAiError>),
|
||||
#[error("sent too many tokens in a request to OpenAI: {0}")]
|
||||
OpenAiTooManyTokens(OpenAiError),
|
||||
#[error("sent too many tokens in a request to OpenAI: {0:?}")]
|
||||
OpenAiTooManyTokens(Option<OpenAiError>),
|
||||
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
||||
OpenAiUnhandledStatusCode(u16),
|
||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||
|
@ -102,19 +102,19 @@ impl EmbedError {
|
|||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_network(inner: reqwest::Error) -> Self {
|
||||
pub fn openai_network(inner: ureq::Transport) -> Self {
|
||||
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
|
||||
pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
|
||||
pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
|
||||
pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
|
@ -122,7 +122,7 @@ impl EmbedError {
|
|||
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
|
||||
pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
|
@ -220,7 +220,7 @@ impl NewEmbedderError {
|
|||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
||||
fault: FaultSource::Runtime,
|
||||
|
|
|
@ -131,7 +131,7 @@ impl Embedder {
|
|||
|
||||
let embeddings = this
|
||||
.embed(vec!["test".into()])
|
||||
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
this.dimensions = embeddings.first().unwrap().dimension();
|
||||
|
||||
Ok(this)
|
||||
|
|
|
@ -98,7 +98,7 @@ pub enum Embedder {
|
|||
/// An embedder based on running local models, fetched from the Hugging Face Hub.
|
||||
HuggingFace(hf::Embedder),
|
||||
/// An embedder based on making embedding queries against the OpenAI API.
|
||||
OpenAi(openai::Embedder),
|
||||
OpenAi(openai::sync::Embedder),
|
||||
/// An embedder based on the user providing the embeddings in the documents and queries.
|
||||
UserProvided(manual::Embedder),
|
||||
Ollama(ollama::Embedder),
|
||||
|
@ -201,7 +201,7 @@ impl Embedder {
|
|||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
Ok(match options {
|
||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?),
|
||||
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
||||
EmbedderOptions::UserProvided(options) => {
|
||||
Self::UserProvided(manual::Embedder::new(options))
|
||||
|
@ -218,10 +218,7 @@ impl Embedder {
|
|||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||
Embedder::OpenAi(embedder) => {
|
||||
let client = embedder.new_client()?;
|
||||
embedder.embed(texts, &client).await
|
||||
}
|
||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
||||
Embedder::Ollama(embedder) => {
|
||||
let client = embedder.new_client()?;
|
||||
embedder.embed(texts, &client).await
|
||||
|
|
|
@ -1,18 +1,10 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::error::{EmbedError, NewEmbedderError};
|
||||
use super::{DistributionShift, Embedding, Embeddings};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
headers: reqwest::header::HeaderMap,
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub api_key: Option<String>,
|
||||
|
@ -125,298 +117,6 @@ impl EmbedderOptions {
|
|||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
|
||||
reqwest::ClientBuilder::new()
|
||||
.default_headers(self.headers.clone())
|
||||
.build()
|
||||
.map_err(EmbedError::openai_initialize_web_client)
|
||||
}
|
||||
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let mut inferred_api_key = Default::default();
|
||||
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||
inferred_api_key = infer_api_key();
|
||||
&inferred_api_key
|
||||
});
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
|
||||
);
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// looking at the code it is very unclear that this can actually fail.
|
||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||
|
||||
Ok(Self { options, headers, tokenizer })
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let mut tokenized = false;
|
||||
|
||||
for attempt in 0..7 {
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, client).await
|
||||
} else {
|
||||
self.try_embed(&texts, client).await
|
||||
};
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(embeddings) => return Ok(embeddings),
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
tokenized |= retry.must_tokenize();
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||
tracing::warn!(
|
||||
"Attempt #{}, retrying after {}ms.",
|
||||
attempt,
|
||||
retry_duration.as_millis()
|
||||
);
|
||||
tokio::time::sleep(retry_duration).await;
|
||||
}
|
||||
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, client).await
|
||||
} else {
|
||||
self.try_embed(&texts, client).await
|
||||
};
|
||||
|
||||
result.map_err(Retry::into_error)
|
||||
}
|
||||
|
||||
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
|
||||
if !response.status().is_success() {
|
||||
match response.status() {
|
||||
StatusCode::UNAUTHORIZED => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::give_up(EmbedError::openai_auth_error(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
| StatusCode::BAD_GATEWAY
|
||||
| StatusCode::SERVICE_UNAVAILABLE => {
|
||||
let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
|
||||
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
|
||||
error_response.ok().map(|error_response| error_response.error),
|
||||
)));
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
// Most probably, one text contained too many tokens
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
|
||||
|
||||
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
code => {
|
||||
return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
|
||||
code.as_u16(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn try_embed<S: AsRef<str> + serde::Serialize>(
|
||||
&self,
|
||||
texts: &[S],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
for text in texts {
|
||||
tracing::trace!("Received prompt: {}", text.as_ref())
|
||||
}
|
||||
let request = OpenAiRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: texts,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::trace!("response: {:?}", response.data);
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|data| Embeddings::from_single_embedding(data.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn try_embed_tokenized(
|
||||
&self,
|
||||
text: &[String],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
pub const OVERLAP_SIZE: usize = 200;
|
||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||
for text in text {
|
||||
let max_token_count = self.options.embedding_model.max_token();
|
||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||
let len = encoded.len();
|
||||
if len < max_token_count {
|
||||
all_embeddings.append(&mut self.try_embed(&[text], client).await?);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tokens = encoded.as_slice();
|
||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||
while tokens.len() > max_token_count {
|
||||
let window = &tokens[..max_token_count];
|
||||
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
|
||||
|
||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||
}
|
||||
|
||||
// end of text
|
||||
embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
|
||||
|
||||
all_embeddings.push(embeddings_for_prompt);
|
||||
}
|
||||
Ok(all_embeddings)
|
||||
}
|
||||
|
||||
async fn embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Embedding, Retry> {
|
||||
for attempt in 0..9 {
|
||||
let duration = match self.try_embed_tokens(tokens, client).await {
|
||||
Ok(embedding) => return Ok(embedding),
|
||||
Err(retry) => retry.into_duration(attempt),
|
||||
}
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
|
||||
self.try_embed_tokens(tokens, client)
|
||||
.await
|
||||
.map_err(|retry| Retry::give_up(retry.into_error()))
|
||||
}
|
||||
|
||||
async fn try_embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Embedding, Retry> {
|
||||
let request = OpenAiTokensRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: tokens,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let mut response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.map_err(EmbedError::openai_runtime_init)?;
|
||||
let client = self.new_client()?;
|
||||
rt.block_on(futures::future::try_join_all(
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
|
||||
} else {
|
||||
self.options.embedding_model.default_dimensions()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.embedding_model.distribution()
|
||||
}
|
||||
|
||||
fn overriden_dimensions(&self) -> Option<usize> {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// retrying in case of failure
|
||||
|
||||
pub struct Retry {
|
||||
|
@ -524,3 +224,257 @@ fn infer_api_key() -> String {
|
|||
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub mod sync {
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||
|
||||
use super::{
|
||||
EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest,
|
||||
OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL,
|
||||
};
|
||||
use crate::vector::DistributionShift;
|
||||
|
||||
const REQUEST_PARALLELISM: usize = 10;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
options: super::EmbedderOptions,
|
||||
bearer: String,
|
||||
threads: rayon::ThreadPool,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut inferred_api_key = Default::default();
|
||||
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||
inferred_api_key = super::infer_api_key();
|
||||
&inferred_api_key
|
||||
});
|
||||
let bearer = format!("Bearer {api_key}");
|
||||
|
||||
// looking at the code it is very unclear that this can actually fail.
|
||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||
|
||||
// FIXME: unwrap
|
||||
let threads = rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(REQUEST_PARALLELISM)
|
||||
.thread_name(|index| format!("embedder-chunk-{index}"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Ok(Self { options, bearer, tokenizer, threads })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let mut tokenized = false;
|
||||
|
||||
let client = ureq::agent();
|
||||
|
||||
for attempt in 0..7 {
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, &client)
|
||||
} else {
|
||||
self.try_embed(&texts, &client)
|
||||
};
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(embeddings) => return Ok(embeddings),
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
tokenized |= retry.must_tokenize();
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||
tracing::warn!(
|
||||
"Attempt #{}, retrying after {}ms.",
|
||||
attempt,
|
||||
retry_duration.as_millis()
|
||||
);
|
||||
std::thread::sleep(retry_duration);
|
||||
}
|
||||
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, &client)
|
||||
} else {
|
||||
self.try_embed(&texts, &client)
|
||||
};
|
||||
|
||||
result.map_err(Retry::into_error)
|
||||
}
|
||||
|
||||
fn check_response(
|
||||
response: Result<ureq::Response, ureq::Error>,
|
||||
) -> Result<ureq::Response, Retry> {
|
||||
match response {
|
||||
Ok(response) => Ok(response),
|
||||
Err(ureq::Error::Status(code, response)) => {
|
||||
let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
|
||||
let error = error_response.map(|response| response.error);
|
||||
Err(match code {
|
||||
401 => Retry::give_up(EmbedError::openai_auth_error(error)),
|
||||
429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
|
||||
400 => {
|
||||
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||
|
||||
Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
|
||||
}
|
||||
500..=599 => {
|
||||
Retry::retry_later(EmbedError::openai_internal_server_error(error))
|
||||
}
|
||||
x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
|
||||
})
|
||||
}
|
||||
Err(ureq::Error::Transport(transport)) => {
|
||||
Err(Retry::retry_later(EmbedError::openai_network(transport)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_embed<S: AsRef<str> + serde::Serialize>(
|
||||
&self,
|
||||
texts: &[S],
|
||||
client: &ureq::Agent,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
for text in texts {
|
||||
tracing::trace!("Received prompt: {}", text.as_ref())
|
||||
}
|
||||
let request = OpenAiRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: texts,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.set("Authorization", &self.bearer)
|
||||
.send_json(&request);
|
||||
|
||||
let response = Self::check_response(response)?;
|
||||
|
||||
let response: OpenAiResponse = response
|
||||
.into_json()
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::trace!("response: {:?}", response.data);
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|data| Embeddings::from_single_embedding(data.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn try_embed_tokenized(
|
||||
&self,
|
||||
text: &[String],
|
||||
client: &ureq::Agent,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
pub const OVERLAP_SIZE: usize = 200;
|
||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||
for text in text {
|
||||
let max_token_count = self.options.embedding_model.max_token();
|
||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||
let len = encoded.len();
|
||||
if len < max_token_count {
|
||||
all_embeddings.append(&mut self.try_embed(&[text], client)?);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tokens = encoded.as_slice();
|
||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||
while tokens.len() > max_token_count {
|
||||
let window = &tokens[..max_token_count];
|
||||
embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
|
||||
|
||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||
}
|
||||
|
||||
// end of text
|
||||
embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
|
||||
|
||||
all_embeddings.push(embeddings_for_prompt);
|
||||
}
|
||||
Ok(all_embeddings)
|
||||
}
|
||||
|
||||
fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
|
||||
for attempt in 0..9 {
|
||||
let duration = match self.try_embed_tokens(tokens, client) {
|
||||
Ok(embedding) => return Ok(embedding),
|
||||
Err(retry) => retry.into_duration(attempt),
|
||||
}
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
std::thread::sleep(duration);
|
||||
}
|
||||
|
||||
self.try_embed_tokens(tokens, client)
|
||||
.map_err(|retry| Retry::give_up(retry.into_error()))
|
||||
}
|
||||
|
||||
fn try_embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
client: &ureq::Agent,
|
||||
) -> Result<Embedding, Retry> {
|
||||
let request = OpenAiTokensRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: tokens,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.set("Authorization", &self.bearer)
|
||||
.send_json(&request);
|
||||
|
||||
let response = Self::check_response(response)?;
|
||||
|
||||
let mut response: OpenAiResponse = response
|
||||
.into_json()
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
self.threads
|
||||
.install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
|
||||
} else {
|
||||
self.options.embedding_model.default_dimensions()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.embedding_model.distribution()
|
||||
}
|
||||
|
||||
fn overriden_dimensions(&self) -> Option<usize> {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue