Added in-process embedding models (#41)

- all-minilm-l6-v2
- all-minilm-l6-v2-q
- e5-small-v2
- e5-small-v2-q

The idea is to give users an option to embed documents/texts in the same
Java process without any external dependencies.
ONNX Runtime is used to run models inside JVM.
Each model resides in it's own maven module (inside the jar).
This commit is contained in:
LangChain4j 2023-07-23 19:05:13 +02:00 committed by GitHub
parent 3cc75c771e
commit 529ef6b647
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 122894 additions and 47 deletions

1
.gitattributes vendored Normal file
View File

@ -0,0 +1 @@
*.onnx filter=lfs diff=lfs merge=lfs -text

View File

@ -23,7 +23,7 @@ Please provide a relevant code snippets to reproduce this bug.
A clear and concise description of what you expected to happen.
**Please complete the following information:**
- LangChain4j version: e.g. 0.16.0
- LangChain4j version: e.g. 0.17.0
- Java version: e.g. 11
- Spring Boot version (if applicable): e.g. 2.7.13

View File

@ -7,14 +7,20 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Pull LFS files
run: git lfs pull
- name: Set up JDK 8
uses: actions/setup-java@v3
with:
java-version: '8'
distribution: 'temurin'
- name: Test
run: mvn --batch-mode test
# For checking some compliance things (require a recent JDK due to plugins so in a separate step)
compliance:
runs-on: ubuntu-latest

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

View File

@ -1,15 +1,21 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.util.List;
import static java.util.Collections.singletonList;
public interface EmbeddingModel {
Embedding embed(String text);
default Embedding embed(String text) {
return embed(TextSegment.from(text));
}
Embedding embed(TextSegment textSegment);
default Embedding embed(TextSegment textSegment) {
return embedAll(singletonList(textSegment)).get(0);
}
List<Embedding> embedAll(List<TextSegment> textSegments);
}

View File

@ -0,0 +1,63 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<packaging>jar</packaging>
<name>langchain4j-embeddings-all-minilm-l6-v2-q</name>
<description>In-process all-minilm-l6-v2 (quantized) embedding model</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,14 @@
package dev.langchain4j.model.embedding;
public class ALL_MINILM_L6_V2_Q_EmbeddingModel extends AbstractInProcessEmbeddingModel {
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/all-minilm-l6-v2-q.onnx",
"/vocab.txt"
);
@Override
protected OnnxEmbeddingModel model() {
return model;
}
}

View File

@ -0,0 +1,5 @@
Model card:
https://huggingface.co/Xenova/all-MiniLM-L6-v2
Model file:
https://huggingface.co/Xenova/all-MiniLM-L6-v2/blob/main/onnx/model_quantized.onnx

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afdb6f1a0e45b715d0bb9b11772f032c399babd23bfc31fed1c170afc848bdb1
size 22972370

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class ALL_MINILM_L6_V2_Q_EmbeddingModelTest {
@Test
void should_embed() {
EmbeddingModel model = new ALL_MINILM_L6_V2_Q_EmbeddingModel();
Embedding first = model.embed("hi");
assertThat(first.vector()).hasSize(384);
Embedding second = model.embed("hello");
assertThat(second.vector()).hasSize(384);
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.8);
}
}

View File

@ -0,0 +1,63 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<packaging>jar</packaging>
<name>langchain4j-embeddings-all-minilm-l6-v2</name>
<description>In-process all-minilm-l6-v2 embedding model</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,14 @@
package dev.langchain4j.model.embedding;
public class ALL_MINILM_L6_V2_EmbeddingModel extends AbstractInProcessEmbeddingModel {
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/all-minilm-l6-v2.onnx",
"/vocab.txt"
);
@Override
protected OnnxEmbeddingModel model() {
return model;
}
}

View File

@ -0,0 +1,5 @@
Model card:
https://huggingface.co/Xenova/all-MiniLM-L6-v2
Model file:
https://huggingface.co/Xenova/all-MiniLM-L6-v2/blob/main/onnx/model.onnx

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ca46f1a88a9c6e61b918af1ab38be3e7903b986616551f0a6f10a7ecc5730cd4
size 90984153

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class ALL_MINILM_L6_V2_EmbeddingModelTest {
@Test
void should_embed() {
EmbeddingModel model = new ALL_MINILM_L6_V2_EmbeddingModel();
Embedding first = model.embed("hi");
assertThat(first.vector()).hasSize(384);
Embedding second = model.embed("hello");
assertThat(second.vector()).hasSize(384);
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.8);
}
}

