Added in-process embedding models (#41)
- all-minilm-l6-v2 - all-minilm-l6-v2-q - e5-small-v2 - e5-small-v2-q The idea is to give users an option to embed documents/texts in the same Java process without any external dependencies. ONNX Runtime is used to run models inside the JVM. Each model resides in its own Maven module (inside the jar).
This commit is contained in:
parent
3cc75c771e
commit
529ef6b647
|
@ -0,0 +1 @@
|
|||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
@ -23,7 +23,7 @@ Please provide a relevant code snippets to reproduce this bug.
|
|||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Please complete the following information:**
|
||||
- LangChain4j version: e.g. 0.16.0
|
||||
- LangChain4j version: e.g. 0.17.0
|
||||
- Java version: e.g. 11
|
||||
- Spring Boot version (if applicable): e.g. 2.7.13
|
||||
|
||||
|
|
|
@ -7,14 +7,20 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Pull LFS files
|
||||
run: git lfs pull
|
||||
|
||||
- name: Set up JDK 8
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'temurin'
|
||||
|
||||
- name: Test
|
||||
run: mvn --batch-mode test
|
||||
|
||||
|
||||
# For checking some compliance things (require a recent JDK due to plugins so in a separate step)
|
||||
compliance:
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
|
|
@ -1,15 +1,21 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
public interface EmbeddingModel {
|
||||
|
||||
Embedding embed(String text);
|
||||
default Embedding embed(String text) {
|
||||
return embed(TextSegment.from(text));
|
||||
}
|
||||
|
||||
Embedding embed(TextSegment textSegment);
|
||||
default Embedding embed(TextSegment textSegment) {
|
||||
return embedAll(singletonList(textSegment)).get(0);
|
||||
}
|
||||
|
||||
List<Embedding> embedAll(List<TextSegment> textSegments);
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>langchain4j-embeddings-all-minilm-l6-v2-q</name>
|
||||
<description>In-process all-minilm-l6-v2 (quantized) embedding model</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,14 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
public class ALL_MINILM_L6_V2_Q_EmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
|
||||
"/all-minilm-l6-v2-q.onnx",
|
||||
"/vocab.txt"
|
||||
);
|
||||
|
||||
@Override
|
||||
protected OnnxEmbeddingModel model() {
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
Model card:
|
||||
https://huggingface.co/Xenova/all-MiniLM-L6-v2
|
||||
|
||||
Model file:
|
||||
https://huggingface.co/Xenova/all-MiniLM-L6-v2/blob/main/onnx/model_quantized.onnx
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:afdb6f1a0e45b715d0bb9b11772f032c399babd23bfc31fed1c170afc848bdb1
|
||||
size 22972370
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class ALL_MINILM_L6_V2_Q_EmbeddingModelTest {
|
||||
|
||||
@Test
|
||||
void should_embed() {
|
||||
|
||||
EmbeddingModel model = new ALL_MINILM_L6_V2_Q_EmbeddingModel();
|
||||
|
||||
Embedding first = model.embed("hi");
|
||||
assertThat(first.vector()).hasSize(384);
|
||||
|
||||
Embedding second = model.embed("hello");
|
||||
assertThat(second.vector()).hasSize(384);
|
||||
|
||||
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.8);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>langchain4j-embeddings-all-minilm-l6-v2</name>
|
||||
<description>In-process all-minilm-l6-v2 embedding model</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,14 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
public class ALL_MINILM_L6_V2_EmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
|
||||
"/all-minilm-l6-v2.onnx",
|
||||
"/vocab.txt"
|
||||
);
|
||||
|
||||
@Override
|
||||
protected OnnxEmbeddingModel model() {
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
Model card:
|
||||
https://huggingface.co/Xenova/all-MiniLM-L6-v2
|
||||
|
||||
Model file:
|
||||
https://huggingface.co/Xenova/all-MiniLM-L6-v2/blob/main/onnx/model.onnx
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ca46f1a88a9c6e61b918af1ab38be3e7903b986616551f0a6f10a7ecc5730cd4
|
||||
size 90984153
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class ALL_MINILM_L6_V2_EmbeddingModelTest {
|
||||
|
||||
@Test
|
||||
void should_embed() {
|
||||
|
||||
EmbeddingModel model = new ALL_MINILM_L6_V2_EmbeddingModel();
|
||||
|
||||
Embedding first = model.embed("hi");
|
||||
assertThat(first.vector()).hasSize(384);
|
||||
|
||||
Embedding second = model.embed("hello");
|
||||
assertThat(second.vector()).hasSize(384);
|
||||
|
||||
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.8);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-embeddings-e5-small-v2-q</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>langchain4j-embeddings-e5-small-v2-q</name>
|
||||
<description>In-process e5-small-v2 (quantized) embedding model</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,14 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
public class E5_SMALL_V2_Q_EmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
|
||||
"/e5-small-v2-q.onnx",
|
||||
"/vocab.txt"
|
||||
);
|
||||
|
||||
@Override
|
||||
protected OnnxEmbeddingModel model() {
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
Model card:
|
||||
https://huggingface.co/Xenova/e5-small-v2
|
||||
|
||||
Model file:
|
||||
https://huggingface.co/Xenova/e5-small-v2/resolve/main/onnx/model_quantized.onnx
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7d9092cb25f2bd1c023b7e8d2aa459044a02030ac880e5a59fdaf27af69f1ded
|
||||
size 34014367
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class E5_SMALL_V2_Q_EmbeddingModelTest {
|
||||
|
||||
@Test
|
||||
void should_embed() {
|
||||
|
||||
EmbeddingModel model = new E5_SMALL_V2_Q_EmbeddingModel();
|
||||
|
||||
Embedding first = model.embed("hi");
|
||||
assertThat(first.vector()).hasSize(384);
|
||||
|
||||
Embedding second = model.embed("hello");
|
||||
assertThat(second.vector()).hasSize(384);
|
||||
|
||||
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.95);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-embeddings-e5-small-v2</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>langchain4j-embeddings-e5-small-v2</name>
|
||||
<description>In-process e5-small-v2 embedding model</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,14 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
public class E5_SMALL_V2_EmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
|
||||
"/e5-small-v2.onnx",
|
||||
"/vocab.txt"
|
||||
);
|
||||
|
||||
@Override
|
||||
protected OnnxEmbeddingModel model() {
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
Model card:
|
||||
https://huggingface.co/Xenova/e5-small-v2
|
||||
|
||||
Model file:
|
||||
https://huggingface.co/Xenova/e5-small-v2/resolve/main/onnx/model.onnx
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b2a43b66f7f9b6f29643a21340ad7c03d2d91e6bd5d43429a77799a9ee880eb0
|
||||
size 133093467
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class E5_SMALL_V2_EmbeddingModelTest {
|
||||
|
||||
@Test
|
||||
void should_embed() {
|
||||
|
||||
EmbeddingModel model = new E5_SMALL_V2_EmbeddingModel();
|
||||
|
||||
Embedding first = model.embed("hi");
|
||||
assertThat(first.vector()).hasSize(384);
|
||||
|
||||
Embedding second = model.embed("hello");
|
||||
assertThat(second.vector()).hasSize(384);
|
||||
|
||||
assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.95);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>langchain4j-embeddings</name>
|
||||
<description>Common functionality for other langchain4j-embeddings-xxx modules</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>1.15.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>api</artifactId>
|
||||
<version>0.23.0</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,23 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public abstract class AbstractInProcessEmbeddingModel implements EmbeddingModel {
|
||||
|
||||
protected abstract OnnxEmbeddingModel model();
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedAll(List<TextSegment> segments) {
|
||||
return segments.stream()
|
||||
.map(segment -> {
|
||||
float[] vector = model().embed(segment.text());
|
||||
return Embedding.from(vector);
|
||||
})
|
||||
.collect(toList());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import ai.djl.modality.nlp.DefaultVocabulary;
|
||||
import ai.djl.modality.nlp.bert.BertFullTokenizer;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import ai.onnxruntime.OrtSession.Result;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static ai.onnxruntime.OnnxTensor.createTensor;
|
||||
import static java.nio.LongBuffer.wrap;
|
||||
|
||||
public class OnnxEmbeddingModel {
|
||||
|
||||
private final OrtEnvironment environment;
|
||||
private final OrtSession session;
|
||||
private final DefaultVocabulary vocabulary;
|
||||
private final BertFullTokenizer tokenizer;
|
||||
|
||||
public OnnxEmbeddingModel(String modelFilePath, String vocabularyFilePath) {
|
||||
try {
|
||||
this.environment = OrtEnvironment.getEnvironment();
|
||||
this.session = environment.createSession(loadModel(modelFilePath));
|
||||
this.vocabulary = DefaultVocabulary.builder()
|
||||
.addFromTextFile(getClass().getResource(vocabularyFilePath))
|
||||
.build();
|
||||
this.tokenizer = new BertFullTokenizer(vocabulary, true);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] embed(String text) {
|
||||
try (Result result = runModel(text)) {
|
||||
return toEmbedding(result);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private Result runModel(String text) throws OrtException {
|
||||
|
||||
List<String> stringTokens = new ArrayList<>();
|
||||
stringTokens.add("[CLS]");
|
||||
stringTokens.addAll(tokenizer.tokenize(text));
|
||||
stringTokens.add("[SEP]");
|
||||
|
||||
// TODO reusable buffers
|
||||
long[] tokens = stringTokens.stream()
|
||||
.mapToLong(vocabulary::getIndex)
|
||||
.toArray();
|
||||
|
||||
long[] attentionMasks = new long[stringTokens.size()];
|
||||
for (int i = 0; i < stringTokens.size(); i++) {
|
||||
attentionMasks[i] = 1L;
|
||||
}
|
||||
|
||||
long[] tokenTypeIds = new long[stringTokens.size()];
|
||||
for (int i = 0; i < stringTokens.size(); i++) {
|
||||
tokenTypeIds[i] = 0L;
|
||||
}
|
||||
|
||||
long[] shape = {1, tokens.length};
|
||||
|
||||
try (
|
||||
OnnxTensor tokensTensor = createTensor(environment, wrap(tokens), shape);
|
||||
OnnxTensor attentionMasksTensor = createTensor(environment, wrap(attentionMasks), shape);
|
||||
OnnxTensor tokenTypeIdsTensor = createTensor(environment, wrap(tokenTypeIds), shape)
|
||||
) {
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", tokensTensor);
|
||||
inputs.put("token_type_ids", tokenTypeIdsTensor);
|
||||
inputs.put("attention_mask", attentionMasksTensor);
|
||||
|
||||
return session.run(inputs);
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] loadModel(String modelFilePath) {
|
||||
try (
|
||||
InputStream inputStream = getClass().getResourceAsStream(modelFilePath);
|
||||
ByteArrayOutputStream buffer = new ByteArrayOutputStream()
|
||||
) {
|
||||
int nRead;
|
||||
byte[] data = new byte[1024];
|
||||
|
||||
while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
|
||||
buffer.write(data, 0, nRead);
|
||||
}
|
||||
|
||||
buffer.flush();
|
||||
return buffer.toByteArray();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] toEmbedding(Result result) throws OrtException {
|
||||
float[][] vectors = ((float[][][]) result.get(0).getValue())[0];
|
||||
return normalize(meanPool(vectors));
|
||||
}
|
||||
|
||||
private static float[] meanPool(float[][] vectors) {
|
||||
|
||||
int numVectors = vectors.length;
|
||||
int vectorLength = vectors[0].length;
|
||||
|
||||
float[] averagedVector = new float[vectorLength];
|
||||
|
||||
// TODO ignore [CLS] and [SEP] ?
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
for (int j = 0; j < vectorLength; j++) {
|
||||
averagedVector[j] += vectors[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
// TODO ignore [CLS] and [SEP] ?
|
||||
for (int j = 0; j < vectorLength; j++) {
|
||||
averagedVector[j] /= numVectors;
|
||||
}
|
||||
|
||||
return averagedVector;
|
||||
}
|
||||
|
||||
private static float[] normalize(float[] vector) {
|
||||
|
||||
float sumSquare = 0;
|
||||
for (float v : vector) {
|
||||
sumSquare += v * v;
|
||||
}
|
||||
float norm = (float) Math.sqrt(sumSquare);
|
||||
|
||||
float[] normalizedVector = new float[vector.length];
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
normalizedVector[i] = vector[i] / norm;
|
||||
}
|
||||
|
||||
return normalizedVector;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package dev.langchain4j.model.embedding;
|
||||
|
||||
/**
 * Vector similarity helpers.
 */
public class Similarity {

    /**
     * Computes the cosine similarity between two vectors.
     * <p>
     * All arithmetic is done in {@code double}, so products of large {@code float}
     * components do not overflow before being accumulated.
     *
     * @param a first vector
     * @param b second vector, must have the same length as {@code a}
     * @return the cosine of the angle between {@code a} and {@code b};
     *         {@code NaN} if either vector has zero norm (including empty vectors)
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static double cosine(float[] a, float[] b) {

        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Vectors must have the same length, but got " + a.length + " and " + b.length);
        }

        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;

        for (int i = 0; i < a.length; i++) {
            // Widen before multiplying: float * float can overflow to Infinity.
            double x = a[i];
            double y = b[i];
            dotProduct += x * y;
            normA += x * x;
            normB += y * y;
        }

        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<name>langchain4j parent POM</name>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -23,7 +23,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -22,7 +22,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<relativePath>../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
@ -24,7 +24,7 @@
|
|||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
@ -37,13 +37,6 @@
|
|||
<artifactId>jtokkit</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-pinecone</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
|
|
|
@ -9,7 +9,6 @@ import java.time.Duration;
|
|||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.model.huggingface.HuggingFaceModelName.SENTENCE_TRANSFORMERS_ALL_MINI_LM_L6_V2;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public class HuggingFaceEmbeddingModel implements EmbeddingModel {
|
||||
|
@ -32,17 +31,6 @@ public class HuggingFaceEmbeddingModel implements EmbeddingModel {
|
|||
this.waitForModel = waitForModel == null ? true : waitForModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Embedding embed(String text) {
|
||||
List<Embedding> embeddings = embedTexts(singletonList(text));
|
||||
return embeddings.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Embedding embed(TextSegment textSegment) {
|
||||
return embed(textSegment.text());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedAll(List<TextSegment> textSegments) {
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package dev.langchain4j.model.inprocess;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static java.lang.String.format;
|
||||
|
||||
/**
|
||||
* This is an embedding model that runs within your Java application's process.
|
||||
* This class serves as a thin wrapper over the actual implementations.
|
||||
* The implementations are located in separate, optional langchain4j-embeddings-xxx modules and are loaded dynamically.
|
||||
* If you wish to use the model XXX, please add the langchain4j-embeddings-xxx dependency to your project.
|
||||
* The model execution is carried out using the ONNX Runtime.
|
||||
*/
|
||||
public class InProcessEmbeddingModel implements EmbeddingModel {
|
||||
|
||||
private final EmbeddingModel implementation;
|
||||
|
||||
public InProcessEmbeddingModel(InProcessEmbeddingModelType type) {
|
||||
try {
|
||||
implementation = loadDynamically(type);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw new RuntimeException(getMessage(type), e);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static EmbeddingModel loadDynamically(InProcessEmbeddingModelType type) throws Exception {
|
||||
Class<?> implementationClass = Class.forName(
|
||||
format("dev.langchain4j.model.embedding.%s_EmbeddingModel", type.name()));
|
||||
return (EmbeddingModel) implementationClass.getConstructor().newInstance();
|
||||
}
|
||||
|
||||
private static String getMessage(InProcessEmbeddingModelType type) {
|
||||
return format("To use %s embedding model, please add the following dependency to your project:\n"
|
||||
+ "\n"
|
||||
+ "Maven:\n"
|
||||
+ "<dependency>\n" +
|
||||
" <groupId>dev.langchain4j</groupId>\n" +
|
||||
" <artifactId>langchain4j-embeddings-%s</artifactId>\n" +
|
||||
" <version>0.17.0</version>\n" +
|
||||
"</dependency>\n"
|
||||
+ "\n"
|
||||
+ "Gradle:\n"
|
||||
+ "implementation 'dev.langchain4j:langchain4j-embeddings-%s:0.17.0'\n",
|
||||
type.name(),
|
||||
type.name().replace("_", "-").toLowerCase(),
|
||||
type.name().replace("_", "-").toLowerCase()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedAll(List<TextSegment> textSegments) {
|
||||
return implementation.embedAll(textSegments);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package dev.langchain4j.model.inprocess;
|
||||
|
||||
/**
 * Enumerates the in-process embedding models that are currently supported.
 * Further models will be added over time.
 * To request a new model, please open a GitHub issue at: https://github.com/langchain4j/langchain4j/issues/new/choose
 */
public enum InProcessEmbeddingModelType {

    /**
     * The all-MiniLM-L6-v2 model from SentenceTransformers.
     *
     * @see <a href="https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2">Model details</a>
     */
    ALL_MINILM_L6_V2,

    /**
     * The all-MiniLM-L6-v2 model from SentenceTransformers, quantized variant.
     *
     * @see <a href="https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2">Model details</a>
     */
    ALL_MINILM_L6_V2_Q,

    /**
     * The E5-small-v2 model.
     *
     * @see <a href="https://huggingface.co/intfloat/e5-small-v2">Model details</a>
     */
    E5_SMALL_V2,

    /**
     * The E5-small-v2 model, quantized variant.
     *
     * @see <a href="https://huggingface.co/intfloat/e5-small-v2">Model details</a>
     */
    E5_SMALL_V2_Q
}
|
|
@ -16,7 +16,6 @@ import static dev.langchain4j.internal.RetryUtils.withRetry;
|
|||
import static dev.langchain4j.model.openai.OpenAiHelper.*;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator {
|
||||
|
@ -58,17 +57,6 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
|
|||
this.tokenizer = new OpenAiTokenizer(this.modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Embedding embed(String text) {
|
||||
List<Embedding> embeddings = embedTexts(singletonList(text));
|
||||
return embeddings.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Embedding embed(TextSegment textSegment) {
|
||||
return embed(textSegment.text());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedAll(List<TextSegment> textSegments) {
|
||||
|
||||
|
|
|
@ -29,10 +29,10 @@ public class PineconeEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
+ "<dependency>\n" +
|
||||
" <groupId>dev.langchain4j</groupId>\n" +
|
||||
" <artifactId>langchain4j-pinecone</artifactId>\n" +
|
||||
" <version>0.16.0</version>\n" +
|
||||
" <version>0.17.0</version>\n" +
|
||||
"</dependency>\n\n"
|
||||
+ "Gradle:\n"
|
||||
+ "implementation 'dev.langchain4j:langchain4j-pinecone:0.16.0'\n";
|
||||
+ "implementation 'dev.langchain4j:langchain4j-pinecone:0.17.0'\n";
|
||||
}
|
||||
|
||||
private static EmbeddingStore<TextSegment> loadDynamically(String implementationClassName, String apiKey, String environment, String projectName, String index, String nameSpace) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException {
|
||||
|
|
13
pom.xml
13
pom.xml
|
@ -6,14 +6,23 @@
|
|||
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-aggregator</artifactId>
|
||||
<version>0.16.0</version>
|
||||
<version>0.17.0</version>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<modules>
|
||||
<module>langchain4j</module>
|
||||
<module>langchain4j-parent</module>
|
||||
|
||||
<module>langchain4j-core</module>
|
||||
<module>langchain4j</module>
|
||||
|
||||
<module>langchain4j-pinecone</module>
|
||||
|
||||
<module>langchain4j-embeddings</module>
|
||||
<module>langchain4j-embeddings-all-minilm-l6-v2</module>
|
||||
<module>langchain4j-embeddings-all-minilm-l6-v2-q</module>
|
||||
<module>langchain4j-embeddings-e5-small-v2</module>
|
||||
<module>langchain4j-embeddings-e5-small-v2-q</module>
|
||||
|
||||
<module>langchain4j-spring-boot-starter</module>
|
||||
</modules>
|
||||
|
||||
|
|
Loading…
Reference in New Issue