package main

import (
	"encoding/json"
	"log"
	"net/http"
	"runtime"
	"unsafe"
)

/*
#include <stdlib.h>
#include "llama.h"
*/
import "C"

// Constant LLaMA parameters
const (
	ParamContextSize             = 512 // RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
	ParamTopK                    = 40
	ParamTopP                    = 0.95
	ParamTemperature             = 0.08
	ParamRepeatPenalty           = 1.10
	ParamRepeatPenaltyWindowSize = 64
)

func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "The HTTP request does not support live updates", 500)
		return
	}

	// Parse API POST request
	type requestBody struct {
		ConversationID string
		APIKey         string
		Content        string
	}
	var apiParams requestBody
	err := json.NewDecoder(r.Body).Decode(&apiParams)
	if err != nil {
		http.Error(w, err.Error(), 400)
		return
	}

	// Verify API key
	// TODO

	// Wait for a free worker
	select {
	case this.sem <- struct{}{}:
		// OK
	case <-r.Context().Done():
		return // Request cancelled while waiting for free worker
	}
	defer func() { <-this.sem }()

	// Queue request to worker
	w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
	w.WriteHeader(200)
	flusher.Flush() // Flush before any responses, so the webui knows things are happening

	// Start a new LLaMA session
	lparams := C.llama_context_default_params()
	lparams.n_ctx = ParamContextSize

	cModelPath := C.CString(this.cfg.ModelPath)
	defer C.free(unsafe.Pointer(cModelPath)) // C.CString allocates C memory that must be freed manually

	lcontext := C.llama_init_from_file(cModelPath, lparams)
	if lcontext == nil {
		log.Printf("llama_init_from_file failed for model %q", this.cfg.ModelPath)
		http.Error(w, "Internal error", 500)
		return
	}
	defer C.llama_free(lcontext)

	// Feed in the conversation so far
	// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
	apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour

	cPrompt := C.CString(apiParams.Content)
	defer C.free(unsafe.Pointer(cPrompt))

	llast_n_tokens := make([]C.llama_token, ParamContextSize)

	log.Println("tokenizing supplied prompt...")
	llast_n_tokens_used_size := C.llama_tokenize(lcontext, cPrompt, &llast_n_tokens[0], ParamContextSize, true)
	if llast_n_tokens_used_size <= 0 {
		log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
		http.Error(w, "Internal error", 500)
		return
	}
	if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
		http.Error(w, "Prompt is too long", 400)
		return
	}

	// We've consumed all the input.
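	// Generation loop: each pass evaluates the model, samples one new token,
	// streams it back to the client, and appends it to llast_n_tokens so that
	// later evaluations (and the repeat-penalty window) can see it. We stop
	// when the model emits an end-of-stream token, the client disconnects, or
	// the context window fills up.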
	for i := int(llast_n_tokens_used_size); i < ParamContextSize; i += 1 {
		if err := r.Context().Err(); err != nil {
			return
		}

		// Get the next token from LLaMA
		log.Println("doing llama_eval...")

		// On the first pass, evaluate the whole prompt against an empty KV
		// cache; on later passes, evaluate only the most recently generated
		// token and reuse the KV cache built up by previous llama_eval calls.
		batchStart := 0
		batchLen := i
		nPast := 0
		if i > int(llast_n_tokens_used_size) {
			batchStart = i - 1
			batchLen = 1
			nPast = i - 1
		}

		evalErr := C.llama_eval(lcontext,
			&llast_n_tokens[batchStart], C.int(batchLen), // tokens + n_tokens is the provided batch of new tokens to process
			C.int(nPast), // n_past is the number of tokens to use from previous eval calls
			C.int(runtime.GOMAXPROCS(0)))
		if evalErr != 0 {
			log.Printf("llama_eval: %d", evalErr)
			http.Error(w, "Internal error", 500)
			return
		}

		if err := r.Context().Err(); err != nil {
			return
		}

		// log.Println("doing llama_sample_top_p_top_k...")

		// Apply the repeat penalty only over the most recent window of tokens
		penalizeStart := 0
		penalizeLen := i
		if i > ParamRepeatPenaltyWindowSize {
			penalizeStart = i - ParamRepeatPenaltyWindowSize
			penalizeLen = ParamRepeatPenaltyWindowSize
		}

		newTokenId := C.llama_sample_top_p_top_k(lcontext,
			// Penalize recent tokens
			&llast_n_tokens[penalizeStart], C.int(penalizeLen),
			// Other static parameters
			ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)

		if newTokenId == C.llama_token_eos() {
			// The model doesn't have anything to say
			return
		}

		if err := r.Context().Err(); err != nil {
			return
		}

		log.Println("got token OK")

		// The model did have something to say
		tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))

		// Push this new token into the lembedding_ state, or else we'll just get it over and over again
		llast_n_tokens[i] = newTokenId

		// time.Sleep(1 * time.Second)
		w.Write([]byte(tokenStr)) // fmt.Sprintf(" update %d", i)))
		flusher.Flush()
	}
}
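
// Example: POST_Chat expects a JSON body and streams plain text back as tokens
// are generated. The sketch below shows one way a client could consume the
// stream; it is illustrative only, and the route ("/chat") and listen address
// are assumptions rather than anything defined in this file. It would live in
// a separate program.
//
//	package main
//
//	import (
//		"bytes"
//		"encoding/json"
//		"fmt"
//		"io"
//		"net/http"
//	)
//
//	func main() {
//		body, _ := json.Marshal(map[string]string{
//			"ConversationID": "example",
//			"APIKey":         "",
//			"Content":        "Hello, llama!",
//		})
//
//		// NOTE: the URL is a placeholder; use whatever route the application's
//		// mux registers for POST_Chat.
//		resp, err := http.Post("http://localhost:8080/chat", "application/json", bytes.NewReader(body))
//		if err != nil {
//			panic(err)
//		}
//		defer resp.Body.Close()
//
//		// Tokens arrive incrementally as the server flushes them; print each
//		// chunk as soon as it is read.
//		buf := make([]byte, 4096)
//		for {
//			n, readErr := resp.Body.Read(buf)
//			if n > 0 {
//				fmt.Print(string(buf[:n]))
//			}
//			if readErr == io.EOF {
//				break
//			}
//			if readErr != nil {
//				panic(readErr)
//			}
//		}
//	}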