[FEATURE] Google web search integration (#641)

Integrating [Google Custom Search](https://developers.google.com/custom-search) as a `WebSearchEngine` and as a `Tool` for function calling
This commit is contained in:
Carlos Zela Bueno 2024-05-21 13:05:21 +01:00 committed by GitHub
parent 788de9fd91
commit 43274ff465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1133 additions and 0 deletions

View File

@ -369,6 +369,13 @@
<version>${project.version}</version>
</dependency>
<!-- web searchers -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-web-search-engine-google-custom</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@ -75,6 +75,9 @@
<!-- embedding store filter parsers -->
<module>embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql</module>
<!-- web searchers -->
<module>web-searchers/langchain4j-web-search-engine-google-custom</module>
</modules>
<build>

View File

@ -0,0 +1,93 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Maven module descriptor for the Google Custom Search web search engine integration.
Dependencies declared below without a <version> presumably rely on versions managed by the
parent POM (langchain4j-parent) - confirm there before adding or pinning versions here.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Parent POM is resolved from the repository checkout via the relative path below -->
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.31.0-SNAPSHOT</version>
<relativePath>../../langchain4j-parent/pom.xml</relativePath>
</parent>
<artifactId>langchain4j-web-search-engine-google-custom</artifactId>
<packaging>jar</packaging>
<name>LangChain4j :: Web Search Engine :: Google Custom Search</name>
<description>Implementation of Google Custom Search API for LangChain4j</description>
<dependencies>
<!-- Compile-scope dependencies -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>
<!-- Custom Search API Client Library for Java (version is pinned in this module) -->
<dependency>
<groupId>com.google.apis</groupId>
<artifactId>google-api-services-customsearch</artifactId>
<version>v1-rev20240417-2.0.0</version>
</dependency>
<!-- SLF4J logging facade; version is pinned explicitly in this module -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.7</version>
</dependency>
<!-- Lombok is needed at compile time only (provided scope) -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<!-- Test-scope dependencies: test framework and assertions -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<!-- Test-scope logging backend and its SLF4J binding -->
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>
<!-- Main LangChain4j artifact, used by the tests only -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>
<!-- Test classes of langchain4j-core (test-jar), giving visibility of WebSearchEngineIT -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<!-- OpenAI integration, used by the tests only -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,188 @@
package dev.langchain4j.web.search.google.customsearch;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestInitializer;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.client.json.gson.GsonFactory;
import com.google.api.services.customsearch.v1.CustomSearchAPI;
import com.google.api.services.customsearch.v1.CustomSearchAPIRequest;
import com.google.api.services.customsearch.v1.model.Search;
import lombok.Builder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.time.Duration;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
/**
 * Thin wrapper around the Google Custom Search API Java client ({@link CustomSearchAPI}).
 * <p>
 * At construction time an instance is bound either to the regular {@code cse().list()} request or to
 * the site-restricted {@code cse().siterestrict().list()} request, and that single request object is
 * reused by every call to {@link #searchResults(Search.Queries.Request)}.
 * <p>
 * NOTE(review): the shared request object is mutated on every search, so an instance does not look
 * safe for concurrent use - confirm before sharing one client across threads.
 */
class GoogleCustomSearchApiClient {

    private static final Logger LOGGER = LoggerFactory.getLogger(GoogleCustomSearchApiClient.class);

    /** Largest value this client sends as the {@code num} (results per request) parameter. */
    private static final Integer MAXIMUM_VALUE_NUM = 10;

    /** Fallback used when no timeout is supplied to the builder. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);

    /** Fallback used when no retry count is supplied to the builder. */
    private static final int DEFAULT_MAX_RETRIES = 10;

    private final CustomSearchAPIRequest<Search> customSearchRequest;
    private final boolean logRequestResponse;

    /**
     * Creates the client and prepares the request object used for all searches.
     *
     * @param apiKey             Google API key; must not be null or blank
     * @param csi                Custom Search Engine ID ({@code cx}); must not be null or blank
     * @param siteRestrict       {@code true} to use the site-restricted endpoint; {@code null} means {@code false}
     * @param timeout            connect/read/write timeout applied to every HTTP request; {@code null} means 60 seconds
     * @param maxRetries         number of retries configured on every HTTP request; {@code null} means 10
     * @param logRequestResponse {@code true} to log requests and responses; {@code null} means {@code false}
     * @throws IllegalArgumentException if {@code apiKey} or {@code csi} is null or blank
     * @throws ArithmeticException      if the timeout in milliseconds does not fit into an {@code int}
     * @throws RuntimeException         wrapping the {@link IOException} or {@link GeneralSecurityException}
     *                                  raised while creating the underlying transport or request
     */
    @Builder
    GoogleCustomSearchApiClient(String apiKey,
                                String csi,
                                Boolean siteRestrict,
                                Duration timeout,
                                Integer maxRetries,
                                Boolean logRequestResponse) {
        if (isNullOrBlank(apiKey)) {
            throw new IllegalArgumentException("Google Custom Search API Key must be defined. " +
                    "It can be generated here: https://console.developers.google.com/apis/credentials");
        }
        if (isNullOrBlank(csi)) {
            throw new IllegalArgumentException("Google Custom Search Engine ID must be defined. " +
                    "It can be created here: https://cse.google.com/cse/create/new");
        }

        // Optional builder fields arrive as null when not set; resolve them once here so that
        // neither the constructor nor the per-request initializer can fail on auto-unboxing.
        final boolean restrictToSites = getOrDefault(siteRestrict, false);
        final boolean loggingEnabled = getOrDefault(logRequestResponse, false);
        final int retries = getOrDefault(maxRetries, DEFAULT_MAX_RETRIES);
        final int timeoutMillis = Math.toIntExact(getOrDefault(timeout, DEFAULT_TIMEOUT).toMillis());

        this.logRequestResponse = loggingEnabled;

        // Invoked by the Google client for every outgoing HTTP request.
        final HttpRequestInitializer requestInitializer = (HttpRequest httpRequest) -> {
            httpRequest.setConnectTimeout(timeoutMillis);
            httpRequest.setReadTimeout(timeoutMillis);
            httpRequest.setWriteTimeout(timeoutMillis);
            httpRequest.setLoggingEnabled(loggingEnabled);
            httpRequest.setNumberOfRetries(retries);
            if (loggingEnabled) {
                httpRequest.setInterceptor(new GoogleSearchApiHttpRequestLoggingInterceptor());
                httpRequest.setResponseInterceptor(new GoogleSearchApiHttpResponseLoggingInterceptor());
            }
        };

        try {
            CustomSearchAPI customSearchAPI = new CustomSearchAPI.Builder(
                    GoogleNetHttpTransport.newTrustedTransport(), new GsonFactory(), requestInitializer)
                    .setApplicationName("LangChain4j")
                    .build();

            if (restrictToSites) {
                this.customSearchRequest = customSearchAPI.cse().siterestrict().list().setKey(apiKey).setCx(csi);
            } else {
                this.customSearchRequest = customSearchAPI.cse().list().setKey(apiKey).setCx(csi);
            }
        } catch (IOException e) {
            LOGGER.error("Error occurred while creating Google Custom Search API client", e);
            throw new RuntimeException(e);
        } catch (GeneralSecurityException e) {
            LOGGER.error("Error occurred while creating Google Custom Search API client using GoogleNetHttpTransport.newTrustedTransport()", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Copies the parameters of {@code requestQuery} onto the prepared request and executes it.
     * <p>
     * The result count is clamped to the range 1..10 and the start index is the larger of the
     * explicitly requested start index and the first index of the requested page.
     *
     * @param requestQuery search parameters to send
     * @return the response returned by the API
     * @throws IllegalStateException if the prepared request is of an unexpected type
     * @throws RuntimeException      wrapping the {@link IOException} raised while executing the request
     */
    Search searchResults(Search.Queries.Request requestQuery) {
        final int num = maxResultsAllowed(getDefaultNaturalNumber(requestQuery.getCount()));
        final long start = calculateIndexStartPage(
                getDefaultNaturalNumber(requestQuery.getStartPage()),
                getDefaultNaturalNumber(requestQuery.getStartIndex()));
        try {
            final Search searchPerformed;
            // The two request types expose the same setters but share no common typed interface,
            // hence one helper per type.
            if (customSearchRequest instanceof CustomSearchAPI.Cse.Siterestrict.List) {
                searchPerformed = executeSiteRestricted(
                        (CustomSearchAPI.Cse.Siterestrict.List) customSearchRequest, requestQuery, num, start);
            } else if (customSearchRequest instanceof CustomSearchAPI.Cse.List) {
                searchPerformed = executeStandard(
                        (CustomSearchAPI.Cse.List) customSearchRequest, requestQuery, num, start);
            } else {
                throw new IllegalStateException("Invalid CustomSearchAPIRequest type");
            }
            if (logRequestResponse) {
                log(searchPerformed);
            }
            return searchPerformed;
        } catch (IOException e) {
            LOGGER.error("Error occurred while searching", e);
            throw new RuntimeException(e);
        }
    }

    /** Populates and executes a site-restricted search request. */
    private static Search executeSiteRestricted(CustomSearchAPI.Cse.Siterestrict.List request,
                                                Search.Queries.Request requestQuery,
                                                int num,
                                                long start) throws IOException {
        return request
                .setPrettyPrint(true)
                .setQ(requestQuery.getSearchTerms())
                .setNum(num)
                .setSort(requestQuery.getSort())
                .setSafe(requestQuery.getSafe())
                .setDateRestrict(requestQuery.getDateRestrict())
                .setGl(requestQuery.getGl())
                .setLr(requestQuery.getLanguage())
                .setHl(requestQuery.getHl())
                .setHq(requestQuery.getHq())
                .setSiteSearch(requestQuery.getSiteSearch())
                .setSiteSearchFilter(requestQuery.getSiteSearchFilter())
                .setExactTerms(requestQuery.getExactTerms())
                .setExcludeTerms(requestQuery.getExcludeTerms())
                .setLinkSite(requestQuery.getLinkSite())
                .setOrTerms(requestQuery.getOrTerms())
                .setLowRange(requestQuery.getLowRange())
                .setHighRange(requestQuery.getHighRange())
                .setSearchType(requestQuery.getSearchType())
                .setFileType(requestQuery.getFileType())
                .setRights(requestQuery.getRights())
                .setImgSize(requestQuery.getImgSize())
                .setImgType(requestQuery.getImgType())
                .setImgColorType(requestQuery.getImgColorType())
                .setImgDominantColor(requestQuery.getImgDominantColor())
                .setC2coff(requestQuery.getDisableCnTwTranslation())
                .setCr(requestQuery.getCr())
                .setGooglehost(requestQuery.getGoogleHost())
                .setStart(start)
                .setFilter(requestQuery.getFilter())
                .execute();
    }

    /** Populates and executes a regular (non site-restricted) search request. */
    private static Search executeStandard(CustomSearchAPI.Cse.List request,
                                          Search.Queries.Request requestQuery,
                                          int num,
                                          long start) throws IOException {
        return request
                .setPrettyPrint(true)
                .setQ(requestQuery.getSearchTerms())
                .setNum(num)
                .setSort(requestQuery.getSort())
                .setSafe(requestQuery.getSafe())
                .setDateRestrict(requestQuery.getDateRestrict())
                .setGl(requestQuery.getGl())
                .setLr(requestQuery.getLanguage())
                .setHl(requestQuery.getHl())
                .setHq(requestQuery.getHq())
                .setSiteSearch(requestQuery.getSiteSearch())
                .setSiteSearchFilter(requestQuery.getSiteSearchFilter())
                .setExactTerms(requestQuery.getExactTerms())
                .setExcludeTerms(requestQuery.getExcludeTerms())
                .setLinkSite(requestQuery.getLinkSite())
                .setOrTerms(requestQuery.getOrTerms())
                .setLowRange(requestQuery.getLowRange())
                .setHighRange(requestQuery.getHighRange())
                .setSearchType(requestQuery.getSearchType())
                .setFileType(requestQuery.getFileType())
                .setRights(requestQuery.getRights())
                .setImgSize(requestQuery.getImgSize())
                .setImgType(requestQuery.getImgType())
                .setImgColorType(requestQuery.getImgColorType())
                .setImgDominantColor(requestQuery.getImgDominantColor())
                .setC2coff(requestQuery.getDisableCnTwTranslation())
                .setCr(requestQuery.getCr())
                .setGooglehost(requestQuery.getGoogleHost())
                .setStart(start)
                .setFilter(requestQuery.getFilter())
                .execute();
    }

    /** Writes the pretty-printed response to the debug log; a failure to render it is only warned about. */
    private static void log(Search search) {
        try {
            LOGGER.debug("Response:\n- Response: {}", search.toPrettyString());
        } catch (IOException e) {
            LOGGER.warn("Error while logging response: {}", e.getMessage());
        }
    }

    /** Caps the requested number of results at {@link #MAXIMUM_VALUE_NUM}. */
    private static Integer maxResultsAllowed(Integer maxResults) {
        return Math.min(maxResults, MAXIMUM_VALUE_NUM);
    }

    /** Returns {@code number} when it is a positive integer, otherwise (null, zero or negative) 1. */
    private static Integer getDefaultNaturalNumber(Integer number) {
        int defaultNumber = getOrDefault(number, 1);
        return defaultNumber > 0 ? defaultNumber : 1;
    }

    /**
     * Returns the start index to request: the first index of {@code pageNumber}
     * (pages hold {@link #MAXIMUM_VALUE_NUM} results, indices start at 1) or {@code index},
     * whichever is larger.
     */
    private static Integer calculateIndexStartPage(Integer pageNumber, Integer index) {
        int indexStartPage = ((pageNumber - 1) * MAXIMUM_VALUE_NUM) + 1;
        return Math.max(indexStartPage, index);
    }
}

View File

@ -0,0 +1,273 @@
package dev.langchain4j.web.search.google.customsearch;
import com.google.api.client.json.GenericJson;
import com.google.api.services.customsearch.v1.model.Result;
import com.google.api.services.customsearch.v1.model.Search;
import dev.langchain4j.web.search.WebSearchEngine;
import dev.langchain4j.web.search.WebSearchInformationResult;
import dev.langchain4j.web.search.WebSearchOrganicResult;
import dev.langchain4j.web.search.WebSearchRequest;
import dev.langchain4j.web.search.WebSearchResults;
import lombok.Builder;
import java.net.URI;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import static com.google.api.services.customsearch.v1.model.Search.*;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.internal.ValidationUtils.*;
import static java.util.stream.Collectors.toList;
/**
* An implementation of a {@link WebSearchEngine} that uses
* <a href="https://programmablesearchengine.google.com/">Google Custom Search API</a> for performing web searches.
*/
public class GoogleCustomWebSearchEngine implements WebSearchEngine {
private final GoogleCustomSearchApiClient googleCustomSearchApiClient;
private final Boolean includeImages;
/**
* Constructs a new GoogleCustomWebSearchEngine with the specified parameters.
*
* @param apiKey the Google Search API key for accessing the Google Custom Search API
* <p>
* You can just generate an API key <a href="https://developers.google.com/custom-search/docs/paid_element#api_key">here</a>
* @param csi the Custom Search ID parameter for search the entire web
* <p>
* You can create a Custom Search Engine <a href="https://cse.google.com/cse/create/new">here</a>
* @param siteRestrict if your Search Engine is restricted to only searching specific sites, you can set this parameter to true.
* <p>
* Default value is false. View the documentation for more information <a href="https://developers.google.com/custom-search/v1/site_restricted_api">here</a>
* @param includeImages If it is true then include public images relevant to the query. This can add more latency to the search.
* <p>
* Default value is false.
* @param timeout the timeout duration for API requests
* <p>
* Default value is 60 seconds.
* @param logRequestResponse whether to log API request and response
* <p>
* Default value is false.
* @param maxRetries the maximum number of retries for API requests
* <p>
* Default value is 10.
*/
@Builder
public GoogleCustomWebSearchEngine(String apiKey,
String csi,
Boolean siteRestrict,
Boolean includeImages,
Duration timeout,
Boolean logRequestResponse,
Integer maxRetries) {
this.googleCustomSearchApiClient = GoogleCustomSearchApiClient.builder()
.apiKey(apiKey)
.csi(csi)
.siteRestrict(getOrDefault(siteRestrict,false))
.timeout(getOrDefault(timeout,Duration.ofSeconds(60)))
.logRequestResponse(getOrDefault(logRequestResponse,false))
.maxRetries(getOrDefault(maxRetries,10))
.build();
this.includeImages = getOrDefault(includeImages,false);
}
/**
* Creates a new builder for constructing a GoogleCustomWebSearchEngine with the specified API key and Custom Search ID.
*
* @param apiKey the API key for accessing the Google Custom Search API
* @param csi the Custom Search ID parameter for search the entire web
* @return a new builder instance
*/
public static GoogleCustomWebSearchEngine withApiKeyAndCsi(String apiKey, String csi){
return GoogleCustomWebSearchEngine.builder().apiKey(apiKey).csi(csi).build();
}
@Override
public WebSearchResults search(WebSearchRequest webSearchRequest) {
ensureNotNull(webSearchRequest, "webSearchRequest");
Queries.Request requestQuery = new Queries.Request();
requestQuery.setSearchTerms(webSearchRequest.searchTerms());
requestQuery.setCount(getOrDefault(webSearchRequest.maxResults(),5));
requestQuery.setGl(webSearchRequest.geoLocation());
requestQuery.setLanguage(webSearchRequest.language());
requestQuery.setStartPage(webSearchRequest.startPage());
requestQuery.setStartIndex(webSearchRequest.startIndex());
requestQuery.setSafe(webSearchRequest.safeSearch()?"active":"off");
requestQuery.setFilter("1"); // By default, applies filtering to remove duplicate content
requestQuery.setCr(setCountryRestrict(webSearchRequest));
webSearchRequest.additionalParams().forEach(requestQuery::set);
boolean searchTypeImage = isNotNullOrBlank(requestQuery.getSearchType()) && requestQuery.getSearchType().equals("image");
// Web search
Search search = googleCustomSearchApiClient.searchResults(requestQuery);
Map<String, Object> searchMetadata = toSearchMetadata(search, searchTypeImage);
Map<String, Object> searchInformationMetadata = new HashMap<>();
// Images search
if (includeImages && !searchTypeImage) {
requestQuery.setSearchType("image");
Search imagesSearch = googleCustomSearchApiClient.searchResults(requestQuery);
List<ImageSearchResult> images = imagesSearch.getItems().stream()
.map(result -> ImageSearchResult.from(
result.getTitle(),
URI.create(result.getLink()),
URI.create(result.getImage().getContextLink()),
URI.create(result.getImage().getThumbnailLink())))
.collect(toList());
addImagesToSearchInformation(searchInformationMetadata, images);
}
return WebSearchResults.from(
searchMetadata
, WebSearchInformationResult.from(
Long.valueOf(getOrDefault(search.getSearchInformation().getTotalResults(),"0")),
!isNullOrEmpty(search.getQueries().getRequest())
?calculatePageNumberFromQueries(search.getQueries().getRequest().get(0)):1,
searchInformationMetadata.isEmpty()?null:searchInformationMetadata)
, search.getItems().stream()
.map(result -> WebSearchOrganicResult.from(
result.getTitle(),
URI.create(result.getLink()),
result.getSnippet(),
null, // by default google custom search api does not return content
toResultMetadataMap(result, searchTypeImage)
)).collect(toList()));
}
private static void addImagesToSearchInformation(Map<String, Object> searchInformationMetadata, List<ImageSearchResult> images) {
if (!isNullOrEmpty(images)) {
searchInformationMetadata.put("images", images);
}
}
private static Map<String, Object> toSearchMetadata(Search search, Boolean searchTypeImage) {
if (search == null) {
return null;
}
Map<String, Object> searchMetadata = new HashMap<>();
searchMetadata.put("status", "Success");
searchMetadata.put("searchTime", search.getSearchInformation().getSearchTime());
searchMetadata.put("processedAt", LocalDateTime.now().toString());
searchMetadata.put("searchType", searchTypeImage ? "images" : "web");
searchMetadata.putAll(search.getContext());
return searchMetadata;
}
private static Map<String, String> toResultMetadataMap(Result result, boolean searchTypeImage) {
Map<String, String> metadata = new HashMap<>();
// Image search type
if (searchTypeImage) {
metadata.put("imageLink", result.getLink());
metadata.put("contextLink", result.getImage().getContextLink());
metadata.put("thumbnailLink", result.getImage().getThumbnailLink());
metadata.put("mimeType", result.getMime());
return metadata;
}
// Web search type
if (!result.getPagemap().isEmpty()) {
result.getPagemap().forEach((key, value) -> {
if (key.equals("metatags")) {
if (value instanceof List) {
metadata.put(key, ((List<?>) value).stream().map(Object::toString).reduce((a, b) -> a + ", " + b).orElse(""));
} else {
metadata.put(key, value.toString());
}
}
metadata.put("mimeType", isNotNullOrBlank(result.getMime()) ? result.getMime() : "text/html");
});
return metadata;
}
return null;
}
private static Integer calculatePageNumberFromQueries(GenericJson query) {
if (query instanceof Queries.PreviousPage) {
Queries.PreviousPage previousPage = (Queries.PreviousPage) query;
return calculatePageNumber(previousPage.getStartIndex());
}
if (query instanceof Queries.Request) {
Queries.Request currentPage = (Queries.Request) query;
return calculatePageNumber(getOrDefault(currentPage.getStartIndex(),1));
}
if (query instanceof Queries.NextPage) {
Queries.NextPage nextPage = (Queries.NextPage) query;
return calculatePageNumber(nextPage.getStartIndex());
}
return null;
}
private static Integer calculatePageNumber(Integer startIndex) {
if (startIndex == null)
return null;
return ((startIndex -1) / 10) + 1;
}
private static String setCountryRestrict(WebSearchRequest webSearchRequest){
return webSearchRequest.additionalParams().get("cr") != null ? webSearchRequest.additionalParams().get("cr").toString()
: isNotNullOrBlank(webSearchRequest.geoLocation()) ? "country" + webSearchRequest.geoLocation().toUpperCase()
: ""; // default value
}
public static final class ImageSearchResult {
private final String title;
private final URI imageLink;
private final URI contextLink;
private final URI thumbnailLink;
private ImageSearchResult(String title, URI imageLink) {
this.title = ensureNotNull(title,"title");
this.imageLink = ensureNotNull(imageLink,"imageLink");
this.contextLink = null;
this.thumbnailLink = null;
}
private ImageSearchResult(String title, URI imageLink, URI contextLink, URI thumbnailLink) {
this.title = ensureNotNull(title,"title");
this.imageLink = ensureNotNull(imageLink,"imageLink");
this.contextLink = contextLink;
this.thumbnailLink = thumbnailLink;
}
public String title() {
return title;
}
public URI imageLink() {
return imageLink;
}
public URI contextLink() {
return contextLink;
}
public URI thumbnailLink() {
return thumbnailLink;
}
@Override
public String toString() {
return "ImageSearchResult{" +
"title='" + title + '\'' +
", imageLink=" + imageLink +
", contextLink=" + contextLink +
", thumbnailLink=" + thumbnailLink +
'}';
}
public static ImageSearchResult from(String title, URI imageLink) {
return new ImageSearchResult(title, imageLink);
}
public static ImageSearchResult from(String title, URI imageLink, URI contextLink, URI thumbnailLink) {
return new ImageSearchResult(title, imageLink, contextLink, thumbnailLink);
}
}
}

View File

@ -0,0 +1,47 @@
package dev.langchain4j.web.search.google.customsearch;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpExecuteInterceptor;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.stream.Collectors;
/**
 * {@link HttpExecuteInterceptor} that writes the method, URL, headers and content of a request
 * to the debug log. Any failure while logging is reported as a warning and otherwise ignored.
 */
class GoogleSearchApiHttpRequestLoggingInterceptor implements HttpExecuteInterceptor {

    private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpRequestLoggingInterceptor.class);

    @Override
    public void intercept(HttpRequest httpRequest) throws IOException {
        log(httpRequest);
    }

    /** Logs one request; never lets a logging problem propagate to the caller. */
    private void log(HttpRequest request) {
        try {
            String method = request.getRequestMethod();
            Object url = request.getUrl();
            String headers = getHeaders(request.getHeaders());
            String body = getBody(request.getContent());
            LOGGER.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}",
                    method, url, headers, body);
        } catch (Exception failure) {
            LOGGER.warn("Error while logging request: {}", failure.getMessage());
        }
    }

    /** Renders every header as {@code [name: value]}, joined by {@code ", "}. */
    private static String getHeaders(HttpHeaders headers) {
        return headers.entrySet()
                .stream()
                .map(header -> String.format("[%s: %s]", header.getKey(), header.getValue()))
                .collect(Collectors.joining(", "));
    }

    /** Returns the string form of the content, or an empty string when there is no content. */
    private static String getBody(HttpContent content) {
        if (content == null) {
            return "";
        }
        try {
            return content.toString();
        } catch (Exception failure) {
            LOGGER.warn("Exception while getting body", failure);
            return "Exception while getting body: " + failure.getMessage();
        }
    }
}

View File

@ -0,0 +1,36 @@
package dev.langchain4j.web.search.google.customsearch;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.http.HttpResponseInterceptor;
import com.google.api.client.json.gson.GsonFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.stream.Collectors;
/**
 * {@link HttpResponseInterceptor} that writes the status code and headers of a response to the
 * debug log. Any failure while logging is reported as a warning and otherwise ignored.
 */
class GoogleSearchApiHttpResponseLoggingInterceptor implements HttpResponseInterceptor {

    private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpResponseLoggingInterceptor.class);

    @Override
    public void interceptResponse(HttpResponse httpResponse) throws IOException {
        log(httpResponse);
    }

    /** Logs one response; never lets a logging problem propagate to the caller. */
    private void log(HttpResponse response) {
        try {
            // A Gson-based JSON object parser is installed on the originating request before logging.
            response.getRequest().setParser(new GsonFactory().createJsonObjectParser());

            // The response body is not logged here because it cannot be read twice;
            // it is logged by GoogleCustomSearchApiClient instead.
            int statusCode = response.getStatusCode();
            String headers = getHeaders(response.getHeaders());
            LOGGER.debug("Response:\n- status code: {}\n- headers: {}", statusCode, headers);
        } catch (Exception failure) {
            LOGGER.warn("Error while logging response: {}", failure.getMessage());
        }
    }

    /** Renders every header as {@code [name: value]}, joined by {@code ", "}. */
    private static String getHeaders(HttpHeaders headers) {
        return headers.entrySet()
                .stream()
                .map(header -> String.format("[%s: %s]", header.getKey(), header.getValue()))
                .collect(Collectors.joining(", "));
    }
}

View File

@ -0,0 +1,111 @@
package dev.langchain4j.web.search.google.customsearch;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.WebSearchContentRetriever;
import dev.langchain4j.rag.content.retriever.WebSearchContentRetrieverIT;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.web.search.WebSearchEngine;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for {@link WebSearchContentRetriever} backed by {@link GoogleCustomWebSearchEngine}.
 * <p>
 * The whole class runs only when {@code GOOGLE_API_KEY} and {@code GOOGLE_SEARCH_ENGINE_ID} are set
 * to non-empty values; the test that also calls OpenAI additionally requires {@code OPENAI_API_KEY}.
 */
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".+")
class GoogleCustomWebSearchContentRetrieverIT extends WebSearchContentRetrieverIT {

    WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
            System.getenv("GOOGLE_API_KEY"),
            System.getenv("GOOGLE_SEARCH_ENGINE_ID"));

    // Created inside the only test that needs it, so that the other tests (including the
    // inherited ones) never construct an OpenAI model.
    // NOTE(review): building it eagerly presumably fails when OPENAI_API_KEY is unset - confirm.
    ChatLanguageModel chatModel;

    interface Assistant {
        @SystemMessage({
                "You are a web search support agent.",
                "If there is any event that has not happened yet, ",
                "you MUST use a web search tool to look up the information on the web.",
                "Include the source link and the image urls in your final response if these known, otherwise, do not include them.",
                "Do not say that you have not the capability to browse the web in real time"
        })
        String answer(String userMessage);
    }

    @Test
    void should_retrieve_web_content_with_google_for_current_info() {
        // given
        googleSearchEngine = GoogleCustomWebSearchEngine.builder()
                .apiKey(System.getenv("GOOGLE_API_KEY"))
                .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
                .logRequestResponse(true)
                .maxRetries(3)
                .build();
        WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.from(googleSearchEngine);
        Query query = Query.from("What is the latest currency exchange rates for the US Dollar and Euro");

        // when
        List<Content> contents = contentRetriever.retrieve(query);
        System.out.println("contents: " + contents);

        // then
        assertThat(contents)
                .as("At least one content should contain 'us dollar' and 'euro' ignoring case")
                .anySatisfy(content -> {
                            assertThat(content.textSegment().text())
                                    .containsIgnoringCase("us dollar")
                                    .containsIgnoringCase("euro");
                            assertThat(content.textSegment().metadata().get("url"))
                                    .startsWith("https://");
                        }
                );
    }

    @Test
    @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
    void should_retrieve_web_content_with_google_and_use_AiServices_to_summary_response() {
        // given
        googleSearchEngine = GoogleCustomWebSearchEngine.builder()
                .apiKey(System.getenv("GOOGLE_API_KEY"))
                .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
                .logRequestResponse(true)
                .maxRetries(3)
                .build();
        chatModel = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .modelName(OpenAiChatModelName.GPT_3_5_TURBO)
                .logRequests(true)
                .build();
        WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.from(googleSearchEngine);
        String query = "My family is coming to visit me in Madrid next week, list the best tourist activities suitable for the whole family";

        // when
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .contentRetriever(contentRetriever)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();
        String answer = assistant.answer(query);
        System.out.println(answer);

        // then
        assertThat(answer)
                .as("The answer should contain 'Madrid' and 'Royal Palace of Madrid' ignoring case")
                .containsIgnoringCase("Madrid")
                .containsIgnoringCase("Royal Palace of Madrid");
    }

    /** Supplies the engine under test to the inherited test cases. */
    @Override
    protected WebSearchEngine searchEngine() {
        return googleSearchEngine;
    }
}

