diff --git a/api.go b/api.go
index 3ca633d..45e1980 100644
--- a/api.go
+++ b/api.go
@@ -101,16 +101,33 @@ func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
 
 		// Get the next token from LLaMA
-		log.Println("doing llama_eval...")
+		if i == int(llast_n_tokens_used_size) {
 
-		evalErr := C.llama_eval(lcontext,
-			&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
-			C.int(i), // n_past is the number of tokens to use from previous eval calls
-			C.int(runtime.GOMAXPROCS(0)))
-		if evalErr != 0 {
-			log.Printf("llama_eval: %d", evalErr)
-			http.Error(w, "Internal error", 500)
-			return
+			log.Println("doing llama_eval (for the first time on all supplied input)...")
+
+			evalErr := C.llama_eval(lcontext,
+				&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
+				C.int(0), // n_past is the number of tokens to use from previous eval calls
+				C.int(runtime.GOMAXPROCS(0)))
+			if evalErr != 0 {
+				log.Printf("llama_eval: %d", evalErr)
+				http.Error(w, "Internal error", 500)
+				return
+			}
+
+		} else {
+
+			log.Println("doing llama_eval (incrementally on the newly generated token)...")
+
+			evalErr := C.llama_eval(lcontext,
+				&llast_n_tokens[i-1], 1, // tokens + n_tokens is the provided batch of new tokens to process
+				C.int(i), // n_past is the number of tokens to use from previous eval calls
+				C.int(runtime.GOMAXPROCS(0)))
+			if evalErr != 0 {
+				log.Printf("llama_eval: %d", evalErr)
+				http.Error(w, "Internal error", 500)
+				return
+			}
 		}
 
 		if err := r.Context().Err(); err != nil {
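For context on what this diff changes: the old code re-evaluated the token buffer on every loop iteration with a single `llama_eval` call. The new code splits this into two cases. On the first pass (`i == int(llast_n_tokens_used_size)`) it evaluates the entire supplied prompt in one batch with `n_past = 0`; on every later pass it feeds only the single newly generated token, with `n_past` set so llama.cpp reuses the attention state (KV cache) already built for the preceding tokens. Below is a minimal, self-contained sketch of that same pattern, assuming the pre-batch `llama_eval(ctx, tokens, n_tokens, n_past, n_threads)` API that this diff calls; `generate`, `sampleNext`, and the surrounding setup are hypothetical illustrations, not code from this repository:

```go
package llamasketch

/*
#cgo LDFLAGS: -lllama
#include <llama.h>
*/
import "C"

import "runtime"

// sampleNext is a hypothetical stand-in for the handler's sampling step
// (e.g. an argmax over C.llama_get_logits(lcontext)); elided here.
func sampleNext(lcontext *C.struct_llama_context) C.llama_token {
	return 0
}

// generate sketches the two-phase eval pattern from the diff: one
// full-prompt eval, then one-token incremental evals that reuse the
// context's KV cache via n_past.
func generate(lcontext *C.struct_llama_context, tokens []C.llama_token, maxNew int) bool {
	nThreads := C.int(runtime.GOMAXPROCS(0))

	// First call: evaluate every supplied prompt token in one batch.
	// n_past = 0 because nothing has been evaluated into the cache yet.
	if C.llama_eval(lcontext, &tokens[0], C.int(len(tokens)), C.int(0), nThreads) != 0 {
		return false
	}
	nPast := len(tokens)

	for n := 0; n < maxNew; n++ {
		tokens = append(tokens, sampleNext(lcontext))

		// Incremental call: pass only the newly sampled token; n_past
		// tells llama.cpp how many cached positions precede it, so the
		// whole history is not re-evaluated on every step.
		if C.llama_eval(lcontext, &tokens[len(tokens)-1], C.int(1), C.int(nPast), nThreads) != 0 {
			return false
		}
		nPast++
	}
	return true
}
```

The payoff of the incremental form is that each generated token costs one single-token eval rather than a re-run of the model over the entire history, so per-token latency stays roughly flat as the output grows instead of increasing with every token.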