diff --git a/docs/static/img/web-search-engine.png b/docs/static/img/web-search-engine.png new file mode 100644 index 000000000..c4974e057 Binary files /dev/null and b/docs/static/img/web-search-engine.png differ diff --git a/langchain4j-core/pom.xml b/langchain4j-core/pom.xml index 52566b823..c586a09e5 100644 --- a/langchain4j-core/pom.xml +++ b/langchain4j-core/pom.xml @@ -201,4 +201,4 @@ - \ No newline at end of file + diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java index eaadae942..8f29bfb47 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/ContentRetriever.java @@ -13,15 +13,16 @@ import java.util.List; * The underlying data source can be virtually anything: *
  * - Embedding (vector) store (see {@link EmbeddingStoreContentRetriever})
- * - Full-text search engine (e.g., Apache Lucene, Elasticsearch, Vespa)
- * - Hybrid of keyword and vector search
- * - The Web (e.g., Google, Bing)
- * - Knowledge graph
+ * - Full-text search engine (see {@code AzureAiSearchContentRetriever} in {@code langchain4j-azure-ai-search} module)
+ * - Hybrid of vector and full-text search (see {@code AzureAiSearchContentRetriever} in {@code langchain4j-azure-ai-search} module)
+ * - Web Search Engine (see {@link WebSearchContentRetriever})
+ * - Knowledge graph (see {@code Neo4jContentRetriever} in {@code langchain4j-neo4j} module)
  * - Relational database
  * - etc.
  * 