View File

@ -0,0 +1,207 @@
package dev.langchain4j.web.search.google.customsearch;
import dev.langchain4j.web.search.*;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static dev.langchain4j.web.search.google.customsearch.GoogleCustomWebSearchEngine.ImageSearchResult;
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*")
class GoogleCustomWebSearchEngineIT extends WebSearchEngineIT {
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
System.getenv("GOOGLE_API_KEY"),
System.getenv("GOOGLE_SEARCH_ENGINE_ID"));
@Test
void should_return_google_web_results_with_search_information() {
// given
String query = "What is LangChain4j project?";
// when
WebSearchResults results = googleSearchEngine.search(query);
// then
assertThat(results.searchMetadata()).isNotNull();
assertThat(results.searchInformation().totalResults()).isGreaterThan(0);
assertThat(results.results().size()).isGreaterThan(0);
}
@Test
void should_return_google_safe_web_results_in_spanish_language() {
// given
String query = "Who won the FIFA World Cup 2022?";
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
.searchTerms(query)
.language("lang_es")
.safeSearch(true)
.build();
// when
List<WebSearchOrganicResult> results = googleSearchEngine.search(webSearchRequest).results();
// then
assertThat(results)
.as("At least one result should be contains 'argentina' ignoring case")
.anySatisfy(result -> assertThat(result.snippet())
.containsIgnoringCase("argentina"));
}
@Test
void should_return_google_results_of_the_second_page_and_log_http_req_resp() {
// given
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder()
.apiKey(System.getenv("GOOGLE_API_KEY"))
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
.logRequestResponse(true)
.build();
String query = "What is the weather in Porto?";
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
.searchTerms(query)
.maxResults(5)
.startPage(2)
.build();
// when
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
// then
assertThat(webSearchResults.results())
.as("At least the string result should be contains 'weather' and 'Porto' ignoring case")
.anySatisfy(result -> assertThat(result.snippet())
.containsIgnoringCase("weather")
.containsIgnoringCase("porto"));
}
@Test
void should_return_google_results_using_and_fix_startpage_by_startindex(){
// given
String query = "What is LangChain4j project?";
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
.searchTerms(query)
.language("lang_en")
.startPage(1) //user bad request
.startIndex(15)
.build();
// when
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
// then
assertThat(webSearchResults.results())
.as("At least one result should be contains 'Java' and 'AI' ignoring case")
.anySatisfy(result -> assertThat(result.snippet())
.containsIgnoringCase("Java")
.containsIgnoringCase("AI"));
}
@Test
void should_return_google_result_using_additional_params() {
// given
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder()
.apiKey(System.getenv("GOOGLE_API_KEY"))
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
.logRequestResponse(true)
.build();
String query = "What is LangChain4j project?";
Map<String, Object> additionalParams = new HashMap<>();
additionalParams.put("dateRestrict", "w[2]");
additionalParams.put("linkSite", "https://github.com/langchain4j/langchain4j");
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
.searchTerms(query)
.additionalParams(additionalParams)
.build();
// when
List<WebSearchOrganicResult> results = googleSearchEngine.search(webSearchRequest).results();
// then
assertThat(results)
.as("At least one result should be contains 'Java' and 'AI' ignoring case")
.anySatisfy(result -> assertThat(result.snippet())
.containsIgnoringCase("Java")
.containsIgnoringCase("github.com/langchain4j/langchain4j"));
}
/**
 * Builds an engine with {@code includeImages(true)} and checks that a web search
 * reports {@code searchType} "web", exposes a list of {@code ImageSearchResult}
 * under the "images" key of the search-information metadata, and still returns
 * organic web results.
 */
