diff --git a/Makefile b/Makefile index 3499e3d..2e798aa 100644 --- a/Makefile +++ b/Makefile @@ -100,13 +100,16 @@ ui : force_look cd less/; $(MAKE) $(MFLAGS) assets : generate - go-bindata -pkg writefreely -ignore=\\.gitignore schema.sql sqlite.sql + go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql assets-no-sqlite: generate - go-bindata -pkg writefreely -ignore=\\.gitignore schema.sql + go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql dev-assets : generate - go-bindata -pkg writefreely -ignore=\\.gitignore -debug schema.sql sqlite.sql + go-bindata -pkg writefreely -ignore=\\.gitignore -debug -tags="!wflib" schema.sql sqlite.sql + +lib-assets : generate + go-bindata -pkg writefreely -ignore=\\.gitignore -o bindata-lib.go -tags="wflib" schema.sql generate : @hash go-bindata > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ @@ -123,7 +126,7 @@ $(TMPBIN)/xgo: deps $(TMPBIN) $(GOBUILD) -o $(TMPBIN)/xgo github.com/karalabe/xgo ci-assets : $(TMPBIN)/go-bindata - $(TMPBIN)/go-bindata -pkg writefreely -ignore=\\.gitignore schema.sql sqlite.sql + $(TMPBIN)/go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql clean : -rm -rf build diff --git a/account.go b/account.go index 06151c9..d8ea0df 100644 --- a/account.go +++ b/account.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -49,7 +49,7 @@ type ( } ) -func NewUserPage(app *app, r *http.Request, u *User, title string, flashes []string) *UserPage { +func NewUserPage(app *App, r *http.Request, u *User, title string, flashes []string) *UserPage { up := &UserPage{ StaticPage: pageForReq(app, r), PageTitle: title, @@ -73,12 +73,12 @@ const ( var actuallyUsernameReg = regexp.MustCompile("username is actually ([a-z0-9\\-]+)\\. Please try that, instead") -func apiSignup(app *app, w http.ResponseWriter, r *http.Request) error { +func apiSignup(app *App, w http.ResponseWriter, r *http.Request) error { _, err := signup(app, w, r) return err } -func signup(app *app, w http.ResponseWriter, r *http.Request) (*AuthUser, error) { +func signup(app *App, w http.ResponseWriter, r *http.Request) (*AuthUser, error) { reqJSON := IsJSON(r.Header.Get("Content-Type")) // Get params @@ -113,7 +113,7 @@ func signup(app *app, w http.ResponseWriter, r *http.Request) (*AuthUser, error) return signupWithRegistration(app, ur, w, r) } -func signupWithRegistration(app *app, signup userRegistration, w http.ResponseWriter, r *http.Request) (*AuthUser, error) { +func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWriter, r *http.Request) (*AuthUser, error) { reqJSON := IsJSON(r.Header.Get("Content-Type")) // Validate required params (alias) @@ -154,7 +154,7 @@ func signupWithRegistration(app *app, signup userRegistration, w http.ResponseWr Created: time.Now().Truncate(time.Second).UTC(), } if signup.Email != "" { - encEmail, err := data.Encrypt(app.keys.emailKey, signup.Email) + encEmail, err := data.Encrypt(app.keys.EmailKey, signup.Email) if err != nil { log.Error("Unable to encrypt email: %s\n", err) } else { @@ -229,7 +229,7 @@ func signupWithRegistration(app *app, signup userRegistration, w http.ResponseWr return resUser, nil } -func viewLogout(app *app, w http.ResponseWriter, r *http.Request) error { +func viewLogout(app *App, w http.ResponseWriter, r *http.Request) error { session, err := app.sessionStore.Get(r, cookieName) if err != nil { return ErrInternalCookieSession @@ -268,7 +268,7 @@ func viewLogout(app *app, w http.ResponseWriter, r *http.Request) error { return impart.HTTPError{http.StatusFound, "/"} } -func handleAPILogout(app *app, w http.ResponseWriter, r *http.Request) error { +func handleAPILogout(app *App, w http.ResponseWriter, r *http.Request) error { accessToken := r.Header.Get("Authorization") if accessToken == "" { return ErrNoAccessToken @@ -284,7 +284,7 @@ func handleAPILogout(app *app, w http.ResponseWriter, r *http.Request) error { return impart.HTTPError{Status: http.StatusNoContent} } -func viewLogin(app *app, w http.ResponseWriter, r *http.Request) error { +func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { var earlyError string oneTimeToken := r.FormValue("with") if oneTimeToken != "" { @@ -333,7 +333,7 @@ func viewLogin(app *app, w http.ResponseWriter, r *http.Request) error { return nil } -func webLogin(app *app, w http.ResponseWriter, r *http.Request) error { +func webLogin(app *App, w http.ResponseWriter, r *http.Request) error { err := login(app, w, r) if err != nil { username := r.FormValue("alias") @@ -370,7 +370,7 @@ func webLogin(app *app, w http.ResponseWriter, r *http.Request) error { var loginAttemptUsers = sync.Map{} -func login(app *app, w http.ResponseWriter, r *http.Request) error { +func login(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) oneTimeToken := r.FormValue("with") verbose := r.FormValue("all") == "true" || r.FormValue("verbose") == "1" || r.FormValue("verbose") == "true" || (reqJSON && oneTimeToken != "") @@ -534,7 +534,7 @@ func login(app *app, w http.ResponseWriter, r *http.Request) error { return nil } -func getVerboseAuthUser(app *app, token string, u *User, verbose bool) *AuthUser { +func getVerboseAuthUser(app *App, token string, u *User, verbose bool) *AuthUser { resUser := &AuthUser{ AccessToken: token, User: u, @@ -563,7 +563,7 @@ func getVerboseAuthUser(app *app, token string, u *User, verbose bool) *AuthUser return resUser } -func viewExportOptions(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewExportOptions(app *App, u *User, w http.ResponseWriter, r *http.Request) error { // Fetch extra user data p := NewUserPage(app, r, u, "Export", nil) @@ -571,7 +571,7 @@ func viewExportOptions(app *app, u *User, w http.ResponseWriter, r *http.Request return nil } -func viewExportPosts(app *app, w http.ResponseWriter, r *http.Request) ([]byte, string, error) { +func viewExportPosts(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) { var filename string var u = &User{} reqJSON := IsJSON(r.Header.Get("Content-Type")) @@ -635,7 +635,7 @@ func viewExportPosts(app *app, w http.ResponseWriter, r *http.Request) ([]byte, return data, filename, err } -func viewExportFull(app *app, w http.ResponseWriter, r *http.Request) ([]byte, string, error) { +func viewExportFull(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) { var err error filename := "" u := getUserSession(app, r) @@ -655,7 +655,7 @@ func viewExportFull(app *app, w http.ResponseWriter, r *http.Request) ([]byte, s return data, filename, err } -func viewMeAPI(app *app, w http.ResponseWriter, r *http.Request) error { +func viewMeAPI(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) uObj := struct { ID int64 `json:"id,omitempty"` @@ -679,7 +679,7 @@ func viewMeAPI(app *app, w http.ResponseWriter, r *http.Request) error { return impart.WriteSuccess(w, uObj, http.StatusOK) } -func viewMyPostsAPI(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewMyPostsAPI(app *App, u *User, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) if !reqJSON { return ErrBadRequestedType @@ -710,7 +710,7 @@ func viewMyPostsAPI(app *app, u *User, w http.ResponseWriter, r *http.Request) e return impart.WriteSuccess(w, p, http.StatusOK) } -func viewMyCollectionsAPI(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewMyCollectionsAPI(app *App, u *User, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) if !reqJSON { return ErrBadRequestedType @@ -724,7 +724,7 @@ func viewMyCollectionsAPI(app *app, u *User, w http.ResponseWriter, r *http.Requ return impart.WriteSuccess(w, p, http.StatusOK) } -func viewArticles(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewArticles(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p, err := app.db.GetAnonymousPosts(u) if err != nil { log.Error("unable to fetch anon posts: %v", err) @@ -761,7 +761,7 @@ func viewArticles(app *app, u *User, w http.ResponseWriter, r *http.Request) err return nil } -func viewCollections(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewCollections(app *App, u *User, w http.ResponseWriter, r *http.Request) error { c, err := app.db.GetCollections(u) if err != nil { log.Error("unable to fetch collections: %v", err) @@ -792,7 +792,7 @@ func viewCollections(app *app, u *User, w http.ResponseWriter, r *http.Request) return nil } -func viewEditCollection(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewEditCollection(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) c, err := app.db.GetCollection(vars["collection"]) if err != nil { @@ -815,7 +815,7 @@ func viewEditCollection(app *app, u *User, w http.ResponseWriter, r *http.Reques return nil } -func updateSettings(app *app, w http.ResponseWriter, r *http.Request) error { +func updateSettings(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) var s userSettings @@ -904,7 +904,7 @@ func updateSettings(app *app, w http.ResponseWriter, r *http.Request) error { return nil } -func updatePassphrase(app *app, w http.ResponseWriter, r *http.Request) error { +func updatePassphrase(app *App, w http.ResponseWriter, r *http.Request) error { accessToken := r.Header.Get("Authorization") if accessToken == "" { return ErrNoAccessToken @@ -943,7 +943,7 @@ func updatePassphrase(app *app, w http.ResponseWriter, r *http.Request) error { return impart.WriteSuccess(w, struct{}{}, http.StatusOK) } -func viewStats(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewStats(app *App, u *User, w http.ResponseWriter, r *http.Request) error { var c *Collection var err error vars := mux.Vars(r) @@ -994,7 +994,7 @@ func viewStats(app *app, u *User, w http.ResponseWriter, r *http.Request) error return nil } -func viewSettings(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) error { fullUser, err := app.db.GetUserByID(u.ID) if err != nil { log.Error("Unable to get user for settings: %s", err) @@ -1025,7 +1025,7 @@ func viewSettings(app *app, u *User, w http.ResponseWriter, r *http.Request) err return nil } -func saveTempInfo(app *app, key, val string, r *http.Request, w http.ResponseWriter) error { +func saveTempInfo(app *App, key, val string, r *http.Request, w http.ResponseWriter) error { session, err := app.sessionStore.Get(r, "t") if err != nil { return ErrInternalCookieSession @@ -1039,7 +1039,7 @@ func saveTempInfo(app *app, key, val string, r *http.Request, w http.ResponseWri return err } -func getTempInfo(app *app, key string, r *http.Request, w http.ResponseWriter) string { +func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) string { session, err := app.sessionStore.Get(r, "t") if err != nil { return "" diff --git a/activitypub.go b/activitypub.go index 4d67a20..0ac4d0c 100644 --- a/activitypub.go +++ b/activitypub.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -24,7 +24,6 @@ import ( "strconv" "time" - "github.com/go-sql-driver/mysql" "github.com/gorilla/mux" "github.com/writeas/activity/streams" "github.com/writeas/httpsig" @@ -63,7 +62,7 @@ func (ru *RemoteUser) AsPerson() *activitystreams.Person { } } -func handleFetchCollectionActivities(app *app, w http.ResponseWriter, r *http.Request) error { +func handleFetchCollectionActivities(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) @@ -81,13 +80,14 @@ func handleFetchCollectionActivities(app *app, w http.ResponseWriter, r *http.Re if err != nil { return err } + c.hostName = app.cfg.App.Host p := c.PersonObject() return impart.RenderActivityJSON(w, p, http.StatusOK) } -func handleFetchCollectionOutbox(app *app, w http.ResponseWriter, r *http.Request) error { +func handleFetchCollectionOutbox(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) @@ -105,6 +105,7 @@ func handleFetchCollectionOutbox(app *app, w http.ResponseWriter, r *http.Reques if err != nil { return err } + c.hostName = app.cfg.App.Host if app.cfg.App.SingleUser { if alias != c.Alias { @@ -139,7 +140,7 @@ func handleFetchCollectionOutbox(app *app, w http.ResponseWriter, r *http.Reques return impart.RenderActivityJSON(w, ocp, http.StatusOK) } -func handleFetchCollectionFollowers(app *app, w http.ResponseWriter, r *http.Request) error { +func handleFetchCollectionFollowers(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) @@ -157,6 +158,7 @@ func handleFetchCollectionFollowers(app *app, w http.ResponseWriter, r *http.Req if err != nil { return err } + c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() @@ -184,7 +186,7 @@ func handleFetchCollectionFollowers(app *app, w http.ResponseWriter, r *http.Req return impart.RenderActivityJSON(w, ocp, http.StatusOK) } -func handleFetchCollectionFollowing(app *app, w http.ResponseWriter, r *http.Request) error { +func handleFetchCollectionFollowing(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) @@ -202,6 +204,7 @@ func handleFetchCollectionFollowing(app *app, w http.ResponseWriter, r *http.Req if err != nil { return err } + c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() @@ -219,7 +222,7 @@ func handleFetchCollectionFollowing(app *app, w http.ResponseWriter, r *http.Req return impart.RenderActivityJSON(w, ocp, http.StatusOK) } -func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request) error { +func handleFetchCollectionInbox(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) @@ -235,6 +238,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request // TODO: return Reject? return err } + c.hostName = app.cfg.App.Host if debugging { dump, err := httputil.DumpRequest(r, true) @@ -350,7 +354,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request log.Error("No to! %v", err) return } - err = makeActivityPost(p, fullActor.Inbox, am) + err = makeActivityPost(app.cfg.App.Host, p, fullActor.Inbox, am) if err != nil { log.Error("Unable to make activity POST: %v", err) return @@ -371,13 +375,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request // Add follower locally, since it wasn't found before res, err := t.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox) VALUES (?, ?, ?)", fullActor.ID, fullActor.Inbox, fullActor.Endpoints.SharedInbox) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number != mySQLErrDuplicateKey { - t.Rollback() - log.Error("Couldn't add new remoteuser in DB: %v\n", err) - return - } - } else { + if !app.db.isDuplicateKeyErr(err) { t.Rollback() log.Error("Couldn't add new remoteuser in DB: %v\n", err) return @@ -394,13 +392,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request // Add in key _, err = t.Exec("INSERT INTO remoteuserkeys (id, remote_user_id, public_key) VALUES (?, ?, ?)", fullActor.PublicKey.ID, followerID, fullActor.PublicKey.PublicKeyPEM) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number != mySQLErrDuplicateKey { - t.Rollback() - log.Error("Couldn't add follower keys in DB: %v\n", err) - return - } - } else { + if !app.db.isDuplicateKeyErr(err) { t.Rollback() log.Error("Couldn't add follower keys in DB: %v\n", err) return @@ -411,13 +403,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request // Add follow _, err = t.Exec("INSERT INTO remotefollows (collection_id, remote_user_id, created) VALUES (?, ?, "+app.db.now()+")", c.ID, followerID) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number != mySQLErrDuplicateKey { - t.Rollback() - log.Error("Couldn't add follower in DB: %v\n", err) - return - } - } else { + if !app.db.isDuplicateKeyErr(err) { t.Rollback() log.Error("Couldn't add follower in DB: %v\n", err) return @@ -442,7 +428,7 @@ func handleFetchCollectionInbox(app *app, w http.ResponseWriter, r *http.Request return nil } -func makeActivityPost(p *activitystreams.Person, url string, m interface{}) error { +func makeActivityPost(hostName string, p *activitystreams.Person, url string, m interface{}) error { log.Info("POST %s", url) b, err := json.Marshal(m) if err != nil { @@ -496,7 +482,7 @@ func makeActivityPost(p *activitystreams.Person, url string, m interface{}) erro return nil } -func resolveIRI(url string) ([]byte, error) { +func resolveIRI(hostName, url string) ([]byte, error) { log.Info("GET %s", url) r, _ := http.NewRequest("GET", url, nil) @@ -532,7 +518,7 @@ func resolveIRI(url string) ([]byte, error) { return body, nil } -func deleteFederatedPost(app *app, p *PublicPost, collID int64) error { +func deleteFederatedPost(app *App, p *PublicPost, collID int64) error { if debugging { log.Info("Deleting federated post!") } @@ -566,7 +552,7 @@ func deleteFederatedPost(app *app, p *PublicPost, collID int64) error { na.CC = append(na.CC, f) } - err = makeActivityPost(actor, si, activitystreams.NewDeleteActivity(na)) + err = makeActivityPost(app.cfg.App.Host, actor, si, activitystreams.NewDeleteActivity(na)) if err != nil { log.Error("Couldn't delete post! %v", err) } @@ -574,7 +560,7 @@ func deleteFederatedPost(app *app, p *PublicPost, collID int64) error { return nil } -func federatePost(app *app, p *PublicPost, collID int64, isUpdate bool) error { +func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { if debugging { if isUpdate { log.Info("Federating updated post!") @@ -620,7 +606,7 @@ func federatePost(app *app, p *PublicPost, collID int64, isUpdate bool) error { activity.To = na.To activity.CC = na.CC } - err = makeActivityPost(actor, si, activity) + err = makeActivityPost(app.cfg.App.Host, actor, si, activity) if err != nil { log.Error("Couldn't post! %v", err) } @@ -628,7 +614,7 @@ func federatePost(app *app, p *PublicPost, collID int64, isUpdate bool) error { return nil } -func getRemoteUser(app *app, actorID string) (*RemoteUser, error) { +func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { u := RemoteUser{ActorID: actorID} err := app.db.QueryRow("SELECT id, inbox, shared_inbox FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox) switch { @@ -642,7 +628,7 @@ func getRemoteUser(app *app, actorID string) (*RemoteUser, error) { return &u, nil } -func getActor(app *app, actorIRI string) (*activitystreams.Person, *RemoteUser, error) { +func getActor(app *App, actorIRI string) (*activitystreams.Person, *RemoteUser, error) { log.Info("Fetching actor %s locally", actorIRI) actor := &activitystreams.Person{} remoteUser, err := getRemoteUser(app, actorIRI) @@ -651,7 +637,7 @@ func getActor(app *app, actorIRI string) (*activitystreams.Person, *RemoteUser, if iErr.Status == http.StatusNotFound { // Fetch remote actor log.Info("Not found; fetching actor %s remotely", actorIRI) - actorResp, err := resolveIRI(actorIRI) + actorResp, err := resolveIRI(app.cfg.App.Host, actorIRI) if err != nil { log.Error("Unable to get actor! %v", err) return nil, nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't fetch actor."} diff --git a/admin.go b/admin.go index 29f3501..7d21abd 100644 --- a/admin.go +++ b/admin.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -96,7 +96,7 @@ func (c instanceContent) UpdatedFriendly() string { return c.Updated.Format("January 2, 2006, 3:04 PM") } -func handleViewAdminDash(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewAdminDash(app *App, u *User, w http.ResponseWriter, r *http.Request) error { updateAppStats() p := struct { *UserPage @@ -117,7 +117,7 @@ func handleViewAdminDash(app *app, u *User, w http.ResponseWriter, r *http.Reque return nil } -func handleViewAdminUsers(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewAdminUsers(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage Config config.AppCfg @@ -157,7 +157,7 @@ func handleViewAdminUsers(app *app, u *User, w http.ResponseWriter, r *http.Requ return nil } -func handleViewAdminUser(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) username := vars["username"] if username == "" { @@ -229,7 +229,7 @@ func handleViewAdminUser(app *app, u *User, w http.ResponseWriter, r *http.Reque return nil } -func handleViewAdminPages(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewAdminPages(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage Config config.AppCfg @@ -287,7 +287,7 @@ func handleViewAdminPages(app *app, u *User, w http.ResponseWriter, r *http.Requ return nil } -func handleViewAdminPage(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewAdminPage(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) slug := vars["slug"] if slug == "" { @@ -329,7 +329,7 @@ func handleViewAdminPage(app *app, u *User, w http.ResponseWriter, r *http.Reque return nil } -func handleAdminUpdateSite(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleAdminUpdateSite(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) id := vars["page"] @@ -347,33 +347,34 @@ func handleAdminUpdateSite(app *app, u *User, w http.ResponseWriter, r *http.Req return impart.HTTPError{http.StatusFound, "/admin/page/" + id + m} } -func handleAdminUpdateConfig(app *app, u *User, w http.ResponseWriter, r *http.Request) error { - app.cfg.App.SiteName = r.FormValue("site_name") - app.cfg.App.SiteDesc = r.FormValue("site_desc") - app.cfg.App.OpenRegistration = r.FormValue("open_registration") == "on" +func handleAdminUpdateConfig(apper Apper, u *User, w http.ResponseWriter, r *http.Request) error { + apper.App().cfg.App.SiteName = r.FormValue("site_name") + apper.App().cfg.App.SiteDesc = r.FormValue("site_desc") + apper.App().cfg.App.Landing = r.FormValue("landing") + apper.App().cfg.App.OpenRegistration = r.FormValue("open_registration") == "on" mul, err := strconv.Atoi(r.FormValue("min_username_len")) if err == nil { - app.cfg.App.MinUsernameLen = mul + apper.App().cfg.App.MinUsernameLen = mul } mb, err := strconv.Atoi(r.FormValue("max_blogs")) if err == nil { - app.cfg.App.MaxBlogs = mb + apper.App().cfg.App.MaxBlogs = mb } - app.cfg.App.Federation = r.FormValue("federation") == "on" - app.cfg.App.PublicStats = r.FormValue("public_stats") == "on" - app.cfg.App.Private = r.FormValue("private") == "on" - app.cfg.App.LocalTimeline = r.FormValue("local_timeline") == "on" - if app.cfg.App.LocalTimeline && app.timeline == nil { + apper.App().cfg.App.Federation = r.FormValue("federation") == "on" + apper.App().cfg.App.PublicStats = r.FormValue("public_stats") == "on" + apper.App().cfg.App.Private = r.FormValue("private") == "on" + apper.App().cfg.App.LocalTimeline = r.FormValue("local_timeline") == "on" + if apper.App().cfg.App.LocalTimeline && apper.App().timeline == nil { log.Info("Initializing local timeline...") - initLocalTimeline(app) + initLocalTimeline(apper.App()) } - app.cfg.App.UserInvites = r.FormValue("user_invites") - if app.cfg.App.UserInvites == "none" { - app.cfg.App.UserInvites = "" + apper.App().cfg.App.UserInvites = r.FormValue("user_invites") + if apper.App().cfg.App.UserInvites == "none" { + apper.App().cfg.App.UserInvites = "" } m := "?cm=Configuration+saved." - err = config.Save(app.cfg, app.cfgFile) + err = apper.SaveConfig(apper.App().cfg) if err != nil { m = "?cm=" + err.Error() } @@ -418,7 +419,7 @@ func updateAppStats() { sysStatus.NumGC = m.NumGC } -func adminResetPassword(app *app, u *User, newPass string) error { +func adminResetPassword(app *App, u *User, newPass string) error { hashedPass, err := auth.HashPass([]byte(newPass)) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not create password hash: %v", err)} diff --git a/app.go b/app.go index 4cb7dab..3120368 100644 --- a/app.go +++ b/app.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -12,9 +12,9 @@ package writefreely import ( "database/sql" - "flag" "fmt" "html/template" + "io/ioutil" "net/http" "net/url" "os" @@ -25,18 +25,18 @@ import ( "syscall" "time" - _ "github.com/go-sql-driver/mysql" - "github.com/gorilla/mux" "github.com/gorilla/schema" "github.com/gorilla/sessions" "github.com/manifoldco/promptui" "github.com/writeas/go-strip-markdown" + "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/converter" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/author" "github.com/writeas/writefreely/config" + "github.com/writeas/writefreely/key" "github.com/writeas/writefreely/migrations" "github.com/writeas/writefreely/page" ) @@ -57,27 +57,129 @@ var ( softwareVer = "0.9.0" // DEPRECATED VARS - // TODO: pass app.cfg into GetCollection* calls so we can get these values - // from Collection methods and we no longer need these. - hostName string isSingleUser bool ) -type app struct { +// App holds data and configuration for an individual WriteFreely instance. +type App struct { router *mux.Router + shttp *http.ServeMux db *datastore cfg *config.Config cfgFile string - keys *keychain + keys *key.Keychain sessionStore *sessions.CookieStore formDecoder *schema.Decoder timeline *localTimeline } +// DB returns the App's datastore +func (app *App) DB() *datastore { + return app.db +} + +// Router returns the App's router +func (app *App) Router() *mux.Router { + return app.router +} + +// Config returns the App's current configuration. +func (app *App) Config() *config.Config { + return app.cfg +} + +// SetConfig updates the App's Config to the given value. +func (app *App) SetConfig(cfg *config.Config) { + app.cfg = cfg +} + +// SetKeys updates the App's Keychain to the given value. +func (app *App) SetKeys(k *key.Keychain) { + app.keys = k +} + +// Apper is the interface for getting data into and out of a WriteFreely +// instance (or "App"). +// +// App returns the App for the current instance. +// +// LoadConfig reads an app configuration into the App, returning any error +// encountered. +// +// SaveConfig persists the current App configuration. +// +// LoadKeys reads the App's encryption keys and loads them into its +// key.Keychain. +type Apper interface { + App() *App + + LoadConfig() error + SaveConfig(*config.Config) error + + LoadKeys() error +} + +// App returns the App +func (app *App) App() *App { + return app +} + +// LoadConfig loads and parses a config file. +func (app *App) LoadConfig() error { + log.Info("Loading %s configuration...", app.cfgFile) + cfg, err := config.Load(app.cfgFile) + if err != nil { + log.Error("Unable to load configuration: %v", err) + os.Exit(1) + return err + } + app.cfg = cfg + return nil +} + +// SaveConfig saves the given Config to disk -- namely, to the App's cfgFile. +func (app *App) SaveConfig(c *config.Config) error { + return config.Save(c, app.cfgFile) +} + +// LoadKeys reads all needed keys from disk into the App. In order to use the +// configured `Server.KeysParentDir`, you must call initKeyPaths(App) before +// this. +func (app *App) LoadKeys() error { + var err error + app.keys = &key.Keychain{} + + if debugging { + log.Info(" %s", emailKeyPath) + } + app.keys.EmailKey, err = ioutil.ReadFile(emailKeyPath) + if err != nil { + return err + } + + if debugging { + log.Info(" %s", cookieAuthKeyPath) + } + app.keys.CookieAuthKey, err = ioutil.ReadFile(cookieAuthKeyPath) + if err != nil { + return err + } + + if debugging { + log.Info(" %s", cookieKeyPath) + } + app.keys.CookieKey, err = ioutil.ReadFile(cookieKeyPath) + if err != nil { + return err + } + + return nil +} + // handleViewHome shows page at root path. Will be the Pad if logged in and the // catch-all landing page otherwise. -func handleViewHome(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewHome(app *App, w http.ResponseWriter, r *http.Request) error { if app.cfg.App.SingleUser { // Render blog index return handleViewCollection(app, w, r) @@ -90,6 +192,10 @@ func handleViewHome(app *app, w http.ResponseWriter, r *http.Request) error { return handleViewPad(app, w, r) } + if land := app.cfg.App.LandingPath(); land != "/" { + return impart.HTTPError{http.StatusFound, land} + } + p := struct { page.StaticPage Flashes []template.HTML @@ -112,7 +218,7 @@ func handleViewHome(app *app, w http.ResponseWriter, r *http.Request) error { return renderPage(w, "landing.tmpl", p) } -func handleTemplatedPage(app *app, w http.ResponseWriter, r *http.Request, t *template.Template) error { +func handleTemplatedPage(app *App, w http.ResponseWriter, r *http.Request, t *template.Template) error { p := struct { page.StaticPage ContentTitle string @@ -158,7 +264,7 @@ func handleTemplatedPage(app *app, w http.ResponseWriter, r *http.Request, t *te return nil } -func pageForReq(app *app, r *http.Request) page.StaticPage { +func pageForReq(app *App, r *http.Request) page.StaticPage { p := page.StaticPage{ AppCfg: app.cfg.App, Path: r.URL.Path, @@ -187,226 +293,96 @@ func pageForReq(app *app, r *http.Request) page.StaticPage { return p } -var shttp = http.NewServeMux() var fileRegex = regexp.MustCompile("/([^/]*\\.[^/]*)$") -func Serve() { - // General options usable with other commands - debugPtr := flag.Bool("debug", false, "Enables debug logging.") - configFile := flag.String("c", "config.ini", "The configuration file to use") - - // Setup actions - createConfig := flag.Bool("create-config", false, "Creates a basic configuration and exits") - doConfig := flag.Bool("config", false, "Run the configuration process") - configSections := flag.String("sections", "server db app", "Which sections of the configuration to go through (requires --config), " + - "valid values are any combination of 'server', 'db' and 'app' " + - "example: writefreely --config --sections \"db app\"") - genKeys := flag.Bool("gen-keys", false, "Generate encryption and authentication keys") - createSchema := flag.Bool("init-db", false, "Initialize app database") - migrate := flag.Bool("migrate", false, "Migrate the database") - - // Admin actions - createAdmin := flag.String("create-admin", "", "Create an admin with the given username:password") - createUser := flag.String("create-user", "", "Create a regular user with the given username:password") - resetPassUser := flag.String("reset-pass", "", "Reset the given user's password") - outputVersion := flag.Bool("v", false, "Output the current version") - flag.Parse() +// Initialize loads the app configuration and initializes templates, keys, +// session, route handlers, and the database connection. +func Initialize(apper Apper, debug bool) (*App, error) { + debugging = debug - debugging = *debugPtr + apper.LoadConfig() - app := &app{ - cfgFile: *configFile, + // Load templates + err := InitTemplates(apper.App().Config()) + if err != nil { + return nil, fmt.Errorf("load templates: %s", err) } - if *outputVersion { - fmt.Println(serverSoftware + " " + softwareVer) - os.Exit(0) - } else if *createConfig { - log.Info("Creating configuration...") - c := config.New() - log.Info("Saving configuration %s...", app.cfgFile) - err := config.Save(c, app.cfgFile) - if err != nil { - log.Error("Unable to save configuration: %v", err) - os.Exit(1) - } - os.Exit(0) - } else if *doConfig { - if *configSections == "" { - *configSections = "server db app" - } - // let's check there aren't any garbage in the list - configSectionsArray := strings.Split(*configSections, " ") - for _, element := range configSectionsArray { - if element != "server" && element != "db" && element != "app" { - log.Error("Invalid argument to --sections. Valid arguments are only \"server\", \"db\" and \"app\"") - os.Exit(1) - } - } - d, err := config.Configure(app.cfgFile, *configSections) - if err != nil { - log.Error("Unable to configure: %v", err) - os.Exit(1) - } - if d.User != nil { - app.cfg = d.Config - connectToDatabase(app) - defer shutdown(app) - - if !app.db.DatabaseInitialized() { - err = adminInitDatabase(app) - if err != nil { - log.Error(err.Error()) - os.Exit(1) - } - } - - u := &User{ - Username: d.User.Username, - HashedPass: d.User.HashedPass, - Created: time.Now().Truncate(time.Second).UTC(), - } - - // Create blog - log.Info("Creating user %s...\n", u.Username) - err = app.db.CreateUser(u, app.cfg.App.SiteName) - if err != nil { - log.Error("Unable to create user: %s", err) - os.Exit(1) - } - log.Info("Done!") - } - os.Exit(0) - } else if *genKeys { - errStatus := 0 + // Load keys and set up session + initKeyPaths(apper.App()) // TODO: find a better way to do this, since it's unneeded in all Apper implementations + err = InitKeys(apper) + if err != nil { + return nil, fmt.Errorf("init keys: %s", err) + } + apper.App().InitSession() - // Read keys path from config - loadConfig(app) + apper.App().InitDecoder() - // Create keys dir if it doesn't exist yet - fullKeysDir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir) - if _, err := os.Stat(fullKeysDir); os.IsNotExist(err) { - err = os.Mkdir(fullKeysDir, 0700) - if err != nil { - log.Error("%s", err) - os.Exit(1) - } - } + err = ConnectToDatabase(apper.App()) + if err != nil { + return nil, fmt.Errorf("connect to DB: %s", err) + } - // Generate keys - initKeyPaths(app) - err := generateKey(emailKeyPath) - if err != nil { - errStatus = 1 - } - err = generateKey(cookieAuthKeyPath) - if err != nil { - errStatus = 1 - } - err = generateKey(cookieKeyPath) - if err != nil { - errStatus = 1 - } + // Handle local timeline, if enabled + if apper.App().cfg.App.LocalTimeline { + log.Info("Initializing local timeline...") + initLocalTimeline(apper.App()) + } - os.Exit(errStatus) - } else if *createSchema { - loadConfig(app) - connectToDatabase(app) - defer shutdown(app) - err := adminInitDatabase(app) - if err != nil { - log.Error(err.Error()) - os.Exit(1) - } - os.Exit(0) - } else if *createAdmin != "" { - err := adminCreateUser(app, *createAdmin, true) - if err != nil { - log.Error(err.Error()) - os.Exit(1) - } - os.Exit(0) - } else if *createUser != "" { - err := adminCreateUser(app, *createUser, false) - if err != nil { - log.Error(err.Error()) - os.Exit(1) - } - os.Exit(0) - } else if *resetPassUser != "" { - // Connect to the database - loadConfig(app) - connectToDatabase(app) - defer shutdown(app) + return apper.App(), nil +} - // Fetch user - u, err := app.db.GetUserForAuth(*resetPassUser) - if err != nil { - log.Error("Get user: %s", err) - os.Exit(1) - } +func Serve(app *App, r *mux.Router) { + log.Info("Going to serve...") - // Prompt for new password - prompt := promptui.Prompt{ - Templates: &promptui.PromptTemplates{ - Success: "{{ . | bold | faint }}: ", - }, - Label: "New password", - Mask: '*', - } - newPass, err := prompt.Run() - if err != nil { - log.Error("%s", err) - os.Exit(1) - } + isSingleUser = app.cfg.App.SingleUser + app.cfg.Server.Dev = debugging - // Do the update - log.Info("Updating...") - err = adminResetPassword(app, u, newPass) - if err != nil { - log.Error("%s", err) - os.Exit(1) - } - log.Info("Success.") + // Handle shutdown + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + log.Info("Shutting down...") + shutdown(app) + log.Info("Done.") os.Exit(0) - } else if *migrate { - loadConfig(app) - connectToDatabase(app) - defer shutdown(app) - - err := migrations.Migrate(migrations.NewDatastore(app.db.DB, app.db.driverName)) - if err != nil { - log.Error("migrate: %s", err) - os.Exit(1) - } + }() - os.Exit(0) + // Start web application server + var bindAddress = app.cfg.Server.Bind + if bindAddress == "" { + bindAddress = "localhost" } + var err error + if app.cfg.IsSecureStandalone() { + log.Info("Serving redirects on http://%s:80", bindAddress) + go func() { + err = http.ListenAndServe( + fmt.Sprintf("%s:80", bindAddress), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, app.cfg.App.Host, http.StatusMovedPermanently) + })) + log.Error("Unable to start redirect server: %v", err) + }() - log.Info("Initializing...") - - loadConfig(app) - - hostName = app.cfg.App.Host - isSingleUser = app.cfg.App.SingleUser - app.cfg.Server.Dev = *debugPtr - - err := initTemplates(app.cfg) - if err != nil { - log.Error("load templates: %s", err) - os.Exit(1) + log.Info("Serving on https://%s:443", bindAddress) + log.Info("---") + err = http.ListenAndServeTLS( + fmt.Sprintf("%s:443", bindAddress), app.cfg.Server.TLSCertPath, app.cfg.Server.TLSKeyPath, r) + } else { + log.Info("Serving on http://%s:%d\n", bindAddress, app.cfg.Server.Port) + log.Info("---") + err = http.ListenAndServe(fmt.Sprintf("%s:%d", bindAddress, app.cfg.Server.Port), r) } - - // Load keys - log.Info("Loading encryption keys...") - initKeyPaths(app) - err = initKeys(app) if err != nil { - log.Error("\n%s\n", err) + log.Error("Unable to start: %v", err) + os.Exit(1) } +} +func (app *App) InitDecoder() { + // TODO: do this at the package level, instead of the App level // Initialize modules - app.sessionStore = initSession(app) app.formDecoder = schema.NewDecoder() app.formDecoder.RegisterConverter(converter.NullJSONString{}, converter.ConvertJSONNullString) app.formDecoder.RegisterConverter(converter.NullJSONBool{}, converter.ConvertJSONNullBool) @@ -414,11 +390,14 @@ func Serve() { app.formDecoder.RegisterConverter(sql.NullBool{}, converter.ConvertSQLNullBool) app.formDecoder.RegisterConverter(sql.NullInt64{}, converter.ConvertSQLNullInt64) app.formDecoder.RegisterConverter(sql.NullFloat64{}, converter.ConvertSQLNullFloat64) +} +// ConnectToDatabase validates and connects to the configured database, then +// tests the connection. +func ConnectToDatabase(app *App) error { // Check database configuration if app.cfg.Database.Type == driverMySQL && (app.cfg.Database.User == "" || app.cfg.Database.Password == "") { - log.Error("Database user or password not set.") - os.Exit(1) + return fmt.Errorf("Database user or password not set.") } if app.cfg.Database.Host == "" { app.cfg.Database.Host = "localhost" @@ -427,92 +406,190 @@ func Serve() { app.cfg.Database.Database = "writefreely" } + // TODO: check err connectToDatabase(app) - defer shutdown(app) // Test database connection - err = app.db.Ping() + err := app.db.Ping() if err != nil { - log.Error("Database ping failed: %s", err) + return fmt.Errorf("Database ping failed: %s", err) } - r := mux.NewRouter() - handler := NewHandler(app) - handler.SetErrorPages(&ErrorPages{ - NotFound: pages["404-general.tmpl"], - Gone: pages["410.tmpl"], - InternalServerError: pages["500.tmpl"], - Blank: pages["blank.tmpl"], - }) + return nil +} - // Handle app routes - initRoutes(handler, r, app.cfg, app.db) +// OutputVersion prints out the version of the application. +func OutputVersion() { + fmt.Println(serverSoftware + " " + softwareVer) +} - // Handle local timeline, if enabled - if app.cfg.App.LocalTimeline { - log.Info("Initializing local timeline...") - initLocalTimeline(app) +// NewApp creates a new app instance. +func NewApp(cfgFile string) *App { + return &App{ + cfgFile: cfgFile, + } +} + +// CreateConfig creates a default configuration and saves it to the app's cfgFile. +func CreateConfig(app *App) error { + log.Info("Creating configuration...") + c := config.New() + log.Info("Saving configuration %s...", app.cfgFile) + err := config.Save(c, app.cfgFile) + if err != nil { + return fmt.Errorf("Unable to save configuration: %v", err) } + return nil +} - // Handle static files - fs := http.FileServer(http.Dir(filepath.Join(app.cfg.Server.StaticParentDir, staticDir))) - shttp.Handle("/", fs) - r.PathPrefix("/").Handler(fs) +// DoConfig runs the interactive configuration process. +func DoConfig(app *App, configSections string) { + if configSections == "" { + configSections = "server db app" + } + // let's check there aren't any garbage in the list + configSectionsArray := strings.Split(configSections, " ") + for _, element := range configSectionsArray { + if element != "server" && element != "db" && element != "app" { + log.Error("Invalid argument to --sections. Valid arguments are only \"server\", \"db\" and \"app\"") + os.Exit(1) + } + } + d, err := config.Configure(app.cfgFile, configSections) + if err != nil { + log.Error("Unable to configure: %v", err) + os.Exit(1) + } + if d.User != nil { + app.cfg = d.Config + connectToDatabase(app) + defer shutdown(app) - // Handle shutdown - c := make(chan os.Signal, 2) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - log.Info("Shutting down...") - shutdown(app) - log.Info("Done.") - os.Exit(0) - }() + if !app.db.DatabaseInitialized() { + err = adminInitDatabase(app) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + } - http.Handle("/", r) + u := &User{ + Username: d.User.Username, + HashedPass: d.User.HashedPass, + Created: time.Now().Truncate(time.Second).UTC(), + } - // Start web application server - var bindAddress = app.cfg.Server.Bind - if bindAddress == "" { - bindAddress = "localhost" + // Create blog + log.Info("Creating user %s...\n", u.Username) + err = app.db.CreateUser(u, app.cfg.App.SiteName) + if err != nil { + log.Error("Unable to create user: %s", err) + os.Exit(1) + } + log.Info("Done!") } - if app.cfg.IsSecureStandalone() { - log.Info("Serving redirects on http://%s:80", bindAddress) - go func() { - err = http.ListenAndServe( - fmt.Sprintf("%s:80", bindAddress), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, app.cfg.App.Host, http.StatusMovedPermanently) - })) - log.Error("Unable to start redirect server: %v", err) - }() + os.Exit(0) +} - log.Info("Serving on https://%s:443", bindAddress) - log.Info("---") - err = http.ListenAndServeTLS( - fmt.Sprintf("%s:443", bindAddress), app.cfg.Server.TLSCertPath, app.cfg.Server.TLSKeyPath, nil) - } else { - log.Info("Serving on http://%s:%d\n", bindAddress, app.cfg.Server.Port) - log.Info("---") - err = http.ListenAndServe(fmt.Sprintf("%s:%d", bindAddress, app.cfg.Server.Port), nil) +// GenerateKeyFiles creates app encryption keys and saves them into the configured KeysParentDir. +func GenerateKeyFiles(app *App) error { + // Read keys path from config + app.LoadConfig() + + // Create keys dir if it doesn't exist yet + fullKeysDir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir) + if _, err := os.Stat(fullKeysDir); os.IsNotExist(err) { + err = os.Mkdir(fullKeysDir, 0700) + if err != nil { + return err + } } + + // Generate keys + initKeyPaths(app) + // TODO: use something like https://github.com/hashicorp/go-multierror to return errors + var keyErrs error + err := generateKey(emailKeyPath) if err != nil { - log.Error("Unable to start: %v", err) - os.Exit(1) + keyErrs = err + } + err = generateKey(cookieAuthKeyPath) + if err != nil { + keyErrs = err } + err = generateKey(cookieKeyPath) + if err != nil { + keyErrs = err + } + + return keyErrs } -func loadConfig(app *app) { - log.Info("Loading %s configuration...", app.cfgFile) - cfg, err := config.Load(app.cfgFile) +// CreateSchema creates all database tables needed for the application. +func CreateSchema(apper Apper) error { + apper.LoadConfig() + connectToDatabase(apper.App()) + defer shutdown(apper.App()) + err := adminInitDatabase(apper.App()) if err != nil { - log.Error("Unable to load configuration: %v", err) + return err + } + return nil +} + +// Migrate runs all necessary database migrations. +func Migrate(app *App) error { + app.LoadConfig() + connectToDatabase(app) + defer shutdown(app) + + err := migrations.Migrate(migrations.NewDatastore(app.db.DB, app.db.driverName)) + if err != nil { + return fmt.Errorf("migrate: %s", err) + } + return nil +} + +// ResetPassword runs the interactive password reset process. +func ResetPassword(app *App, username string) error { + // Connect to the database + app.LoadConfig() + connectToDatabase(app) + defer shutdown(app) + + // Fetch user + u, err := app.db.GetUserForAuth(username) + if err != nil { + log.Error("Get user: %s", err) os.Exit(1) } - app.cfg = cfg + + // Prompt for new password + prompt := promptui.Prompt{ + Templates: &promptui.PromptTemplates{ + Success: "{{ . | bold | faint }}: ", + }, + Label: "New password", + Mask: '*', + } + newPass, err := prompt.Run() + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + + // Do the update + log.Info("Updating...") + err = adminResetPassword(app, u, newPass) + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + log.Info("Success.") + return nil } -func connectToDatabase(app *app) { +func connectToDatabase(app *App) { log.Info("Connecting to %s database...", app.cfg.Database.Type) var db *sql.DB @@ -542,28 +619,20 @@ func connectToDatabase(app *app) { app.db = &datastore{db, app.cfg.Database.Type} } -func shutdown(app *app) { +func shutdown(app *App) { log.Info("Closing database connection...") app.db.Close() } -func adminCreateUser(app *app, credStr string, isAdmin bool) error { +// CreateUser creates a new admin or normal user from the given credentials. +func CreateUser(apper Apper, username, password string, isAdmin bool) error { // Create an admin user with --create-admin - creds := strings.Split(credStr, ":") - if len(creds) != 2 { - c := "user" - if isAdmin { - c = "admin" - } - return fmt.Errorf("usage: writefreely --create-%s username:password", c) - } - - loadConfig(app) - connectToDatabase(app) - defer shutdown(app) + apper.LoadConfig() + connectToDatabase(apper.App()) + defer shutdown(apper.App()) // Ensure an admin / first user doesn't already exist - firstUser, _ := app.db.GetUserByID(1) + firstUser, _ := apper.App().db.GetUserByID(1) if isAdmin { // Abort if trying to create admin user, but one already exists if firstUser != nil { @@ -577,9 +646,6 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) error { } // Create the user - username := creds[0] - password := creds[1] - // Normalize and validate username desiredUsername := username username = getSlug(username, "") @@ -589,8 +655,8 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) error { usernameDesc += " (originally: " + desiredUsername + ")" } - if !author.IsValidUsername(app.cfg, username) { - return fmt.Errorf("Username %s is invalid, reserved, or shorter than configured minimum length (%d characters).", usernameDesc, app.cfg.App.MinUsernameLen) + if !author.IsValidUsername(apper.App().cfg, username) { + return fmt.Errorf("Username %s is invalid, reserved, or shorter than configured minimum length (%d characters).", usernameDesc, apper.App().cfg.App.MinUsernameLen) } // Hash the password @@ -610,7 +676,7 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) error { userType = "admin" } log.Info("Creating %s %s...", userType, usernameDesc) - err = app.db.CreateUser(u, desiredUsername) + err = apper.App().db.CreateUser(u, desiredUsername) if err != nil { return fmt.Errorf("Unable to create user: %s", err) } @@ -618,7 +684,7 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) error { return nil } -func adminInitDatabase(app *app) error { +func adminInitDatabase(app *App) error { schemaFileName := "schema.sql" if app.cfg.Database.Type == driverSQLite { schemaFileName = "sqlite.sql" diff --git a/bindata-lib.go b/bindata-lib.go new file mode 100644 index 0000000..74429dc --- /dev/null +++ b/bindata-lib.go @@ -0,0 +1,105 @@ +// +build wflib + +package writefreely + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "strings" +) + +func bindata_read(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + return buf.Bytes(), nil +} + +var _schema_sql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xd4\x59\x5f\x6f\xa3\x38\x10\x7f\xef\xa7\xf0\xdb\xa6\x52\x23\x6d\x7a\xdd\xaa\xba\xd3\x3e\x64\x53\x76\x2f\xba\x94\xee\x25\x44\xba\x7d\x02\x03\x93\xd4\xaa\xb1\x91\x6d\x92\xe6\xdb\x9f\x8c\x49\x08\x86\x24\xd0\xdb\x3b\x71\x7d\x2a\xcc\x6f\x8c\xfd\x9b\x3f\x9e\x99\x0c\x87\x57\xc3\x21\x7a\xc4\x0a\x87\x58\xc2\xaf\x28\xd8\x0a\xa2\x60\x25\x00\xe8\x2e\xb8\x1a\x0e\xaf\xb4\x78\xf8\xce\x3f\xad\xac\xf5\x3d\x1c\x52\x40\x52\x89\x2c\x52\x99\x00\xb4\xe2\x02\xa9\xfc\x5d\x80\xa3\x08\xa4\x54\xfc\x15\x98\x34\xdf\x9b\xcc\x9d\xb1\xe7\x20\x6f\xfc\x65\xe6\xa0\xe9\x57\xe4\x3e\x7b\xc8\xf9\x6b\xba\xf0\x16\x16\x1a\x0d\xae\x10\x0a\xf2\x87\x00\x85\x84\x61\xb1\x1b\x8c\xee\xaf\x73\x05\x77\x39\x9b\xdd\x68\x71\x26\x41\xf8\x24\x0e\x10\x61\x6a\x60\x0b\x65\x16\xf3\x00\x29\xc2\x76\x5a\x3a\x2a\xa5\xe8\xd1\xf9\x3a\x5e\xce\x3c\xf4\xe1\xe3\x87\x1c\xc9\x19\xf8\x8a\x24\xd0\x0e\x1d\x09\xc0\x0a\xe2\x00\xc5\x58\x81\x56\xab\x43\x27\xcb\xf9\xdc\x71\x3d\xdf\x9b\x3e\x39\x0b\x6f\xfc\xf4\x3d\x57\x84\xb7\x94\x08\x90\x47\x8a\x7b\x7c\xf5\x40\x78\x0d\x4c\x05\x68\x83\x45\xf4\x82\xc5\xe0\xf6\xd3\xa7\xeb\x1a\xf2\xfb\x7c\xfa\x34\x9e\xff\x40\x7f\x38\x3f\xd0\xa0\xa0\xe9\xfa\xea\x1a\x39\xee\xb7\xa9\xeb\x7c\x9e\x32\xc6\x1f\xbf\x94\xfb\xf9\x7d\x3c\x5f\x38\xde\x67\x8a\x15\x61\xa3\xdf\xfe\x75\xb3\xa7\x69\xc4\x99\xd2\xa7\xb8\x6c\xf4\x12\x6b\x4c\xae\xcd\xb9\x3f\xfa\x2f\xb6\x4d\x0f\xd0\x04\x62\x92\x25\x0a\xde\x54\x7e\xb8\xf1\xc4\x73\xe6\x68\xe1\x78\x28\x53\xab\x07\x34\x79\x9e\xcd\xf4\x17\xf5\x83\x1f\x12\x66\x79\x4d\x1a\xbf\xcb\x80\x55\xce\x49\xdc\x2b\xc2\x13\xb2\x16\x58\x11\xde\x18\x68\x16\xc0\x10\xbd\x01\x21\x09\x67\x26\x78\x46\x23\x8b\x69\x03\x6f\x64\x29\x97\x0b\x90\x19\x55\x01\xca\x4d\xb0\x97\xf4\x85\x8f\x88\x53\x0a\x91\x3e\x2c\x56\x4a\x90\x30\x53\xd0\x22\xff\x34\x6a\x19\xae\x4a\xd1\xc9\x74\x73\xd0\x29\xdd\x77\x74\xfb\x60\x81\x36\x98\x66\x60\x85\x76\xdd\x7f\x93\xf0\xae\xe2\xc2\x49\x78\x57\xf3\xe2\xaa\x33\x56\xf7\x77\x73\xb4\x99\xde\xf8\x68\xb9\xc5\x57\xd8\x75\xb2\x46\x8e\x6f\x6d\x87\x34\x0b\x29\x89\xfc\x57\xd8\x05\x28\xa4\x3c\xb4\xa4\x82\x6c\xb0\x82\x13\xe2\x73\xa4\xf6\x90\xc8\x14\x4b\xb9\xe5\x22\xee\xc4\x66\xa9\xd4\x9e\xd2\x42\x25\x40\xb9\xd7\xde\x7f\xbc\xfe\x3f\xb3\x26\x20\x26\x02\x22\xd5\x89\xb5\x52\xc9\xb0\x96\x0a\xd8\xf8\x98\x12\x2c\x8f\xc2\xfd\xa3\x45\x4c\xc0\x60\x7b\x11\x54\x65\xef\x68\xdd\x1e\x52\xd7\x89\x32\x79\x74\xa1\x5b\x5e\x85\xc6\x4b\xef\xd9\x9f\xba\x93\xb9\xf3\xe4\xb8\x9e\xc9\x9f\x0d\x3c\xb5\x4f\x8d\xb5\x4a\x4a\x11\x45\x7f\x4e\xa6\x0d\x62\x90\x91\x20\xa9\xca\x2f\xcb\xc3\xfe\xee\x3b\xed\xaf\x5a\x99\xaa\x1d\x05\x5f\xbe\x00\x14\x17\xa8\x79\x9b\x7f\xa4\xb8\x51\x5b\xaf\x9c\xab\xae\xb8\x48\xf0\x51\xc9\xf8\x50\x2f\x18\x4d\xe6\x8b\x76\x8d\x35\xae\xa9\x82\xb7\xec\x4c\x35\xbd\x21\xb0\xf5\x23\x9e\xe9\xe2\xab\x41\x5e\xaf\x8d\xf4\xdb\xa5\x3b\xfd\x73\xe9\xe4\x2f\xf7\xf6\x1d\x04\x3d\xf3\xee\x94\xcb\x36\xa9\xc0\xc0\x4a\x8f\x2e\x9c\xc0\xee\x39\x68\xb6\xb6\x7c\xb8\x66\x88\x84\xc7\x64\xb5\xf3\x8b\xd6\xc6\xd4\xb9\xb7\x0d\x38\xed\x07\x3e\x4e\x53\xc0\x02\xb3\x08\x0a\xe8\x5d\x53\x67\xc2\xb8\x48\x4c\x73\x42\x31\x5b\x67\x78\xbd\x47\x37\xad\x2b\x14\xad\x38\xc1\x4f\xf0\x94\xda\x12\xcd\x97\x4a\xfd\x4b\x84\x31\x88\xfd\x94\x4b\x62\xa2\xeb\xe8\x8b\x4b\x77\x31\xfd\xe6\x3a\x8f\x0d\x8b\xef\x1b\x30\x5d\x95\x4a\x85\x93\xb4\x6d\x07\x76\xa8\xfc\x3b\x6b\x5e\x70\x7f\x3b\xdd\xfc\x93\xec\x70\xe8\x71\xba\x25\x82\x8e\xe1\x48\x62\xdf\x38\x6b\xbd\x78\xcc\xdf\xd7\x14\x4a\xa3\x0f\xca\xff\x6f\x0e\x6b\xe7\x98\xc2\x73\x0a\xd4\xde\x8f\x6e\x7a\xd5\x2b\x09\x48\xb8\x82\x15\xa7\x94\x6f\x5b\xc4\x7d\x15\x7e\xb2\x64\xaa\xf5\x4f\x46\xcf\xaf\x4c\x28\x6a\xa0\xd3\xa3\x84\xcb\x25\xbe\xf5\x81\x9e\xf1\xab\xb7\xd5\xae\xce\xb7\xf0\xf5\x21\x40\x7e\x75\x77\xe7\xf6\x6c\x1f\x70\x39\x3e\x8c\xc5\x0f\x1e\xdf\x7f\xb6\x3b\x51\x6d\xd7\x66\xc7\xec\x35\x16\x67\x91\xe2\x86\x8a\xd3\x56\x21\x2c\xe4\x6f\xe7\x00\xf2\x05\x0b\x88\xfd\x4b\xb8\xcb\xb6\xb1\xe2\x6f\x50\x6e\xaf\x37\x76\xd1\x24\x77\x99\x3d\x58\x78\x63\x9d\xb3\xe3\xcd\x86\x79\xc3\xfd\xdd\x7f\x34\x6e\xd8\x6f\xac\x97\x83\x06\xbd\x39\xc2\x36\xa4\x99\xf7\x8a\xd8\x2a\xe7\x6c\x8a\xab\x75\x4e\x7d\x44\x86\xdf\x74\x42\x90\x01\x92\x09\xa6\xf4\x64\x2d\x74\x36\xc9\xb7\x99\x0a\x13\x86\x23\x45\x36\xcd\xf3\xe9\x3e\xd1\xde\xd2\xd1\x3b\x76\x86\x5a\x85\xe1\x04\xde\xdd\x1c\x5e\x1a\x66\x54\x57\x32\x7c\x1d\x16\x32\x8f\xf5\x75\x20\xc1\x84\xe6\x5b\x2a\x7e\x9d\x68\x9c\xd3\xbf\xfb\xd7\x82\xcb\x59\xb0\xa4\x65\x50\xfe\xdf\xab\x28\x94\x26\xce\xe2\x53\x61\x78\x90\x17\xee\x90\x3f\xf9\x27\xc3\xf1\xe4\x7d\xdf\xfa\xcc\x7f\x07\x00\x00\xff\xff\xbe\x79\x68\xa8\x10\x1b\x00\x00") + +func schema_sql() ([]byte, error) { + return bindata_read( + _schema_sql, + "schema.sql", + ) +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + return f() + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() ([]byte, error){ + "schema.sql": schema_sql, +} +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for name := range node.Children { + rv = append(rv, name) + } + return rv, nil +} + +type _bintree_t struct { + Func func() ([]byte, error) + Children map[string]*_bintree_t +} +var _bintree = &_bintree_t{nil, map[string]*_bintree_t{ + "schema.sql": &_bintree_t{schema_sql, map[string]*_bintree_t{ + }}, +}} diff --git a/cmd/writefreely/main.go b/cmd/writefreely/main.go index e0a293c..98a984e 100644 --- a/cmd/writefreely/main.go +++ b/cmd/writefreely/main.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -11,9 +11,135 @@ package main import ( + "flag" + "fmt" + "github.com/gorilla/mux" + "github.com/writeas/web-core/log" "github.com/writeas/writefreely" + "os" + "strings" ) func main() { - writefreely.Serve() + // General options usable with other commands + debugPtr := flag.Bool("debug", false, "Enables debug logging.") + configFile := flag.String("c", "config.ini", "The configuration file to use") + + // Setup actions + createConfig := flag.Bool("create-config", false, "Creates a basic configuration and exits") + doConfig := flag.Bool("config", false, "Run the configuration process") + configSections := flag.String("sections", "server db app", "Which sections of the configuration to go through (requires --config), " + + "valid values are any combination of 'server', 'db' and 'app' " + + "example: writefreely --config --sections \"db app\"") + genKeys := flag.Bool("gen-keys", false, "Generate encryption and authentication keys") + createSchema := flag.Bool("init-db", false, "Initialize app database") + migrate := flag.Bool("migrate", false, "Migrate the database") + + // Admin actions + createAdmin := flag.String("create-admin", "", "Create an admin with the given username:password") + createUser := flag.String("create-user", "", "Create a regular user with the given username:password") + resetPassUser := flag.String("reset-pass", "", "Reset the given user's password") + outputVersion := flag.Bool("v", false, "Output the current version") + flag.Parse() + + app := writefreely.NewApp(*configFile) + + if *outputVersion { + writefreely.OutputVersion() + os.Exit(0) + } else if *createConfig { + err := writefreely.CreateConfig(app) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *doConfig { + writefreely.DoConfig(app, *configSections) + os.Exit(0) + } else if *genKeys { + err := writefreely.GenerateKeyFiles(app) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *createSchema { + err := writefreely.CreateSchema(app) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *createAdmin != "" { + username, password, err := userPass(*createAdmin, true) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + err = writefreely.CreateUser(app, username, password, true) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *createUser != "" { + username, password, err := userPass(*createUser, false) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + err = writefreely.CreateUser(app, username, password, false) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *resetPassUser != "" { + err := writefreely.ResetPassword(app, *resetPassUser) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } else if *migrate { + err := writefreely.Migrate(app) + if err != nil { + log.Error(err.Error()) + os.Exit(1) + } + os.Exit(0) + } + + // Initialize the application + var err error + app, err = writefreely.Initialize(app, *debugPtr) + if err != nil { + log.Error("%s", err) + os.Exit(1) + } + + // Set app routes + r := mux.NewRouter() + writefreely.InitRoutes(app, r) + app.InitStaticRoutes(r) + + // Serve the application + writefreely.Serve(app, r) +} + +func userPass(credStr string, isAdmin bool) (user string, pass string, err error) { + creds := strings.Split(credStr, ":") + if len(creds) != 2 { + c := "user" + if isAdmin { + c = "admin" + } + err = fmt.Errorf("usage: writefreely --create-%s username:password", c) + return + } + + user = creds[0] + pass = creds[1] + return } diff --git a/collections.go b/collections.go index 391d153..1a8ceca 100644 --- a/collections.go +++ b/collections.go @@ -54,7 +54,8 @@ type ( PublicOwner bool `datastore:"public_owner" json:"-"` URL string `json:"url,omitempty"` - db *datastore + db *datastore + hostName string } CollectionObj struct { Collection @@ -211,10 +212,10 @@ func (c *Collection) DisplayCanonicalURL() string { func (c *Collection) RedirectingCanonicalURL(isRedir bool) string { if isSingleUser { - return hostName + "/" + return c.hostName + "/" } - return fmt.Sprintf("%s/%s/", hostName, c.Alias) + return fmt.Sprintf("%s/%s/", c.hostName, c.Alias) } // PrevPageURL provides a full URL for the previous page of collection posts, @@ -300,11 +301,11 @@ func (c *Collection) AvatarURL() string { if !isAvatarChar(fl) { return "" } - return hostName + "/img/avatars/" + fl + ".png" + return c.hostName + "/img/avatars/" + fl + ".png" } func (c *Collection) FederatedAPIBase() string { - return hostName + "/" + return c.hostName + "/" } func (c *Collection) FederatedAccount() string { @@ -316,7 +317,7 @@ func (c *Collection) RenderMathJax() bool { return c.db.CollectionHasAttribute(c.ID, "render_mathjax") } -func newCollection(app *app, w http.ResponseWriter, r *http.Request) error { +func newCollection(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) alias := r.FormValue("alias") title := r.FormValue("title") @@ -399,7 +400,7 @@ func newCollection(app *app, w http.ResponseWriter, r *http.Request) error { return impart.HTTPError{http.StatusFound, redirectTo} } -func apiCheckCollectionPermissions(app *app, r *http.Request, c *Collection) (int64, error) { +func apiCheckCollectionPermissions(app *App, r *http.Request, c *Collection) (int64, error) { accessToken := r.Header.Get("Authorization") var userID int64 = -1 if accessToken != "" { @@ -419,7 +420,7 @@ func apiCheckCollectionPermissions(app *app, r *http.Request, c *Collection) (in } // fetchCollection handles the API endpoint for retrieving collection data. -func fetchCollection(app *app, w http.ResponseWriter, r *http.Request) error { +func fetchCollection(app *App, w http.ResponseWriter, r *http.Request) error { accept := r.Header.Get("Accept") if strings.Contains(accept, "application/activity+json") { return handleFetchCollectionActivities(app, w, r) @@ -434,6 +435,8 @@ func fetchCollection(app *app, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + c.hostName = app.cfg.App.Host + // Redirect users who aren't requesting JSON reqJSON := IsJSON(r.Header.Get("Content-Type")) if !reqJSON { @@ -467,7 +470,7 @@ func fetchCollection(app *app, w http.ResponseWriter, r *http.Request) error { // fetchCollectionPosts handles an API endpoint for retrieving a collection's // posts. -func fetchCollectionPosts(app *app, w http.ResponseWriter, r *http.Request) error { +func fetchCollectionPosts(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) alias := vars["alias"] @@ -475,6 +478,7 @@ func fetchCollectionPosts(app *app, w http.ResponseWriter, r *http.Request) erro if err != nil { return err } + c.hostName = app.cfg.App.Host // Check permissions userID, err := apiCheckCollectionPermissions(app, r, c) @@ -563,7 +567,7 @@ func processCollectionRequest(cr *collectionReq, vars map[string]string, w http. // domain that doesn't yet have a collection associated, or if a collection // requires a password. In either case, this will return nil, nil -- thus both // values should ALWAYS be checked to determine whether or not to continue. -func processCollectionPermissions(app *app, cr *collectionReq, u *User, w http.ResponseWriter, r *http.Request) (*Collection, error) { +func processCollectionPermissions(app *App, cr *collectionReq, u *User, w http.ResponseWriter, r *http.Request) (*Collection, error) { // Display collection if this is a collection var c *Collection var err error @@ -600,6 +604,7 @@ func processCollectionPermissions(app *app, cr *collectionReq, u *User, w http.R } return nil, err } + c.hostName = app.cfg.App.Host // Update CollectionRequest to reflect owner status cr.isCollOwner = u != nil && u.ID == c.OwnerID @@ -654,7 +659,7 @@ func processCollectionPermissions(app *app, cr *collectionReq, u *User, w http.R return c, nil } -func checkUserForCollection(app *app, cr *collectionReq, r *http.Request, isPostReq bool) (*User, error) { +func checkUserForCollection(app *App, cr *collectionReq, r *http.Request, isPostReq bool) (*User, error) { u := getUserSession(app, r) return u, nil } @@ -682,7 +687,7 @@ func getCollectionPage(vars map[string]string) int { } // handleViewCollection displays the requested Collection -func handleViewCollection(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) cr := &collectionReq{} @@ -788,7 +793,7 @@ func handleViewCollection(app *app, w http.ResponseWriter, r *http.Request) erro return err } -func handleViewCollectionTag(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewCollectionTag(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) tag := vars["tag"] @@ -867,7 +872,7 @@ func handleViewCollectionTag(app *app, w http.ResponseWriter, r *http.Request) e return nil } -func handleCollectionPostRedirect(app *app, w http.ResponseWriter, r *http.Request) error { +func handleCollectionPostRedirect(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) slug := vars["slug"] @@ -885,7 +890,7 @@ func handleCollectionPostRedirect(app *app, w http.ResponseWriter, r *http.Reque return impart.HTTPError{http.StatusFound, loc} } -func existingCollection(app *app, w http.ResponseWriter, r *http.Request) error { +func existingCollection(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) vars := mux.Vars(r) collAlias := vars["alias"] @@ -980,7 +985,7 @@ func collectionAliasFromReq(r *http.Request) string { return alias } -func handleWebCollectionUnlock(app *app, w http.ResponseWriter, r *http.Request) error { +func handleWebCollectionUnlock(app *App, w http.ResponseWriter, r *http.Request) error { var readReq struct { Alias string `schema:"alias" json:"alias"` Pass string `schema:"password" json:"password"` @@ -1047,7 +1052,7 @@ func handleWebCollectionUnlock(app *app, w http.ResponseWriter, r *http.Request) return impart.HTTPError{http.StatusFound, next} } -func isAuthorizedForCollection(app *app, alias string, r *http.Request) bool { +func isAuthorizedForCollection(app *App, alias string, r *http.Request) bool { authd := false session, err := app.sessionStore.Get(r, blogPassCookieName) if err == nil { diff --git a/config/config.go b/config/config.go index 2b07bed..add5447 100644 --- a/config/config.go +++ b/config/config.go @@ -13,6 +13,7 @@ package config import ( "gopkg.in/ini.v1" + "strings" ) const ( @@ -64,6 +65,7 @@ type ( Theme string `ini:"theme"` JSDisabled bool `ini:"disable_js"` WebFonts bool `ini:"webfonts"` + Landing string `ini:"landing"` // Users SingleUser bool `ini:"single_user"` @@ -134,6 +136,13 @@ func (cfg *Config) IsSecureStandalone() bool { return cfg.Server.Port == 443 && cfg.Server.TLSCertPath != "" && cfg.Server.TLSKeyPath != "" } +func (ac *AppCfg) LandingPath() string { + if !strings.HasPrefix(ac.Landing, "/") { + return "/" + ac.Landing + } + return ac.Landing +} + // Load reads the given configuration file, then parses and returns it as a Config. func Load(fname string) (*Config, error) { if fname == "" { diff --git a/database-lib.go b/database-lib.go new file mode 100644 index 0000000..58beb05 --- /dev/null +++ b/database-lib.go @@ -0,0 +1,20 @@ +// +build wflib + +/* + * Copyright © 2019 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +// This file contains dummy database funcs for when writefreely is used as a +// library. + +package writefreely + +func (db *datastore) isDuplicateKeyErr(err error) bool { + return false +} diff --git a/database-no-sqlite.go b/database-no-sqlite.go index 10db7d5..a3d50fc 100644 --- a/database-no-sqlite.go +++ b/database-no-sqlite.go @@ -1,4 +1,4 @@ -// +build !sqlite +// +build !sqlite,!wflib /* * Copyright © 2019 A Bunch Tell LLC. diff --git a/database-sqlite.go b/database-sqlite.go index 5fa3f6c..3741169 100644 --- a/database-sqlite.go +++ b/database-sqlite.go @@ -1,4 +1,4 @@ -// +build sqlite +// +build sqlite,!wflib /* * Copyright © 2019 A Bunch Tell LLC. diff --git a/database.go b/database.go index 27769b0..3af659d 100644 --- a/database.go +++ b/database.go @@ -29,6 +29,7 @@ import ( "github.com/writeas/web-core/log" "github.com/writeas/web-core/query" "github.com/writeas/writefreely/author" + "github.com/writeas/writefreely/key" ) const ( @@ -44,7 +45,7 @@ var ( type writestore interface { CreateUser(*User, string) error - UpdateUserEmail(keys *keychain, userID int64, email string) error + UpdateUserEmail(keys *key.Keychain, userID int64, email string) error UpdateEncryptedUserEmail(int64, []byte) error GetUserByID(int64) (*User, error) GetUserForAuth(string) (*User, error) @@ -60,7 +61,7 @@ type writestore interface { GetTemporaryAccessToken(userID int64, validSecs int) (string, error) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) DeleteAccount(userID int64) (l *string, err error) - ChangeSettings(app *app, u *User, s *userSettings) error + ChangeSettings(app *App, u *User, s *userSettings) error ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error GetCollections(u *User) (*[]Collection, error) @@ -219,8 +220,8 @@ func (db *datastore) CreateUser(u *User, collectionTitle string) error { // FIXME: We're returning errors inconsistently in this file. Do we use Errorf // for returned value, or impart? -func (db *datastore) UpdateUserEmail(keys *keychain, userID int64, email string) error { - encEmail, err := data.Encrypt(keys.emailKey, email) +func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error { + encEmail, err := data.Encrypt(keys.EmailKey, email) if err != nil { return fmt.Errorf("Couldn't encrypt email %s: %s\n", email, err) } @@ -1779,13 +1780,13 @@ func (db *datastore) GetUserPostsCount(userID int64) int64 { // ChangeSettings takes a User and applies the changes in the given // userSettings, MODIFYING THE USER with successful changes. -func (db *datastore) ChangeSettings(app *app, u *User, s *userSettings) error { +func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error { var errPass error q := query.NewUpdate() // Update email if given if s.Email != "" { - encEmail, err := data.Encrypt(app.keys.emailKey, s.Email) + encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email) if err != nil { log.Error("Couldn't encrypt email %s: %s\n", s.Email, err) return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."} diff --git a/export.go b/export.go index c04d629..47a2603 100644 --- a/export.go +++ b/export.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -99,7 +99,7 @@ func exportPostsZip(u *User, posts *[]PublicPost) []byte { return b.Bytes() } -func compileFullExport(app *app, u *User) *ExportUser { +func compileFullExport(app *App, u *User) *ExportUser { exportUser := &ExportUser{ User: u, } diff --git a/feed.go b/feed.go index fc6478d..dd82c33 100644 --- a/feed.go +++ b/feed.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -20,7 +20,7 @@ import ( "time" ) -func ViewFeed(app *app, w http.ResponseWriter, req *http.Request) error { +func ViewFeed(app *App, w http.ResponseWriter, req *http.Request) error { alias := collectionAliasFromReq(req) // Display collection if this is a collection @@ -34,6 +34,7 @@ func ViewFeed(app *app, w http.ResponseWriter, req *http.Request) error { if err != nil { return nil } + c.hostName = app.cfg.App.Host if c.IsPrivate() || c.IsProtected() { return ErrCollectionNotFound diff --git a/handle.go b/handle.go index 706a2fa..81a4823 100644 --- a/handle.go +++ b/handle.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -36,16 +36,17 @@ const ( ) type ( - handlerFunc func(app *app, w http.ResponseWriter, r *http.Request) error - userHandlerFunc func(app *app, u *User, w http.ResponseWriter, r *http.Request) error - dataHandlerFunc func(app *app, w http.ResponseWriter, r *http.Request) ([]byte, string, error) - authFunc func(app *app, r *http.Request) (*User, error) + handlerFunc func(app *App, w http.ResponseWriter, r *http.Request) error + userHandlerFunc func(app *App, u *User, w http.ResponseWriter, r *http.Request) error + userApperHandlerFunc func(apper Apper, u *User, w http.ResponseWriter, r *http.Request) error + dataHandlerFunc func(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) + authFunc func(app *App, r *http.Request) (*User, error) ) type Handler struct { errors *ErrorPages sessionStore *sessions.CookieStore - app *app + app Apper } // ErrorPages hold template HTML error pages for displaying errors to the user. @@ -59,7 +60,7 @@ type ErrorPages struct { // NewHandler returns a new Handler instance, using the given StaticPage data, // and saving alias to the application's CookieStore. -func NewHandler(app *app) *Handler { +func NewHandler(apper Apper) *Handler { h := &Handler{ errors: &ErrorPages{ NotFound: template.Must(template.New("").Parse("{{define \"base\"}}404

Not found.

{{end}}")), @@ -67,13 +68,26 @@ func NewHandler(app *app) *Handler { InternalServerError: template.Must(template.New("").Parse("{{define \"base\"}}500

Internal server error.

{{end}}")), Blank: template.Must(template.New("").Parse("{{define \"base\"}}{{.Title}}

{{.Content}}

{{end}}")), }, - sessionStore: app.sessionStore, - app: app, + sessionStore: apper.App().sessionStore, + app: apper, } return h } +// NewWFHandler returns a new Handler instance, using WriteFreely template files. +// You MUST call writefreely.InitTemplates() before this. +func NewWFHandler(apper Apper) *Handler { + h := NewHandler(apper) + h.SetErrorPages(&ErrorPages{ + NotFound: pages["404-general.tmpl"], + Gone: pages["410.tmpl"], + InternalServerError: pages["500.tmpl"], + Blank: pages["blank.tmpl"], + }) + return h +} + // SetErrorPages sets the given set of ErrorPages as templates for any errors // that come up. func (h *Handler) SetErrorPages(e *ErrorPages) { @@ -91,21 +105,21 @@ func (h *Handler) User(f userHandlerFunc) http.HandlerFunc { defer func() { if e := recover(); e != nil { log.Error("%s: %s", e, debug.Stack()) - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = http.StatusInternalServerError } log.Info("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent()) }() - u := getUserSession(h.app, r) + u := getUserSession(h.app.App(), r) if u == nil { err := ErrNotLoggedIn status = err.Status return err } - err := f(h.app, u, w, r) + err := f(h.app.App(), u, w, r) if err == nil { status = http.StatusOK } else if err, ok := err.(impart.HTTPError); ok { @@ -129,14 +143,52 @@ func (h *Handler) Admin(f userHandlerFunc) http.HandlerFunc { defer func() { if e := recover(); e != nil { log.Error("%s: %s", e, debug.Stack()) - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) + status = http.StatusInternalServerError + } + + log.Info(fmt.Sprintf("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent())) + }() + + u := getUserSession(h.app.App(), r) + if u == nil || !u.IsAdmin() { + err := impart.HTTPError{http.StatusNotFound, ""} + status = err.Status + return err + } + + err := f(h.app.App(), u, w, r) + if err == nil { + status = http.StatusOK + } else if err, ok := err.(impart.HTTPError); ok { + status = err.Status + } else { + status = http.StatusInternalServerError + } + + return err + }()) + } +} + +// AdminApper handles requests on /admin routes that require an Apper. +func (h *Handler) AdminApper(f userApperHandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h.handleHTTPError(w, r, func() error { + var status int + start := time.Now() + + defer func() { + if e := recover(); e != nil { + log.Error("%s: %s", e, debug.Stack()) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = http.StatusInternalServerError } log.Info(fmt.Sprintf("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent())) }() - u := getUserSession(h.app, r) + u := getUserSession(h.app.App(), r) if u == nil || !u.IsAdmin() { err := impart.HTTPError{http.StatusNotFound, ""} status = err.Status @@ -160,7 +212,7 @@ func (h *Handler) Admin(f userHandlerFunc) http.HandlerFunc { // UserAPI handles requests made in the API by the authenticated user. // This provides user-friendly HTML pages and actions that work in the browser. func (h *Handler) UserAPI(f userHandlerFunc) http.HandlerFunc { - return h.UserAll(false, f, func(app *app, r *http.Request) (*User, error) { + return h.UserAll(false, f, func(app *App, r *http.Request) (*User, error) { // Authorize user from Authorization header t := r.Header.Get("Authorization") if t == "" { @@ -191,7 +243,7 @@ func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerF log.Info("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent()) }() - u, err := a(h.app, r) + u, err := a(h.app.App(), r) if err != nil { if err, ok := err.(impart.HTTPError); ok { status = err.Status @@ -201,7 +253,7 @@ func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerF return err } - err = f(h.app, u, w, r) + err = f(h.app.App(), u, w, r) if err == nil { status = 200 } else if err, ok := err.(impart.HTTPError); ok { @@ -222,7 +274,7 @@ func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerF } func (h *Handler) RedirectOnErr(f handlerFunc, loc string) handlerFunc { - return func(app *app, w http.ResponseWriter, r *http.Request) error { + return func(app *App, w http.ResponseWriter, r *http.Request) error { err := f(app, w, r) if err != nil { if ie, ok := err.(impart.HTTPError); ok { @@ -239,7 +291,7 @@ func (h *Handler) RedirectOnErr(f handlerFunc, loc string) handlerFunc { } func (h *Handler) Page(n string) http.HandlerFunc { - return h.Web(func(app *app, w http.ResponseWriter, r *http.Request) error { + return h.Web(func(app *App, w http.ResponseWriter, r *http.Request) error { t, ok := pages[n] if !ok { return impart.HTTPError{http.StatusNotFound, "Page not found."} @@ -264,13 +316,13 @@ func (h *Handler) WebErrors(f handlerFunc, ul UserLevel) http.HandlerFunc { defer func() { if e := recover(); e != nil { - u := getUserSession(h.app, r) + u := getUserSession(h.app.App(), r) username := "None" if u != nil { username = u.Username } log.Error("User: %s\n\n%s: %s", username, e, debug.Stack()) - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } @@ -302,13 +354,13 @@ func (h *Handler) WebErrors(f handlerFunc, ul UserLevel) http.HandlerFunc { } // TODO: pass User object to function - err = f(h.app, w, r) + err = f(h.app.App(), w, r) if err == nil { status = 200 } else if httpErr, ok := err.(impart.HTTPError); ok { status = httpErr.Status if status < 300 || status > 399 { - addSessionFlash(h.app, w, r, httpErr.Message, session) + addSessionFlash(h.app.App(), w, r, httpErr.Message, session) return impart.HTTPError{http.StatusFound, r.Referer()} } } else { @@ -319,7 +371,7 @@ func (h *Handler) WebErrors(f handlerFunc, ul UserLevel) http.HandlerFunc { log.Error(e) } log.Info("Web handler internal error render") - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } @@ -338,14 +390,14 @@ func (h *Handler) Web(f handlerFunc, ul UserLevel) http.HandlerFunc { defer func() { if e := recover(); e != nil { - u := getUserSession(h.app, r) + u := getUserSession(h.app.App(), r) username := "None" if u != nil { username = u.Username } log.Error("User: %s\n\n%s: %s", username, e, debug.Stack()) log.Info("Web deferred internal error render") - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } @@ -375,7 +427,7 @@ func (h *Handler) Web(f handlerFunc, ul UserLevel) http.HandlerFunc { } // TODO: pass User object to function - err := f(h.app, w, r) + err := f(h.app.App(), w, r) if err == nil { status = 200 } else if httpErr, ok := err.(impart.HTTPError); ok { @@ -384,7 +436,7 @@ func (h *Handler) Web(f handlerFunc, ul UserLevel) http.HandlerFunc { e := fmt.Sprintf("[Web handler] 500: %v", err) log.Error(e) log.Info("Web internal error render") - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } @@ -412,7 +464,7 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc { // TODO: do any needed authentication - err := f(h.app, w, r) + err := f(h.app.App(), w, r) if err != nil { if err, ok := err.(impart.HTTPError); ok { status = err.Status @@ -434,14 +486,14 @@ func (h *Handler) Download(f dataHandlerFunc, ul UserLevel) http.HandlerFunc { defer func() { if e := recover(); e != nil { log.Error("%s: %s", e, debug.Stack()) - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } log.Info("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent()) }() - data, filename, err := f(h.app, w, r) + data, filename, err := f(h.app.App(), w, r) if err != nil { if err, ok := err.(impart.HTTPError); ok { status = err.Status @@ -530,7 +582,7 @@ func (h *Handler) handleHTTPError(w http.ResponseWriter, r *http.Request, err er page.StaticPage Content *template.HTML }{ - StaticPage: pageForReq(h.app, r), + StaticPage: pageForReq(h.app.App(), r), } if err.Message != "" { co := template.HTML(err.Message) @@ -540,12 +592,12 @@ func (h *Handler) handleHTTPError(w http.ResponseWriter, r *http.Request, err er return } else if err.Status == http.StatusNotFound { w.WriteHeader(err.Status) - h.errors.NotFound.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.NotFound.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) return } else if err.Status == http.StatusInternalServerError { w.WriteHeader(err.Status) log.Info("handleHTTPErorr internal error render") - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) return } else if err.Status == http.StatusAccepted { impart.WriteSuccess(w, "", err.Status) @@ -556,7 +608,7 @@ func (h *Handler) handleHTTPError(w http.ResponseWriter, r *http.Request, err er Title string Content template.HTML }{ - pageForReq(h.app, r), + pageForReq(h.app.App(), r), fmt.Sprintf("Uh oh (%d)", err.Status), template.HTML(fmt.Sprintf("

%s

", err.Message)), } @@ -591,7 +643,7 @@ func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error) impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "This is an unhelpful error message for a miscellaneous internal error."}) return } - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) } func correctPageFromLoginAttempt(r *http.Request) string { @@ -613,7 +665,7 @@ func (h *Handler) LogHandlerFunc(f http.HandlerFunc) http.HandlerFunc { defer func() { if e := recover(); e != nil { log.Error("Handler.LogHandlerFunc\n\n%s: %s", e, debug.Stack()) - h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app, r)) + h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) status = 500 } diff --git a/hostmeta.go b/hostmeta.go index 70a8856..4f452c3 100644 --- a/hostmeta.go +++ b/hostmeta.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -15,7 +15,7 @@ import ( "net/http" ) -func handleViewHostMeta(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewHostMeta(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) w.Header().Set("Content-Type", "application/xrd+xml; charset=utf-8") diff --git a/invites.go b/invites.go index 54d6619..561255f 100644 --- a/invites.go +++ b/invites.go @@ -45,7 +45,7 @@ func (i Invite) ExpiresFriendly() string { return i.Expires.Format("January 2, 2006, 3:04 PM") } -func handleViewUserInvites(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleViewUserInvites(app *App, u *User, w http.ResponseWriter, r *http.Request) error { // Don't show page if instance doesn't allow it if !(app.cfg.App.UserInvites != "" && (u.IsAdmin() || app.cfg.App.UserInvites != "admin")) { return impart.HTTPError{http.StatusNotFound, ""} @@ -73,7 +73,7 @@ func handleViewUserInvites(app *app, u *User, w http.ResponseWriter, r *http.Req return nil } -func handleCreateUserInvite(app *app, u *User, w http.ResponseWriter, r *http.Request) error { +func handleCreateUserInvite(app *App, u *User, w http.ResponseWriter, r *http.Request) error { muVal := r.FormValue("uses") expVal := r.FormValue("expires") @@ -106,7 +106,7 @@ func handleCreateUserInvite(app *app, u *User, w http.ResponseWriter, r *http.Re return impart.HTTPError{http.StatusFound, "/me/invites"} } -func handleViewInvite(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewInvite(app *App, w http.ResponseWriter, r *http.Request) error { inviteCode := mux.Vars(r)["code"] i, err := app.db.GetUserInvite(inviteCode) diff --git a/key/key.go b/key/key.go new file mode 100644 index 0000000..1cb3cf8 --- /dev/null +++ b/key/key.go @@ -0,0 +1,63 @@ +/* + * Copyright © 2019 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +// Package key holds application keys and utilities around generating them. +package key + +import ( + "crypto/rand" +) + +const ( + EncKeysBytes = 32 +) + +type Keychain struct { + EmailKey, CookieAuthKey, CookieKey []byte +} + +// GenerateKeys generates necessary keys for the app on the given Keychain, +// skipping any that already exist. +func (keys *Keychain) GenerateKeys() error { + // Generate keys only if they don't already exist + // TODO: use something like https://github.com/hashicorp/go-multierror to return errors + var err, keyErrs error + if len(keys.EmailKey) == 0 { + keys.EmailKey, err = GenerateBytes(EncKeysBytes) + if err != nil { + keyErrs = err + } + } + if len(keys.CookieAuthKey) == 0 { + keys.CookieAuthKey, err = GenerateBytes(EncKeysBytes) + if err != nil { + keyErrs = err + } + } + if len(keys.CookieKey) == 0 { + keys.CookieKey, err = GenerateBytes(EncKeysBytes) + if err != nil { + keyErrs = err + } + } + + return keyErrs +} + +// GenerateBytes returns securely generated random bytes. +func GenerateBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} diff --git a/keys.go b/keys.go index 3b9c360..5cc63a3 100644 --- a/keys.go +++ b/keys.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -11,8 +11,8 @@ package writefreely import ( - "crypto/rand" "github.com/writeas/web-core/log" + "github.com/writeas/writefreely/key" "io/ioutil" "os" "path/filepath" @@ -20,8 +20,6 @@ import ( const ( keysDir = "keys" - - encKeysBytes = 32 ) var ( @@ -30,47 +28,22 @@ var ( cookieKeyPath = filepath.Join(keysDir, "cookies_enc.aes256") ) -type keychain struct { - emailKey, cookieAuthKey, cookieKey []byte +// InitKeys loads encryption keys into memory via the given Apper interface +func InitKeys(apper Apper) error { + log.Info("Loading encryption keys...") + err := apper.LoadKeys() + if err != nil { + return err + } + return nil } -func initKeyPaths(app *app) { +func initKeyPaths(app *App) { emailKeyPath = filepath.Join(app.cfg.Server.KeysParentDir, emailKeyPath) cookieAuthKeyPath = filepath.Join(app.cfg.Server.KeysParentDir, cookieAuthKeyPath) cookieKeyPath = filepath.Join(app.cfg.Server.KeysParentDir, cookieKeyPath) } -func initKeys(app *app) error { - var err error - app.keys = &keychain{} - - if debugging { - log.Info(" %s", emailKeyPath) - } - app.keys.emailKey, err = ioutil.ReadFile(emailKeyPath) - if err != nil { - return err - } - - if debugging { - log.Info(" %s", cookieAuthKeyPath) - } - app.keys.cookieAuthKey, err = ioutil.ReadFile(cookieAuthKeyPath) - if err != nil { - return err - } - - if debugging { - log.Info(" %s", cookieKeyPath) - } - app.keys.cookieKey, err = ioutil.ReadFile(cookieKeyPath) - if err != nil { - return err - } - - return nil -} - // generateKey generates a key at the given path used for the encryption of // certain user data. Because user data becomes unrecoverable without these // keys, this won't overwrite any existing key, and instead outputs a message. @@ -85,7 +58,7 @@ func generateKey(path string) error { } log.Info("Generating %s.", path) - b, err := generateBytes(encKeysBytes) + b, err := key.GenerateBytes(key.EncKeysBytes) if err != nil { log.Error("FAILED. %s. Run writefreely --gen-keys again.", err) return err @@ -98,14 +71,3 @@ func generateKey(path string) error { log.Info("Success.") return nil } - -// generateBytes returns securely generated random bytes. -func generateBytes(n int) ([]byte, error) { - b := make([]byte, n) - _, err := rand.Read(b) - if err != nil { - return nil, err - } - - return b, nil -} diff --git a/pad.go b/pad.go index 1a602dd..0057299 100644 --- a/pad.go +++ b/pad.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -19,7 +19,7 @@ import ( "strings" ) -func handleViewPad(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) action := vars["action"] slug := vars["slug"] @@ -102,7 +102,7 @@ func handleViewPad(app *app, w http.ResponseWriter, r *http.Request) error { return nil } -func handleViewMeta(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewMeta(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) action := vars["action"] slug := vars["slug"] diff --git a/pages.go b/pages.go index ddbd132..29ba07a 100644 --- a/pages.go +++ b/pages.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -18,7 +18,7 @@ import ( var defaultPageUpdatedTime = time.Date(2018, 11, 8, 12, 0, 0, 0, time.Local) -func getAboutPage(app *app) (*instanceContent, error) { +func getAboutPage(app *App) (*instanceContent, error) { c, err := app.db.GetDynamicContent("about") if err != nil { return nil, err @@ -40,7 +40,7 @@ func defaultAboutTitle(cfg *config.Config) sql.NullString { return sql.NullString{String: "About " + cfg.App.SiteName, Valid: true} } -func getPrivacyPage(app *app) (*instanceContent, error) { +func getPrivacyPage(app *App) (*instanceContent, error) { c, err := app.db.GetDynamicContent("privacy") if err != nil { return nil, err diff --git a/posts.go b/posts.go index edd1c91..0efa5ec 100644 --- a/posts.go +++ b/posts.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -262,7 +262,7 @@ func (p *Post) HasTitleLink() bool { return hasLink } -func handleViewPost(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewPost(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) friendlyID := vars["post"] @@ -277,7 +277,7 @@ func handleViewPost(app *app, w http.ResponseWriter, r *http.Request) error { return handleTemplatedPage(app, w, r, t) } else if (strings.Contains(r.URL.Path, ".") && !isRaw && !isMarkdown) || r.URL.Path == "/robots.txt" || r.URL.Path == "/manifest.json" { // Serve static file - shttp.ServeHTTP(w, r) + app.shttp.ServeHTTP(w, r) return nil } @@ -468,7 +468,7 @@ func handleViewPost(app *app, w http.ResponseWriter, r *http.Request) error { // /posts // /posts?collection={alias} // ? /collections/{alias}/posts -func newPost(app *app, w http.ResponseWriter, r *http.Request) error { +func newPost(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) vars := mux.Vars(r) collAlias := vars["alias"] @@ -593,7 +593,7 @@ func newPost(app *app, w http.ResponseWriter, r *http.Request) error { return response } -func existingPost(app *app, w http.ResponseWriter, r *http.Request) error { +func existingPost(app *App, w http.ResponseWriter, r *http.Request) error { reqJSON := IsJSON(r.Header.Get("Content-Type")) vars := mux.Vars(r) postID := vars["post"] @@ -717,7 +717,7 @@ func existingPost(app *app, w http.ResponseWriter, r *http.Request) error { return nil } -func deletePost(app *app, w http.ResponseWriter, r *http.Request) error { +func deletePost(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) friendlyID := vars["post"] editToken := r.FormValue("token") @@ -834,7 +834,7 @@ func deletePost(app *app, w http.ResponseWriter, r *http.Request) error { } // addPost associates a post with the authenticated user. -func addPost(app *app, w http.ResponseWriter, r *http.Request) error { +func addPost(app *App, w http.ResponseWriter, r *http.Request) error { var ownerID int64 // Authenticate user @@ -883,7 +883,7 @@ func addPost(app *app, w http.ResponseWriter, r *http.Request) error { return impart.WriteSuccess(w, res, http.StatusOK) } -func dispersePost(app *app, w http.ResponseWriter, r *http.Request) error { +func dispersePost(app *App, w http.ResponseWriter, r *http.Request) error { var ownerID int64 // Authenticate user @@ -927,7 +927,7 @@ type ( ) // pinPost pins a post to a blog -func pinPost(app *app, w http.ResponseWriter, r *http.Request) error { +func pinPost(app *App, w http.ResponseWriter, r *http.Request) error { var userID int64 // Authenticate user @@ -985,7 +985,7 @@ func pinPost(app *app, w http.ResponseWriter, r *http.Request) error { return impart.WriteSuccess(w, res, http.StatusOK) } -func fetchPost(app *app, w http.ResponseWriter, r *http.Request) error { +func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { var collID int64 var coll *Collection var err error @@ -996,6 +996,7 @@ func fetchPost(app *app, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + coll.hostName = app.cfg.App.Host _, err = apiCheckCollectionPermissions(app, r, coll) if err != nil { return err @@ -1034,7 +1035,7 @@ func fetchPost(app *app, w http.ResponseWriter, r *http.Request) error { return impart.WriteSuccess(w, p, http.StatusOK) } -func fetchPostProperty(app *app, w http.ResponseWriter, r *http.Request) error { +func fetchPostProperty(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) p, err := app.db.GetPostProperty(vars["post"], 0, vars["property"]) if err != nil { @@ -1056,7 +1057,7 @@ func (p *Post) processPost() PublicPost { func (p *PublicPost) CanonicalURL() string { if p.Collection == nil || p.Collection.Alias == "" { - return hostName + "/" + p.ID + return p.Collection.hostName + "/" + p.ID } return p.Collection.CanonicalURL() + p.Slug.String } @@ -1087,7 +1088,7 @@ func (p *PublicPost) ActivityObject() *activitystreams.Object { if isSingleUser { tagBaseURL = p.Collection.CanonicalURL() + "tag:" } else { - tagBaseURL = fmt.Sprintf("%s/%s/tag:", hostName, p.Collection.Alias) + tagBaseURL = fmt.Sprintf("%s/%s/tag:", p.Collection.hostName, p.Collection.Alias) } for _, t := range p.Tags { o.Tag = append(o.Tag, activitystreams.Tag{ @@ -1132,7 +1133,7 @@ func (p *SubmittedPost) isFontValid() bool { return valid } -func getRawPost(app *app, friendlyID string) *RawPost { +func getRawPost(app *App, friendlyID string) *RawPost { var content, font, title string var isRTL sql.NullBool var lang sql.NullString @@ -1152,7 +1153,7 @@ func getRawPost(app *app, friendlyID string) *RawPost { } // TODO; return a Post! -func getRawCollectionPost(app *app, slug, collAlias string) *RawPost { +func getRawCollectionPost(app *App, slug, collAlias string) *RawPost { var id, title, content, font string var isRTL sql.NullBool var lang sql.NullString @@ -1189,7 +1190,7 @@ func getRawCollectionPost(app *app, slug, collAlias string) *RawPost { } } -func viewCollectionPost(app *app, w http.ResponseWriter, r *http.Request) error { +func viewCollectionPost(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) slug := vars["slug"] @@ -1200,7 +1201,7 @@ func viewCollectionPost(app *app, w http.ResponseWriter, r *http.Request) error if strings.Contains(r.URL.Path, ".") && !isRaw { // Serve static file - shttp.ServeHTTP(w, r) + app.shttp.ServeHTTP(w, r) return nil } @@ -1244,6 +1245,7 @@ func viewCollectionPost(app *app, w http.ResponseWriter, r *http.Request) error } return err } + c.hostName = app.cfg.App.Host // Check collection permissions if c.IsPrivate() && (u == nil || u.ID != c.OwnerID) { diff --git a/read.go b/read.go index 5098935..3bc91c7 100644 --- a/read.go +++ b/read.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -49,20 +49,20 @@ type readPublication struct { TotalPages int } -func initLocalTimeline(app *app) { +func initLocalTimeline(app *App) { app.timeline = &localTimeline{ postsPerPage: tlPostsPerPage, - m: memo.New(app.db.FetchPublicPosts, 10*time.Minute), + m: memo.New(app.FetchPublicPosts, 10*time.Minute), } } // satisfies memo.Func -func (db *datastore) FetchPublicPosts() (interface{}, error) { +func (app *App) FetchPublicPosts() (interface{}, error) { // Finds all public posts and posts in a public collection published during the owner's active subscription period and within the last 3 months - rows, err := db.Query(`SELECT p.id, alias, c.title, p.slug, p.title, p.content, p.text_appearance, p.language, p.rtl, p.created, p.updated + rows, err := app.db.Query(`SELECT p.id, alias, c.title, p.slug, p.title, p.content, p.text_appearance, p.language, p.rtl, p.created, p.updated FROM collections c LEFT JOIN posts p ON p.collection_id = c.id - WHERE c.privacy = 1 AND (p.created >= ` + db.dateSub(3, "month") + ` AND p.created <= ` + db.now() + ` AND pinned_position IS NULL) + WHERE c.privacy = 1 AND (p.created >= ` + app.db.dateSub(3, "month") + ` AND p.created <= ` + app.db.now() + ` AND pinned_position IS NULL) ORDER BY p.created DESC`) if err != nil { log.Error("Failed selecting from posts: %v", err) @@ -82,6 +82,8 @@ func (db *datastore) FetchPublicPosts() (interface{}, error) { log.Error("[READ] Unable to scan row, skipping: %v", err) continue } + c.hostName = app.cfg.App.Host + isCollectionPost := alias.Valid if isCollectionPost { c.Alias = alias.String @@ -108,7 +110,7 @@ func (db *datastore) FetchPublicPosts() (interface{}, error) { return posts, nil } -func viewLocalTimelineAPI(app *app, w http.ResponseWriter, r *http.Request) error { +func viewLocalTimelineAPI(app *App, w http.ResponseWriter, r *http.Request) error { updateTimelineCache(app.timeline) skip, _ := strconv.Atoi(r.FormValue("skip")) @@ -121,7 +123,7 @@ func viewLocalTimelineAPI(app *app, w http.ResponseWriter, r *http.Request) erro return impart.WriteSuccess(w, posts, http.StatusOK) } -func viewLocalTimeline(app *app, w http.ResponseWriter, r *http.Request) error { +func viewLocalTimeline(app *App, w http.ResponseWriter, r *http.Request) error { if !app.cfg.App.LocalTimeline { return impart.HTTPError{http.StatusNotFound, "Page doesn't exist."} } @@ -153,7 +155,7 @@ func updateTimelineCache(tl *localTimeline) { } } -func showLocalTimeline(app *app, w http.ResponseWriter, r *http.Request, page int, author, tag string) error { +func showLocalTimeline(app *App, w http.ResponseWriter, r *http.Request, page int, author, tag string) error { updateTimelineCache(app.timeline) pl := len(*(app.timeline.posts)) @@ -226,7 +228,7 @@ func (c *readPublication) PrevPageURL(n int) string { // handlePostIDRedirect handles a route where a post ID is given and redirects // the user to the canonical post URL. -func handlePostIDRedirect(app *app, w http.ResponseWriter, r *http.Request) error { +func handlePostIDRedirect(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) postID := vars["post"] p, err := app.db.GetPost(postID, 0) @@ -244,12 +246,13 @@ func handlePostIDRedirect(app *app, w http.ResponseWriter, r *http.Request) erro if err != nil { return err } + c.hostName = app.cfg.App.Host // Retrieve collection information and send user to canonical URL return impart.HTTPError{http.StatusFound, c.CanonicalURL() + p.Slug.String} } -func viewLocalTimelineFeed(app *app, w http.ResponseWriter, req *http.Request) error { +func viewLocalTimelineFeed(app *App, w http.ResponseWriter, req *http.Request) error { if !app.cfg.App.LocalTimeline { return impart.HTTPError{http.StatusNotFound, "Page doesn't exist."} } diff --git a/routes.go b/routes.go index bd0136f..13dd3a5 100644 --- a/routes.go +++ b/routes.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -14,15 +14,30 @@ import ( "github.com/gorilla/mux" "github.com/writeas/go-webfinger" "github.com/writeas/web-core/log" - "github.com/writeas/writefreely/config" "github.com/writefreely/go-nodeinfo" "net/http" + "path/filepath" "strings" ) -func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datastore) { - hostSubroute := cfg.App.Host[strings.Index(cfg.App.Host, "://")+3:] - if cfg.App.SingleUser { +// InitStaticRoutes adds routes for serving static files. +// TODO: this should just be a func, not method +func (app *App) InitStaticRoutes(r *mux.Router) { + // Handle static files + fs := http.FileServer(http.Dir(filepath.Join(app.cfg.Server.StaticParentDir, staticDir))) + app.shttp = http.NewServeMux() + app.shttp.Handle("/", fs) + r.PathPrefix("/").Handler(fs) +} + +// InitRoutes adds dynamic routes for the given mux.Router. +func InitRoutes(apper Apper, r *mux.Router) *mux.Router { + // Create handler + handler := NewWFHandler(apper) + + // Set up routes + hostSubroute := apper.App().cfg.App.Host[strings.Index(apper.App().cfg.App.Host, "://")+3:] + if apper.App().cfg.App.SingleUser { hostSubroute = "{domain}" } else { if strings.HasPrefix(hostSubroute, "localhost") { @@ -30,7 +45,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto } } - if cfg.App.SingleUser { + if apper.App().cfg.App.SingleUser { log.Info("Adding %s routes (single user)...", hostSubroute) } else { log.Info("Adding %s routes (multi-user)...", hostSubroute) @@ -40,7 +55,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto write := r.PathPrefix("/").Subrouter() // Federation endpoint configurations - wf := webfinger.Default(wfResolver{db, cfg}) + wf := webfinger.Default(wfResolver{apper.App().db, apper.App().cfg}) wf.NoTLSHandler = nil // Federation endpoints @@ -49,15 +64,15 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto // webfinger write.HandleFunc(webfinger.WebFingerPath, handler.LogHandlerFunc(http.HandlerFunc(wf.Webfinger))) // nodeinfo - niCfg := nodeInfoConfig(db, cfg) - ni := nodeinfo.NewService(*niCfg, nodeInfoResolver{cfg, db}) + niCfg := nodeInfoConfig(apper.App().db, apper.App().cfg) + ni := nodeinfo.NewService(*niCfg, nodeInfoResolver{apper.App().cfg, apper.App().db}) write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) // Set up dyamic page handlers // Handle auth auth := write.PathPrefix("/api/auth/").Subrouter() - if cfg.App.OpenRegistration { + if apper.App().cfg.App.OpenRegistration { auth.HandleFunc("/signup", handler.All(apiSignup)).Methods("POST") } auth.HandleFunc("/login", handler.All(login)).Methods("POST") @@ -130,7 +145,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto write.HandleFunc("/admin/user/{username}", handler.Admin(handleViewAdminUser)).Methods("GET") write.HandleFunc("/admin/pages", handler.Admin(handleViewAdminPages)).Methods("GET") write.HandleFunc("/admin/page/{slug}", handler.Admin(handleViewAdminPage)).Methods("GET") - write.HandleFunc("/admin/update/config", handler.Admin(handleAdminUpdateConfig)).Methods("POST") + write.HandleFunc("/admin/update/config", handler.AdminApper(handleAdminUpdateConfig)).Methods("POST") write.HandleFunc("/admin/update/{page}", handler.Admin(handleAdminUpdateSite)).Methods("POST") // Handle special pages first @@ -144,7 +159,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto RouteRead(handler, readPerm, write.PathPrefix("/read").Subrouter()) draftEditPrefix := "" - if cfg.App.SingleUser { + if apper.App().cfg.App.SingleUser { draftEditPrefix = "/d" write.HandleFunc("/me/new", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") } else { @@ -155,7 +170,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto write.HandleFunc(draftEditPrefix+"/{action}/edit", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") write.HandleFunc(draftEditPrefix+"/{action}/meta", handler.Web(handleViewMeta, UserLevelOptional)).Methods("GET") // Collections - if cfg.App.SingleUser { + if apper.App().cfg.App.SingleUser { RouteCollections(handler, write.PathPrefix("/").Subrouter()) } else { write.HandleFunc("/{prefix:[@~$!\\-+]}{collection}", handler.Web(handleViewCollection, UserLevelOptional)) @@ -165,6 +180,7 @@ func initRoutes(handler *Handler, r *mux.Router, cfg *config.Config, db *datasto } write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) + return r } func RouteCollections(handler *Handler, r *mux.Router) { diff --git a/scripts/upgrade-server.sh b/scripts/upgrade-server.sh new file mode 100755 index 0000000..c8e004a --- /dev/null +++ b/scripts/upgrade-server.sh @@ -0,0 +1,96 @@ +#! /bin/bash +############################################################################### +## writefreely update script ## +## ## +## WARNING: running this script will overwrite any modifed assets or ## +## template files. If you have any custom changes to these files you ## +## should back them up FIRST. ## +## ## +## This must be run from the web application root directory ## +## i.e. /var/www/writefreely, and operates under the assumption that you## +## have not installed the binary `writefreely` in another location. ## +############################################################################### +# +# Copyright © 2019 A Bunch Tell LLC. +# +# This file is part of WriteFreely. +# +# WriteFreely is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License, included +# in the LICENSE file in this source code package. +# + + +# only execute as root, or use sudo + +if [[ `id -u` -ne 0 ]]; then + echo "You must login as root, or execute this script with sudo" + exit 10 +fi + +# go ahead and check for the latest release on linux +echo "Checking for updates..." + +url=`curl -s https://api.github.com/repos/writeas/writefreely/releases/latest | grep 'browser_' | grep linux | cut -d\" -f4` + +# check current version + +bin_output=`./writefreely -v` +if [ -z "$bin_output" ]; then + exit 1 +fi + +current=${bin_output:12:5} +echo "Current version is v$current" + +# grab latest version number +IFS='/' +read -ra parts <<< "$url" + +latest=${parts[-2]} +echo "Latest release is $latest" + + +IFS='.' +read -ra cv <<< "$current" +read -ra lv <<< "${latest#v}" + +IFS=' ' +tempdir=$(mktemp -d) + + +if [[ ${lv[0]} -gt ${cv[0]} ]]; then + echo "New major version available." + echo "Downloading..." + `wget -P $tempdir -q --show-progress $url` +elif [[ ${lv[0]} -eq ${cv[0]} ]] && [[ ${lv[1]} -gt ${cv[1]} ]]; then + echo "New minor version available." + echo "Downloading..." + `wget -P $tempdir -q --show-progress $url` +elif [[ ${lv[2]} -gt ${cv[2]} ]]; then + echo "New patch version available." + echo "Downloading..." + `wget -P $tempdir -q --show-progress $url` +else + echo "Up to date." + exit 0 +fi + +filename=${parts[-1]} + +# extract +echo "Extracting files..." +tar -zxf $tempdir/$filename -C $tempdir + +# copy files +echo "Copying files..." +cp -r $tempdir/{pages,static,templates,writefreely} . + +# restart service +echo "Restarting writefreely systemd service..." +if `systemctl restart writefreely`; then + echo "Success, version has been upgraded to $latest." +else + echo "Upgrade complete, but failed to restart service." + exit 1 +fi diff --git a/session.go b/session.go index 6140b05..e379496 100644 --- a/session.go +++ b/session.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -27,24 +27,24 @@ const ( blogPassCookieName = "ub" ) -// initSession creates the cookie store. It depends on the keychain already +// InitSession creates the cookie store. It depends on the keychain already // being loaded. -func initSession(app *app) *sessions.CookieStore { +func (app *App) InitSession() { // Register complex data types we'll be storing in cookies gob.Register(&User{}) // Create the cookie store - store := sessions.NewCookieStore(app.keys.cookieAuthKey, app.keys.cookieKey) + store := sessions.NewCookieStore(app.keys.CookieAuthKey, app.keys.CookieKey) store.Options = &sessions.Options{ Path: "/", MaxAge: sessionLength, HttpOnly: true, Secure: strings.HasPrefix(app.cfg.App.Host, "https://"), } - return store + app.sessionStore = store } -func getSessionFlashes(app *app, w http.ResponseWriter, r *http.Request, session *sessions.Session) ([]string, error) { +func getSessionFlashes(app *App, w http.ResponseWriter, r *http.Request, session *sessions.Session) ([]string, error) { var err error if session == nil { session, err = app.sessionStore.Get(r, cookieName) @@ -66,7 +66,7 @@ func getSessionFlashes(app *app, w http.ResponseWriter, r *http.Request, session return f, nil } -func addSessionFlash(app *app, w http.ResponseWriter, r *http.Request, m string, session *sessions.Session) error { +func addSessionFlash(app *App, w http.ResponseWriter, r *http.Request, m string, session *sessions.Session) error { var err error if session == nil { session, err = app.sessionStore.Get(r, cookieName) @@ -82,7 +82,7 @@ func addSessionFlash(app *app, w http.ResponseWriter, r *http.Request, m string, return nil } -func getUserAndSession(app *app, r *http.Request) (*User, *sessions.Session) { +func getUserAndSession(app *App, r *http.Request) (*User, *sessions.Session) { session, err := app.sessionStore.Get(r, cookieName) if err == nil { // Got the currently logged-in user @@ -97,12 +97,12 @@ func getUserAndSession(app *app, r *http.Request) (*User, *sessions.Session) { return nil, nil } -func getUserSession(app *app, r *http.Request) *User { +func getUserSession(app *App, r *http.Request) *User { u, _ := getUserAndSession(app, r) return u } -func saveUserSession(app *app, r *http.Request, w http.ResponseWriter) error { +func saveUserSession(app *App, r *http.Request, w http.ResponseWriter) error { session, err := app.sessionStore.Get(r, cookieName) if err != nil { return ErrInternalCookieSession @@ -127,7 +127,7 @@ func saveUserSession(app *app, r *http.Request, w http.ResponseWriter) error { return err } -func getFullUserSession(app *app, r *http.Request) *User { +func getFullUserSession(app *App, r *http.Request) *User { u := getUserSession(app, r) if u == nil { return nil diff --git a/sitemap.go b/sitemap.go index 517a379..5c37366 100644 --- a/sitemap.go +++ b/sitemap.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -34,7 +34,7 @@ func buildSitemap(host, alias string) *stm.Sitemap { return sm } -func handleViewSitemap(app *app, w http.ResponseWriter, r *http.Request) error { +func handleViewSitemap(app *App, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) // Determine canonical blog URL @@ -57,6 +57,7 @@ func handleViewSitemap(app *app, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + c.hostName = app.cfg.App.Host if !isSubdomain { pre += alias + "/" diff --git a/templates.go b/templates.go index 0f93cb9..7a45c45 100644 --- a/templates.go +++ b/templates.go @@ -98,7 +98,8 @@ func initUserPage(parentDir, path, key string) { )) } -func initTemplates(cfg *config.Config) error { +// InitTemplates loads all template files from the configured parent dir. +func InitTemplates(cfg *config.Config) error { log.Info("Loading templates...") tmplFiles, err := ioutil.ReadDir(filepath.Join(cfg.Server.TemplatesParentDir, templatesDir)) if err != nil { diff --git a/templates/user/admin.tmpl b/templates/user/admin.tmpl index 5ce11a2..90b5d70 100644 --- a/templates/user/admin.tmpl +++ b/templates/user/admin.tmpl @@ -71,6 +71,8 @@ p.docs {
{{.Config.Host}}
User Mode
{{if .Config.SingleUser}}Single user{{else}}Multiple users{{end}}
+