package main

import (
	"encoding/json"
	"log"
	"net/http"
	"runtime"
	"time"
	"unsafe"
)

/*
#include <stdlib.h>
#include "llama.h"
*/
import "C"

// Constant LLaMA parameters
const (
	ParamContextSize             = 1024 // mem_required is 9800MB + 3216MB per state regardless of n_ctx; however, n_ctx does affect the KV size for persistence
	ParamTopK                    = 40
	ParamTopP                    = 0.95
	ParamTemperature             = 0.08
	ParamRepeatPenalty           = 1.10
	ParamRepeatPenaltyWindowSize = 64
)

func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "The HTTP request does not support live updates", 500)
		return
	}

	// Parse API POST request
	type requestBody struct {
		ConversationID string
		APIKey         string
		Content        string
		MaxTokens      int
	}
	var apiParams requestBody
	err := json.NewDecoder(r.Body).Decode(&apiParams)
	if err != nil {
		http.Error(w, err.Error(), 400)
		return
	}
	if apiParams.MaxTokens < 0 {
		http.Error(w, "MaxTokens should be 0 or positive", 400)
		return
	}

	// Verify API key
	// TODO

	// Wait for a free worker
	// TODO signal the queue length to the user?
	select {
	case this.sem <- struct{}{}:
		// OK
	case <-r.Context().Done():
		return // Request cancelled while waiting for a free worker
	}
	defer func() { <-this.sem }()

	// Queue request to worker
	w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
	w.WriteHeader(200)
	flusher.Flush() // Flush before any responses, so the webui knows things are happening

	// Start a new LLaMA session
	lparams := C.llama_context_default_params()
	lparams.n_ctx = ParamContextSize

	cModelPath := C.CString(this.cfg.ModelPath)
	defer C.free(unsafe.Pointer(cModelPath))

	lcontext := C.llama_init_from_file(cModelPath, lparams)
	if lcontext == nil {
		log.Printf("llama_init_from_file: failed to load model from %q", this.cfg.ModelPath)
		http.Error(w, "Internal error", 500)
		return
	}
	defer C.llama_free(lcontext)

	// Feed in the conversation so far
	// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
	apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour

	cContent := C.CString(apiParams.Content)
	defer C.free(unsafe.Pointer(cContent))

	llast_n_tokens := make([]C.llama_token, ParamContextSize)
	llast_n_tokens_used_size := C.llama_tokenize(lcontext, cContent, &llast_n_tokens[0], ParamContextSize, true)
	if llast_n_tokens_used_size <= 0 {
		log.Printf("llama_tokenize: got non-positive size (%d)", llast_n_tokens_used_size)
		http.Error(w, "Internal error", 500)
		return
	}

	if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
		http.Error(w, "Prompt is too long", 400)
		return
	}

	// We've consumed all the input.
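	// Generation loop: EndPos caps output at MaxTokens (if set) or the context size.
	// The first iteration evaluates the whole prompt in a single llama_eval call;
	// every later iteration feeds back only the token sampled in the previous step,
	// advancing n_past so the KV cache built by earlier calls is reused. Sampling
	// penalizes repeats over the last ParamRepeatPenaltyWindowSize tokens, and each
	// new token is streamed to the client immediately.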
	EndPos := ParamContextSize
	if apiParams.MaxTokens != 0 && (int(llast_n_tokens_used_size)+apiParams.MaxTokens) < ParamContextSize {
		EndPos = int(llast_n_tokens_used_size) + apiParams.MaxTokens
	}

	for i := int(llast_n_tokens_used_size); i < EndPos; i++ {
		if err := r.Context().Err(); err != nil {
			return
		}

		// Perform the LLaMA evaluation step
		evalTokenStart := i - 1 // On later iterations, evaluate only the single token sampled last time
		evalTokenCount := 1
		evalTokenPast := i - 1 // Tokens already processed by earlier eval calls
		if i == int(llast_n_tokens_used_size) {
			// First iteration: evaluate the entire prompt as one batch
			evalTokenStart = 0
			evalTokenCount = i
			evalTokenPast = 0
		}

		evalStartTime := time.Now()
		evalErr := C.llama_eval(lcontext,
			&llast_n_tokens[evalTokenStart], C.int(evalTokenCount), // tokens + n_tokens is the provided batch of new tokens to process
			C.int(evalTokenPast), // n_past is the number of tokens to use from previous eval calls
			C.int(runtime.GOMAXPROCS(0)))
		log.Printf("llama_eval: Evaluated %d token(s) in %s", evalTokenCount, time.Since(evalStartTime))

		if evalErr != 0 {
			log.Printf("llama_eval: %d", evalErr)
			http.Error(w, "Internal error", 500)
			return
		}

		if err := r.Context().Err(); err != nil {
			return
		}

		// Perform the LLaMA sampling step
		penalizeStart := 0
		penalizeLen := i
		if i > ParamRepeatPenaltyWindowSize {
			penalizeStart = i - ParamRepeatPenaltyWindowSize
			penalizeLen = ParamRepeatPenaltyWindowSize
		}

		newTokenId := C.llama_sample_top_p_top_k(lcontext,
			&llast_n_tokens[penalizeStart], C.int(penalizeLen), // Penalize recent tokens
			ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty) // Other static parameters

		if newTokenId == C.llama_token_eos() {
			// The model doesn't have anything to say
			return
		}

		if err := r.Context().Err(); err != nil {
			return
		}

		log.Println("got token OK")

		// The model did have something to say
		tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))

		// Push this new token into the llast_n_tokens state, or else we'll just get it over and over again
		llast_n_tokens[i] = newTokenId

		w.Write([]byte(tokenStr))
		flusher.Flush()
	}
}