diff --git a/AuthorHash.go b/AuthorHash.go new file mode 100644 index 0000000..b894aba --- /dev/null +++ b/AuthorHash.go @@ -0,0 +1,15 @@ +package yatwiki3 + +import ( + "crypto/md5" + "encoding/hex" + "net/http" + "strings" +) + +func Author(r *http.Request) string { + userAgentHash := md5.Sum([]byte(r.UserAgent())) + ipAddr := strings.TrimRight(strings.TrimRight(r.RemoteAddr, `0123456789`), `:`) // trim trailing port; IPv4 and IPv6-safe + + return ipAddr + "-" + hex.EncodeToString(userAgentHash[:])[:6] +} diff --git a/DB.go b/DB.go index 05223c3..d493ebf 100644 --- a/DB.go +++ b/DB.go @@ -1,30 +1,29 @@ package yatwiki3 import ( - "bytes" - "compress/flate" - "crypto/md5" "database/sql" - "encoding/hex" "fmt" - "io" - "net/http" "time" _ "github.com/mattn/go-sqlite3" ) type WikiDB struct { - db *sql.DB + db *sql.DB + compressionLevel int } -func NewWikiDB(dbFilePath string) (*WikiDB, error) { +func NewWikiDB(dbFilePath string, compressionLevel int) (*WikiDB, error) { db, err := sql.Open("sqlite3", dbFilePath) if err != nil { return nil, err } - wdb := WikiDB{db: db} + wdb := WikiDB{ + db: db, + compressionLevel: compressionLevel, + } + err = wdb.assertSchema() if err != nil { return nil, fmt.Errorf("assertSchema: %s", err.Error()) @@ -73,23 +72,14 @@ func (this *WikiDB) assertSchema() error { } type Article struct { - ID int - TitleID int + ID int64 + TitleID int64 Modified int64 Body []byte Author string Title string } -func (this *Article) FillModifiedTimestamp() { - this.Modified = time.Now().Unix() -} - -func (this *Article) FillAuthor(r *http.Request) { - userAgentHash := md5.Sum([]byte(r.UserAgent())) - this.Author = r.RemoteAddr + "-" + hex.EncodeToString(userAgentHash[:])[:6] -} - func (this *WikiDB) GetArticleById(articleId int) (*Article, error) { row := this.db.QueryRow(`SELECT articles.* FROM articles WHERE id = ?`, articleId) return this.parseArticle(row) @@ -105,6 +95,57 @@ func (this *WikiDB) GetLatestVersion(title string) (*Article, error) { return this.parseArticle(row) } +type ArticleAlteredError struct { + got, expected int64 +} + +func (aae ArticleAlteredError) Error() string { + return fmt.Sprintf("Warning: Your changes were not based on the most recent version of the page (r%d ≠ r%d). No changes were saved.", aae.got, aae.expected) +} + +func (this *WikiDB) SaveArticle(title, author, body string, expectBaseRev int64) error { + isNewArticle := false + a, err := this.GetLatestVersion(title) + if err != nil { + if err == sql.ErrNoRows { + isNewArticle = true + } else { + return fmt.Errorf("Couldn't check for existing article title: %s", err.Error()) + } + } + + if !isNewArticle && a.ID != expectBaseRev { + return ArticleAlteredError{got: expectBaseRev, expected: a.ID} + } + + zBody, err := gzdeflate([]byte(body), this.compressionLevel) + if err != nil { + return err + } + + var titleId int64 + if isNewArticle { + titleInsert, err := this.db.Exec(`INSERT INTO titles (title) VALUES (?)`, title) + if err != nil { + return err + } + + titleId, err = titleInsert.LastInsertId() + if err != nil { + return err + } + } else { + titleId = a.TitleID + } + + _, err = this.db.Exec(`INSERT INTO articles (article, modified, body, author) VALUES (?, ?, ?, ?)`, titleId, time.Now().Unix(), zBody, author) + if err != nil { + return err + } + + return nil +} + func (this *WikiDB) GetRevisionHistory(title string) ([]Article, error) { rows, err := this.db.Query(`SELECT articles.id, articles.modified, articles.author FROM articles WHERE article = (SELECT id FROM titles WHERE title = ?) ORDER BY modified DESC`, title) if err != nil { @@ -177,21 +218,6 @@ func (this *WikiDB) ListTitles() ([]string, error) { return ret, nil } -func (this *WikiDB) gzinflate(gzBody []byte) ([]byte, error) { - gzBodyReader := bytes.NewReader(gzBody) - - gzReader := flate.NewReader(gzBodyReader) - defer gzReader.Close() - - buffer := bytes.Buffer{} - _, err := io.Copy(&buffer, gzReader) - if err != nil { - return nil, err - } - - return buffer.Bytes(), nil -} - func (this *WikiDB) parseArticle(row *sql.Row) (*Article, error) { a := Article{} var gzBody []byte @@ -200,7 +226,7 @@ func (this *WikiDB) parseArticle(row *sql.Row) (*Article, error) { return nil, err } - decompressed, err := this.gzinflate(gzBody) + decompressed, err := gzinflate(gzBody) if err != nil { return nil, err } @@ -218,7 +244,7 @@ func (this *WikiDB) parseArticleWithTitle(row *sql.Row) (*Article, error) { return nil, err } - decompressed, err := this.gzinflate(gzBody) + decompressed, err := gzinflate(gzBody) if err != nil { return nil, err } diff --git a/WikiServer.go b/WikiServer.go index 8a7cdf9..f4a6a37 100644 --- a/WikiServer.go +++ b/WikiServer.go @@ -19,7 +19,7 @@ type WikiServer struct { } func NewWikiServer(opts *ServerOptions) (*WikiServer, error) { - wdb, err := NewWikiDB(opts.DBFilePath) + wdb, err := NewWikiDB(opts.DBFilePath, opts.GzipCompressionLevel) if err != nil { return nil, err } @@ -141,6 +141,33 @@ func (this *WikiServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } + } else if r.Method == "POST" { + + if r.URL.Path == this.opts.ExpectBaseURL+"save" { + err := r.ParseForm() + if err != nil { + this.serveErrorMessage(w, err) + return + } + + title := r.Form.Get("pname") + body := r.Form.Get("content") + expectRev, err := strconv.Atoi(r.Form.Get("baserev")) + if err != nil { + this.serveErrorMessage(w, err) + return + } + + err = this.db.SaveArticle(title, Author(r), body, int64(expectRev)) + if err != nil { + this.serveErrorMessage(w, err) + return + } + + this.serveRedirect(w, this.opts.ExpectBaseURL+`view/`+url.QueryEscape(title)) + return + } + } // No match? Add 'Page not found' to next session response, and redirect to homepage diff --git a/gzflate.go b/gzflate.go new file mode 100644 index 0000000..decd5c2 --- /dev/null +++ b/gzflate.go @@ -0,0 +1,44 @@ +package yatwiki3 + +import ( + "bytes" + "compress/flate" + "io" +) + +func gzinflate(gzBody []byte) ([]byte, error) { + gzBodyReader := bytes.NewReader(gzBody) + + gzReader := flate.NewReader(gzBodyReader) + defer gzReader.Close() + + buffer := bytes.Buffer{} + _, err := io.Copy(&buffer, gzReader) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func gzdeflate(plaintext []byte, level int) ([]byte, error) { + compressedContent := bytes.Buffer{} + + zipper, err := flate.NewWriter(&compressedContent, level) + if err != nil { + return nil, err // e.g. bad level + } + defer zipper.Close() + + _, err = zipper.Write(plaintext) + if err != nil { + return nil, err + } + + err = zipper.Close() // flush data + if err != nil { + return nil, err + } + + return compressedContent.Bytes(), nil +} diff --git a/rModify.go b/rModify.go index 309e25c..f031c45 100644 --- a/rModify.go +++ b/rModify.go @@ -26,7 +26,7 @@ func (this *WikiServer) routeModify(w http.ResponseWriter, r *http.Request, arti } var pageTitleHTML string - var baseRev int + var baseRev int64 var existingBody string if isNewArticle { pageTitleHTML = `Creating new article`