initial commit

parent 7c6a0cdaa2
commit d044a9e424

162  api.go  Normal file
@@ -0,0 +1,162 @@
package main

import (
    "encoding/json"
    "log"
    "net/http"
    "runtime"
    "unsafe"
)

/*
#include <stdlib.h>
#include "llama.h"
*/
import "C"

func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "The HTTP request does not support live updates", 500)
        return
    }

    // Parse API POST request

    type requestBody struct {
        ConversationID string
        APIKey         string
        Content        string
    }
    var apiParams requestBody
    err := json.NewDecoder(r.Body).Decode(&apiParams)
    if err != nil {
        http.Error(w, err.Error(), 400)
        return
    }

    // Verify API key
    // TODO

    // Wait for a free worker
    select {
    case this.sem <- struct{}{}:
        // OK
    case <-r.Context().Done():
        return // Request cancelled while waiting for a free worker
    }

    defer func() { <-this.sem }()
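
    // this.sem is a buffered channel used as a counting semaphore: the send
    // above blocks once cfg.SimultaneousRequests handlers hold a slot, and
    // the deferred receive releases the slot when this request finishes.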

    // Queue request to worker

    w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
    w.WriteHeader(200)
    flusher.Flush() // Flush before any responses, so the webui knows things are happening

    // Constant LLaMA parameters

    const (
        ParamContextSize             = 512 // RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
        ParamTopK                    = 40
        ParamTopP                    = 0.95
        ParamTemperature             = 0.08
        ParamRepeatPenalty           = 1.10
        ParamRepeatPenaltyWindowSize = 64
    )
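
    // For reference: sampling keeps the ParamTopK most likely tokens within
    // cumulative probability ParamTopP, ParamTemperature rescales the logits
    // (0.08 is close to greedy decoding), and ParamRepeatPenalty down-weights
    // tokens seen within the last ParamRepeatPenaltyWindowSize tokens.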

    // Start a new LLaMA session

    lparams := C.llama_context_default_params()
    lparams.n_ctx = ParamContextSize

    modelPathC := C.CString(this.cfg.ModelPath)
    defer C.free(unsafe.Pointer(modelPathC)) // C.CString allocates C memory; free it or we leak on every request
    lcontext := C.llama_init_from_file(modelPathC, lparams)
    defer C.llama_free(lcontext)

    // Feed in the conversation so far
    // TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption

    apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour

    llast_n_tokens := make([]C.llama_token, ParamContextSize)

    log.Println("tokenizing supplied prompt...")

    contentC := C.CString(apiParams.Content)
    defer C.free(unsafe.Pointer(contentC))
    llast_n_tokens_used_size := C.llama_tokenize(lcontext, contentC, &llast_n_tokens[0], ParamContextSize, true)
    if llast_n_tokens_used_size <= 0 {
        log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
        http.Error(w, "Internal error", 500) // Note: the 200 status was already flushed, so this can only append the message to the body
        return
    }

    // Require some headroom in the context for generated tokens
    if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
        http.Error(w, "Prompt is too long", 400)
        return
    }

    // We've consumed all the input.

    nEvaluated := 0 // Number of tokens the model has already evaluated (n_past)
    for i := int(llast_n_tokens_used_size); i < ParamContextSize; i += 1 {
        if err := r.Context().Err(); err != nil {
            return
        }

        // Get the next token from LLaMA

        log.Println("doing llama_eval...")

        evalErr := C.llama_eval(lcontext,
            &llast_n_tokens[nEvaluated], C.int(i-nEvaluated), // tokens + n_tokens is the batch of not-yet-evaluated tokens
            C.int(nEvaluated), // n_past is the number of tokens to reuse from previous eval calls
            C.int(runtime.GOMAXPROCS(0)))
        if evalErr != 0 {
            log.Printf("llama_eval: %d", evalErr)
            http.Error(w, "Internal error", 500)
            return
        }
        nEvaluated = i

        if err := r.Context().Err(); err != nil {
            return
        }

        // Sample the next token, penalizing recent repeats

        log.Println("doing llama_sample_top_p_top_k...")

        penalizeStart := 0
        penalizeLen := i
        if i > ParamRepeatPenaltyWindowSize {
            penalizeStart = i - ParamRepeatPenaltyWindowSize
            penalizeLen = ParamRepeatPenaltyWindowSize
        }

        newTokenId := C.llama_sample_top_p_top_k(lcontext,

            // Penalize recent tokens
            &llast_n_tokens[penalizeStart], C.int(penalizeLen),

            // Other static parameters
            ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)

        if newTokenId == C.llama_token_eos() {
            // The model doesn't have anything to say
            return
        }

        if err := r.Context().Err(); err != nil {
            return
        }

        log.Println("got token OK")

        // The model did have something to say
        tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))

        log.Printf("token is %q", tokenStr)

        // Push this new token into the context state, or else we'll just get it over and over again
        llast_n_tokens[i] = newTokenId

        w.Write([]byte(tokenStr))
        flusher.Flush()
    }
}
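
// For reference, a streaming request against this handler might look like the
// following (hypothetical address; the actual bind address comes from
// LCH_NET_BIND):
//
//	curl -N -H 'Content-Type: application/json' \
//	    -d '{"ConversationID": "test", "APIKey": "public-web-interface", "Content": "Hello"}' \
//	    http://localhost:8080/api/v1/generate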

7  cflags_linux_amd64.go  Normal file
@@ -0,0 +1,7 @@
package main

/*
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -march=native -mtune=native -pthread
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -march=native -mtune=native -pthread
*/
import "C"

7  cflags_linux_arm64.go  Normal file
@@ -0,0 +1,7 @@
package main

/*
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -mcpu=native -pthread
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -mcpu=native -pthread
*/
import "C"

7  go.mod  Normal file
@@ -0,0 +1,7 @@
module code.ivysaur.me/llamacpphtmld

go 1.19

require (
    github.com/google/uuid v1.3.0
)

2  go.sum  Normal file
@@ -0,0 +1,2 @@
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=

75  main.go  Normal file
@@ -0,0 +1,75 @@
package main

import (
    "fmt"
    "log"
    "net/http"
    "os"
    "strconv"
)

const (
    AppTitle   = `llamacpphtmld`
    AppVersion = `0.0.0-dev` // should be overridden by go build argument
)

type Config struct {
    NetBind              string
    ModelPath            string
    SimultaneousRequests int
}

func NewConfigFromEnv() (Config, error) {
    ret := Config{
        NetBind:   os.Getenv(`LCH_NET_BIND`),
        ModelPath: os.Getenv(`LCH_MODEL_PATH`),
    }

    simultaneousRequests, err := strconv.Atoi(os.Getenv(`LCH_SIMULTANEOUS_REQUESTS`))
    if err != nil {
        return Config{}, fmt.Errorf("LCH_SIMULTANEOUS_REQUESTS: %w", err)
    }
    ret.SimultaneousRequests = simultaneousRequests

    if _, err := os.Stat(ret.ModelPath); err != nil {
        return Config{}, fmt.Errorf("LCH_MODEL_PATH: %w", err)
    }

    return ret, nil
}
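
// Example invocation (hypothetical values; adjust the bind address and model
// path for your environment, and assuming the built binary is named
// llamacpphtmld):
//
//	LCH_NET_BIND=:8080 \
//	LCH_MODEL_PATH=./models/7B/ggml-model-q4_0.bin \
//	LCH_SIMULTANEOUS_REQUESTS=1 \
//	./llamacpphtmld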

type Application struct {
    cfg Config
    sem chan struct{}
}

func main() {
    log.Printf("%s v%s", AppTitle, AppVersion)

    cfg, err := NewConfigFromEnv()
    if err != nil {
        log.Fatal(err)
    }

    app := Application{
        cfg: cfg,
        sem: make(chan struct{}, cfg.SimultaneousRequests), // use a buffered channel as a semaphore
    }

    router := http.NewServeMux()
    router.HandleFunc(`/`, func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set(`Server`, AppTitle+`/`+AppVersion)

        if r.Method == `GET` && r.URL.Path == `/` {
            app.GET_Root(w, r)
        } else if r.Method == `POST` && r.URL.Path == `/api/v1/generate` {
            app.POST_Chat(w, r)
        } else {
            http.Error(w, "Not found", 404)
        }
    })

    log.Printf("Listening on %s ...", cfg.NetBind)

    log.Fatal(http.ListenAndServe(cfg.NetBind, router))
}

123  webui.go  Normal file
@@ -0,0 +1,123 @@
package main

import (
    "html"
    "net/http"

    "github.com/google/uuid"
)

func (this *Application) GET_Root(w http.ResponseWriter, r *http.Request) {
    w.Header().Set(`Content-Type`, `text/html;charset=UTF-8`)
    w.WriteHeader(200)
    w.Write([]byte(`<!DOCTYPE html>
<html>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>` + html.EscapeString(AppTitle) + `</title>
<style type="text/css">
html {
    font-family: sans-serif;
}
textarea {
    border-radius: 4px;
    display: block;
    width: 100%;
    min-height: 100px;

    background: #fff;
    transition: background-color 0.5s ease-out;
}
textarea.alert {
    background: lightyellow;
    transition: initial;
}
button {
    margin-top: 8px;
    padding: 4px 6px;
}
</style>
<body>
<h2>🦙 ` + html.EscapeString(AppTitle) + `</h2>
<textarea id="main" autofocus></textarea>
<button id="generate">▶️ Generate</button>
<button id="interrupt" disabled>⏸️ Interrupt</button>
<script type="text/javascript">
function main() {
    let conversationID = "` + uuid.New().String() + `";
    const apiKey = "public-web-interface";

    const $generate = document.getElementById("generate");
    const $interrupt = document.getElementById("interrupt");
    const $main = document.getElementById("main");

    $generate.addEventListener('click', async function() {
        const content = $main.value;
        if (content.split(" ").length >= 2047) {
            if (! confirm("Warning: high prompt length, the model will forget part of the content. Are you sure you want to continue?")) {
                return;
            }
        }

        $main.readOnly = true;
        $generate.disabled = true;

        try {
            const controller = new AbortController();

            const response = await fetch("/api/v1/generate", {
                method: "POST",
                signal: controller.signal,
                headers: {
                    "Content-Type": "application/json"
                },
                body: JSON.stringify({
                    "ConversationID": conversationID,
                    "APIKey": apiKey,
                    "Content": content
                })
            });

            $interrupt.disabled = false;
            const doInterrupt = () => {
                controller.abort();
                $interrupt.removeEventListener('click', doInterrupt);
            };
            $interrupt.addEventListener('click', doInterrupt);

            const reader = response.body.getReader();
            const decoder = new TextDecoder();

            for (;;) {
                const singleReadResult = await reader.read();
                if (singleReadResult.done) {
                    break;
                }
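
                // Each chunk is a fragment of the streamed plain-text
                // response; decode it incrementally and append it to the
                // textarea as it arrives.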
                $main.value += decoder.decode(singleReadResult.value, {stream: true}); // stream:true handles multi-byte characters split across chunks
                $main.className = 'alert';
                setTimeout(() => { $main.className = ''; }, 1);
            }

        } catch (ex) {
            alert(
                "Error processing the request: " +
                (ex instanceof Error ? ex.message : JSON.stringify(ex))
            );
            return;

        } finally {
            $main.readOnly = false;
            $generate.disabled = false;
            $interrupt.disabled = true;
        }
    });
}

main();
</script>
</body>
</html>
`))
}