* * @see EmbeddingStoreContentRetriever + * @see WebSearchContentRetriever */ public interface ContentRetriever { diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java new file mode 100644 index 000000000..78ca0a4d1 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetriever.java @@ -0,0 +1,50 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.WebSearchEngine; +import dev.langchain4j.web.search.WebSearchResults; + +import java.util.List; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.stream.Collectors.toList; + +/** + * A {@link ContentRetriever} that retrieves relevant {@link Content} from the web using a {@link WebSearchEngine}. + *
+ * It returns one {@link Content} for each result that a {@link WebSearchEngine} has returned for a given {@link Query}. + *
+ * Depending on the {@link WebSearchEngine} implementation, the {@link Content#textSegment()} + * can contain either a snippet of a web page or a complete content of a web page. + */ +public class WebSearchContentRetriever implements ContentRetriever { + + private final WebSearchEngine webSearchEngine; + + /** + * Constructs a new WebSearchContentRetriever with the specified web search engine. + * + * @param webSearchEngine The web search engine to use for retrieving search results. + */ + public WebSearchContentRetriever(WebSearchEngine webSearchEngine) { + this.webSearchEngine = ensureNotNull(webSearchEngine, "webSearchEngine"); + } + + @Override + public List retrieve(Query query) { + WebSearchResults webSearchResults = webSearchEngine.search(query.text()); + return webSearchResults.toTextSegments().stream() + .map(Content::from) + .collect(toList()); + } + + /** + * Creates a new instance of {@code WebSearchContentRetriever} with the specified {@link WebSearchEngine}. + * + * @return A new instance of WebSearchContentRetriever. + */ + public static WebSearchContentRetriever from(WebSearchEngine webSearchEngine) { + return new WebSearchContentRetriever(webSearchEngine); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java new file mode 100644 index 000000000..4fb927e93 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchEngine.java @@ -0,0 +1,25 @@ +package dev.langchain4j.web.search; + +/** + * Represents a web search engine that can be used to perform searches on the Web in response to a user query. + */ +public interface WebSearchEngine { + + /** + * Performs a search query on the web search engine and returns the search results. + * + * @param query the search query + * @return the search results + */ + default WebSearchResults search(String query) { + return search(WebSearchRequest.from(query)); + } + + /** + * Performs a search request on the web search engine and returns the search results. + * + * @param webSearchRequest the search request + * @return the web search results + */ + WebSearchResults search(WebSearchRequest webSearchRequest); +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java new file mode 100644 index 000000000..c29919b01 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchInformationResult.java @@ -0,0 +1,117 @@ +package dev.langchain4j.web.search; + +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents general information about the web search performed. + * This includes the total number of results, the page number, and metadata. + *

+ * The total number of results is the total number of web pages that are found by the search engine in response to a search query. + * The page number is the current page number of the search results. + * The metadata is a map of key-value pairs that provide additional information about the search. + * For example, it could include the search query, the search engine used, the time it took to perform the search, etc. + */ +public class WebSearchInformationResult { + + private final Long totalResults; + private final Integer pageNumber; + private final Map metadata; + + /** + * Constructs a new WebSearchInformationResult with the specified total results. + * + * @param totalResults The total number of results. + */ + public WebSearchInformationResult(Long totalResults) { + this(totalResults, null, null); + } + + /** + * Constructs a new WebSearchInformationResult with the specified total results, page number, and metadata. + * + * @param totalResults The total number of results. + * @param pageNumber The page number. + * @param metadata The metadata. + */ + public WebSearchInformationResult(Long totalResults, Integer pageNumber, Map metadata) { + this.totalResults = ensureNotNull(totalResults, "totalResults"); + this.pageNumber = pageNumber; + this.metadata = metadata; + } + + /** + * Gets the total number of results. + * + * @return The total number of results. + */ + public Long totalResults() { + return totalResults; + } + + /** + * Gets the page number. + * + * @return The page number. + */ + public Integer pageNumber() { + return pageNumber; + } + + /** + * Gets the metadata. + * + * @return The metadata. + */ + public Map metadata() { + return metadata; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchInformationResult that = (WebSearchInformationResult) o; + return Objects.equals(totalResults, that.totalResults) + && Objects.equals(pageNumber, that.pageNumber) + && Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(totalResults, pageNumber, metadata); + } + + @Override + public String toString() { + return "WebSearchInformationResult{" + + "totalResults=" + totalResults + + ", pageNumber=" + pageNumber + + ", metadata=" + metadata + + '}'; + } + + /** + * Creates a new WebSearchInformationResult with the specified total results. + * + * @param totalResults The total number of results. + * @return The new WebSearchInformationResult. + */ + public static WebSearchInformationResult from(Long totalResults) { + return new WebSearchInformationResult(totalResults); + } + + /** + * Creates a new WebSearchInformationResult with the specified total results, page number, and metadata. + * + * @param totalResults The total number of results. + * @param pageNumber The page number. + * @param metadata The metadata. + * @return The new WebSearchInformationResult. + */ + public static WebSearchInformationResult from(Long totalResults, Integer pageNumber, Map metadata) { + return new WebSearchInformationResult(totalResults, pageNumber, metadata); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java new file mode 100644 index 000000000..3689c406b --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchOrganicResult.java @@ -0,0 +1,230 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; + +import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNotNullOrBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * Represents an organic search results are the web pages that are returned by the search engine in response to a search query. + * This includes the title, URL, snippet and/or content, and metadata of the web page. + *

+ * These results are typically ranked by relevance to the search query. + *

+ */ +public class WebSearchOrganicResult { + private final String title; + private final URI url; + private final String snippet; + private final String content; + private final Map metadata; + + + /** + * Constructs a WebSearchOrganicResult object with the given title and URL. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + */ + public WebSearchOrganicResult(String title, URI url) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url, "url"); + this.snippet = null; + this.content = null; + this.metadata = null; + } + + /** + * Constructs a WebSearchOrganicResult object with the given title, URL, snippet and/or content. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + */ + public WebSearchOrganicResult(String title, URI url, String snippet, String content) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url, "url"); + this.snippet = snippet; + this.content = content; + this.metadata = null; + } + + /** + * Constructs a WebSearchOrganicResult object with the given title, URL, snippet and/or content, and metadata. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @param metadata The metadata associated with the search result. + */ + public WebSearchOrganicResult(String title, URI url, String snippet, String content, Map metadata) { + this.title = ensureNotBlank(title, "title"); + this.url = ensureNotNull(url,"url"); + this.snippet = snippet; + this.content = content; + this.metadata = getOrDefault(metadata, new HashMap<>()); + } + + /** + * Returns the title of the web page. + * + * @return The title of the web page. + */ + public String title() { + return title; + } + + /** + * Returns the URL associated with the web page. + * + * @return The URL associated with the web page. + */ + public URI url() { + return url; + } + + /** + * Returns the snippet associated with the web page. + * + * @return The snippet associated with the web page. + */ + public String snippet() { + return snippet; + } + + /** + * Returns the content scraped from the web page. + * + * @return The content scraped from the web page. + */ + public String content() { + return content; + } + + /** + * Returns the result metadata associated with the search result. + * + * @return The result metadata associated with the search result. + */ + public Map metadata() { + return metadata; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchOrganicResult that = (WebSearchOrganicResult) o; + return Objects.equals(title, that.title) + && Objects.equals(url, that.url) + && Objects.equals(snippet, that.snippet) + && Objects.equals(content, that.content) + && Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(title, url, snippet, content, metadata); + } + + @Override + public String toString() { + return "WebSearchOrganicResult{" + + "title='" + title + '\'' + + ", url=" + url + + ", snippet='" + snippet + '\'' + + ", content='" + content + '\'' + + ", metadata=" + metadata + + '}'; + } + + /** + * Converts this WebSearchOrganicResult to a TextSegment. + * + * @return The TextSegment representation of this WebSearchOrganicResult. + */ + public TextSegment toTextSegment() { + return TextSegment.from(copyToText(), copyToMetadata()); + } + + /** + * Converts this WebSearchOrganicResult to a Document. + * + * @return The Document representation of this WebSearchOrganicResult. + */ + public Document toDocument() { + return Document.from(copyToText(), copyToMetadata()); + } + + private String copyToText() { + StringBuilder text = new StringBuilder(); + text.append(title); + text.append("\n"); + if (isNotNullOrBlank(content)) { + text.append(content); + } else if (isNotNullOrBlank(snippet)) { + text.append(snippet); + } + return text.toString(); + } + + private Metadata copyToMetadata() { + Metadata docMetadata = new Metadata(); + docMetadata.add("url", url); + if (metadata != null) { + for (Map.Entry entry : metadata.entrySet()) { + docMetadata.add(entry.getKey(), entry.getValue()); + } + } + return docMetadata; + } + + /** + * Creates a WebSearchOrganicResult object from the given title and URL. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @return The created WebSearchOrganicResult object. + */ + public static WebSearchOrganicResult from(String title, URI url) { + return new WebSearchOrganicResult(title, url); + } + + /** + * Creates a WebSearchOrganicResult object from the given title, URL, snippet and/or content. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @return The created WebSearchOrganicResult object. + */ + public static WebSearchOrganicResult from(String title, URI url, String snippet, String content) { + return new WebSearchOrganicResult(title, url, snippet, content); + } + + /** + * Creates a WebSearchOrganicResult object from the given title, URL, snippet and/or content, and result metadata. + * + * @param title The title of the search result. + * @param url The URL associated with the search result. + * @param snippet The snippet of the search result, in plain text. + * @param content The most query related content from the scraped url. + * @param metadata The metadata associated with the search result. + * @return The created WebSearchOrganicResult object. + */ + public static WebSearchOrganicResult from(String title, URI url, String snippet, String content, Map metadata) { + return new WebSearchOrganicResult(title, url, snippet, content, metadata); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java new file mode 100644 index 000000000..c79eb68fd --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchRequest.java @@ -0,0 +1,312 @@ +package dev.langchain4j.web.search; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + + +/** + * Represents a search request that can be made by the user to perform searches in any implementation of {@link WebSearchEngine}. + *

+ * {@link WebSearchRequest} follow opensearch foundation standard implemented by most web search engine libs like Google, Bing, Yahoo, etc. + * OpenSearch#parameters + *

+ *

+ * The {@link #searchTerms} are the keywords that the search client desires to search for. This param is mandatory to perform a search. + *

+ *
+ * Configurable parameters (optional): + *

+ */ +public class WebSearchRequest { + + private final String searchTerms; + private final Integer maxResults; + private final String language; + private final String geoLocation; + private final Integer startPage; + private final Integer startIndex; + private final Boolean safeSearch; + private final Map additionalParams; + + private WebSearchRequest(Builder builder){ + this.searchTerms = ensureNotBlank(builder.searchTerms,"searchTerms"); + this.maxResults = builder.maxResults; + this.language = builder.language; + this.geoLocation = builder.geoLocation; + this.startPage = getOrDefault(builder.startPage,1); + this.startIndex = builder.startIndex; + this.safeSearch = getOrDefault(builder.safeSearch,true); + this.additionalParams = getOrDefault(builder.additionalParams, () -> new HashMap<>()); + } + + /** + * Get the search terms. + * + * @return The search terms. + */ + public String searchTerms() { + return searchTerms; + } + + /** + * Get the maximum number of results. + * + * @return The maximum number of results. + */ + public Integer maxResults() { + return maxResults; + } + + /** + * Get the desired language for search results. + * + * @return The desired language for search results. + */ + public String language() { + return language; + } + + /** + * Get the desired geolocation for search results. + * + * @return The desired geolocation for search results. + */ + public String geoLocation() { + return geoLocation; + } + + /** + * Get the start page number for search results. + * + * @return The start page number for search results. + */ + public Integer startPage() { + return startPage; + } + + /** + * Get the start index for search results. + * + * @return The start index for search results. + */ + public Integer startIndex() { + return startIndex; + } + + /** + * Get the safe search flag. + * + * @return The safe search flag. + */ + public Boolean safeSearch() { + return safeSearch; + } + + /** + * Get the additional parameters for the search request. + * + * @return The additional parameters for the search request. + */ + public Map additionalParams() { + return additionalParams; + } + + @Override + public boolean equals(Object another) { + if (this == another) return true; + return another instanceof WebSearchRequest + && equalTo((WebSearchRequest) another); + } + + private boolean equalTo(WebSearchRequest another){ + return Objects.equals(searchTerms, another.searchTerms) + && Objects.equals(maxResults, another.maxResults) + && Objects.equals(language, another.language) + && Objects.equals(geoLocation, another.geoLocation) + && Objects.equals(startPage, another.startPage) + && Objects.equals(startIndex, another.startIndex) + && Objects.equals(safeSearch, another.safeSearch) + && Objects.equals(additionalParams, another.additionalParams); + } + + @Override + public int hashCode() { + int h = 5381; + h += (h << 5) + Objects.hashCode(searchTerms); + h += (h << 5) + Objects.hashCode(maxResults); + h += (h << 5) + Objects.hashCode(language); + h += (h << 5) + Objects.hashCode(geoLocation); + h += (h << 5) + Objects.hashCode(startPage); + h += (h << 5) + Objects.hashCode(startIndex); + h += (h << 5) + Objects.hashCode(safeSearch); + h += (h << 5) + Objects.hashCode(additionalParams); + return h; + } + + @Override + public String toString() { + return "WebSearchRequest{" + + "searchTerms='" + searchTerms + '\'' + + ", maxResults=" + maxResults + + ", language='" + language + '\'' + + ", geoLocation='" + geoLocation + '\'' + + ", startPage=" + startPage + + ", startIndex=" + startIndex + + ", siteRestrict=" + safeSearch + + ", additionalParams=" + additionalParams + + '}'; + } + + /** + * Create a new builder instance. + * + * @return A new builder instance. + */ + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private String searchTerms; + private Integer maxResults; + private String language; + private String geoLocation; + private Integer startPage; + private Integer startIndex; + private Boolean safeSearch; + private Map additionalParams; + + private Builder() { + } + + /** + * Set the search terms. + * + * @param searchTerms The keyword or keywords desired by the search user. + * @return The builder instance. + */ + public Builder searchTerms(String searchTerms) { + this.searchTerms = searchTerms; + return this; + } + + /** + * Set the maximum number of results. + * + * @param maxResults The maximum number of results. + * @return The builder instance. + */ + public Builder maxResults(Integer maxResults) { + this.maxResults = maxResults; + return this; + } + + /** + * Set the desired language for search results. + * + * @param language The desired language for search results. + * @return The builder instance. + */ + public Builder language(String language) { + this.language = language; + return this; + } + + /** + * Set the desired geolocation for search results. + * + * @param geoLocation The desired geolocation for search results. + * @return The builder instance. + */ + public Builder geoLocation(String geoLocation) { + this.geoLocation = geoLocation; + return this; + } + + /** + * Set the start page number for search results. + * + * @param startPage The start page number for search results. + * @return The builder instance. + */ + public Builder startPage(Integer startPage) { + this.startPage = startPage; + return this; + } + + /** + * Set the start index for search results. + * + * @param startIndex The start index for search results. + * @return The builder instance. + */ + public Builder startIndex(Integer startIndex) { + this.startIndex = startIndex; + return this; + } + + /** + * Set the safe search flag. + * + * @param safeSearch The safe search flag. + * @return The builder instance. + */ + public Builder safeSearch(Boolean safeSearch) { + this.safeSearch = safeSearch; + return this; + } + + /** + * Set the additional parameters for the search request. + * + * @param additionalParams The additional parameters for the search request. + * @return The builder instance. + */ + public Builder additionalParams(Map additionalParams) { + this.additionalParams = additionalParams; + return this; + } + + /** + * Build the web search request. + * + * @return The web search request. + */ + public WebSearchRequest build() { + return new WebSearchRequest(this); + } + } + + /** + * Create a web search request with the given search terms. + * + * @param searchTerms The search terms. + * @return The web search request. + */ + public static WebSearchRequest from(String searchTerms) { + return WebSearchRequest.builder().searchTerms(searchTerms).build(); + } + + /** + * Create a web search request with the given search terms and maximum number of results. + * + * @param searchTerms The search terms. + * @param maxResults The maximum number of results. + * @return The web search request. + */ + public static WebSearchRequest from(String searchTerms, Integer maxResults) { + return WebSearchRequest.builder().searchTerms(searchTerms).maxResults(maxResults).build(); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java new file mode 100644 index 000000000..d3667db84 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchResults.java @@ -0,0 +1,149 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.segment.TextSegment; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.stream.Collectors.toList; + +/** + * Represents the response of a web search performed. + * This includes the list of organic search results, information about the search, and pagination information. + *

+ * {@link WebSearchResults} follow opensearch foundation standard implemented by most web search engine libs like Google, Bing, Yahoo, etc. + * OpenSearch#response + *

+ *

+ * The organic search results are the web pages that are returned by the search engine in response to a search query. + * These results are typically ranked by relevance to the search query. + */ +public class WebSearchResults { + + private final Map searchMetadata; + private final WebSearchInformationResult searchInformation; + private final List results; + + /** + * Constructs a new instance of WebSearchResults. + * + * @param searchInformation The information about the web search. + * @param results The list of organic search results. + */ + public WebSearchResults(WebSearchInformationResult searchInformation, List results) { + this(null, searchInformation, results); + } + + /** + * Constructs a new instance of WebSearchResults. + * + * @param searchMetadata The metadata associated with the web search. + * @param searchInformation The information about the web search. + * @param results The list of organic search results. + */ + public WebSearchResults(Map searchMetadata, WebSearchInformationResult searchInformation, List results) { + this.searchMetadata = searchMetadata; + this.searchInformation = ensureNotNull(searchInformation, "searchInformation"); + this.results = ensureNotEmpty(results, "results"); + } + + /** + * Gets the metadata associated with the web search. + * + * @return The metadata associated with the web search. + */ + public Map searchMetadata() { + return searchMetadata; + } + + /** + * Gets the information about the web search. + * + * @return The information about the web search. + */ + public WebSearchInformationResult searchInformation() { + return searchInformation; + } + + /** + * Gets the list of organic search results. + * + * @return The list of organic search results. + */ + public List results() { + return results; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WebSearchResults that = (WebSearchResults) o; + return Objects.equals(searchMetadata, that.searchMetadata) + && Objects.equals(searchInformation, that.searchInformation) + && Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(searchMetadata, searchInformation, results); + } + + @Override + public String toString() { + return "WebSearchResults{" + + "searchMetadata=" + searchMetadata + + ", searchInformation=" + searchInformation + + ", results=" + results + + '}'; + } + + /** + * Converts the organic search results to a list of text segments. + * + * @return The list of text segments. + */ + public List toTextSegments() { + return results.stream() + .map(WebSearchOrganicResult::toTextSegment) + .collect(toList()); + } + + /** + * Converts the organic search results to a list of documents. + * + * @return The list of documents. + */ + public List toDocuments() { + return results.stream() + .map(WebSearchOrganicResult::toDocument) + .collect(toList()); + } + + /** + * Creates a new instance of WebSearchResults from the specified parameters. + * + * @param results The list of organic search results. + * @param searchInformation The information about the web search. + * @return The new instance of WebSearchResults. + */ + public static WebSearchResults from(WebSearchInformationResult searchInformation, List results) { + return new WebSearchResults(searchInformation, results); + } + + /** + * Creates a new instance of WebSearchResults from the specified parameters. + * + * @param searchMetadata The metadata associated with the search results. + * @param searchInformation The information about the web search. + * @param results The list of organic search results. + * @return The new instance of WebSearchResults. + */ + public static WebSearchResults from(Map searchMetadata, WebSearchInformationResult searchInformation, List results) { + return new WebSearchResults(searchMetadata, searchInformation, results); + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java new file mode 100644 index 000000000..4b550b4d8 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/web/search/WebSearchTool.java @@ -0,0 +1,48 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.agent.tool.P; +import dev.langchain4j.agent.tool.Tool; + +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +public class WebSearchTool { + + private final WebSearchEngine searchEngine; + + public WebSearchTool(WebSearchEngine searchEngine) { + this.searchEngine = ensureNotNull(searchEngine, "searchEngine"); + } + + /** + * Runs a search query on the web search engine and returns a pretty-string representation of the search results. + * + * @param query the search user query + * @return a pretty-string representation of the search results + */ + @Tool("This tool can be used to perform web searches using search engines such as Google, particularly when seeking information about recent events.") + public String searchWeb(@P("Web search query") String query) { + WebSearchResults results = searchEngine.search(query); + return format(results); + } + + private String format(WebSearchResults results) { + return results.results() + .stream() + .map(organicResult -> "Title: " + organicResult.title() + "\n" + + "Source: " + organicResult.url().toString() + "\n" + + (organicResult.content() != null ? "Content:" + "\n" + organicResult.content() : "Snippet:" + "\n" + organicResult.snippet())) + .collect(Collectors.joining("\n\n")); + } + + /** + * Creates a new WebSearchTool with the specified web search engine. + * + * @param searchEngine the web search engine to use for searching the web + * @return a new WebSearchTool + */ + public static WebSearchTool from(WebSearchEngine searchEngine) { + return new WebSearchTool(searchEngine); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java new file mode 100644 index 000000000..f7f581645 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverIT.java @@ -0,0 +1,37 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.WebSearchEngineIT; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class WebSearchContentRetrieverIT extends WebSearchEngineIT { + + @Test + void should_retrieve_web_page_as_content() { + // given + WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.from(searchEngine()); + Query query = Query.from("What is the current weather in New York?"); + + // when + List contents = contentRetriever.retrieve(query); + + // then + assertThat(contents) + .as("At least one content should be contains 'weather' and 'New York' ignoring case") + .anySatisfy(content -> { + assertThat(content.textSegment().text()) + .containsIgnoringCase("weather") + .containsIgnoringCase("New York"); + assertThat(content.textSegment().metadata().get("url")) + .startsWith("https://"); + assertThat(content.textSegment().metadata().get("title")) + .isNotBlank(); + } + ); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java new file mode 100644 index 000000000..86154abca --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/content/retriever/WebSearchContentRetrieverTest.java @@ -0,0 +1,68 @@ +package dev.langchain4j.rag.content.retriever; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.web.search.WebSearchEngine; +import dev.langchain4j.web.search.WebSearchInformationResult; +import dev.langchain4j.web.search.WebSearchOrganicResult; +import dev.langchain4j.web.search.WebSearchResults; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +class WebSearchContentRetrieverTest { + + WebSearchEngine webSearchEngine; + + @BeforeEach + void mockWebSearchEngine() { + webSearchEngine = mock(WebSearchEngine.class); + when(webSearchEngine.search(anyString())).thenReturn( + new WebSearchResults( + WebSearchInformationResult.from(3L, 1, new HashMap<>()), + asList( + WebSearchOrganicResult.from("title 1", URI.create("https://google.com"), "snippet 1", null), + WebSearchOrganicResult.from("title 2", URI.create("https://docs.langchain4j.dev"), null, "content 2"), + WebSearchOrganicResult.from("title 3", URI.create("https://github.com/dewitt/opensearch/blob/master/README.md"), "snippet 3", "content 3") + ) + ) + ); + } + + @AfterEach + void resetWebSearchEngine() { + reset(webSearchEngine); + } + + @Test + void should_retrieve_web_pages_back() { + // given + ContentRetriever contentRetriever = WebSearchContentRetriever.from(webSearchEngine); + + Query query = Query.from("query"); + + // when + List contents = contentRetriever.retrieve(query); + + // then + assertThat(contents).containsExactly( + Content.from(TextSegment.from("title 1\nsnippet 1", Metadata.from("url", "https://google.com"))), + Content.from(TextSegment.from("title 2\ncontent 2", Metadata.from("url", "https://docs.langchain4j.dev"))), + Content.from(TextSegment.from("title 3\ncontent 3", Metadata.from("url", "https://github.com/dewitt/opensearch/blob/master/README.md"))) + ); + + verify(webSearchEngine).search(query.text()); + verifyNoMoreInteractions(webSearchEngine); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java new file mode 100644 index 000000000..f378abeb7 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchEngineIT.java @@ -0,0 +1,71 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * A minimum set of tests that each implementation of {@link WebSearchEngine} must pass. + */ +public abstract class WebSearchEngineIT { + + protected abstract WebSearchEngine searchEngine(); + + @Test + void should_return_web_results_with_default_constructor() { + // given + String searchTerm = "What is the current weather in New York?"; + + // when + WebSearchResults results = searchEngine().search(searchTerm); + + // then + assertThat(results).isNotNull(); + assertThat(results.searchInformation()).isNotNull(); + assertThat(results.results()).isNotNull(); + + assertThat(results.searchInformation().totalResults()).isGreaterThan(0); + assertThat(results.results().size()).isGreaterThan(0); + } + + @Test + void should_return_web_results_with_max_results() { + // given + String searchTerm = "What is the current weather in New York?"; + WebSearchRequest webSearchRequest = WebSearchRequest.from(searchTerm, 5); + + // when + WebSearchResults results = searchEngine().search(webSearchRequest); + + // then + assertThat(results.searchInformation().totalResults()).isGreaterThanOrEqualTo (5); + assertThat(results.results()).hasSize(5); + assertThat(results.results()) + .as("At least one result should be contains 'weather' and 'New York' ignoring case") + .anySatisfy(result -> assertThat(result.snippet()) + .containsIgnoringCase("weather") + .containsIgnoringCase("New York")); + } + + @Test + void should_return_web_results_with_geolocation() { + // given + String searchTerm = "Who is the current president?"; + WebSearchRequest webSearchRequest = WebSearchRequest.builder() + .searchTerms(searchTerm) + .geoLocation("fr") + .build(); + + // when + List webSearchOrganicResults = searchEngine().search(webSearchRequest).results(); + + // then + assertThat(webSearchOrganicResults).isNotNull(); + assertThat(webSearchOrganicResults) + .as("At least one result should be contains 'Emmanuel Macro' ignoring case") + .anySatisfy(result -> assertThat(result.title()) + .containsIgnoringCase("Emmanuel Macro")); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java new file mode 100644 index 000000000..8664c4513 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchInformationResultTest.java @@ -0,0 +1,53 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchInformationResultTest { + + @Test + void should_return_webSearchInformationResult_with_default_values(){ + WebSearchInformationResult webSearchInformationResult = new WebSearchInformationResult(1L); + + assertThat(webSearchInformationResult.totalResults()).isEqualTo(1L); + assertThat(webSearchInformationResult.pageNumber()).isNull(); + assertThat(webSearchInformationResult.metadata()).isNull(); + + assertThat(webSearchInformationResult).hasToString("WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}"); + } + + @Test + void should_return_webSearchInformationResult_with_informationResult(){ + WebSearchInformationResult webSearchInformationResult = WebSearchInformationResult.from(1L); + + assertThat(webSearchInformationResult.totalResults()).isEqualTo(1L); + assertThat(webSearchInformationResult.pageNumber()).isNull(); + assertThat(webSearchInformationResult.metadata()).isNull(); + + assertThat(webSearchInformationResult).hasToString("WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchInformationResult wsi1 = WebSearchInformationResult.from(1L); + WebSearchInformationResult wsi2 = WebSearchInformationResult.from(1L); + + assertThat(wsi1) + .isEqualTo(wsi1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsi2) + .hasSameHashCodeAs(wsi2); + + assertThat(WebSearchInformationResult.from(2L)) + .isNotEqualTo(wsi1); + } + + @Test + void should_throw_illegalArgumentException(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> WebSearchInformationResult.from(null)); + assertThat(exception.getMessage()).isEqualTo("totalResults cannot be null"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java new file mode 100644 index 000000000..f96675fd9 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchOrganicResultTest.java @@ -0,0 +1,134 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Metadata; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.AbstractMap; +import java.util.Map; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchOrganicResultTest { + + @Test + void should_build_webSearchOrganicResult_with_default_values(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com")); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isNull(); + assertThat(webSearchOrganicResult.content()).isNull(); + assertThat(webSearchOrganicResult.metadata()).isNull(); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_snippet(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isEqualTo("snippet"); + assertThat(webSearchOrganicResult.content()).isNull(); + assertThat(webSearchOrganicResult.metadata()).isNull(); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='snippet', content='null', metadata=null}"); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_content(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), null, "content"); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isNull(); + assertThat(webSearchOrganicResult.content()).isEqualTo("content"); + assertThat(webSearchOrganicResult.metadata()).isNull(); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='null', content='content', metadata=null}"); + } + + @Test + void should_build_webSearchOrganicResult_with_custom_title_link_and_metadata(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.title()).isEqualTo("title"); + assertThat(webSearchOrganicResult.url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchOrganicResult.snippet()).isEqualTo("snippet"); + assertThat(webSearchOrganicResult.metadata()).containsExactly(new AbstractMap.SimpleEntry<>("key", "value")); + + assertThat(webSearchOrganicResult).hasToString("WebSearchOrganicResult{title='title', url=https://google.com, snippet='snippet', content='null', metadata={key=value}}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchOrganicResult wsor1 = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + WebSearchOrganicResult wsor2 = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(wsor1) + .isEqualTo(wsor1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsor2) + .hasSameHashCodeAs(wsor2); + + assertThat(WebSearchOrganicResult.from("other title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://docs.langchain4j.dev"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://google.com"), "other snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + + assertThat(WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("other key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))) + .isNotEqualTo(wsor1); + } + + @Test + void should_return_textSegment(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.toTextSegment().text()).isEqualTo("title\nsnippet"); + assertThat(webSearchOrganicResult.toTextSegment().metadata()).isEqualTo( + Metadata.from(Stream.of( + new AbstractMap.SimpleEntry<>("url", "https://google.com"), + new AbstractMap.SimpleEntry<>("key", "value")) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + ); + } + + @Test + void should_return_document(){ + WebSearchOrganicResult webSearchOrganicResult = WebSearchOrganicResult.from("title", URI.create("https://google.com"), "snippet", null, + Stream.of(new AbstractMap.SimpleEntry<>("key", "value")).collect(toMap(Map.Entry::getKey, Map.Entry::getValue))); + + assertThat(webSearchOrganicResult.toDocument().text()).isEqualTo("title\nsnippet"); + assertThat(webSearchOrganicResult.toDocument().metadata()).isEqualTo( + Metadata.from(Stream.of( + new AbstractMap.SimpleEntry<>("url", "https://google.com"), + new AbstractMap.SimpleEntry<>("key", "value")) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + ); + } + + @Test + void should_throw_illegalArgumentException_without_title(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> WebSearchOrganicResult.from(null, URI.create("https://google.com"), "snippet", "content")); + assertThat(exception).hasMessage("title cannot be null or blank"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java new file mode 100644 index 000000000..375e48586 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchRequestTest.java @@ -0,0 +1,99 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class WebSearchRequestTest { + + @Test + void should_build_webSearchRequest_with_default_values(){ + WebSearchRequest webSearchRequest = WebSearchRequest.from("query"); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isNull(); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=null, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_default_values_builder(){ + WebSearchRequest webSearchRequest = WebSearchRequest.builder().searchTerms("query").build(); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isNull(); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=null, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_custom_maxResults(){ + WebSearchRequest webSearchRequest = WebSearchRequest.from("query", 10); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isEqualTo(10); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=10, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void should_build_webSearchRequest_with_custom_maxResults_builder(){ + WebSearchRequest webSearchRequest = WebSearchRequest.builder().searchTerms("query").maxResults(10).build(); + + assertThat(webSearchRequest.searchTerms()).isEqualTo("query"); + assertThat(webSearchRequest.startPage()).isEqualTo(1); + assertThat(webSearchRequest.maxResults()).isEqualTo(10); + assertThat(webSearchRequest.language()).isNull(); + assertThat(webSearchRequest.geoLocation()).isNull(); + assertThat(webSearchRequest.startIndex()).isNull(); + assertThat(webSearchRequest.safeSearch()).isTrue(); + assertThat(webSearchRequest.additionalParams()).isEmpty(); + + assertThat(webSearchRequest).hasToString("WebSearchRequest{searchTerms='query', maxResults=10, language='null', geoLocation='null', startPage=1, startIndex=null, siteRestrict=true, additionalParams={}}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchRequest wsr1 = WebSearchRequest.from("query", 10); + WebSearchRequest wsr2 = WebSearchRequest.from("query", 10); + + assertThat(wsr1) + .isEqualTo(wsr1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsr2) + .hasSameHashCodeAs(wsr2); + + assertThat(WebSearchRequest.from("other query", 10)) + .isNotEqualTo(wsr1); + + assertThat(WebSearchRequest.from("query", 20)) + .isNotEqualTo(wsr1); + } + + @Test + void should_throw_illegalArgumentException_without_searchTerms(){ + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + WebSearchRequest.builder().build()); + assertThat(exception).hasMessage("searchTerms cannot be null or blank"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java new file mode 100644 index 000000000..d73285b51 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchResultsTest.java @@ -0,0 +1,109 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.data.document.Metadata; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyList; + +class WebSearchResultsTest { + + @Test + void should_build_webSearchResults(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + assertThat(webSearchResults.results()).hasSize(1); + assertThat(webSearchResults.results().get(0).url().toString()).isEqualTo("https://google.com"); + assertThat(webSearchResults.searchInformation().totalResults()).isEqualTo(1L); + + assertThat(webSearchResults).hasToString("WebSearchResults{searchMetadata=null, searchInformation=WebSearchInformationResult{totalResults=1, pageNumber=null, metadata=null}, results=[WebSearchOrganicResult{title='title', url=https://google.com, snippet='null', content='null', metadata=null}]}"); + } + + @Test + void test_equals_and_hash(){ + WebSearchResults wsr1 = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + WebSearchResults wsr2 = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com")))); + + assertThat(wsr1) + .isEqualTo(wsr1) + .isNotEqualTo(null) + .isNotEqualTo(new Object()) + .isEqualTo(wsr2) + .hasSameHashCodeAs(wsr2); + + assertThat(WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://docs.langchain4j.dev"))))) + .isNotEqualTo(wsr1); + + assertThat(WebSearchResults.from( + WebSearchInformationResult.from(2L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"))))) + .isNotEqualTo(wsr1); + } + + @Test + void should_return_array_of_textSegments_with_snippet(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),"snippet", null))); + + assertThat(webSearchResults.toTextSegments()).hasSize(1); + assertThat(webSearchResults.toTextSegments().get(0).text()).isEqualTo("title\nsnippet"); + assertThat(webSearchResults.toTextSegments().get(0).metadata()).isEqualTo(Metadata.from("url", "https://google.com")); + } + + @Test + void should_return_array_of_documents_with_content(){ + WebSearchResults webSearchResults = WebSearchResults.from( + WebSearchInformationResult.from(1L), + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),null, "content"))); + + assertThat(webSearchResults.toDocuments()).hasSize(1); + assertThat(webSearchResults.toDocuments().get(0).text()).isEqualTo("title\ncontent"); + assertThat(webSearchResults.toDocuments().get(0).metadata()).isEqualTo(Metadata.from("url", "https://google.com")); + } + + @Test + void should_throw_illegalArgumentException_without_searchInformation(){ + // given + Map searchMetadata = new HashMap<>(); + searchMetadata.put("key", "value"); + + // then + assertThrows(IllegalArgumentException.class, () -> new WebSearchResults( + searchMetadata, + null, + singletonList(WebSearchOrganicResult.from("title", URI.create("https://google.com"),"snippet",null)))); + } + + @Test + void should_throw_illegalArgumentException_without_results(){ + // given + Map searchMetadata = new HashMap<>(); + searchMetadata.put("key", "value"); + + // then + assertThrows(IllegalArgumentException.class, () -> new WebSearchResults( + searchMetadata, + WebSearchInformationResult.from(1L), + emptyList())); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java new file mode 100644 index 000000000..534874979 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolIT.java @@ -0,0 +1,59 @@ +package dev.langchain4j.web.search; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.agent.tool.ToolSpecifications; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class WebSearchToolIT extends WebSearchEngineIT { + + protected abstract ChatLanguageModel chatLanguageModel(); + + @Test + void should_be_usable_tool_with_chatLanguageModel(){ + // given + WebSearchTool webSearchTool = WebSearchTool.from(searchEngine()); + List tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool); + + UserMessage userMessage = UserMessage.from("What is LangChain4j project?"); + + // when + AiMessage aiMessage = chatLanguageModel().generate(singletonList(userMessage), tools).content(); + + // then + assertThat(aiMessage.hasToolExecutionRequests()).isTrue(); + assertThat(aiMessage.toolExecutionRequests()) + .anySatisfy(toolSpec -> { + assertThat(toolSpec.name()) + .containsIgnoringCase("searchWeb"); + assertThat(toolSpec.arguments()) + .isNotBlank(); + } + ); + } + + @Test + void should_return_pretty_result_as_a_tool(){ + // given + WebSearchTool webSearchTool = WebSearchTool.from(searchEngine()); + String searchTerm = "What is LangChain4j project?"; + + // when + String strResult = webSearchTool.searchWeb(searchTerm); + + // then + assertThat(strResult).isNotBlank(); + assertThat(strResult) + .as("At least the string result should be contains 'java' and 'AI' ignoring case") + .containsIgnoringCase("Java") + .containsIgnoringCase("AI"); + } +} diff --git a/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java new file mode 100644 index 000000000..8ad8a04e9 --- /dev/null +++ b/langchain4j-core/src/test/java/dev/langchain4j/web/search/WebSearchToolTest.java @@ -0,0 +1,58 @@ +package dev.langchain4j.web.search; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; + +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +class WebSearchToolTest { + + WebSearchEngine webSearchEngine; + + @BeforeEach + void mockWebSearchEngine(){ + webSearchEngine = mock(WebSearchEngine.class); + when(webSearchEngine.search(anyString())).thenReturn( + new WebSearchResults( + WebSearchInformationResult.from(3L,1, new HashMap<>()), + asList( + WebSearchOrganicResult.from("title 1", URI.create("https://google.com"), "snippet 1", "content 1"), + WebSearchOrganicResult.from("title 2", URI.create("https://docs.langchain4j.dev"), "snippet 2", "content 2"), + WebSearchOrganicResult.from("title 3", URI.create("https://github.com/dewitt/opensearch/blob/master/README.md"), "snippet 3","content 3") + ) + ) + ); + } + + @AfterEach + void resetWebSearchEngine(){ + reset(webSearchEngine); + } + + @Test + void should_build_webSearchTool(){ + // given + String searchTerm = "Any text to search"; + WebSearchTool webSearchTool = WebSearchTool.from(webSearchEngine); + + // when + String strResult = webSearchTool.searchWeb(searchTerm); + + // then + assertThat(strResult).isNotBlank(); + assertThat(strResult) + .as("At least one result should be contains 'title 1' and 'https://google.com' and 'content 1'") + .contains("Title: title 1\nSource: https://google.com\nContent:\ncontent 1"); + + verify(webSearchEngine).search(searchTerm); + verifyNoMoreInteractions(webSearchEngine); + } +}