2023-04-08 03:30:15 +00:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
	"encoding/json"
	"log"
	"net/http"
	"runtime"
	"unsafe"
)

/*
#include <stdlib.h>
#include "llama.h"
*/
import "C"
|
|
|
|
|
2023-04-08 03:48:16 +00:00
|
|
|
// Fixed LLaMA settings applied to every chat request handled by POST_Chat.
const (
	// ParamContextSize is assigned to the LLaMA context's n_ctx and is also the
	// length of the per-request token buffer.
	// RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
	ParamContextSize = 512

	// ParamTopK, ParamTopP, ParamTemperature and ParamRepeatPenalty are passed,
	// in that order, as the trailing arguments of llama_sample_top_p_top_k.
	ParamTopK          = 40
	ParamTopP          = 0.95
	ParamTemperature   = 0.08
	ParamRepeatPenalty = 1.10

	// ParamRepeatPenaltyWindowSize caps how many of the most recent tokens are
	// handed to the sampler as the repeat-penalty window.
	ParamRepeatPenaltyWindowSize = 64
)
|
|
|
|
|
2023-04-08 03:30:15 +00:00
|
|
|
func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
|
|
if !ok {
|
|
|
|
http.Error(w, "The HTTP request does not support live updates", 500)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Parse API POST request
|
|
|
|
|
|
|
|
type requestBody struct {
|
|
|
|
ConversationID string
|
|
|
|
APIKey string
|
|
|
|
Content string
|
2023-04-08 04:03:59 +00:00
|
|
|
MaxTokens int
|
2023-04-08 03:30:15 +00:00
|
|
|
}
|
|
|
|
var apiParams requestBody
|
|
|
|
err := json.NewDecoder(r.Body).Decode(&apiParams)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, err.Error(), 400)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-04-08 04:03:59 +00:00
|
|
|
if apiParams.MaxTokens < 0 {
|
|
|
|
http.Error(w, "MaxTokens should be 0 or positive", 400)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-04-08 03:30:15 +00:00
|
|
|
// Verify API key
|
|
|
|
// TODO
|
|
|
|
|
|
|
|
// Wait for a free worker
|
|
|
|
select {
|
|
|
|
case this.sem <- struct{}{}:
|
|
|
|
// OK
|
|
|
|
case <-r.Context().Done():
|
|
|
|
return // Request cancelled while waiting for free worker
|
|
|
|
}
|
|
|
|
|
|
|
|
defer func() { <-this.sem }()
|
|
|
|
|
|
|
|
// Queue request to worker
|
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
|
|
|
|
w.WriteHeader(200)
|
|
|
|
flusher.Flush() // Flush before any responses, so the webui knows things are happening
|
|
|
|
|
|
|
|
// Start a new LLaMA session
|
|
|
|
|
|
|
|
lparams := C.llama_context_default_params()
|
|
|
|
lparams.n_ctx = ParamContextSize
|
|
|
|
|
|
|
|
lcontext := C.llama_init_from_file(C.CString(this.cfg.ModelPath), lparams)
|
|
|
|
defer C.llama_free(lcontext)
|
|
|
|
|
|
|
|
// Feed in the conversation so far
|
|
|
|
// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
|
|
|
|
|
|
|
|
apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour
|
|
|
|
|
|
|
|
llast_n_tokens := make([]C.llama_token, ParamContextSize)
|
|
|
|
|
|
|
|
log.Println("tokenizing supplied prompt...")
|
|
|
|
|
|
|
|
llast_n_tokens_used_size := C.llama_tokenize(lcontext, C.CString(apiParams.Content), &llast_n_tokens[0], ParamContextSize, true)
|
|
|
|
if llast_n_tokens_used_size <= 0 {
|
|
|
|
log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
|
|
|
|
http.Error(w, "Internal error", 500)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
|
|
|
|
http.Error(w, "Prompt is too long", 400)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// We've consumed all the input.
|
|
|
|
|
2023-04-08 04:03:59 +00:00
|
|
|
EndPos := ParamContextSize
|
|
|
|
if apiParams.MaxTokens != 0 && (int(llast_n_tokens_used_size)+apiParams.MaxTokens) < ParamContextSize {
|
|
|
|
EndPos = int(llast_n_tokens_used_size) + apiParams.MaxTokens
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := int(llast_n_tokens_used_size); i < EndPos; i += 1 {
|
2023-04-08 03:30:15 +00:00
|
|
|
if err := r.Context().Err(); err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get the next token from LLaMA
|
|
|
|
|
2023-04-08 03:48:38 +00:00
|
|
|
if i == int(llast_n_tokens_used_size) {
|
|
|
|
|
|
|
|
log.Println("doing llama_eval (for the first time on all supplied input)...")
|
|
|
|
|
|
|
|
evalErr := C.llama_eval(lcontext,
|
|
|
|
&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
|
|
|
|
C.int(0), // n_past is the number of tokens to use from previous eval calls
|
|
|
|
C.int(runtime.GOMAXPROCS(0)))
|
|
|
|
if evalErr != 0 {
|
|
|
|
log.Printf("llama_eval: %d", evalErr)
|
|
|
|
http.Error(w, "Internal error", 500)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
log.Println("doing llama_eval (incrementally on the newly generated token)...")
|
|
|
|
|
|
|
|
evalErr := C.llama_eval(lcontext,
|
|
|
|
&llast_n_tokens[i-1], 1, // tokens + n_tokens is the provided batch of new tokens to process
|
|
|
|
C.int(i), // n_past is the number of tokens to use from previous eval calls
|
|
|
|
C.int(runtime.GOMAXPROCS(0)))
|
|
|
|
if evalErr != 0 {
|
|
|
|
log.Printf("llama_eval: %d", evalErr)
|
|
|
|
http.Error(w, "Internal error", 500)
|
|
|
|
return
|
|
|
|
}
|
2023-04-08 03:30:15 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
if err := r.Context().Err(); err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
log.Println("doing llama_sample_top_p_top_k...")
|
|
|
|
|
|
|
|
penalizeStart := 0
|
|
|
|
penalizeLen := i
|
|
|
|
if i > ParamRepeatPenaltyWindowSize {
|
|
|
|
penalizeStart = i - ParamRepeatPenaltyWindowSize
|
|
|
|
penalizeLen = ParamRepeatPenaltyWindowSize
|
|
|
|
}
|
|
|
|
|
|
|
|
newTokenId := C.llama_sample_top_p_top_k(lcontext,
|
|
|
|
|
|
|
|
// Penalize recent tokens
|
|
|
|
&llast_n_tokens[penalizeStart], C.int(penalizeLen),
|
|
|
|
|
|
|
|
// Other static parameters
|
|
|
|
ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)
|
|
|
|
|
|
|
|
if newTokenId == C.llama_token_eos() {
|
|
|
|
// The model doesn't have anything to say
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := r.Context().Err(); err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Println("got token OK")
|
|
|
|
|
|
|
|
// The model did have something to say
|
|
|
|
tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))
|
|
|
|
|
|
|
|
// Push this new token into the lembedding_ state, or else we'll just get it over and over again
|
|
|
|
llast_n_tokens[i] = newTokenId
|
|
|
|
|
|
|
|
// time.Sleep(1 * time.Second)
|
|
|
|
w.Write([]byte(tokenStr)) // fmt.Sprintf(" update %d", i)))
|
|
|
|
flusher.Flush()
|
|
|
|
}
|
|
|
|
}
|