initial commit

mappu 2023-04-08 15:30:15 +12:00
parent 7c6a0cdaa2
commit d044a9e424
7 changed files with 383 additions and 0 deletions

api.go Normal file (162 lines added)

@@ -0,0 +1,162 @@
package main
import (
"encoding/json"
"log"
"net/http"
"runtime"
)
/*
#include "llama.h"
*/
import "C"
func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "The HTTP request does not support live updates", 500)
return
}
// Parse API POST request
type requestBody struct {
ConversationID string
APIKey string
Content string
}
var apiParams requestBody
err := json.NewDecoder(r.Body).Decode(&apiParams)
if err != nil {
http.Error(w, err.Error(), 400)
return
}
// Verify API key
// TODO
// Wait for a free worker
select {
case this.sem <- struct{}{}:
// OK
case <-r.Context().Done():
return // Request cancelled while waiting for free worker
}
defer func() { <-this.sem }()
// Queue request to worker
w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
w.WriteHeader(200)
flusher.Flush() // Flush before any responses, so the webui knows things are happening
// Constant LLaMA parameters
const (
ParamContextSize = 512 // RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
ParamTopK = 40
ParamTopP = 0.95
ParamTemperature = 0.08
ParamRepeatPenalty = 1.10
ParamRepeatPenaltyWindowSize = 64
)
// Start a new LLaMA session
lparams := C.llama_context_default_params()
lparams.n_ctx = ParamContextSize
lcontext := C.llama_init_from_file(C.CString(this.cfg.ModelPath), lparams)
if lcontext == nil {
log.Printf("llama_init_from_file failed for %q", this.cfg.ModelPath)
http.Error(w, "Internal error", 500)
return
}
defer C.llama_free(lcontext)
// Feed in the conversation so far
// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour
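// llast_n_tokens holds every token placed into the context window so far: the prompt first, then each newly sampled token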
llast_n_tokens := make([]C.llama_token, ParamContextSize)
log.Println("tokenizing supplied prompt...")
llast_n_tokens_used_size := C.llama_tokenize(lcontext, C.CString(apiParams.Content), &llast_n_tokens[0], ParamContextSize, true)
if llast_n_tokens_used_size <= 0 {
log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
http.Error(w, "Internal error", 500)
return
}
if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
http.Error(w, "Prompt is too long", 400)
return
}
// Feed the prompt through the model, then keep generating tokens until the context window fills up.
for i := int(llast_n_tokens_used_size); i < ParamContextSize; i += 1 {
if err := r.Context().Err(); err != nil {
return
}
// Get the next token from LLaMA.
// The first pass evaluates the whole prompt with n_past=0; each later pass only
// evaluates the single token appended on the previous iteration.
log.Println("doing llama_eval...")
nPast, batchStart := i-1, i-1
if i == int(llast_n_tokens_used_size) {
nPast, batchStart = 0, 0
}
evalErr := C.llama_eval(lcontext,
&llast_n_tokens[batchStart], C.int(i-batchStart), // tokens + n_tokens is the new batch of tokens to process
C.int(nPast), // n_past is the number of tokens already evaluated by previous calls
C.int(runtime.GOMAXPROCS(0)))
if evalErr != 0 {
log.Printf("llama_eval: %d", evalErr)
http.Error(w, "Internal error", 500)
return
}
if err := r.Context().Err(); err != nil {
return
}
//
log.Println("doing llama_sample_top_p_top_k...")
penalizeStart := 0
penalizeLen := i
if i > ParamRepeatPenaltyWindowSize {
penalizeStart = i - ParamRepeatPenaltyWindowSize
penalizeLen = ParamRepeatPenaltyWindowSize
}
newTokenId := C.llama_sample_top_p_top_k(lcontext,
// Penalize recent tokens
&llast_n_tokens[penalizeStart], C.int(penalizeLen),
// Other static parameters
ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)
if newTokenId == C.llama_token_eos() {
// The model doesn't have anything to say
return
}
if err := r.Context().Err(); err != nil {
return
}
log.Println("got token OK")
// The model did have something to say
tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))
log.Printf("token is %q", tokenStr)
// Append the new token to llast_n_tokens so the next eval sees it and the repeat penalty can account for it
llast_n_tokens[i] = newTokenId
w.Write([]byte(tokenStr))
flusher.Flush()
}
}
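For reference, a minimal client for the streaming endpoint above might look like the following sketch (not part of this commit). It assumes the server is listening on 127.0.0.1:8080, which is a placeholder address; it POSTs the same JSON fields that POST_Chat decodes and prints the plain-text tokens as they are flushed back.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "log"
    "net/http"
)

func main() {
    // Same fields as the requestBody struct decoded by POST_Chat.
    body, err := json.Marshal(map[string]string{
        "ConversationID": "example",               // placeholder ID
        "APIKey":         "public-web-interface",
        "Content":        "Tell me about llamas.", // placeholder prompt
    })
    if err != nil {
        log.Fatal(err)
    }
    resp, err := http.Post("http://127.0.0.1:8080/api/v1/generate", "application/json", bytes.NewReader(body))
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // Tokens arrive incrementally in the text/plain body; print each chunk as it lands.
    buf := make([]byte, 4096)
    for {
        n, readErr := resp.Body.Read(buf)
        if n > 0 {
            fmt.Print(string(buf[:n]))
        }
        if readErr == io.EOF {
            break
        }
        if readErr != nil {
            log.Fatal(readErr)
        }
    }
}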

cflags_linux_amd64.go Normal file (7 lines added)

@@ -0,0 +1,7 @@
package main
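// The _linux_amd64 suffix in this file's name is an implicit build constraint,
// so these cgo flags are only applied when building for linux/amd64.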
/*
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -march=native -mtune=native -pthread
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -march=native -mtune=native -pthread
*/
import "C"

cflags_linux_arm64.go Normal file (7 lines added)

@@ -0,0 +1,7 @@
package main
/*
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -mcpu=native -pthread
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -mcpu=native -pthread
*/
import "C"

go.mod Normal file (7 lines added)

@@ -0,0 +1,7 @@
module code.ivysaur.me/llamacpphtmld
go 1.19
require (
github.com/google/uuid v1.3.0
)

go.sum Normal file (2 lines added)

@@ -0,0 +1,2 @@
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=

main.go Normal file (75 lines added)

@@ -0,0 +1,75 @@
package main
import (
"fmt"
"log"
"net/http"
"os"
"strconv"
)
const (
AppTitle = `llamacpphtmld`
AppVersion = `0.0.0-dev` // should be overridden by go build argument
)
type Config struct {
NetBind string
ModelPath string
SimultaneousRequests int
}
func NewConfigFromEnv() (Config, error) {
ret := Config{
NetBind: os.Getenv(`LCH_NET_BIND`),
ModelPath: os.Getenv(`LCH_MODEL_PATH`),
}
SimultaneousRequests, err := strconv.Atoi(os.Getenv(`LCH_SIMULTANEOUS_REQUESTS`))
if err != nil {
return Config{}, fmt.Errorf("LCH_SIMULTANEOUS_REQUESTS: %w", err)
}
ret.SimultaneousRequests = SimultaneousRequests
if _, err := os.Stat(ret.ModelPath); err != nil {
return Config{}, fmt.Errorf("LCH_MODEL_PATH: %w", err)
}
return ret, nil
}
type Application struct {
cfg Config
sem chan struct{}
}
func main() {
log.Printf("%s v%s", AppTitle, AppVersion)
cfg, err := NewConfigFromEnv()
if err != nil {
log.Fatal(err)
}
app := Application{
cfg: cfg,
sem: make(chan struct{}, cfg.SimultaneousRequests), // use a buffered channel as a semaphore
}
router := http.NewServeMux()
router.HandleFunc(`/`, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(`Server`, AppTitle+`/`+AppVersion)
if r.Method == `GET` && r.URL.Path == `/` {
app.GET_Root(w, r)
} else if r.Method == `POST` && r.URL.Path == `/api/v1/generate` {
app.POST_Chat(w, r)
} else {
http.Error(w, "Not found", 404)
}
})
log.Printf("Listening on %s ...", cfg.NetBind)
log.Fatal(http.ListenAndServe(cfg.NetBind, router))
}
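As a rough illustration of the configuration handling above, the following hypothetical test (not part of this commit) shows the three environment variables NewConfigFromEnv reads; the model path shown is a placeholder and must point at an existing file, because the function stats it.

package main

import (
    "os"
    "testing"
)

// Hypothetical sketch of how NewConfigFromEnv is driven by the environment.
func TestNewConfigFromEnv(t *testing.T) {
    os.Setenv("LCH_NET_BIND", "127.0.0.1:8080")            // placeholder bind address
    os.Setenv("LCH_MODEL_PATH", "testdata/ggml-model.bin") // placeholder path; must exist on disk
    os.Setenv("LCH_SIMULTANEOUS_REQUESTS", "2")

    cfg, err := NewConfigFromEnv()
    if err != nil {
        t.Fatal(err)
    }
    if cfg.SimultaneousRequests != 2 {
        t.Fatalf("expected 2 simultaneous requests, got %d", cfg.SimultaneousRequests)
    }
}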

webui.go Normal file (123 lines added)

@@ -0,0 +1,123 @@
package main
import (
"html"
"net/http"
"github.com/google/uuid"
)
func (this *Application) GET_Root(w http.ResponseWriter, r *http.Request) {
w.Header().Set(`Content-Type`, `text/html;charset=UTF-8`)
w.WriteHeader(200)
w.Write([]byte(`<!DOCTYPE html>
<html>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>` + html.EscapeString(AppTitle) + `</title>
<style type="text/css">
html {
font-family: sans-serif;
}
textarea {
border-radius: 4px;
display: block;
width: 100%;
min-height: 100px;
background: #fff;
transition: background-color 0.5s ease-out;
}
textarea.alert {
background: lightyellow;
transition: initial;
}
button {
margin-top: 8px;
padding: 4px 6px;
}
</style>
<body>
<h2>🦙 ` + html.EscapeString(AppTitle) + `</h2>
<textarea id="main" autofocus></textarea>
<button id="generate"> Generate</button>
<button id="interrupt" disabled> Interrupt</button>
<script type="text/javascript">
function main() {
let conversationID = "` + uuid.New().String() + `";
const apiKey = "public-web-interface";
const $generate = document.getElementById("generate");
const $interrupt = document.getElementById("interrupt");
const $main = document.getElementById("main");
$generate.addEventListener('click', async function() {
const content = $main.value;
if (content.split(" ").length >= 2047) {
if (! confirm("Warning: high prompt length, the model will forget part of the content. Are you sure you want to continue?")) {
return;
}
}
$main.readOnly = true;
$generate.disabled = true;
try {
const controller = new AbortController();
const response = await fetch("/api/v1/generate", {
method: "POST",
signal: controller.signal,
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({
"ConversationID": conversationID,
"APIKey": apiKey,
"Content": content
})
});
$interrupt.disabled = false;
const doInterrupt = () => {
controller.abort();
$interrupt.removeEventListener('click', doInterrupt);
};
$interrupt.addEventListener('click', doInterrupt);
const reader = response.body.getReader();
const decoder = new TextDecoder();
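// Read the chunked text/plain response incrementally; each read() returns the next batch of tokens flushed by the server.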
for(;;) {
const singleReadResult = await reader.read();
if (singleReadResult.done) {
break;
}
$main.value += decoder.decode(singleReadResult.value);
$main.className = 'alert';
setTimeout(() => { $main.className = ''; }, 1);
}
} catch (ex) {
alert(
"Error processing the request: " +
(ex instanceof Error ? ex.message : JSON.stringify(ex))
);
return;
} finally {
$main.readOnly = false;
$generate.disabled = false;
$interrupt.disabled = true;
}
});
}
main();
</script>
</body>
</html>
`))
}