From d044a9e4246fa1768698a4ac9945fe8b885fe624 Mon Sep 17 00:00:00 2001
From: mappu
Date: Sat, 8 Apr 2023 15:30:15 +1200
Subject: [PATCH] initial commit

---
 api.go                | 162 ++++++++++++++++++++++++++++++++++++++++++
 cflags_linux_amd64.go |   7 ++
 cflags_linux_arm64.go |   7 ++
 go.mod                |   7 ++
 go.sum                |   2 +
 main.go               |  75 +++++++++++++++++++
 webui.go              | 123 ++++++++++++++++++++++++++++++++
 7 files changed, 383 insertions(+)
 create mode 100644 api.go
 create mode 100644 cflags_linux_amd64.go
 create mode 100644 cflags_linux_arm64.go
 create mode 100644 go.mod
 create mode 100644 go.sum
 create mode 100644 main.go
 create mode 100644 webui.go

diff --git a/api.go b/api.go
new file mode 100644
index 0000000..be2faf2
--- /dev/null
+++ b/api.go
@@ -0,0 +1,162 @@
+package main
+
+import (
+	"encoding/json"
+	"log"
+	"net/http"
+	"runtime"
+	"unsafe"
+)
+
+/*
+#include <stdlib.h>
+#include "llama.h"
+*/
+import "C"
+
+func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "The HTTP connection does not support streaming responses", 500)
+		return
+	}
+
+	// Parse the API POST request
+
+	type requestBody struct {
+		ConversationID string
+		APIKey         string
+		Content        string
+	}
+	var apiParams requestBody
+	err := json.NewDecoder(r.Body).Decode(&apiParams)
+	if err != nil {
+		http.Error(w, err.Error(), 400)
+		return
+	}
+
+	// Verify API key
+	// TODO
+
+	// Wait for a free worker slot
+	select {
+	case this.sem <- struct{}{}:
+		// OK
+	case <-r.Context().Done():
+		return // Request cancelled while waiting for a free worker
+	}
+
+	defer func() { <-this.sem }()
+
+	// We hold a worker slot; begin the streaming response
+
+	w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
+	w.WriteHeader(200)
+	flusher.Flush() // Flush before any content, so the webui knows things are happening
+
+	// Constant LLaMA parameters
+
+	const (
+		ParamContextSize             = 512 // RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
+		ParamTopK                    = 40
+		ParamTopP                    = 0.95
+		ParamTemperature             = 0.08
+		ParamRepeatPenalty           = 1.10
+		ParamRepeatPenaltyWindowSize = 64
+	)
+
+	// Start a new LLaMA session
+
+	lparams := C.llama_context_default_params()
+	lparams.n_ctx = ParamContextSize
+
+	modelPathC := C.CString(this.cfg.ModelPath)
+	defer C.free(unsafe.Pointer(modelPathC)) // C.CString allocates C memory that must be freed manually
+
+	lcontext := C.llama_init_from_file(modelPathC, lparams)
+	if lcontext == nil {
+		log.Printf("llama_init_from_file(%q) failed", this.cfg.ModelPath)
+		w.Write([]byte("Internal error")) // The 200 header is already sent; we can only report failure in the body
+		return
+	}
+	defer C.llama_free(lcontext)
+
+	// Feed in the conversation so far
+	// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
+
+	apiParams.Content = " " + apiParams.Content // Add a leading space to match the LLaMA Python behaviour
+
+	llast_n_tokens := make([]C.llama_token, ParamContextSize)
+
+	log.Println("tokenizing supplied prompt...")
+
+	contentC := C.CString(apiParams.Content)
+	defer C.free(unsafe.Pointer(contentC))
+
+	llast_n_tokens_used_size := C.llama_tokenize(lcontext, contentC, &llast_n_tokens[0], ParamContextSize, true)
+	if llast_n_tokens_used_size <= 0 {
+		log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
+		w.Write([]byte("Internal error")) // The 200 header is already sent; we can only report failure in the body
+		return
+	}
+
+	if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
+		w.Write([]byte("Prompt is too long")) // The 200 header is already sent; we can only report failure in the body
+		return
+	}
+
+	// We've consumed all the input; generate new tokens one at a time until the context window is full
+
+	nPast := 0 // Number of tokens already evaluated into the model's KV cache
+	for i := int(llast_n_tokens_used_size); i < ParamContextSize; i += 1 {
+		if err := r.Context().Err(); err != nil {
+			return
+		}
+
+		// Evaluate the tokens the model has not seen yet: the whole prompt on the
+		// first pass, then only the single newly sampled token on later passes
+
+		log.Println("doing llama_eval...")
+
+		evalErr := C.llama_eval(lcontext,
+			&llast_n_tokens[nPast], C.int(i-nPast), // tokens + n_tokens is the batch of new tokens to process
+			C.int(nPast), // n_past is the number of tokens already processed by previous eval calls
+			C.int(runtime.GOMAXPROCS(0)))
+		if evalErr != 0 {
+			log.Printf("llama_eval: %d", evalErr)
+			w.Write([]byte("Internal error")) // The 200 header is already sent; we can only report failure in the body
+			return
+		}
+		nPast = i
+
+		if err := r.Context().Err(); err != nil {
+			return
+		}
+
+		// Sample the next token, penalizing recent repeats
+
+		log.Println("doing llama_sample_top_p_top_k...")
+
+		penalizeStart := 0
+		penalizeLen := i
+		if i > ParamRepeatPenaltyWindowSize {
+			penalizeStart = i - ParamRepeatPenaltyWindowSize
+			penalizeLen = ParamRepeatPenaltyWindowSize
+		}
+
+		newTokenId := C.llama_sample_top_p_top_k(lcontext,
+
+			// Penalize recent tokens
+			&llast_n_tokens[penalizeStart], C.int(penalizeLen),
+
+			// Other static parameters
+			ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)
+
+		if newTokenId == C.llama_token_eos() {
+			// The model doesn't have anything more to say
+			return
+		}
+
+		if err := r.Context().Err(); err != nil {
+			return
+		}
+
+		log.Println("got token OK")
+
+		// The model did have something to say
+		tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))
+
+		log.Printf("token is %q", tokenStr)
+
+		// Append the new token to the context, or else we would sample the same token over and over again
+		llast_n_tokens[i] = newTokenId
+
+		w.Write([]byte(tokenStr))
+		flusher.Flush()
+	}
+}
diff --git a/cflags_linux_amd64.go b/cflags_linux_amd64.go
new file mode 100644
index 0000000..0129943
--- /dev/null
+++ b/cflags_linux_amd64.go
@@ -0,0 +1,7 @@
+package main
+
+/*
+#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -march=native -mtune=native -pthread
+#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -march=native -mtune=native -pthread
+*/
+import "C"
diff --git a/cflags_linux_arm64.go b/cflags_linux_arm64.go
new file mode 100644
index 0000000..6e585e2
--- /dev/null
+++ b/cflags_linux_arm64.go
@@ -0,0 +1,7 @@
+package main
+
+/*
+#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -mcpu=native -pthread
+#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -mcpu=native -pthread
+*/
+import "C"
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..8efece7
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,7 @@
+module code.ivysaur.me/llamacpphtmld
+
+go 1.19
+
+require (
+	github.com/google/uuid v1.3.0
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..43b557d
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
\ No newline at end of file
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..333c1bd
--- /dev/null
+++ b/main.go
@@ -0,0 +1,75 @@
+package main
+
+import (
+	"fmt"
+	"log"
+	"net/http"
+	"os"
+	"strconv"
+)
+
+const (
+	AppTitle   = `llamacpphtmld`
+	AppVersion = `0.0.0-dev` // should be overridden by a go build argument
+)
+
+type Config struct {
+	NetBind              string
+	ModelPath            string
+	SimultaneousRequests int
+}
+
+func NewConfigFromEnv() (Config, error) {
+	ret := Config{
+		NetBind:   os.Getenv(`LCH_NET_BIND`),
+		ModelPath: os.Getenv(`LCH_MODEL_PATH`),
+	}
+
+	SimultaneousRequests, err := strconv.Atoi(os.Getenv(`LCH_SIMULTANEOUS_REQUESTS`))
+	if err != nil {
+		return Config{}, fmt.Errorf("LCH_SIMULTANEOUS_REQUESTS: %w", err)
+	}
+	ret.SimultaneousRequests = SimultaneousRequests
+
+	if _, err := os.Stat(ret.ModelPath); err != nil {
+		return Config{}, fmt.Errorf("LCH_MODEL_PATH: %w", err)
+	}
+
+	return ret, nil
+}
+
+type Application struct {
+	cfg Config
+	sem chan struct{}
+}
+
+func main() {
+	log.Printf("%s v%s", AppTitle, AppVersion)
+
+	cfg, err := NewConfigFromEnv()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	app := Application{
+		cfg: cfg,
+		sem: make(chan struct{}, cfg.SimultaneousRequests), // use a buffered channel as a counting semaphore
+	}
+
+	router := http.NewServeMux()
+	router.HandleFunc(`/`, func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set(`Server`, AppTitle+`/`+AppVersion)
+
+		if r.Method == `GET` && r.URL.Path == `/` {
+			app.GET_Root(w, r)
+		} else if r.Method == `POST` && r.URL.Path == `/api/v1/generate` {
+			app.POST_Chat(w, r)
+		} else {
+			http.Error(w, "Not found", 404)
+		}
+	})
+
+	log.Printf("Listening on %s ...", cfg.NetBind)
+
+	log.Fatal(http.ListenAndServe(cfg.NetBind, router))
+}
diff --git a/webui.go b/webui.go
new file mode 100644
index 0000000..9ad08f0
--- /dev/null
+++ b/webui.go
@@ -0,0 +1,123 @@
+package main
+
+import (
+	"html"
+	"net/http"
+
+	"github.com/google/uuid"
+)
+
+func (this *Application) GET_Root(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set(`Content-Type`, `text/html;charset=UTF-8`)
+	w.WriteHeader(200)
+
+	conversationID := uuid.New().String() // a fresh conversation ID for this page load
+
+	w.Write([]byte(`<!DOCTYPE html>
+<html>
+<head>
+	<meta charset="UTF-8">
+	<title>` + html.EscapeString(AppTitle) + `</title>
+</head>
+<body data-conversation-id="` + html.EscapeString(conversationID) + `">
+	<div class="header">🦙 ` + html.EscapeString(AppTitle) + `</div>
+	<!-- chat form, styles, and the streaming script omitted -->
+</body>
+</html>`))
+}
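
Usage note, not part of the patch: a minimal sketch of a client for the streaming endpoint above, assuming the daemon was started with LCH_NET_BIND=127.0.0.1:8080. The address and the example prompt are illustrative only; the JSON fields mirror the requestBody struct that POST_Chat decodes, and tokens are printed as they arrive on the plain-text stream.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Mirror the requestBody struct from api.go
	payload, err := json.Marshal(map[string]string{
		"ConversationID": "",
		"APIKey":         "", // not yet verified server-side (see the TODO in POST_Chat)
		"Content":        "The capital of France is",
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:8080/api/v1/generate", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The server flushes each token as soon as it is sampled; print them as they arrive
	buf := make([]byte, 256)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			fmt.Print(string(buf[:n]))
		}
		if readErr == io.EOF {
			break // generation finished: EOS token, context full, or server-side error text
		}
		if readErr != nil {
			panic(readErr)
		}
	}
	fmt.Println()
}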
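The TODO in api.go suggests stashing llama contexts for reuse with LRU caching, keyed for example by ConversationID. One possible shape for that cache is sketched below; the names and the interface{} state slot are hypothetical, and a real implementation would store the warmed-up *C.llama_context plus its token history and llama_free evicted contexts.

package main

import (
	"container/list"
	"sync"
)

// cacheEntry pairs a conversation ID with the per-conversation state to reuse.
type cacheEntry struct {
	key   string
	state interface{} // hypothetical slot for a *C.llama_context and its token buffer
}

// contextCache is a small mutex-guarded LRU: the front of the list is the most recently used entry.
type contextCache struct {
	mu      sync.Mutex
	max     int
	order   *list.List
	entries map[string]*list.Element
}

func newContextCache(max int) *contextCache {
	return &contextCache{
		max:     max,
		order:   list.New(),
		entries: make(map[string]*list.Element),
	}
}

// Get returns the cached state for key, marking it most recently used.
func (c *contextCache) Get(key string) (interface{}, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if el, ok := c.entries[key]; ok {
		c.order.MoveToFront(el)
		return el.Value.(*cacheEntry).state, true
	}
	return nil, false
}

// Put stores state for key, evicting the least recently used entry when full.
func (c *contextCache) Put(key string, state interface{}) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if el, ok := c.entries[key]; ok {
		c.order.MoveToFront(el)
		el.Value.(*cacheEntry).state = state
		return
	}
	if c.order.Len() >= c.max {
		oldest := c.order.Back()
		c.order.Remove(oldest)
		delete(c.entries, oldest.Value.(*cacheEntry).key)
		// A real implementation would C.llama_free the evicted context here
	}
	c.entries[key] = c.order.PushFront(&cacheEntry{key: key, state: state})
}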