api: simplify/combine the llama_eval branches

mappu 2023-04-08 16:04:16 +12:00
parent 0c96f2bf6b
commit f5ba37a10b

api.go (34 lines changed)

@@ -110,37 +110,29 @@ func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		// Get the next token from LLaMA
+		// Perform the LLaMA evaluation step
+		evalTokenStart := i - 1
+		evalTokenCount := 1
+		evalTokenPast := i
 		if i == int(llast_n_tokens_used_size) {
-			log.Println("doing llama_eval (for the first time on all supplied input)...")
-			evalErr := C.llama_eval(lcontext,
-				&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
-				C.int(0), // n_past is the number of tokens to use from previous eval calls
-				C.int(runtime.GOMAXPROCS(0)))
-			if evalErr != 0 {
-				log.Printf("llama_eval: %d", evalErr)
-				http.Error(w, "Internal error", 500)
-				return
-			}
-		} else {
-			log.Println("doing llama_eval (incrementally on the newly generated token)...")
-			evalErr := C.llama_eval(lcontext,
-				&llast_n_tokens[i-1], 1, // tokens + n_tokens is the provided batch of new tokens to process
-				C.int(i), // n_past is the number of tokens to use from previous eval calls
-				C.int(runtime.GOMAXPROCS(0)))
-			if evalErr != 0 {
-				log.Printf("llama_eval: %d", evalErr)
-				http.Error(w, "Internal error", 500)
-				return
-			}
+			evalTokenStart = 0
+			evalTokenCount = i
+			evalTokenPast = 0
 		}
 
+		evalStartTime := time.Now()
+		evalErr := C.llama_eval(lcontext,
+			&llast_n_tokens[evalTokenStart], C.int(evalTokenCount), // tokens + n_tokens is the provided batch of new tokens to process
+			C.int(evalTokenPast), // n_past is the number of tokens to use from previous eval calls
+			C.int(runtime.GOMAXPROCS(0)))
+		log.Printf("llama_eval: Evaluated %d token(s) in %s", evalTokenCount, time.Now().Sub(evalStartTime).String())
+		if evalErr != 0 {
+			log.Printf("llama_eval: %d", evalErr)
+			http.Error(w, "Internal error", 500)
+			return
+		}
+
 		if err := r.Context().Err(); err != nil {
 			return
 		}
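For context, below is a minimal standalone sketch of the pattern this commit applies: rather than duplicating the eval call in a full-prompt branch and an incremental branch, the arguments that differ (start offset, batch size, n_past) are hoisted into variables set before a single call. The eval function, evalStep, and the sample token values are hypothetical stand-ins invented for illustration, not the real cgo binding; only the start/count/past parameterization mirrors the committed code.

package main

import (
	"fmt"
	"log"
	"time"
)

// eval stands in for C.llama_eval: pretend to process `count` tokens starting
// at tokens[start], given `past` tokens of already-evaluated context.
// Returns 0 on success, matching llama_eval's C convention.
func eval(tokens []int32, start, count, past int) int {
	fmt.Printf("eval: start=%d count=%d n_past=%d batch=%v\n",
		start, count, past, tokens[start:start+count])
	return 0
}

// evalStep mirrors the combined branch above: on the first step
// (i == promptLen) the whole supplied input is evaluated in one batch;
// on later steps only the newest token is evaluated incrementally.
func evalStep(tokens []int32, i, promptLen int) error {
	start, count, past := i-1, 1, i // incremental case: one new token
	if i == promptLen {
		start, count, past = 0, i, 0 // first call: evaluate all supplied input
	}

	begin := time.Now()
	if rc := eval(tokens, start, count, past); rc != 0 {
		return fmt.Errorf("eval failed: %d", rc)
	}
	log.Printf("evaluated %d token(s) in %s", count, time.Since(begin))
	return nil
}

func main() {
	tokens := []int32{1, 15043, 29892, 3186} // a pretend tokenized prompt
	if err := evalStep(tokens, len(tokens), len(tokens)); err != nil {
		log.Fatal(err)
	}
	tokens = append(tokens, 42) // pretend a newly sampled token was appended
	if err := evalStep(tokens, len(tokens), len(tokens)-1); err != nil {
		log.Fatal(err)
	}
}

Besides removing the duplicated error handling, folding the two branches into one call site means the added timing instrumentation (evalStartTime / the "Evaluated %d token(s)" log line) only has to be written once.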