@Test
void should_return_google_result_with_images_related() {
    // given
    WebSearchEngine engineWithImages = GoogleCustomWebSearchEngine.builder()
            .apiKey(System.getenv("GOOGLE_API_KEY"))
            .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
            .includeImages(true) // execute an additional search, searchType: image
            .logRequestResponse(true)
            .build();
    String query = "Which top 2024 universities to study computer science?";
    WebSearchRequest webSearchRequest = WebSearchRequest.builder()
            .searchTerms(query)
            .build();

    // when
    WebSearchResults webSearchResults = engineWithImages.search(webSearchRequest);

    // then
    assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("web"); // searchType: web

    // Images related to the query are read from searchInformation.metadata.
    Object images = webSearchResults.searchInformation().metadata().get("images");
    // isInstanceOf accepts any List implementation; an exact-class check would reject
    // everything except ArrayList and could never match the List interface itself.
    assertThat(images).isInstanceOf(List.class);

    // Element type cannot be checked at runtime; the per-element assertions below
    // would fail with a ClassCastException if the list held anything else.
    @SuppressWarnings("unchecked")
    List<ImageSearchResult> imageResults = (List<ImageSearchResult>) images;
    assertThat(imageResults)
            .as("At least one image result should be contains title, link, contextLink and thumbnailLink")
            .anySatisfy(image -> {
                assertThat(image.title()).isNotNull();
                assertThat(image.imageLink().toString()).startsWith("http");
                assertThat(image.contextLink().toString()).startsWith("http");
                assertThat(image.thumbnailLink().toString()).startsWith("http");
            });

    assertThat(webSearchResults.results()) // Get web results
            .as("At least the string result should be contains 'university' and 'ranking' ignoring case")
            .anySatisfy(result -> assertThat(result.snippet())
                    .containsIgnoringCase("university")
                    .containsIgnoringCase("ranking"));
}
/**
 * Issues a search with the additional parameter {@code searchType=image} and
 * checks that the response is flagged as "images" and that the results carry
 * a matching title, an http URL and the image-related metadata entries
 * (mimeType, imageLink, contextLink, thumbnailLink).
 *
 * <p>Each {@code anySatisfy} is evaluated independently, so the individual
 * conditions may be met by different results.
 */
