Improve Swift API and test app

This commit is contained in:
Alex Rozanski 2023-03-14 11:17:24 +01:00
parent e3df7bb91f
commit 60458cc580
4 changed files with 53 additions and 8 deletions

View File

@ -22,6 +22,14 @@ public class LlamaRunner {
}
}
/// Lifecycle states reported to `run(with:config:tokenHandler:stateChangeHandler:)`'s
/// optional `stateChangeHandler` as a prediction run progresses.
public enum RunState {
/// `run` has been called but the model has not started loading yet.
case notStarted
/// The model is being loaded (maps to the bridge's `startedLoadingModel` event).
case initializing
/// The model is loaded and tokens are being produced.
case generatingOutput
/// The run finished successfully.
case completed
/// The run failed; `error` carries the underlying failure from the bridge.
case failed(error: Error)
}
public let modelURL: URL
private lazy var bridge = _LlamaRunnerBridge(modelPath: modelURL.path)
@ -33,19 +41,40 @@ public class LlamaRunner {
/// Runs a prediction against the loaded model.
///
/// Bridges the Objective-C `_LlamaRunnerBridge` event stream into the Swift
/// `RunState` API. Both callbacks are invoked on the main queue (the bridge is
/// given `DispatchQueue.main` as its `eventHandlerQueue`).
///
/// - Parameters:
///   - prompt: The prompt to feed to the model.
///   - config: Runtime configuration (threads, token count, reverse prompt).
///   - tokenHandler: Called once per generated output token.
///   - stateChangeHandler: Optionally called on each `RunState` transition.
public func run(
  with prompt: String,
  config: Config = .default,
  tokenHandler: @escaping (String) -> Void,
  stateChangeHandler: ((RunState) -> Void)? = nil
) {
  // Translate the Swift config into the Objective-C bridge config.
  let _config = _LlamaRunnerBridgeConfig()
  _config.numberOfThreads = config.numThreads
  _config.numberOfTokens = config.numTokens
  _config.reversePrompt = config.reversePrompt

  // Report the initial state synchronously, before the bridge starts work.
  stateChangeHandler?(.notStarted)

  bridge.run(
    withPrompt: prompt,
    config: _config,
    eventHandler: { event in
      // Map each bridge event onto the corresponding RunState / token callback.
      event.match(
        startedLoadingModel: {
          stateChangeHandler?(.initializing)
        },
        finishedLoadingModel: {},
        startedGeneratingOutput: {
          stateChangeHandler?(.generatingOutput)
        },
        outputToken: { token in
          tokenHandler(token)
        },
        completed: {
          stateChangeHandler?(.completed)
        },
        failed: { error in
          stateChangeHandler?(.failed(error: error))
        }
      )
    },
    eventHandlerQueue: DispatchQueue.main
  )
}
}

View File

@ -23,7 +23,7 @@ typedef void (^_LlamaRunnerBridgeEventHandler)(_LlamaEvent *event);
// Runs a prediction for `prompt` with `config`; `eventHandler` is invoked for
// each prediction event, dispatched on `eventHandlerQueue`.
- (void)runWithPrompt:(nonnull NSString*)prompt
               config:(nonnull _LlamaRunnerBridgeConfig *)config
         eventHandler:(nonnull _LlamaRunnerBridgeEventHandler)eventHandler
    eventHandlerQueue:(nonnull dispatch_queue_t)eventHandlerQueue;
@end
NS_ASSUME_NONNULL_END

View File

@ -29,7 +29,7 @@
- (void)runWithPrompt:(nonnull NSString*)prompt
config:(nonnull _LlamaRunnerBridgeConfig *)config
eventHandler:(nonnull _LlamaRunnerBridgeEventHandler)eventHandler
completion:(void (^)())completion
eventHandlerQueue:(nonnull dispatch_queue_t)eventHandlerQueue
{
gpt_params params;
params.model = [_modelPath cStringUsingEncoding:NSUTF8StringEncoding];
@ -44,8 +44,7 @@
LlamaPredictOperation *operation = [[LlamaPredictOperation alloc] initWithParams:params
eventHandler:eventHandler
eventHandlerQueue:dispatch_get_main_queue()];
[operation setCompletionBlock:completion];
eventHandlerQueue:eventHandlerQueue];
[_operationQueue addOperation:operation];
}

View File

@ -25,8 +25,25 @@ let semaphore = DispatchSemaphore(value: 0)
let llama = LlamaRunner(modelURL: url)
// Kick off generation, streaming tokens to stdout and reporting lifecycle
// state transitions. The semaphore is signalled when the run finishes (either
// successfully or with an error) so the polling loop below can exit.
llama.run(
  with: "Building a website can be done in 10 simple steps:",
  tokenHandler: { token in
    // Print tokens as they stream in, with no separator or trailing newline.
    print(token, terminator: "")
  },
  stateChangeHandler: { state in
    switch state {
    case .notStarted:
      break
    case .initializing:
      print("Initializing model... ", terminator: "")
    case .generatingOutput:
      // Completes the "Initializing model... " line started above.
      print("Done.")
    case .completed:
      semaphore.signal()
    case .failed(error: let error):
      print("")
      print("Failed to generate output: ", error.localizedDescription)
      // Signal on failure too — otherwise the wait loop below spins forever.
      semaphore.signal()
    }
  })
while semaphore.wait(timeout: .now()) == .timedOut {