llamacpphtmld/api.go

174 lines
4.5 KiB
Go

package main
import (
	"encoding/json"
	"log"
	"net/http"
	"runtime"
	"time"
)

/*
#include <stdlib.h>

#include "llama.h"

// free_cstring releases a string allocated by C.CString. Wrapping free() here
// lets the Go side release C strings without importing "unsafe".
static inline void free_cstring(char *s) { free(s); }
*/
import "C"
// Fixed LLaMA inference parameters shared by every chat request.
const (
	// ParamContextSize is the n_ctx given to each new llama context and the
	// capacity of the per-request token buffer. The mem_required is
	// 9800MB + 3216MB/state, regardless of the n_ctx size. However it does
	// affect the KV size for persistence.
	ParamContextSize = 1024

	// Sampling parameters handed to llama_sample_top_p_top_k.
	ParamTopK          = 40
	ParamTopP          = 0.95
	ParamTemperature   = 0.08
	ParamRepeatPenalty = 1.10

	// ParamRepeatPenaltyWindowSize is how many of the most recent tokens the
	// repeat penalty is applied to.
	ParamRepeatPenaltyWindowSize = 64
)
// POST_Chat handles a chat completion request and streams the generated text.
//
// The request body is a JSON object with the fields ConversationID, APIKey,
// Content and MaxTokens. ConversationID and APIKey are decoded but not used
// yet. Content is the prompt. MaxTokens caps the number of generated tokens;
// 0 means "generate until the context window is full or the model emits EOS".
//
// At most one request per free slot in this.sem is processed at a time; other
// requests block until a slot frees up or their client goes away.
//
// Errors detected before streaming starts (bad JSON, negative MaxTokens, model
// load failure, over-long prompt) are reported with a real HTTP status code.
// Once the 200 header has been sent the status can no longer change, so later
// failures are only logged and the stream is ended.
func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "The HTTP request does not support live updates", http.StatusInternalServerError)
		return
	}

	// Parse API POST request
	type requestBody struct {
		ConversationID string
		APIKey         string
		Content        string
		MaxTokens      int
	}
	var apiParams requestBody
	if err := json.NewDecoder(r.Body).Decode(&apiParams); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if apiParams.MaxTokens < 0 {
		http.Error(w, "MaxTokens should be 0 or positive", http.StatusBadRequest)
		return
	}

	// Verify API key
	// TODO

	// Wait for a free worker
	// TODO signal the queue length to the user?
	select {
	case this.sem <- struct{}{}:
		// OK
	case <-r.Context().Done():
		return // Request cancelled while waiting for free worker
	}
	defer func() { <-this.sem }()

	// Start a new LLaMA session.
	// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
	lparams := C.llama_context_default_params()
	lparams.n_ctx = ParamContextSize
	cModelPath := C.CString(this.cfg.ModelPath)
	defer C.free_cstring(cModelPath) // C.CString allocates with malloc
	lcontext := C.llama_init_from_file(cModelPath, lparams)
	if lcontext == nil {
		log.Printf("llama_init_from_file: failed to load model %q", this.cfg.ModelPath)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	}
	defer C.llama_free(lcontext)

	// Feed in the conversation so far.
	// The leading space matches LLaMA python behaviour.
	cPrompt := C.CString(" " + apiParams.Content)
	defer C.free_cstring(cPrompt)
	llast_n_tokens := make([]C.llama_token, ParamContextSize)
	nPrompt := int(C.llama_tokenize(lcontext, cPrompt, &llast_n_tokens[0], ParamContextSize, true))
	switch {
	case nPrompt < 0:
		// Per llama.h, a negative result means the token buffer was too small,
		// i.e. the prompt does not fit in the context window.
		http.Error(w, "Prompt is too long", http.StatusBadRequest)
		return
	case nPrompt == 0:
		log.Printf("llama_tokenize: got non-positive size (%d)", nPrompt)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	case ParamContextSize-nPrompt <= 4:
		http.Error(w, "Prompt is too long", http.StatusBadRequest)
		return
	}

	// All request-level validation is done, so commit to a streaming 200
	// response only now; the errors above therefore reach the client with
	// their real status. Flush immediately so the webui knows things are
	// happening while the prompt is evaluated.
	w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	// Generate until the context is full, or until MaxTokens new tokens have
	// been produced. The comparison is written so a huge MaxTokens cannot
	// overflow the addition.
	endPos := ParamContextSize
	if apiParams.MaxTokens != 0 && apiParams.MaxTokens < ParamContextSize-nPrompt {
		endPos = nPrompt + apiParams.MaxTokens
	}

	for i := nPrompt; i < endPos; i++ {
		if r.Context().Err() != nil {
			return
		}

		// Evaluate whatever is not yet in the KV cache. On the first pass that
		// is the whole prompt. Afterwards it is only the token sampled on the
		// previous pass (stored at index i-1), and the i-1 tokens before it are
		// already cached, so n_past is i-1.
		evalStart, evalCount, evalPast := i-1, 1, i-1
		if i == nPrompt {
			evalStart, evalCount, evalPast = 0, i, 0
		}
		evalStartTime := time.Now()
		evalErr := C.llama_eval(lcontext,
			&llast_n_tokens[evalStart], C.int(evalCount), // tokens + n_tokens is the provided batch of new tokens to process
			C.int(evalPast), // n_past is the number of tokens to use from previous eval calls
			C.int(runtime.GOMAXPROCS(0)))
		log.Printf("llama_eval: Evaluated %d token(s) in %s", evalCount, time.Since(evalStartTime))
		if evalErr != 0 {
			// The 200 header is already sent; all we can do is log and end the stream.
			log.Printf("llama_eval: %d", evalErr)
			return
		}
		if r.Context().Err() != nil {
			return
		}

		// Apply the repeat penalty to at most the last
		// ParamRepeatPenaltyWindowSize tokens.
		penalizeStart := 0
		penalizeLen := i
		if i > ParamRepeatPenaltyWindowSize {
			penalizeStart = i - ParamRepeatPenaltyWindowSize
			penalizeLen = ParamRepeatPenaltyWindowSize
		}
		newTokenID := C.llama_sample_top_p_top_k(lcontext,
			&llast_n_tokens[penalizeStart], C.int(penalizeLen), // Penalize recent tokens
			ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty) // Other static parameters
		if newTokenID == C.llama_token_eos() {
			// The model doesn't have anything more to say
			return
		}
		if r.Context().Err() != nil {
			return
		}

		// Record the token so the next pass evaluates it; otherwise we would
		// just get the same token over and over again.
		llast_n_tokens[i] = newTokenID
		tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenID))
		if _, err := w.Write([]byte(tokenStr)); err != nil {
			return // Client is gone; stop generating
		}
		flusher.Flush()
	}
}