@Test
void should_return_google_image_result_with_param_searchType_image() {
    // given
    String query = "How will be the weather next week in Lisbon and Porto cities?";
    WebSearchRequest webSearchRequest = WebSearchRequest.builder()
            .searchTerms(query)
            .additionalParams(singletonMap("searchType", "image"))
            .build();

    // when
    WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);

    // then
    assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("images"); // searchType: images
    // The description refers to the title, which is what the first condition inspects.
    assertThat(webSearchResults.results()) // Get images as search results
            .as("At least one image result title should contain 'weather' and 'porto' ignoring case, "
                    + "and results should expose http links and image metadata")
            .anySatisfy(result -> assertThat(result.title())
                    .containsIgnoringCase("weather")
                    .containsIgnoringCase("porto"))
            .anySatisfy(result -> assertThat(result.url().toString())
                    .startsWith("http"))
            .anySatisfy(result -> assertThat(result.metadata().get("mimeType"))
                    .startsWith("image"))
            .anySatisfy(result -> assertThat(result.metadata().get("imageLink"))
                    .isEqualTo(result.url().toString()))
            .anySatisfy(result -> assertThat(result.metadata().get("contextLink"))
                    .startsWith("http"))
            .anySatisfy(result -> assertThat(result.metadata().get("thumbnailLink"))
                    .startsWith("http"));
}
/**
 * Returns the engine held in {@code googleSearchEngine}.
 * Overrides a supertype hook; presumably consumed by tests defined in the parent class.
 */
