improve swift API and test app

parent e3df7bb91f
commit 60458cc580

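The diffs below replace LlamaRunner's completion-based run(with:config:completion:) with per-token and per-state callbacks, and switch the Objective-C bridge from a trailing completion block to a caller-supplied event-handler queue. As a quick preview, the revised call reads roughly like the sketch below (distilled from the hunks that follow, not extra code from the commit; the prompt string is illustrative, and the full test-app usage appears in the last hunk):

llama.run(
  with: "some prompt",                                      // illustrative
  tokenHandler: { token in print(token, terminator: "") },  // called once per generated token
  stateChangeHandler: { state in /* RunState transitions */ }
)
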
@@ -22,6 +22,14 @@ public class LlamaRunner {
     }
   }
 
+  public enum RunState {
+    case notStarted
+    case initializing
+    case generatingOutput
+    case completed
+    case failed(error: Error)
+  }
+
   public let modelURL: URL
 
   private lazy var bridge = _LlamaRunnerBridge(modelPath: modelURL.path)

@@ -33,19 +41,40 @@ public class LlamaRunner {
   public func run(
     with prompt: String,
     config: Config = .default,
-    completion: @escaping () -> Void
+    tokenHandler: @escaping (String) -> Void,
+    stateChangeHandler: ((RunState) -> Void)? = nil
   ) {
     let _config = _LlamaRunnerBridgeConfig()
     _config.numberOfThreads = config.numThreads
     _config.numberOfTokens = config.numTokens
     _config.reversePrompt = config.reversePrompt
 
+    stateChangeHandler?(.notStarted)
+
     bridge.run(
       withPrompt: prompt,
       config: _config,
-      completion: completion
+      eventHandler: { event in
+        event.match(
+          startedLoadingModel: {
+            stateChangeHandler?(.initializing)
+          },
+          finishedLoadingModel: {},
+          startedGeneratingOutput: {
+            stateChangeHandler?(.generatingOutput)
+          },
+          outputToken: { token in
+            tokenHandler(token)
+          },
+          completed: {
+            stateChangeHandler?(.completed)
+          },
+          failed: { error in
+            stateChangeHandler?(.failed(error: error))
+          }
+        )
+      },
+      eventHandlerQueue: DispatchQueue.main
     )
   }
 }

@@ -23,7 +23,7 @@ typedef void (^_LlamaRunnerBridgeEventHandler)(_LlamaEvent *event);
 - (void)runWithPrompt:(nonnull NSString*)prompt
                config:(nonnull _LlamaRunnerBridgeConfig *)config
          eventHandler:(nonnull _LlamaRunnerBridgeEventHandler)eventHandler
-           completion:(void (^)())completion;
+    eventHandlerQueue:(nonnull dispatch_queue_t)eventHandlerQueue;
 @end
 
 NS_ASSUME_NONNULL_END

@@ -29,7 +29,7 @@
 - (void)runWithPrompt:(nonnull NSString*)prompt
                config:(nonnull _LlamaRunnerBridgeConfig *)config
          eventHandler:(nonnull _LlamaRunnerBridgeEventHandler)eventHandler
-           completion:(void (^)())completion
+    eventHandlerQueue:(nonnull dispatch_queue_t)eventHandlerQueue
 {
   gpt_params params;
   params.model = [_modelPath cStringUsingEncoding:NSUTF8StringEncoding];

@@ -44,8 +44,7 @@
 
   LlamaPredictOperation *operation = [[LlamaPredictOperation alloc] initWithParams:params
                                                                       eventHandler:eventHandler
-                                                                 eventHandlerQueue:dispatch_get_main_queue()];
-  [operation setCompletionBlock:completion];
+                                                                 eventHandlerQueue:eventHandlerQueue];
   [_operationQueue addOperation:operation];
 }
 
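With this change the bridge delivers events on whatever queue the caller passes in, instead of always using dispatch_get_main_queue(); the Swift wrapper above still pins delivery to DispatchQueue.main. As a minimal sketch of what the new parameter allows, here is a hypothetical direct use of the bridge with a background delivery queue. Only the bridge API itself comes from the hunks above; the model path, prompt, and queue label are illustrative assumptions.

import Foundation

let bridge = _LlamaRunnerBridge(modelPath: "/path/to/model.bin")  // illustrative path
let config = _LlamaRunnerBridgeConfig()
let eventsQueue = DispatchQueue(label: "llama.events")            // assumed label

bridge.run(
  withPrompt: "Hello",                                            // illustrative prompt
  config: config,
  eventHandler: { event in
    // Handlers now run on eventsQueue rather than the main queue.
    event.match(
      startedLoadingModel: {},
      finishedLoadingModel: {},
      startedGeneratingOutput: {},
      outputToken: { token in print(token, terminator: "") },
      completed: {},
      failed: { error in print(error.localizedDescription) }
    )
  },
  eventHandlerQueue: eventsQueue
)
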
@@ -25,8 +25,25 @@ let semaphore = DispatchSemaphore(value: 0)
 let llama = LlamaRunner(modelURL: url)
 llama.run(
   with: "Building a website can be done in 10 simple steps:",
-  completion: {
+  tokenHandler: { token in
+    print(token, terminator: "")
+  },
+  stateChangeHandler: { state in
+    switch state {
+    case .notStarted:
+      break
+    case .initializing:
+      print("Initializing model... ", terminator: "")
+      break
+    case .generatingOutput:
+      print("Done.")
+      break
+    case .completed:
       semaphore.signal()
+    case .failed(error: let error):
+      print("")
+      print("Failed to generate output: ", error.localizedDescription)
+    }
   })
 
 while semaphore.wait(timeout: .now()) == .timedOut {

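The body of the final while loop is truncated by this hunk. Because the wrapper delivers its callbacks on DispatchQueue.main, simply blocking the main thread with semaphore.wait() would prevent those callbacks from ever running; polling with a zero timeout and servicing the main run loop in between is the usual pattern for a command-line tool like this. A sketch of that pattern, as an assumption rather than the commit's actual loop body:

while semaphore.wait(timeout: .now()) == .timedOut {
  // Pump the main run loop so blocks dispatched to DispatchQueue.main get a chance to execute.
  RunLoop.current.run(until: Date(timeIntervalSinceNow: 0.1))
}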