View File

@ -0,0 +1,63 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-embeddings-e5-small-v2-q</artifactId>
<packaging>jar</packaging>
<name>langchain4j-embeddings-e5-small-v2-q</name>
<description>In-process e5-small-v2 (quantized) embedding model</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,14 @@
package dev.langchain4j.model.embedding;
public class E5_SMALL_V2_Q_EmbeddingModel extends AbstractInProcessEmbeddingModel {
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/e5-small-v2-q.onnx",
"/vocab.txt"
);
@Override
protected OnnxEmbeddingModel model() {
return model;
}
}

View File

@ -0,0 +1,5 @@
Model card:
https://huggingface.co/Xenova/e5-small-v2
Model file:
https://huggingface.co/Xenova/e5-small-v2/resolve/main/onnx/model_quantized.onnx

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d9092cb25f2bd1c023b7e8d2aa459044a02030ac880e5a59fdaf27af69f1ded
size 34014367

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class E5_SMALL_V2_Q_EmbeddingModelTest {
@Test
void should_embed() {
EmbeddingModel model = new E5_SMALL_V2_Q_EmbeddingModel();
Embedding first = model.embed("hi");
assertThat(first.vector()).hasSize(384);
Embedding second = model.embed("hello");
assertThat(second.vector()).hasSize(384);
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.95);
}
}

View File

@ -0,0 +1,63 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-embeddings-e5-small-v2</artifactId>
<packaging>jar</packaging>
<name>langchain4j-embeddings-e5-small-v2</name>
<description>In-process e5-small-v2 embedding model</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,14 @@
package dev.langchain4j.model.embedding;
public class E5_SMALL_V2_EmbeddingModel extends AbstractInProcessEmbeddingModel {
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/e5-small-v2.onnx",
"/vocab.txt"
);
@Override
protected OnnxEmbeddingModel model() {
return model;
}
}

View File

@ -0,0 +1,5 @@
Model card:
https://huggingface.co/Xenova/e5-small-v2
Model file:
https://huggingface.co/Xenova/e5-small-v2/resolve/main/onnx/model.onnx

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2a43b66f7f9b6f29643a21340ad7c03d2d91e6bd5d43429a77799a9ee880eb0
size 133093467

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class E5_SMALL_V2_EmbeddingModelTest {
@Test
void should_embed() {
EmbeddingModel model = new E5_SMALL_V2_EmbeddingModel();
Embedding first = model.embed("hi");
assertThat(first.vector()).hasSize(384);
Embedding second = model.embed("hello");
assertThat(second.vector()).hasSize(384);
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.95);
}
}

View File

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-embeddings</artifactId>
<packaging>jar</packaging>
<name>langchain4j-embeddings</name>
<description>Common functionality for other langchain4j-embeddings-xxx modules</description>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.17.0</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.15.1</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.23.0</version>
</dependency>
</dependencies>
<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>
</project>

View File

@ -0,0 +1,23 @@
package dev.langchain4j.model.embedding;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import java.util.List;
import static java.util.stream.Collectors.toList;
public abstract class AbstractInProcessEmbeddingModel implements EmbeddingModel {
protected abstract OnnxEmbeddingModel model();
@Override
public List<Embedding> embedAll(List<TextSegment> segments) {
return segments.stream()
.map(segment -> {
float[] vector = model().embed(segment.text());
return Embedding.from(vector);
})
.collect(toList());
}
}

View File

@ -0,0 +1,150 @@
package dev.langchain4j.model.embedding;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static ai.onnxruntime.OnnxTensor.createTensor;
import static java.nio.LongBuffer.wrap;
public class OnnxEmbeddingModel {
private final OrtEnvironment environment;
private final OrtSession session;
private final DefaultVocabulary vocabulary;
private final BertFullTokenizer tokenizer;
public OnnxEmbeddingModel(String modelFilePath, String vocabularyFilePath) {
try {
this.environment = OrtEnvironment.getEnvironment();
this.session = environment.createSession(loadModel(modelFilePath));
this.vocabulary = DefaultVocabulary.builder()
.addFromTextFile(getClass().getResource(vocabularyFilePath))
.build();
this.tokenizer = new BertFullTokenizer(vocabulary, true);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public float[] embed(String text) {
try (Result result = runModel(text)) {
return toEmbedding(result);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private Result runModel(String text) throws OrtException {
List<String> stringTokens = new ArrayList<>();
stringTokens.add("[CLS]");
stringTokens.addAll(tokenizer.tokenize(text));
stringTokens.add("[SEP]");
// TODO reusable buffers
long[] tokens = stringTokens.stream()
.mapToLong(vocabulary::getIndex)
.toArray();
long[] attentionMasks = new long[stringTokens.size()];
for (int i = 0; i < stringTokens.size(); i++) {
attentionMasks[i] = 1L;
}
long[] tokenTypeIds = new long[stringTokens.size()];
for (int i = 0; i < stringTokens.size(); i++) {
tokenTypeIds[i] = 0L;
}
long[] shape = {1, tokens.length};
try (
OnnxTensor tokensTensor = createTensor(environment, wrap(tokens), shape);
OnnxTensor attentionMasksTensor = createTensor(environment, wrap(attentionMasks), shape);
OnnxTensor tokenTypeIdsTensor = createTensor(environment, wrap(tokenTypeIds), shape)
) {
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", tokensTensor);
inputs.put("token_type_ids", tokenTypeIdsTensor);
inputs.put("attention_mask", attentionMasksTensor);
return session.run(inputs);
}
}
private byte[] loadModel(String modelFilePath) {
try (
InputStream inputStream = getClass().getResourceAsStream(modelFilePath);
ByteArrayOutputStream buffer = new ByteArrayOutputStream()
) {
int nRead;
byte[] data = new byte[1024];
while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
buffer.flush();
return buffer.toByteArray();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private static float[] toEmbedding(Result result) throws OrtException {
float[][] vectors = ((float[][][]) result.get(0).getValue())[0];
return normalize(meanPool(vectors));
}
private static float[] meanPool(float[][] vectors) {
int numVectors = vectors.length;
int vectorLength = vectors[0].length;
float[] averagedVector = new float[vectorLength];
// TODO ignore [CLS] and [SEP] ?
for (int i = 0; i < numVectors; i++) {
for (int j = 0; j < vectorLength; j++) {
averagedVector[j] += vectors[i][j];
}
}
// TODO ignore [CLS] and [SEP] ?
for (int j = 0; j < vectorLength; j++) {
averagedVector[j] /= numVectors;
}
return averagedVector;
}
private static float[] normalize(float[] vector) {
float sumSquare = 0;
for (float v : vector) {
sumSquare += v * v;
}
float norm = (float) Math.sqrt(sumSquare);
float[] normalizedVector = new float[vector.length];
for (int i = 0; i < vector.length; i++) {
normalizedVector[i] = vector[i] / norm;
}
return normalizedVector;
}
}

View File

@ -0,0 +1,19 @@
package dev.langchain4j.model.embedding;
public class Similarity {
public static double cosine(float[] a, float[] b) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += Math.pow(a[i], 2);
normB += Math.pow(b[i], 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
}

View File

@ -6,7 +6,7 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<packaging>pom</packaging>
<name>langchain4j parent POM</name>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -23,7 +23,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
</dependency>
<dependency>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -22,7 +22,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
</dependency>
<dependency>

View File

@ -7,7 +7,7 @@
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>
@ -24,7 +24,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
</dependency>
<dependency>
@ -37,13 +37,6 @@
<artifactId>jtokkit</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-pinecone</artifactId>
<version>0.16.0</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>

View File

@ -9,7 +9,6 @@ import java.time.Duration;
import java.util.List;
import static dev.langchain4j.model.huggingface.HuggingFaceModelName.SENTENCE_TRANSFORMERS_ALL_MINI_LM_L6_V2;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
public class HuggingFaceEmbeddingModel implements EmbeddingModel {
@ -32,17 +31,6 @@ public class HuggingFaceEmbeddingModel implements EmbeddingModel {
this.waitForModel = waitForModel == null ? true : waitForModel;
}
@Override
public Embedding embed(String text) {
List<Embedding> embeddings = embedTexts(singletonList(text));
return embeddings.get(0);
}
@Override
public Embedding embed(TextSegment textSegment) {
return embed(textSegment.text());
}
@Override
public List<Embedding> embedAll(List<TextSegment> textSegments) {

View File

@ -0,0 +1,60 @@
package dev.langchain4j.model.inprocess;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.util.List;
import static java.lang.String.format;
/**
* This is an embedding model that runs within your Java application's process.
* This class serves as a thin wrapper over the actual implementations.
* The implementations are located in separate, optional langchain4j-embeddings-xxx modules and are loaded dynamically.
* If you wish to use the model XXX, please add the langchain4j-embeddings-xxx dependency to your project.
* The model execution is carried out using the ONNX Runtime.
*/
public class InProcessEmbeddingModel implements EmbeddingModel {
private final EmbeddingModel implementation;
public InProcessEmbeddingModel(InProcessEmbeddingModelType type) {
try {
implementation = loadDynamically(type);
} catch (ClassNotFoundException e) {
throw new RuntimeException(getMessage(type), e);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static EmbeddingModel loadDynamically(InProcessEmbeddingModelType type) throws Exception {
Class<?> implementationClass = Class.forName(
format("dev.langchain4j.model.embedding.%s_EmbeddingModel", type.name()));
return (EmbeddingModel) implementationClass.getConstructor().newInstance();
}
private static String getMessage(InProcessEmbeddingModelType type) {
return format("To use %s embedding model, please add the following dependency to your project:\n"
+ "\n"
+ "Maven:\n"
+ "<dependency>\n" +
" <groupId>dev.langchain4j</groupId>\n" +
" <artifactId>langchain4j-embeddings-%s</artifactId>\n" +
" <version>0.17.0</version>\n" +
"</dependency>\n"
+ "\n"
+ "Gradle:\n"
+ "implementation 'dev.langchain4j:langchain4j-embeddings-%s:0.17.0'\n",
type.name(),
type.name().replace("_", "-").toLowerCase(),
type.name().replace("_", "-").toLowerCase()
);
}
@Override
public List<Embedding> embedAll(List<TextSegment> textSegments) {
return implementation.embedAll(textSegments);
}
}

View File

@ -0,0 +1,33 @@
package dev.langchain4j.model.inprocess;
/**
* Lists all the currently supported in-process embedding models.
* New models will be added gradually.
* If you would like a new model to be added, please open a GitHub issue at: https://github.com/langchain4j/langchain4j/issues/new/choose
*/
public enum InProcessEmbeddingModelType {
/**
* SentenceTransformers all-MiniLM-L6-v2 model.
* More details: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
*/
ALL_MINILM_L6_V2,
/**
* SentenceTransformers all-MiniLM-L6-v2 model (quantized).
* More details: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
*/
ALL_MINILM_L6_V2_Q,
/**
* E5-small-v2 model.
* More details: https://huggingface.co/intfloat/e5-small-v2
*/
E5_SMALL_V2,
/**
* E5-small-v2 model (quantized).
* More details: https://huggingface.co/intfloat/e5-small-v2
*/
E5_SMALL_V2_Q
}

View File

@ -16,7 +16,6 @@ import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.openai.OpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator {
@ -58,17 +57,6 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
this.tokenizer = new OpenAiTokenizer(this.modelName);
}
@Override
public Embedding embed(String text) {
List<Embedding> embeddings = embedTexts(singletonList(text));
return embeddings.get(0);
}
@Override
public Embedding embed(TextSegment textSegment) {
return embed(textSegment.text());
}
@Override
public List<Embedding> embedAll(List<TextSegment> textSegments) {

View File

@ -29,10 +29,10 @@ public class PineconeEmbeddingStore implements EmbeddingStore<TextSegment> {
+ "<dependency>\n" +
" <groupId>dev.langchain4j</groupId>\n" +
" <artifactId>langchain4j-pinecone</artifactId>\n" +
" <version>0.16.0</version>\n" +
" <version>0.17.0</version>\n" +
"</dependency>\n\n"
+ "Gradle:\n"
+ "implementation 'dev.langchain4j:langchain4j-pinecone:0.16.0'\n";
+ "implementation 'dev.langchain4j:langchain4j-pinecone:0.17.0'\n";
}
private static EmbeddingStore<TextSegment> loadDynamically(String implementationClassName, String apiKey, String environment, String projectName, String index, String nameSpace) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {

13
pom.xml
View File

@ -6,14 +6,23 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-aggregator</artifactId>
<version>0.16.0</version>
<version>0.17.0</version>
<packaging>pom</packaging>
<modules>
<module>langchain4j</module>
<module>langchain4j-parent</module>
<module>langchain4j-core</module>
<module>langchain4j</module>
<module>langchain4j-pinecone</module>
<module>langchain4j-embeddings</module>
<module>langchain4j-embeddings-all-minilm-l6-v2</module>
<module>langchain4j-embeddings-all-minilm-l6-v2-q</module>
<module>langchain4j-embeddings-e5-small-v2</module>
<module>langchain4j-embeddings-e5-small-v2-q</module>
<module>langchain4j-spring-boot-starter</module>
</modules>