@Override
protected WebSearchEngine searchEngine() {
    return this.googleSearchEngine;
}
}

View File

@ -0,0 +1,168 @@
package dev.langchain4j.web.search.google.customsearch;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.data.message.*;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.web.search.WebSearchEngine;
import dev.langchain4j.web.search.WebSearchTool;
import dev.langchain4j.web.search.WebSearchToolIT;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import java.util.ArrayList;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests that expose {@link GoogleCustomWebSearchEngine} as a {@link WebSearchTool}
 * and drive it through an OpenAI chat model, both by hand-rolled tool calling and via
 * {@link AiServices}.
 *
 * <p>The tests call live services, so the class is enabled only when the Google credentials
 * and the OpenAI API key are all present in the environment.
 */
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*")
// The chat model below reads OPENAI_API_KEY, so gate on it as well instead of failing at runtime.
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
class GoogleCustomWebSearchToolIT extends WebSearchToolIT {

    /** Default engine returned by {@link #searchEngine()} and used by the AiServices test. */
    final WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
            System.getenv("GOOGLE_API_KEY"),
            System.getenv("GOOGLE_SEARCH_ENGINE_ID"));

    /** Chat model returned by {@link #chatLanguageModel()}; request logging is switched on. */
    final ChatLanguageModel chatModel = OpenAiChatModel.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName(OpenAiChatModelName.GPT_3_5_TURBO)
            .logRequests(true)
            .build();

    /** AI service contract used by {@link #should_execute_google_tool_with_AiServices()}. */
    interface Assistant {
        @dev.langchain4j.service.SystemMessage({
                "You are a web search support agent.",
                "If there is any event that has not happened yet",
                "You MUST create a web search request with with user query and",
                "use the web search tool to search the web for organic web results.",
                "Include the source link in your final response."
        })
        String answer(String userMessage);
    }

    /**
     * Manual tool-calling round trip: the model must request the search tool, the tool is then
     * executed with the user query, and the follow-up answer must mention the requested topic.
     */
    @Test
    void should_execute_google_tool_with_chatLanguageModel_to_give_a_final_response() {
        // given
        // A dedicated local engine (with logging and retries) keeps the shared field untouched.
        WebSearchEngine loggingSearchEngine = GoogleCustomWebSearchEngine.builder()
                .apiKey(System.getenv("GOOGLE_API_KEY"))
                .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
                .logRequestResponse(true)
                .maxRetries(3)
                .build();
        WebSearchTool webSearchTool = WebSearchTool.from(loggingSearchEngine);
        List<ToolSpecification> tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool);
        String query = "What are the release dates for the movies coming out last week of May 2024?";
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response. Do not say that you have not the capability to browse the web in real time"));
        messages.add(UserMessage.from(query));

        // when
        AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content();

        // then
        assertRequestsWebSearch(aiMessage);
        messages.add(aiMessage);

        // when
        String strResult = webSearchTool.searchWeb(query);
        messages.add(ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult));
        AiMessage finalResponse = chatLanguageModel().generate(messages).content();
        System.out.println(finalResponse.text());

        // then
        assertThat(finalResponse.text())
                .as("The final response should contain 'movies' and 'May 2024' ignoring case")
                .containsIgnoringCase("movies")
                .containsIgnoringCase("May 2024");
    }

    /**
     * Same round trip as above with a prompt that also asks for image URLs; the final answer
     * must mention Madrid and the Royal Palace of Madrid.
     */
    @Test
    void should_execute_google_tool_with_chatLanguageModel_to_summary_response_in_images() {
        // given
        WebSearchEngine loggingSearchEngine = GoogleCustomWebSearchEngine.builder()
                .apiKey(System.getenv("GOOGLE_API_KEY"))
                .csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
                .logRequestResponse(true)
                .build();
        WebSearchTool webSearchTool = WebSearchTool.from(loggingSearchEngine);
        List<ToolSpecification> tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool);
        String query = "My family is coming to visit me in Madrid next week, list the best tourist activities suitable for the whole family";
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response and the image urls. Do not say that you have not the capability to browse the web in real time"));
        messages.add(UserMessage.from(query));

        // when
        AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content();

        // then
        assertRequestsWebSearch(aiMessage);
        messages.add(aiMessage);

        // when
        String strResult = webSearchTool.searchWeb(query);
        messages.add(ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult));
        AiMessage finalResponse = chatLanguageModel().generate(messages).content();
        System.out.println(finalResponse.text());

        // then
        assertThat(finalResponse.text())
                .as("The final response should contain 'Madrid' and 'Royal Palace of Madrid' ignoring case")
                .containsIgnoringCase("Madrid")
                .containsIgnoringCase("Royal Palace of Madrid");
    }

    /** Lets {@link AiServices} wire the tool and expects the answer to name the 2022 winner. */
    @Test
    void should_execute_google_tool_with_AiServices() {
        // given
        WebSearchTool webTool = WebSearchTool.from(googleSearchEngine);
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .tools(webTool)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();

        // when
        String answer = assistant.answer("Search in the web who won the FIFA World Cup 2022?");

        // then
        assertThat(answer).containsIgnoringCase("Argentina");
    }

    /**
     * Asserts that the model response carries at least one tool execution request whose name
     * contains "searchWeb" and whose arguments are not blank.
     */
    private static void assertRequestsWebSearch(AiMessage aiMessage) {
        assertThat(aiMessage.hasToolExecutionRequests()).isTrue();
        assertThat(aiMessage.toolExecutionRequests())
                .anySatisfy(request -> {
                    assertThat(request.name())
                            .containsIgnoringCase("searchWeb");
                    assertThat(request.arguments())
                            .isNotBlank();
                });
    }

    @Override
    protected WebSearchEngine searchEngine() {
        return googleSearchEngine;
    }

    @Override
    protected ChatLanguageModel chatLanguageModel() {
        return chatModel;
    }
}