initial commit
This commit is contained in:
parent
7c6a0cdaa2
commit
d044a9e424
162
api.go
Normal file
162
api.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include "llama.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
func (this *Application) POST_Chat(w http.ResponseWriter, r *http.Request) {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "The HTTP request does not support live updates", 500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse API POST request
|
||||||
|
|
||||||
|
type requestBody struct {
|
||||||
|
ConversationID string
|
||||||
|
APIKey string
|
||||||
|
Content string
|
||||||
|
}
|
||||||
|
var apiParams requestBody
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&apiParams)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify API key
|
||||||
|
// TODO
|
||||||
|
|
||||||
|
// Wait for a free worker
|
||||||
|
select {
|
||||||
|
case this.sem <- struct{}{}:
|
||||||
|
// OK
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return // Request cancelled while waiting for free worker
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { <-this.sem }()
|
||||||
|
|
||||||
|
// Queue request to worker
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/plain;charset=UTF-8")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
flusher.Flush() // Flush before any responses, so the webui knows things are happening
|
||||||
|
|
||||||
|
// Constant LLaMA parameters
|
||||||
|
|
||||||
|
const (
|
||||||
|
ParamContextSize = 512 // RAM requirements: 512 needs 800MB KV (~3216MB overall), 2048 needs 3200MB KV (~??? overall)
|
||||||
|
ParamTopK = 40
|
||||||
|
ParamTopP = 0.95
|
||||||
|
ParamTemperature = 0.08
|
||||||
|
ParamRepeatPenalty = 1.10
|
||||||
|
ParamRepeatPenaltyWindowSize = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start a new LLaMA session
|
||||||
|
|
||||||
|
lparams := C.llama_context_default_params()
|
||||||
|
lparams.n_ctx = ParamContextSize
|
||||||
|
|
||||||
|
lcontext := C.llama_init_from_file(C.CString(this.cfg.ModelPath), lparams)
|
||||||
|
defer C.llama_free(lcontext)
|
||||||
|
|
||||||
|
// Feed in the conversation so far
|
||||||
|
// TODO stash the contexts for reuse with LRU-caching, for faster startup/resumption
|
||||||
|
|
||||||
|
apiParams.Content = " " + apiParams.Content // Add leading space to match LLaMA python behaviour
|
||||||
|
|
||||||
|
llast_n_tokens := make([]C.llama_token, ParamContextSize)
|
||||||
|
|
||||||
|
log.Println("tokenizing supplied prompt...")
|
||||||
|
|
||||||
|
llast_n_tokens_used_size := C.llama_tokenize(lcontext, C.CString(apiParams.Content), &llast_n_tokens[0], ParamContextSize, true)
|
||||||
|
if llast_n_tokens_used_size <= 0 {
|
||||||
|
log.Printf("llama_tokenize returned non-positive size (%d)", llast_n_tokens_used_size)
|
||||||
|
http.Error(w, "Internal error", 500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ParamContextSize - int(llast_n_tokens_used_size)) <= 4 {
|
||||||
|
http.Error(w, "Prompt is too long", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've consumed all the input.
|
||||||
|
|
||||||
|
for i := int(llast_n_tokens_used_size); i < ParamContextSize; i += 1 {
|
||||||
|
if err := r.Context().Err(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the next token from LLaMA
|
||||||
|
|
||||||
|
log.Println("doing llama_eval...")
|
||||||
|
|
||||||
|
evalErr := C.llama_eval(lcontext,
|
||||||
|
&llast_n_tokens[0], C.int(i), // tokens + n_tokens is the provided batch of new tokens to process
|
||||||
|
C.int(i), // n_past is the number of tokens to use from previous eval calls
|
||||||
|
C.int(runtime.GOMAXPROCS(0)))
|
||||||
|
if evalErr != 0 {
|
||||||
|
log.Printf("llama_eval: %d", evalErr)
|
||||||
|
http.Error(w, "Internal error", 500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Context().Err(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
log.Println("doing llama_sample_top_p_top_k...")
|
||||||
|
|
||||||
|
penalizeStart := 0
|
||||||
|
penalizeLen := i
|
||||||
|
if i > ParamRepeatPenaltyWindowSize {
|
||||||
|
penalizeStart = i - ParamRepeatPenaltyWindowSize
|
||||||
|
penalizeLen = ParamRepeatPenaltyWindowSize
|
||||||
|
}
|
||||||
|
|
||||||
|
newTokenId := C.llama_sample_top_p_top_k(lcontext,
|
||||||
|
|
||||||
|
// Penalize recent tokens
|
||||||
|
&llast_n_tokens[penalizeStart], C.int(penalizeLen),
|
||||||
|
|
||||||
|
// Other static parameters
|
||||||
|
ParamTopK, ParamTopP, ParamTemperature, ParamRepeatPenalty)
|
||||||
|
|
||||||
|
if newTokenId == C.llama_token_eos() {
|
||||||
|
// The model doesn't have anything to say
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Context().Err(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("got token OK")
|
||||||
|
|
||||||
|
// The model did have something to say
|
||||||
|
tokenStr := C.GoString(C.llama_token_to_str(lcontext, newTokenId))
|
||||||
|
|
||||||
|
log.Printf("token is %q", tokenStr)
|
||||||
|
|
||||||
|
// Push this new token into the lembedding_ state, or else we'll just get it over and over again
|
||||||
|
llast_n_tokens[i] = newTokenId
|
||||||
|
|
||||||
|
// time.Sleep(1 * time.Second)
|
||||||
|
w.Write([]byte(tokenStr)) // fmt.Sprintf(" update %d", i)))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
7
cflags_linux_amd64.go
Normal file
7
cflags_linux_amd64.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -march=native -mtune=native -pthread
|
||||||
|
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -march=native -mtune=native -pthread
|
||||||
|
*/
|
||||||
|
import "C"
|
7
cflags_linux_arm64.go
Normal file
7
cflags_linux_arm64.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -O3 -DNDEBUG -std=c11 -mcpu -pthread
|
||||||
|
#cgo CXXFLAGS: -O3 -DNDEBUG -std=c++11 -mcpu -pthread
|
||||||
|
*/
|
||||||
|
import "C"
|
7
go.mod
Normal file
7
go.mod
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
module code.ivysaur.me/llamacpphtmld
|
||||||
|
|
||||||
|
go 1.19
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
|
)
|
2
go.sum
Normal file
2
go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
75
main.go
Normal file
75
main.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AppTitle = `llamacpphtmld`
|
||||||
|
AppVersion = `0.0.0-dev` // should be overridden by go build argument
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
NetBind string
|
||||||
|
ModelPath string
|
||||||
|
SimultaneousRequests int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfigFromEnv() (Config, error) {
|
||||||
|
ret := Config{
|
||||||
|
NetBind: os.Getenv(`LCH_NET_BIND`),
|
||||||
|
ModelPath: os.Getenv(`LCH_MODEL_PATH`),
|
||||||
|
}
|
||||||
|
|
||||||
|
SimultaneousRequests, err := strconv.Atoi(os.Getenv(`LCH_SIMULTANEOUS_REQUESTS`))
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("LCH_SIMULTANEOUS_REQUESTS: %w", err)
|
||||||
|
}
|
||||||
|
ret.SimultaneousRequests = SimultaneousRequests
|
||||||
|
|
||||||
|
if _, err := os.Stat(ret.ModelPath); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("LCH_MODEL_PATH: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Application struct {
|
||||||
|
cfg Config
|
||||||
|
sem chan (struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log.Printf("%s v%s", AppTitle, AppVersion)
|
||||||
|
|
||||||
|
cfg, err := NewConfigFromEnv()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app := Application{
|
||||||
|
cfg: cfg,
|
||||||
|
sem: make(chan struct{}, cfg.SimultaneousRequests), // use a buffered channel as a semaphore
|
||||||
|
}
|
||||||
|
|
||||||
|
router := http.NewServeMux()
|
||||||
|
router.HandleFunc(`/`, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(`Server`, AppTitle+`/`+AppVersion)
|
||||||
|
|
||||||
|
if r.Method == `GET` && r.URL.Path == `/` {
|
||||||
|
app.GET_Root(w, r)
|
||||||
|
} else if r.Method == `POST` && r.URL.Path == `/api/v1/generate` {
|
||||||
|
app.POST_Chat(w, r)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Not found", 404)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Printf("Listening on %s ...", cfg.NetBind)
|
||||||
|
|
||||||
|
log.Fatal(http.ListenAndServe(cfg.NetBind, router))
|
||||||
|
}
|
123
webui.go
Normal file
123
webui.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"html"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (this *Application) GET_Root(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(`Content-Type`, `text/html;charset=UTF-8`)
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
|
<title>` + html.EscapeString(AppTitle) + `</title>
|
||||||
|
<style type="text/css">
|
||||||
|
html {
|
||||||
|
font-family: sans-serif;
|
||||||
|
}
|
||||||
|
textarea {
|
||||||
|
border-radius: 4px;
|
||||||
|
display: block;
|
||||||
|
width: 100%;
|
||||||
|
min-height: 100px;
|
||||||
|
|
||||||
|
background: #fff;
|
||||||
|
transition: background-color 0.5s ease-out;
|
||||||
|
}
|
||||||
|
textarea.alert {
|
||||||
|
background: lightyellow;
|
||||||
|
transition: initial;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
margin-top: 8px;
|
||||||
|
padding: 4px 6px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
<body>
|
||||||
|
<h2>🦙 ` + html.EscapeString(AppTitle) + `</h2>
|
||||||
|
<textarea id="main" autofocus></textarea>
|
||||||
|
<button id="generate">▶️ Generate</button>
|
||||||
|
<button id="interrupt" disabled>⏸️ Interrupt</button>
|
||||||
|
<script type="text/javascript">
|
||||||
|
function main() {
|
||||||
|
let conversationID = "` + uuid.New().String() + `";
|
||||||
|
const apiKey = "public-web-interface";
|
||||||
|
|
||||||
|
const $generate = document.getElementById("generate");
|
||||||
|
const $interrupt = document.getElementById("interrupt");
|
||||||
|
const $main = document.getElementById("main");
|
||||||
|
|
||||||
|
$generate.addEventListener('click', async function() {
|
||||||
|
const content = $main.value;
|
||||||
|
if (content.split(" ").length >= 2047) {
|
||||||
|
if (! confirm("Warning: high prompt length, the model will forget part of the content. Are you sure you want to continue?")) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
$main.readOnly = true;
|
||||||
|
$generate.disabled = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const controller = new AbortController();
|
||||||
|
|
||||||
|
const response = await fetch("/api/v1/generate", {
|
||||||
|
method: "POST",
|
||||||
|
signal: controller.signal,
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
"ConversationID": conversationID,
|
||||||
|
"APIKey": apiKey,
|
||||||
|
"Content": content
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
$interrupt.disabled = false;
|
||||||
|
const doInterrupt = () => {
|
||||||
|
controller.abort();
|
||||||
|
$interrupt.removeEventListener('click', doInterrupt);
|
||||||
|
};
|
||||||
|
$interrupt.addEventListener('click', doInterrupt);
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
|
||||||
|
for(;;) {
|
||||||
|
const singleReadResult = await reader.read();
|
||||||
|
if (singleReadResult.done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
$main.value += decoder.decode(singleReadResult.value);
|
||||||
|
$main.className = 'alert';
|
||||||
|
setTimeout(() => { $main.className = ''; }, 1);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (ex) {
|
||||||
|
alert(
|
||||||
|
"Error processing the request: " +
|
||||||
|
(ex instanceof Error ? ex.message : JSON.stringify(ex))
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
$main.readOnly = false;
|
||||||
|
$generate.disabled = false;
|
||||||
|
$interrupt.disabled = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`))
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user