From f5ba37a10bc10d004087658f6f8407a2de36ee68 Mon Sep 17 00:00:00 2001 From: mappu Date: Sat, 8 Apr 2023 16:04:16 +1200 Subject: [PATCH] api: simplify/combine the llama_eval branches --- api.go | 46 +++++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/api.go b/api.go index 7265ac8..0d6204c 100644 --- a/api.go +++ b/api.go @@ -110,37 +110,29 @@ func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) { return } - // Get the next token from LLaMA + // Perform the LLaMA evaluation step + evalTokenStart := i - 1 + evalTokenCount := 1 + evalTokenPast := i if i == int(llast_n_tokens_used_size) { - - log.Println("doing llama_eval (for the first time on all supplied input)...") - - evalErr := C.llama_eval(lcontext, - &llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process - C.int(0), // n_past is the number of tokens to use from previous eval calls - C.int(runtime.GOMAXPROCS(0))) - if evalErr != 0 { - log.Printf("llama_eval: %d", evalErr) - http.Error(w, "Internal error", 500) - return - } - - } else { - - log.Println("doing llama_eval (incrementally on the newly generated token)...") - - evalErr := C.llama_eval(lcontext, - &llast_n_tokens[i-1], 1, // tokens + n_tokens is the provided batch of new tokens to process - C.int(i), // n_past is the number of tokens to use from previous eval calls - C.int(runtime.GOMAXPROCS(0))) - if evalErr != 0 { - log.Printf("llama_eval: %d", evalErr) - http.Error(w, "Internal error", 500) - return - } + evalTokenStart = 0 + evalTokenCount = i + evalTokenPast = 0 } + evalStartTime := time.Now() + evalErr := C.llama_eval(lcontext, + &llast_n_tokens[evalTokenStart], C.int(evalTokenCount), // tokens + n_tokens is the provided batch of new tokens to process + C.int(evalTokenPast), // n_past is the number of tokens to use from previous eval calls + C.int(runtime.GOMAXPROCS(0))) + + log.Printf("llama_eval: Evaluated %d token(s) in %s", evalTokenCount, time.Now().Sub(evalStartTime).String()) + if evalErr != 0 { + log.Printf("llama_eval: %d", evalErr) + http.Error(w, "Internal error", 500) + return + } if err := r.Context().Err(); err != nil { return }