api: simplify/combine the llama_eval branches
This commit is contained in:
parent
0c96f2bf6b
commit
f5ba37a10b
46
api.go
46
api.go
@ -110,37 +110,29 @@ func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next token from LLaMA
|
// Perform the LLaMA evaluation step
|
||||||
|
|
||||||
|
evalTokenStart := i - 1
|
||||||
|
evalTokenCount := 1
|
||||||
|
evalTokenPast := i
|
||||||
if i == int(llast_n_tokens_used_size) {
|
if i == int(llast_n_tokens_used_size) {
|
||||||
|
evalTokenStart = 0
|
||||||
log.Println("doing llama_eval (for the first time on all supplied input)...")
|
evalTokenCount = i
|
||||||
|
evalTokenPast = 0
|
||||||
evalErr := C.llama_eval(lcontext,
|
|
||||||
&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
|
|
||||||
C.int(0), // n_past is the number of tokens to use from previous eval calls
|
|
||||||
C.int(runtime.GOMAXPROCS(0)))
|
|
||||||
if evalErr != 0 {
|
|
||||||
log.Printf("llama_eval: %d", evalErr)
|
|
||||||
http.Error(w, "Internal error", 500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
|
|
||||||
log.Println("doing llama_eval (incrementally on the newly generated token)...")
|
|
||||||
|
|
||||||
evalErr := C.llama_eval(lcontext,
|
|
||||||
&llast_n_tokens[i-1], 1, // tokens + n_tokens is the provided batch of new tokens to process
|
|
||||||
C.int(i), // n_past is the number of tokens to use from previous eval calls
|
|
||||||
C.int(runtime.GOMAXPROCS(0)))
|
|
||||||
if evalErr != 0 {
|
|
||||||
log.Printf("llama_eval: %d", evalErr)
|
|
||||||
http.Error(w, "Internal error", 500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
evalStartTime := time.Now()
|
||||||
|
evalErr := C.llama_eval(lcontext,
|
||||||
|
&llast_n_tokens[evalTokenStart], C.int(evalTokenCount), // tokens + n_tokens is the provided batch of new tokens to process
|
||||||
|
C.int(evalTokenPast), // n_past is the number of tokens to use from previous eval calls
|
||||||
|
C.int(runtime.GOMAXPROCS(0)))
|
||||||
|
|
||||||
|
log.Printf("llama_eval: Evaluated %d token(s) in %s", evalTokenCount, time.Now().Sub(evalStartTime).String())
|
||||||
|
if evalErr != 0 {
|
||||||
|
log.Printf("llama_eval: %d", evalErr)
|
||||||
|
http.Error(w, "Internal error", 500)
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := r.Context().Err(); err != nil {
|
if err := r.Context().Err(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user