Add streaming API for Bedrock Anthropics (#679)
Currently, langchain4j doesn't support streaming for Bedrock Anthropic models; this PR addresses that by adding streaming support for Anthropic Claude v2 and v2:1. The new tests are disabled because they require AWS credentials, but they pass when enabled.
This commit is contained in:
parent
ded9ecce38
commit
6684ea0b33
|
@ -0,0 +1,33 @@
|
|||
package dev.langchain4j.model.bedrock;
|
||||
|
||||
import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
|
||||
@Getter
|
||||
@SuperBuilder
|
||||
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {
|
||||
@Builder.Default
|
||||
private final BedrockAnthropicChatModel.Types model = BedrockAnthropicChatModel.Types.AnthropicClaudeV2;
|
||||
|
||||
@Override
|
||||
protected String getModelId() {
|
||||
return model.getValue();
|
||||
}
|
||||
|
||||
@Getter
|
||||
/**
|
||||
* Bedrock Anthropic model ids
|
||||
*/
|
||||
public enum Types {
|
||||
AnthropicClaudeV2("anthropic.claude-v2"),
|
||||
AnthropicClaudeV2_1("anthropic.claude-v2:1");
|
||||
|
||||
private final String value;
|
||||
|
||||
Types(String modelID) {
|
||||
this.value = modelID;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
|||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
|
||||
|
@ -30,47 +31,14 @@ import static java.util.stream.Collectors.joining;
|
|||
*/
|
||||
@Getter
|
||||
@SuperBuilder
|
||||
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> implements ChatLanguageModel {
|
||||
private static final String HUMAN_PROMPT = "Human:";
|
||||
private static final String ASSISTANT_PROMPT = "Assistant:";
|
||||
|
||||
@Builder.Default
|
||||
private final String humanPrompt = HUMAN_PROMPT;
|
||||
@Builder.Default
|
||||
private final String assistantPrompt = ASSISTANT_PROMPT;
|
||||
@Builder.Default
|
||||
private final Integer maxRetries = 5;
|
||||
@Builder.Default
|
||||
private final Region region = Region.US_EAST_1;
|
||||
@Builder.Default
|
||||
private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
|
||||
@Builder.Default
|
||||
private final int maxTokens = 300;
|
||||
@Builder.Default
|
||||
private final float temperature = 1;
|
||||
@Builder.Default
|
||||
private final float topP = 0.999f;
|
||||
@Builder.Default
|
||||
private final String[] stopSequences = new String[]{};
|
||||
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
|
||||
@Getter(lazy = true)
|
||||
private final BedrockRuntimeClient client = initClient();
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
|
||||
final String context = messages.stream()
|
||||
.filter(message -> message.type() == ChatMessageType.SYSTEM)
|
||||
.map(ChatMessage::text)
|
||||
.collect(joining("\n"));
|
||||
|
||||
final String userMessages = messages.stream()
|
||||
.filter(message -> message.type() != ChatMessageType.SYSTEM)
|
||||
.map(this::chatMessageToString)
|
||||
.collect(joining("\n"));
|
||||
|
||||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
|
||||
final Map<String, Object> requestParameters = getRequestParameters(prompt);
|
||||
final String body = Json.toJson(requestParameters);
|
||||
final String body = convertMessagesToAwsBody(messages);
|
||||
|
||||
InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
|
||||
final String response = invokeModelResponse.body().asUtf8String();
|
||||
|
@ -81,26 +49,6 @@ public abstract class AbstractBedrockChatModel<T extends BedrockChatModelRespons
|
|||
result.getFinishReason());
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert chat message to string
|
||||
*
|
||||
* @param message chat message
|
||||
* @return string
|
||||
*/
|
||||
protected String chatMessageToString(ChatMessage message) {
|
||||
switch (message.type()) {
|
||||
case SYSTEM:
|
||||
return message.text();
|
||||
case USER:
|
||||
return humanPrompt + " " + message.text();
|
||||
case AI:
|
||||
return assistantPrompt + " " + message.text();
|
||||
case TOOL_EXECUTION_RESULT:
|
||||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException("Unknown message type: " + message.type());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get request parameters
|
||||
|
@ -110,13 +58,6 @@ public abstract class AbstractBedrockChatModel<T extends BedrockChatModelRespons
|
|||
*/
|
||||
protected abstract Map<String, Object> getRequestParameters(final String prompt);
|
||||
|
||||
/**
|
||||
* Get model id
|
||||
*
|
||||
* @return model id
|
||||
*/
|
||||
protected abstract String getModelId();
|
||||
|
||||
|
||||
/**
|
||||
* Get response class type
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
package dev.langchain4j.model.bedrock.internal;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.internal.Json;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Getter;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* Bedrock Streaming chat model
|
||||
*/
|
||||
@Getter
|
||||
@SuperBuilder
|
||||
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {
|
||||
@Getter
|
||||
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();
|
||||
|
||||
class StreamingResponse {
|
||||
public String completion;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) {
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
messages.add(new UserMessage(userMessage));
|
||||
generate(messages, handler);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
|
||||
InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
|
||||
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
|
||||
.modelId(getModelId())
|
||||
.contentType("application/json")
|
||||
.accept("application/json")
|
||||
.build();
|
||||
|
||||
AtomicReference<String> finalCompletion = new AtomicReference<>("");
|
||||
|
||||
InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
|
||||
.onChunk(chunk -> {
|
||||
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
|
||||
finalCompletion.set(finalCompletion.get() + sr.completion);
|
||||
handler.onNext(sr.completion);
|
||||
})
|
||||
.build();
|
||||
|
||||
InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder()
|
||||
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
|
||||
.onComplete(() -> {
|
||||
handler.onComplete(Response.from(new AiMessage(finalCompletion.get())));
|
||||
})
|
||||
.onError(handler::onError)
|
||||
.build();
|
||||
asyncClient.invokeModelWithResponseStream(request, h).join();
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize async bedrock client
|
||||
*
|
||||
* @return async bedrock client
|
||||
*/
|
||||
private BedrockRuntimeAsyncClient initAsyncClient() {
|
||||
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
|
||||
.region(region)
|
||||
.credentialsProvider(credentialsProvider)
|
||||
.build();
|
||||
return client;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
package dev.langchain4j.model.bedrock.internal;
|
||||
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.ChatMessageType;
|
||||
import dev.langchain4j.internal.Json;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static java.util.stream.Collectors.joining;
|
||||
|
||||
@Getter
|
||||
@SuperBuilder
|
||||
public abstract class AbstractSharedBedrockChatModel {
|
||||
protected static final String HUMAN_PROMPT = "Human:";
|
||||
protected static final String ASSISTANT_PROMPT = "Assistant:";
|
||||
protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";
|
||||
|
||||
@Builder.Default
|
||||
protected final String humanPrompt = HUMAN_PROMPT;
|
||||
@Builder.Default
|
||||
protected final String assistantPrompt = ASSISTANT_PROMPT;
|
||||
@Builder.Default
|
||||
protected final Integer maxRetries = 5;
|
||||
@Builder.Default
|
||||
protected final Region region = Region.US_EAST_1;
|
||||
@Builder.Default
|
||||
protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
|
||||
@Builder.Default
|
||||
protected final int maxTokens = 300;
|
||||
@Builder.Default
|
||||
protected final float temperature = 1;
|
||||
@Builder.Default
|
||||
protected final float topP = 0.999f;
|
||||
@Builder.Default
|
||||
protected final String[] stopSequences = new String[]{};
|
||||
@Builder.Default
|
||||
protected final int topK = 250;
|
||||
@Builder.Default
|
||||
protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;
|
||||
|
||||
|
||||
/**
|
||||
* Convert chat message to string
|
||||
*
|
||||
* @param message chat message
|
||||
* @return string
|
||||
*/
|
||||
protected String chatMessageToString(ChatMessage message) {
|
||||
switch (message.type()) {
|
||||
case SYSTEM:
|
||||
return message.text();
|
||||
case USER:
|
||||
return humanPrompt + " " + message.text();
|
||||
case AI:
|
||||
return assistantPrompt + " " + message.text();
|
||||
case TOOL_EXECUTION_RESULT:
|
||||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException("Unknown message type: " + message.type());
|
||||
}
|
||||
|
||||
protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
|
||||
final String context = messages.stream()
|
||||
.filter(message -> message.type() == ChatMessageType.SYSTEM)
|
||||
.map(ChatMessage::text)
|
||||
.collect(joining("\n"));
|
||||
|
||||
final String userMessages = messages.stream()
|
||||
.filter(message -> message.type() != ChatMessageType.SYSTEM)
|
||||
.map(this::chatMessageToString)
|
||||
.collect(joining("\n"));
|
||||
|
||||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
|
||||
final Map<String, Object> requestParameters = getRequestParameters(prompt);
|
||||
final String body = Json.toJson(requestParameters);
|
||||
return body;
|
||||
}
|
||||
|
||||
protected Map<String, Object> getRequestParameters(String prompt) {
|
||||
final Map<String, Object> parameters = new HashMap<>(7);
|
||||
|
||||
parameters.put("prompt", prompt);
|
||||
parameters.put("max_tokens_to_sample", getMaxTokens());
|
||||
parameters.put("temperature", getTemperature());
|
||||
parameters.put("top_k", topK);
|
||||
parameters.put("top_p", getTopP());
|
||||
parameters.put("stop_sequences", getStopSequences());
|
||||
parameters.put("anthropic_version", anthropicVersion);
|
||||
|
||||
return parameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model id
|
||||
*
|
||||
* @return model id
|
||||
*/
|
||||
protected abstract String getModelId();
|
||||
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package dev.langchain4j.model.bedrock;
|
||||
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class BedrockStreamingChatModelIT {
|
||||
@Test
|
||||
@Disabled("To run this test, you must have provide your own access key, secret, region")
|
||||
void testBedrockAnthropicStreamingChatModel() throws ExecutionException, InterruptedException, TimeoutException {
|
||||
BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel
|
||||
.builder()
|
||||
.temperature(0.50f)
|
||||
.maxTokens(300)
|
||||
.region(Region.US_EAST_1)
|
||||
.maxRetries(1)
|
||||
.build();
|
||||
|
||||
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
|
||||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
|
||||
bedrockChatModel.generate("What's the capital of Poland?", new StreamingResponseHandler<AiMessage>() {
|
||||
private final StringBuilder answerBuilder = new StringBuilder();
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
answerBuilder.append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
futureAnswer.complete(answerBuilder.toString());
|
||||
futureResponse.complete(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
System.out.println(throwable);
|
||||
}
|
||||
|
||||
});
|
||||
String answer = futureAnswer.get(30, SECONDS);
|
||||
Response<AiMessage> response = futureResponse.get(30, SECONDS);
|
||||
|
||||
assertThat(answer).contains("Warsaw");
|
||||
assertThat(response.content().text()).contains("Warsaw");
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue