[FEATURE] Google web search integration (#641)
Integrating [Google Custom Search](https://developers.google.com/custom-search) as a `WebSearchEngine` and as a `Tool` for function calling
This commit is contained in:
parent
788de9fd91
commit
43274ff465
|
@ -369,6 +369,13 @@
|
|||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- web searchers -->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-web-search-engine-google-custom</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
|
3
pom.xml
3
pom.xml
|
@ -75,6 +75,9 @@
|
|||
<!-- embedding store filter parsers -->
|
||||
<module>embedding-store-filter-parsers/langchain4j-embedding-store-filter-parser-sql</module>
|
||||
|
||||
<!-- web searchers -->
|
||||
<module>web-searchers/langchain4j-web-search-engine-google-custom</module>
|
||||
|
||||
</modules>
|
||||
|
||||
<build>
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>0.31.0-SNAPSHOT</version>
|
||||
<relativePath>../../langchain4j-parent/pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
<artifactId>langchain4j-web-search-engine-google-custom</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>LangChain4j :: Web Search Engine :: Google Custom Search</name>
|
||||
<description>Implementation of Google Custom Search API for LangChain4j</description>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- Custom Search API Client Library for Java -->
|
||||
<dependency>
|
||||
<groupId>com.google.apis</groupId>
|
||||
<artifactId>google-api-services-customsearch</artifactId>
|
||||
<version>v1-rev20240417-2.0.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>2.0.7</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>tinylog-impl</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.tinylog</groupId>
|
||||
<artifactId>slf4j-tinylog</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- langchain4j-core test-jar: makes the shared WebSearchEngineIT base test class visible to this module's tests -->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<classifier>tests</classifier>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -0,0 +1,188 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
|
||||
import com.google.api.client.http.HttpRequest;
|
||||
import com.google.api.client.http.HttpRequestInitializer;
|
||||
import com.google.api.client.http.javanet.NetHttpTransport;
|
||||
import com.google.api.client.json.gson.GsonFactory;
|
||||
import com.google.api.services.customsearch.v1.CustomSearchAPI;
|
||||
import com.google.api.services.customsearch.v1.CustomSearchAPIRequest;
|
||||
import com.google.api.services.customsearch.v1.model.Search;
|
||||
import lombok.Builder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.time.Duration;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||
|
||||
class GoogleCustomSearchApiClient {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(GoogleCustomSearchApiClient.class);
|
||||
private static final Integer MAXIMUM_VALUE_NUM = 10;
|
||||
private final CustomSearchAPIRequest<Search> customSearchRequest;
|
||||
private final Boolean logRequestResponse;
|
||||
|
||||
@Builder
|
||||
GoogleCustomSearchApiClient(String apiKey,
|
||||
String csi,
|
||||
Boolean siteRestrict,
|
||||
Duration timeout,
|
||||
Integer maxRetries,
|
||||
Boolean logRequestResponse) {
|
||||
|
||||
try {
|
||||
if (isNullOrBlank(apiKey)) {
|
||||
throw new IllegalArgumentException("Google Custom Search API Key must be defined. " +
|
||||
"It can be generated here: https://console.developers.google.com/apis/credentials");
|
||||
}
|
||||
if (isNullOrBlank(csi)) {
|
||||
throw new IllegalArgumentException("Google Custom Search Engine ID must be defined. " +
|
||||
"It can be created here: https://cse.google.com/cse/create/new");
|
||||
}
|
||||
|
||||
CustomSearchAPI.Builder customSearchAPIBuilder = new CustomSearchAPI.Builder(GoogleNetHttpTransport.newTrustedTransport(), new GsonFactory(), new HttpRequestInitializer() {
|
||||
@Override
|
||||
public void initialize(HttpRequest httpRequest) throws IOException {
|
||||
httpRequest.setConnectTimeout(Math.toIntExact(timeout.toMillis()));
|
||||
httpRequest.setReadTimeout(Math.toIntExact(timeout.toMillis()));
|
||||
httpRequest.setWriteTimeout(Math.toIntExact(timeout.toMillis()));
|
||||
httpRequest.setLoggingEnabled(logRequestResponse);
|
||||
httpRequest.setNumberOfRetries(maxRetries);
|
||||
if (logRequestResponse){
|
||||
httpRequest.setInterceptor(new GoogleSearchApiHttpRequestLoggingInterceptor());
|
||||
httpRequest.setResponseInterceptor(new GoogleSearchApiHttpResponseLoggingInterceptor());
|
||||
}
|
||||
}
|
||||
}).setApplicationName("LangChain4j");
|
||||
|
||||
CustomSearchAPI customSearchAPI = customSearchAPIBuilder.build();
|
||||
|
||||
if (siteRestrict) {
|
||||
customSearchRequest = customSearchAPI.cse().siterestrict().list().setKey(apiKey).setCx(csi);
|
||||
} else {
|
||||
customSearchRequest = customSearchAPI.cse().list().setKey(apiKey).setCx(csi);
|
||||
}
|
||||
this.logRequestResponse = logRequestResponse;
|
||||
} catch (IOException e) {
|
||||
LOGGER.error("Error occurred while creating Google Custom Search API client", e);
|
||||
throw new RuntimeException(e);
|
||||
} catch (GeneralSecurityException e) {
|
||||
LOGGER.error("Error occurred while creating Google Custom Search API client using GoogleNetHttpTransport.newTrustedTransport()", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
Search searchResults(Search.Queries.Request requestQuery) {
|
||||
try {
|
||||
Search searchPerformed;
|
||||
if (customSearchRequest instanceof CustomSearchAPI.Cse.Siterestrict.List) {
|
||||
searchPerformed = ((CustomSearchAPI.Cse.Siterestrict.List) customSearchRequest)
|
||||
.setPrettyPrint(true)
|
||||
.setQ(requestQuery.getSearchTerms())
|
||||
.setNum(maxResultsAllowed(getDefaultNaturalNumber(requestQuery.getCount())))
|
||||
.setSort(requestQuery.getSort())
|
||||
.setSafe(requestQuery.getSafe())
|
||||
.setDateRestrict(requestQuery.getDateRestrict())
|
||||
.setGl(requestQuery.getGl())
|
||||
.setLr(requestQuery.getLanguage())
|
||||
.setHl(requestQuery.getHl())
|
||||
.setHq(requestQuery.getHq())
|
||||
.setSiteSearch(requestQuery.getSiteSearch())
|
||||
.setSiteSearchFilter(requestQuery.getSiteSearchFilter())
|
||||
.setExactTerms(requestQuery.getExactTerms())
|
||||
.setExcludeTerms(requestQuery.getExcludeTerms())
|
||||
.setLinkSite(requestQuery.getLinkSite())
|
||||
.setOrTerms(requestQuery.getOrTerms())
|
||||
.setLowRange(requestQuery.getLowRange())
|
||||
.setHighRange(requestQuery.getHighRange())
|
||||
.setSearchType(requestQuery.getSearchType())
|
||||
.setFileType(requestQuery.getFileType())
|
||||
.setRights(requestQuery.getRights())
|
||||
.setImgSize(requestQuery.getImgSize())
|
||||
.setImgType(requestQuery.getImgType())
|
||||
.setImgColorType(requestQuery.getImgColorType())
|
||||
.setImgDominantColor(requestQuery.getImgDominantColor())
|
||||
.setC2coff(requestQuery.getDisableCnTwTranslation())
|
||||
.setCr(requestQuery.getCr())
|
||||
.setGooglehost(requestQuery.getGoogleHost())
|
||||
.setStart(calculateIndexStartPage(
|
||||
getDefaultNaturalNumber(requestQuery.getStartPage()),
|
||||
getDefaultNaturalNumber(requestQuery.getStartIndex())
|
||||
).longValue())
|
||||
.setFilter(requestQuery.getFilter())
|
||||
.execute();
|
||||
} else if (customSearchRequest instanceof CustomSearchAPI.Cse.List) {
|
||||
searchPerformed = ((CustomSearchAPI.Cse.List) customSearchRequest)
|
||||
.setPrettyPrint(true)
|
||||
.setQ(requestQuery.getSearchTerms())
|
||||
.setNum(maxResultsAllowed(getDefaultNaturalNumber(requestQuery.getCount())))
|
||||
.setSort(requestQuery.getSort())
|
||||
.setSafe(requestQuery.getSafe())
|
||||
.setDateRestrict(requestQuery.getDateRestrict())
|
||||
.setGl(requestQuery.getGl())
|
||||
.setLr(requestQuery.getLanguage())
|
||||
.setHl(requestQuery.getHl())
|
||||
.setHq(requestQuery.getHq())
|
||||
.setSiteSearch(requestQuery.getSiteSearch())
|
||||
.setSiteSearchFilter(requestQuery.getSiteSearchFilter())
|
||||
.setExactTerms(requestQuery.getExactTerms())
|
||||
.setExcludeTerms(requestQuery.getExcludeTerms())
|
||||
.setLinkSite(requestQuery.getLinkSite())
|
||||
.setOrTerms(requestQuery.getOrTerms())
|
||||
.setLowRange(requestQuery.getLowRange())
|
||||
.setHighRange(requestQuery.getHighRange())
|
||||
.setSearchType(requestQuery.getSearchType())
|
||||
.setFileType(requestQuery.getFileType())
|
||||
.setRights(requestQuery.getRights())
|
||||
.setImgSize(requestQuery.getImgSize())
|
||||
.setImgType(requestQuery.getImgType())
|
||||
.setImgColorType(requestQuery.getImgColorType())
|
||||
.setImgDominantColor(requestQuery.getImgDominantColor())
|
||||
.setC2coff(requestQuery.getDisableCnTwTranslation())
|
||||
.setCr(requestQuery.getCr())
|
||||
.setGooglehost(requestQuery.getGoogleHost())
|
||||
.setStart(calculateIndexStartPage(
|
||||
getDefaultNaturalNumber(requestQuery.getStartPage()),
|
||||
getDefaultNaturalNumber(requestQuery.getStartIndex())
|
||||
).longValue())
|
||||
.setFilter(requestQuery.getFilter())
|
||||
.execute();
|
||||
} else {
|
||||
throw new IllegalStateException("Invalid CustomSearchAPIRequest type");
|
||||
}
|
||||
if (logRequestResponse) {
|
||||
log(searchPerformed);
|
||||
}
|
||||
return searchPerformed;
|
||||
} catch (IOException e) {
|
||||
LOGGER.error("Error occurred while searching", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static void log(Search search){
|
||||
try {
|
||||
LOGGER.debug("Response:\n- Response: {}", search.toPrettyString());
|
||||
} catch (IOException e) {
|
||||
LOGGER.warn("Error while logging response: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static Integer maxResultsAllowed(Integer maxResults){
|
||||
return maxResults > MAXIMUM_VALUE_NUM ? MAXIMUM_VALUE_NUM : maxResults;
|
||||
}
|
||||
|
||||
private static Integer getDefaultNaturalNumber(Integer number){
|
||||
int defaultNumber = getOrDefault(number, 1);
|
||||
return defaultNumber > 0 ? defaultNumber : 1;
|
||||
}
|
||||
|
||||
private static Integer calculateIndexStartPage(Integer pageNumber, Integer index) {
|
||||
int indexStartPage = ((pageNumber -1) * MAXIMUM_VALUE_NUM) + 1;
|
||||
return indexStartPage >= index ? indexStartPage : index;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,273 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import com.google.api.client.json.GenericJson;
|
||||
import com.google.api.services.customsearch.v1.model.Result;
|
||||
import com.google.api.services.customsearch.v1.model.Search;
|
||||
import dev.langchain4j.web.search.WebSearchEngine;
|
||||
import dev.langchain4j.web.search.WebSearchInformationResult;
|
||||
import dev.langchain4j.web.search.WebSearchOrganicResult;
|
||||
import dev.langchain4j.web.search.WebSearchRequest;
|
||||
import dev.langchain4j.web.search.WebSearchResults;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.time.Duration;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.google.api.services.customsearch.v1.model.Search.*;
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
/**
|
||||
* An implementation of a {@link WebSearchEngine} that uses
|
||||
* <a href="https://programmablesearchengine.google.com/">Google Custom Search API</a> for performing web searches.
|
||||
*/
|
||||
public class GoogleCustomWebSearchEngine implements WebSearchEngine {
|
||||
|
||||
private final GoogleCustomSearchApiClient googleCustomSearchApiClient;
|
||||
private final Boolean includeImages;
|
||||
|
||||
/**
|
||||
* Constructs a new GoogleCustomWebSearchEngine with the specified parameters.
|
||||
*
|
||||
* @param apiKey the Google Search API key for accessing the Google Custom Search API
|
||||
* <p>
|
||||
* You can just generate an API key <a href="https://developers.google.com/custom-search/docs/paid_element#api_key">here</a>
|
||||
* @param csi the Custom Search ID parameter for search the entire web
|
||||
* <p>
|
||||
* You can create a Custom Search Engine <a href="https://cse.google.com/cse/create/new">here</a>
|
||||
* @param siteRestrict if your Search Engine is restricted to only searching specific sites, you can set this parameter to true.
|
||||
* <p>
|
||||
* Default value is false. View the documentation for more information <a href="https://developers.google.com/custom-search/v1/site_restricted_api">here</a>
|
||||
* @param includeImages If it is true then include public images relevant to the query. This can add more latency to the search.
|
||||
* <p>
|
||||
* Default value is false.
|
||||
* @param timeout the timeout duration for API requests
|
||||
* <p>
|
||||
* Default value is 60 seconds.
|
||||
* @param logRequestResponse whether to log API request and response
|
||||
* <p>
|
||||
* Default value is false.
|
||||
* @param maxRetries the maximum number of retries for API requests
|
||||
* <p>
|
||||
* Default value is 10.
|
||||
*/
|
||||
@Builder
|
||||
public GoogleCustomWebSearchEngine(String apiKey,
|
||||
String csi,
|
||||
Boolean siteRestrict,
|
||||
Boolean includeImages,
|
||||
Duration timeout,
|
||||
Boolean logRequestResponse,
|
||||
Integer maxRetries) {
|
||||
|
||||
this.googleCustomSearchApiClient = GoogleCustomSearchApiClient.builder()
|
||||
.apiKey(apiKey)
|
||||
.csi(csi)
|
||||
.siteRestrict(getOrDefault(siteRestrict,false))
|
||||
.timeout(getOrDefault(timeout,Duration.ofSeconds(60)))
|
||||
.logRequestResponse(getOrDefault(logRequestResponse,false))
|
||||
.maxRetries(getOrDefault(maxRetries,10))
|
||||
.build();
|
||||
this.includeImages = getOrDefault(includeImages,false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new builder for constructing a GoogleCustomWebSearchEngine with the specified API key and Custom Search ID.
|
||||
*
|
||||
* @param apiKey the API key for accessing the Google Custom Search API
|
||||
* @param csi the Custom Search ID parameter for search the entire web
|
||||
* @return a new builder instance
|
||||
*/
|
||||
public static GoogleCustomWebSearchEngine withApiKeyAndCsi(String apiKey, String csi){
|
||||
return GoogleCustomWebSearchEngine.builder().apiKey(apiKey).csi(csi).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public WebSearchResults search(WebSearchRequest webSearchRequest) {
|
||||
ensureNotNull(webSearchRequest, "webSearchRequest");
|
||||
|
||||
Queries.Request requestQuery = new Queries.Request();
|
||||
requestQuery.setSearchTerms(webSearchRequest.searchTerms());
|
||||
requestQuery.setCount(getOrDefault(webSearchRequest.maxResults(),5));
|
||||
requestQuery.setGl(webSearchRequest.geoLocation());
|
||||
requestQuery.setLanguage(webSearchRequest.language());
|
||||
requestQuery.setStartPage(webSearchRequest.startPage());
|
||||
requestQuery.setStartIndex(webSearchRequest.startIndex());
|
||||
requestQuery.setSafe(webSearchRequest.safeSearch()?"active":"off");
|
||||
requestQuery.setFilter("1"); // By default, applies filtering to remove duplicate content
|
||||
requestQuery.setCr(setCountryRestrict(webSearchRequest));
|
||||
webSearchRequest.additionalParams().forEach(requestQuery::set);
|
||||
|
||||
boolean searchTypeImage = isNotNullOrBlank(requestQuery.getSearchType()) && requestQuery.getSearchType().equals("image");
|
||||
|
||||
// Web search
|
||||
Search search = googleCustomSearchApiClient.searchResults(requestQuery);
|
||||
Map<String, Object> searchMetadata = toSearchMetadata(search, searchTypeImage);
|
||||
Map<String, Object> searchInformationMetadata = new HashMap<>();
|
||||
|
||||
// Images search
|
||||
if (includeImages && !searchTypeImage) {
|
||||
requestQuery.setSearchType("image");
|
||||
Search imagesSearch = googleCustomSearchApiClient.searchResults(requestQuery);
|
||||
List<ImageSearchResult> images = imagesSearch.getItems().stream()
|
||||
.map(result -> ImageSearchResult.from(
|
||||
result.getTitle(),
|
||||
URI.create(result.getLink()),
|
||||
URI.create(result.getImage().getContextLink()),
|
||||
URI.create(result.getImage().getThumbnailLink())))
|
||||
.collect(toList());
|
||||
addImagesToSearchInformation(searchInformationMetadata, images);
|
||||
}
|
||||
|
||||
|
||||
return WebSearchResults.from(
|
||||
searchMetadata
|
||||
, WebSearchInformationResult.from(
|
||||
Long.valueOf(getOrDefault(search.getSearchInformation().getTotalResults(),"0")),
|
||||
!isNullOrEmpty(search.getQueries().getRequest())
|
||||
?calculatePageNumberFromQueries(search.getQueries().getRequest().get(0)):1,
|
||||
searchInformationMetadata.isEmpty()?null:searchInformationMetadata)
|
||||
, search.getItems().stream()
|
||||
.map(result -> WebSearchOrganicResult.from(
|
||||
result.getTitle(),
|
||||
URI.create(result.getLink()),
|
||||
result.getSnippet(),
|
||||
null, // by default google custom search api does not return content
|
||||
toResultMetadataMap(result, searchTypeImage)
|
||||
)).collect(toList()));
|
||||
}
|
||||
|
||||
private static void addImagesToSearchInformation(Map<String, Object> searchInformationMetadata, List<ImageSearchResult> images) {
|
||||
if (!isNullOrEmpty(images)) {
|
||||
searchInformationMetadata.put("images", images);
|
||||
}
|
||||
}
|
||||
|
||||
private static Map<String, Object> toSearchMetadata(Search search, Boolean searchTypeImage) {
|
||||
if (search == null) {
|
||||
return null;
|
||||
}
|
||||
Map<String, Object> searchMetadata = new HashMap<>();
|
||||
searchMetadata.put("status", "Success");
|
||||
searchMetadata.put("searchTime", search.getSearchInformation().getSearchTime());
|
||||
searchMetadata.put("processedAt", LocalDateTime.now().toString());
|
||||
searchMetadata.put("searchType", searchTypeImage ? "images" : "web");
|
||||
searchMetadata.putAll(search.getContext());
|
||||
return searchMetadata;
|
||||
}
|
||||
|
||||
private static Map<String, String> toResultMetadataMap(Result result, boolean searchTypeImage) {
|
||||
Map<String, String> metadata = new HashMap<>();
|
||||
// Image search type
|
||||
if (searchTypeImage) {
|
||||
metadata.put("imageLink", result.getLink());
|
||||
metadata.put("contextLink", result.getImage().getContextLink());
|
||||
metadata.put("thumbnailLink", result.getImage().getThumbnailLink());
|
||||
metadata.put("mimeType", result.getMime());
|
||||
return metadata;
|
||||
}
|
||||
// Web search type
|
||||
if (!result.getPagemap().isEmpty()) {
|
||||
result.getPagemap().forEach((key, value) -> {
|
||||
if (key.equals("metatags")) {
|
||||
if (value instanceof List) {
|
||||
metadata.put(key, ((List<?>) value).stream().map(Object::toString).reduce((a, b) -> a + ", " + b).orElse(""));
|
||||
} else {
|
||||
metadata.put(key, value.toString());
|
||||
}
|
||||
}
|
||||
metadata.put("mimeType", isNotNullOrBlank(result.getMime()) ? result.getMime() : "text/html");
|
||||
});
|
||||
return metadata;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Integer calculatePageNumberFromQueries(GenericJson query) {
|
||||
if (query instanceof Queries.PreviousPage) {
|
||||
Queries.PreviousPage previousPage = (Queries.PreviousPage) query;
|
||||
return calculatePageNumber(previousPage.getStartIndex());
|
||||
}
|
||||
if (query instanceof Queries.Request) {
|
||||
Queries.Request currentPage = (Queries.Request) query;
|
||||
return calculatePageNumber(getOrDefault(currentPage.getStartIndex(),1));
|
||||
}
|
||||
if (query instanceof Queries.NextPage) {
|
||||
Queries.NextPage nextPage = (Queries.NextPage) query;
|
||||
return calculatePageNumber(nextPage.getStartIndex());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Integer calculatePageNumber(Integer startIndex) {
|
||||
if (startIndex == null)
|
||||
return null;
|
||||
return ((startIndex -1) / 10) + 1;
|
||||
}
|
||||
|
||||
private static String setCountryRestrict(WebSearchRequest webSearchRequest){
|
||||
return webSearchRequest.additionalParams().get("cr") != null ? webSearchRequest.additionalParams().get("cr").toString()
|
||||
: isNotNullOrBlank(webSearchRequest.geoLocation()) ? "country" + webSearchRequest.geoLocation().toUpperCase()
|
||||
: ""; // default value
|
||||
}
|
||||
|
||||
public static final class ImageSearchResult {
|
||||
private final String title;
|
||||
private final URI imageLink;
|
||||
private final URI contextLink;
|
||||
private final URI thumbnailLink;
|
||||
|
||||
private ImageSearchResult(String title, URI imageLink) {
|
||||
this.title = ensureNotNull(title,"title");
|
||||
this.imageLink = ensureNotNull(imageLink,"imageLink");
|
||||
this.contextLink = null;
|
||||
this.thumbnailLink = null;
|
||||
}
|
||||
|
||||
private ImageSearchResult(String title, URI imageLink, URI contextLink, URI thumbnailLink) {
|
||||
this.title = ensureNotNull(title,"title");
|
||||
this.imageLink = ensureNotNull(imageLink,"imageLink");
|
||||
this.contextLink = contextLink;
|
||||
this.thumbnailLink = thumbnailLink;
|
||||
}
|
||||
|
||||
public String title() {
|
||||
return title;
|
||||
}
|
||||
|
||||
public URI imageLink() {
|
||||
return imageLink;
|
||||
}
|
||||
|
||||
public URI contextLink() {
|
||||
return contextLink;
|
||||
}
|
||||
|
||||
public URI thumbnailLink() {
|
||||
return thumbnailLink;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ImageSearchResult{" +
|
||||
"title='" + title + '\'' +
|
||||
", imageLink=" + imageLink +
|
||||
", contextLink=" + contextLink +
|
||||
", thumbnailLink=" + thumbnailLink +
|
||||
'}';
|
||||
}
|
||||
|
||||
public static ImageSearchResult from(String title, URI imageLink) {
|
||||
return new ImageSearchResult(title, imageLink);
|
||||
}
|
||||
|
||||
public static ImageSearchResult from(String title, URI imageLink, URI contextLink, URI thumbnailLink) {
|
||||
return new ImageSearchResult(title, imageLink, contextLink, thumbnailLink);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import com.google.api.client.http.HttpContent;
|
||||
import com.google.api.client.http.HttpExecuteInterceptor;
|
||||
import com.google.api.client.http.HttpHeaders;
|
||||
import com.google.api.client.http.HttpRequest;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class GoogleSearchApiHttpRequestLoggingInterceptor implements HttpExecuteInterceptor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpRequestLoggingInterceptor.class);
|
||||
|
||||
@Override
|
||||
public void intercept(HttpRequest httpRequest) throws IOException {
|
||||
this.log(httpRequest);
|
||||
}
|
||||
|
||||
private void log(HttpRequest httpRequest){
|
||||
try {
|
||||
LOGGER.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}",
|
||||
httpRequest.getRequestMethod(), httpRequest.getUrl(), getHeaders(httpRequest.getHeaders()), getBody(httpRequest.getContent()));
|
||||
} catch (Exception e) {
|
||||
LOGGER.warn("Error while logging request: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static String getHeaders(HttpHeaders headers){
|
||||
return headers.entrySet().stream()
|
||||
.map(entry -> String.format("[%s: %s]", entry.getKey(), entry.getValue())).collect(Collectors.joining(", "));
|
||||
}
|
||||
|
||||
private static String getBody(HttpContent content){
|
||||
try {
|
||||
if (content == null) {
|
||||
return "";
|
||||
}
|
||||
return content.toString();
|
||||
} catch (Exception e) {
|
||||
LOGGER.warn("Exception while getting body", e);
|
||||
return "Exception while getting body: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import com.google.api.client.http.HttpHeaders;
|
||||
import com.google.api.client.http.HttpResponse;
|
||||
import com.google.api.client.http.HttpResponseInterceptor;
|
||||
import com.google.api.client.json.gson.GsonFactory;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class GoogleSearchApiHttpResponseLoggingInterceptor implements HttpResponseInterceptor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(GoogleSearchApiHttpResponseLoggingInterceptor.class);
|
||||
|
||||
@Override
|
||||
public void interceptResponse(HttpResponse httpResponse) throws IOException {
|
||||
this.log(httpResponse);
|
||||
}
|
||||
|
||||
private void log(HttpResponse httpResponse){
|
||||
try {
|
||||
httpResponse.getRequest().setParser(new GsonFactory().createJsonObjectParser());
|
||||
LOGGER.debug("Response:\n- status code: {}\n- headers: {}",
|
||||
httpResponse.getStatusCode(), getHeaders(httpResponse.getHeaders())); // response body can't be got twice by google token constraints, it'll be logged in GoogleCustomSearchApiClient
|
||||
} catch (Exception e) {
|
||||
LOGGER.warn("Error while logging response: {}", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static String getHeaders(HttpHeaders headers){
|
||||
return headers.entrySet().stream()
|
||||
.map(entry -> String.format("[%s: %s]", entry.getKey(), entry.getValue())).collect(Collectors.joining(", "));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModelName;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.content.retriever.WebSearchContentRetriever;
|
||||
import dev.langchain4j.rag.content.retriever.WebSearchContentRetrieverIT;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import dev.langchain4j.service.SystemMessage;
|
||||
import dev.langchain4j.web.search.WebSearchEngine;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*")
|
||||
class GoogleCustomWebSearchContentRetrieverIT extends WebSearchContentRetrieverIT {
|
||||
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
|
||||
System.getenv("GOOGLE_API_KEY"),
|
||||
System.getenv("GOOGLE_SEARCH_ENGINE_ID"));
|
||||
|
||||
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.modelName(OpenAiChatModelName.GPT_3_5_TURBO)
|
||||
.logRequests(true)
|
||||
.build();
|
||||
|
||||
interface Assistant {
|
||||
@SystemMessage({
|
||||
"You are a web search support agent.",
|
||||
"If there is any event that has not happened yet, ",
|
||||
"you MUST use a web search tool to look up the information on the web.",
|
||||
"Include the source link and the image urls in your final response if these known, otherwise, do not include them.",
|
||||
"Do not say that you have not the capability to browse the web in real time"
|
||||
})
|
||||
String answer(String userMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_retrieve_web_content_with_google_for_current_info() {
|
||||
// given
|
||||
googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.maxRetries(3)
|
||||
.build();
|
||||
|
||||
WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.from(googleSearchEngine);
|
||||
Query query = Query.from("What is the latest currency exchange rates for the US Dollar and Euro");
|
||||
|
||||
// when
|
||||
List<Content> contents = contentRetriever.retrieve(query);
|
||||
System.out.println("contents: " + contents);
|
||||
|
||||
// then
|
||||
assertThat(contents)
|
||||
.as("At least one content should be contains 'us dollar' and 'euro' ignoring case")
|
||||
.anySatisfy(content -> {
|
||||
assertThat(content.textSegment().text())
|
||||
.containsIgnoringCase("us dollar")
|
||||
.containsIgnoringCase("euro");
|
||||
assertThat(content.textSegment().metadata().get("url"))
|
||||
.startsWith("https://");
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_retrieve_web_content_with_google_and_use_AiServices_to_summary_response () {
|
||||
// given
|
||||
googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.maxRetries(3)
|
||||
.build();
|
||||
|
||||
WebSearchContentRetriever contentRetriever = WebSearchContentRetriever.from(googleSearchEngine);
|
||||
|
||||
String query = "My family is coming to visit me in Madrid next week, list the best tourist activities suitable for the whole family";
|
||||
|
||||
// when
|
||||
Assistant assistant = AiServices.builder(Assistant.class)
|
||||
.chatLanguageModel(chatModel)
|
||||
.contentRetriever(contentRetriever)
|
||||
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
||||
.build();
|
||||
|
||||
String answer = assistant.answer(query);
|
||||
System.out.println(answer);
|
||||
|
||||
// then
|
||||
assertThat(answer)
|
||||
.as("At least the string result should be contains 'madrid' and 'tourist' ignoring case")
|
||||
.containsIgnoringCase("Madrid")
|
||||
.containsIgnoringCase("Royal Palace of Madrid");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WebSearchEngine searchEngine() {
|
||||
return googleSearchEngine;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import dev.langchain4j.web.search.*;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static java.util.Collections.singletonMap;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static dev.langchain4j.web.search.google.customsearch.GoogleCustomWebSearchEngine.ImageSearchResult;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*")
|
||||
class GoogleCustomWebSearchEngineIT extends WebSearchEngineIT {
|
||||
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
|
||||
System.getenv("GOOGLE_API_KEY"),
|
||||
System.getenv("GOOGLE_SEARCH_ENGINE_ID"));
|
||||
|
||||
@Test
|
||||
void should_return_google_web_results_with_search_information() {
|
||||
// given
|
||||
String query = "What is LangChain4j project?";
|
||||
|
||||
// when
|
||||
WebSearchResults results = googleSearchEngine.search(query);
|
||||
|
||||
// then
|
||||
assertThat(results.searchMetadata()).isNotNull();
|
||||
assertThat(results.searchInformation().totalResults()).isGreaterThan(0);
|
||||
assertThat(results.results().size()).isGreaterThan(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_safe_web_results_in_spanish_language() {
|
||||
// given
|
||||
String query = "Who won the FIFA World Cup 2022?";
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.language("lang_es")
|
||||
.safeSearch(true)
|
||||
.build();
|
||||
|
||||
// when
|
||||
List<WebSearchOrganicResult> results = googleSearchEngine.search(webSearchRequest).results();
|
||||
|
||||
// then
|
||||
assertThat(results)
|
||||
.as("At least one result should be contains 'argentina' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.snippet())
|
||||
.containsIgnoringCase("argentina"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_results_of_the_second_page_and_log_http_req_resp() {
|
||||
// given
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.build();
|
||||
|
||||
String query = "What is the weather in Porto?";
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.maxResults(5)
|
||||
.startPage(2)
|
||||
.build();
|
||||
|
||||
// when
|
||||
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
|
||||
|
||||
// then
|
||||
assertThat(webSearchResults.results())
|
||||
.as("At least the string result should be contains 'weather' and 'Porto' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.snippet())
|
||||
.containsIgnoringCase("weather")
|
||||
.containsIgnoringCase("porto"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_results_using_and_fix_startpage_by_startindex(){
|
||||
// given
|
||||
String query = "What is LangChain4j project?";
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.language("lang_en")
|
||||
.startPage(1) //user bad request
|
||||
.startIndex(15)
|
||||
.build();
|
||||
|
||||
// when
|
||||
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
|
||||
|
||||
// then
|
||||
assertThat(webSearchResults.results())
|
||||
.as("At least one result should be contains 'Java' and 'AI' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.snippet())
|
||||
.containsIgnoringCase("Java")
|
||||
.containsIgnoringCase("AI"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_result_using_additional_params() {
|
||||
// given
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.build();
|
||||
|
||||
String query = "What is LangChain4j project?";
|
||||
Map<String, Object> additionalParams = new HashMap<>();
|
||||
additionalParams.put("dateRestrict", "w[2]");
|
||||
additionalParams.put("linkSite", "https://github.com/langchain4j/langchain4j");
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.additionalParams(additionalParams)
|
||||
.build();
|
||||
|
||||
// when
|
||||
List<WebSearchOrganicResult> results = googleSearchEngine.search(webSearchRequest).results();
|
||||
|
||||
// then
|
||||
assertThat(results)
|
||||
.as("At least one result should be contains 'Java' and 'AI' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.snippet())
|
||||
.containsIgnoringCase("Java")
|
||||
.containsIgnoringCase("github.com/langchain4j/langchain4j"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_result_with_images_related() {
|
||||
// given
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.includeImages(true) // execute an additional search, searchType: image
|
||||
.logRequestResponse(true)
|
||||
.build();
|
||||
|
||||
String query = "Which top 2024 universities to study computer science?";
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.build();
|
||||
|
||||
// when
|
||||
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
|
||||
|
||||
// then
|
||||
assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("web"); // searchType: web
|
||||
assertThat(webSearchResults.searchInformation().metadata().get("images")).isOfAnyClassIn(ArrayList.class, List.class); // should add images related to the query
|
||||
assertThat((List<ImageSearchResult>)webSearchResults.searchInformation().metadata().get("images")) // Get images from searchInformation.metadata
|
||||
.as("At least one image result should be contains title, link, contextLink and thumbnailLink")
|
||||
.anySatisfy(image -> {
|
||||
assertThat(image.title()).isNotNull();
|
||||
assertThat(image.imageLink().toString()).startsWith("http");
|
||||
assertThat(image.contextLink().toString()).startsWith("http");
|
||||
assertThat(image.thumbnailLink().toString()).startsWith("http");
|
||||
});
|
||||
assertThat(webSearchResults.results()) // Get web results
|
||||
.as("At least the string result should be contains 'university' and 'ranking' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.snippet())
|
||||
.containsIgnoringCase("university")
|
||||
.containsIgnoringCase("ranking"));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_google_image_result_with_param_searchType_image() {
|
||||
// given
|
||||
String query = "How will be the weather next week in Lisbon and Porto cities?";
|
||||
WebSearchRequest webSearchRequest = WebSearchRequest.builder()
|
||||
.searchTerms(query)
|
||||
.additionalParams(singletonMap("searchType", "image"))
|
||||
.build();
|
||||
|
||||
// when
|
||||
WebSearchResults webSearchResults = googleSearchEngine.search(webSearchRequest);
|
||||
|
||||
// then
|
||||
assertThat(webSearchResults.searchMetadata().get("searchType").toString()).isEqualTo("images"); // searchType: images
|
||||
assertThat(webSearchResults.results()) // Get images as search results
|
||||
.as("At least the snippet should be contains 'weather' and 'Porto' ignoring case")
|
||||
.anySatisfy(result -> assertThat(result.title())
|
||||
.containsIgnoringCase("weather")
|
||||
.containsIgnoringCase("porto"))
|
||||
.anySatisfy(result -> assertThat(result.url().toString())
|
||||
.startsWith("http"))
|
||||
.anySatisfy(result -> assertThat(result.metadata().get("mimeType"))
|
||||
.startsWith("image"))
|
||||
.anySatisfy(result -> assertThat(result.metadata().get("imageLink"))
|
||||
.isEqualTo(result.url().toString()))
|
||||
.anySatisfy(result -> assertThat(result.metadata().get("contextLink"))
|
||||
.startsWith("http"))
|
||||
.anySatisfy(result -> assertThat(result.metadata().get("thumbnailLink"))
|
||||
.startsWith("http"));
|
||||
}
|
||||
@Override
|
||||
protected WebSearchEngine searchEngine() {
|
||||
return googleSearchEngine;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
package dev.langchain4j.web.search.google.customsearch;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.agent.tool.ToolSpecifications;
|
||||
import dev.langchain4j.data.message.*;
|
||||
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModelName;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import dev.langchain4j.web.search.WebSearchEngine;
|
||||
import dev.langchain4j.web.search.WebSearchTool;
|
||||
import dev.langchain4j.web.search.WebSearchToolIT;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "GOOGLE_SEARCH_ENGINE_ID", matches = ".*")
|
||||
class GoogleCustomWebSearchToolIT extends WebSearchToolIT {
|
||||
|
||||
WebSearchEngine googleSearchEngine = GoogleCustomWebSearchEngine.withApiKeyAndCsi(
|
||||
System.getenv("GOOGLE_API_KEY"),
|
||||
System.getenv("GOOGLE_SEARCH_ENGINE_ID"));
|
||||
|
||||
ChatLanguageModel chatModel = OpenAiChatModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.modelName(OpenAiChatModelName.GPT_3_5_TURBO)
|
||||
.logRequests(true)
|
||||
.build();
|
||||
|
||||
interface Assistant {
|
||||
@dev.langchain4j.service.SystemMessage({
|
||||
"You are a web search support agent.",
|
||||
"If there is any event that has not happened yet",
|
||||
"You MUST create a web search request with with user query and",
|
||||
"use the web search tool to search the web for organic web results.",
|
||||
"Include the source link in your final response."
|
||||
})
|
||||
String answer(String userMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_execute_google_tool_with_chatLanguageModel_to_give_a_final_response(){
|
||||
// given
|
||||
googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.maxRetries(3)
|
||||
.build();
|
||||
|
||||
WebSearchTool webSearchTool = WebSearchTool.from(googleSearchEngine);
|
||||
List<ToolSpecification> tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool);
|
||||
String query = "What are the release dates for the movies coming out last week of May 2024?";
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
SystemMessage systemMessage = SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response. Do not say that you have not the capability to browse the web in real time");
|
||||
messages.add(systemMessage);
|
||||
UserMessage userMessage = UserMessage.from(query);
|
||||
messages.add(userMessage);
|
||||
// when
|
||||
AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content();
|
||||
|
||||
// then
|
||||
assertThat(aiMessage.hasToolExecutionRequests()).isTrue();
|
||||
assertThat(aiMessage.toolExecutionRequests())
|
||||
.anySatisfy(toolSpec -> {
|
||||
assertThat(toolSpec.name())
|
||||
.containsIgnoringCase("searchWeb");
|
||||
assertThat(toolSpec.arguments())
|
||||
.isNotBlank();
|
||||
}
|
||||
);
|
||||
messages.add(aiMessage);
|
||||
|
||||
// when
|
||||
String strResult = webSearchTool.searchWeb(query);
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult);
|
||||
messages.add(toolExecutionResultMessage);
|
||||
|
||||
AiMessage finalResponse = chatLanguageModel().generate(messages).content();
|
||||
System.out.println(finalResponse.text());
|
||||
|
||||
// then
|
||||
assertThat(finalResponse.text())
|
||||
.as("At least the string result should be contains 'movies' and 'coming soon' ignoring case")
|
||||
.containsIgnoringCase("movies")
|
||||
.containsIgnoringCase("May 2024");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_execute_google_tool_with_chatLanguageModel_to_summary_response_in_images() {
|
||||
// given
|
||||
googleSearchEngine = GoogleCustomWebSearchEngine.builder()
|
||||
.apiKey(System.getenv("GOOGLE_API_KEY"))
|
||||
.csi(System.getenv("GOOGLE_SEARCH_ENGINE_ID"))
|
||||
.logRequestResponse(true)
|
||||
.build();
|
||||
|
||||
WebSearchTool webSearchTool = WebSearchTool.from(googleSearchEngine);
|
||||
List<ToolSpecification> tools = ToolSpecifications.toolSpecificationsFrom(webSearchTool);
|
||||
String query = "My family is coming to visit me in Madrid next week, list the best tourist activities suitable for the whole family";
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
SystemMessage systemMessage = SystemMessage.from("You are a web search support agent. If there is any event that has not happened yet, you MUST use a web search tool to look up the information on the web. Include the source link in your final response and the image urls. Do not say that you have not the capability to browse the web in real time");
|
||||
messages.add(systemMessage);
|
||||
UserMessage userMessage = UserMessage.from(query);
|
||||
messages.add(userMessage);
|
||||
// when
|
||||
AiMessage aiMessage = chatLanguageModel().generate(messages, tools).content();
|
||||
|
||||
// then
|
||||
assertThat(aiMessage.hasToolExecutionRequests()).isTrue();
|
||||
assertThat(aiMessage.toolExecutionRequests())
|
||||
.anySatisfy(toolSpec -> {
|
||||
assertThat(toolSpec.name())
|
||||
.containsIgnoringCase("searchWeb");
|
||||
assertThat(toolSpec.arguments())
|
||||
.isNotBlank();
|
||||
}
|
||||
);
|
||||
messages.add(aiMessage);
|
||||
|
||||
// when
|
||||
String strResult = webSearchTool.searchWeb(query);
|
||||
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(aiMessage.toolExecutionRequests().get(0), strResult);
|
||||
messages.add(toolExecutionResultMessage);
|
||||
|
||||
AiMessage finalResponse = chatLanguageModel().generate(messages).content();
|
||||
System.out.println(finalResponse.text());
|
||||
|
||||
// then
|
||||
assertThat(finalResponse.text())
|
||||
.as("At least the string result should be contains 'madrid' and 'tourist' ignoring case")
|
||||
.containsIgnoringCase("Madrid")
|
||||
.containsIgnoringCase("Royal Palace of Madrid");
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_execute_google_tool_with_AiServices() {
|
||||
// given
|
||||
WebSearchTool webTool = WebSearchTool.from(googleSearchEngine);
|
||||
|
||||
Assistant assistant = AiServices.builder(Assistant.class)
|
||||
.chatLanguageModel(chatModel)
|
||||
.tools(webTool)
|
||||
.chatMemory(MessageWindowChatMemory.withMaxMessages(10))
|
||||
.build();
|
||||
// when
|
||||
String answer = assistant.answer("Search in the web who won the FIFA World Cup 2022?");
|
||||
|
||||
// then
|
||||
assertThat(answer).containsIgnoringCase("Argentina");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WebSearchEngine searchEngine() {
|
||||
return googleSearchEngine;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ChatLanguageModel chatLanguageModel() {
|
||||
return chatModel;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue