@@ -1,7 +1,7 @@ | |||
language: go | |||
go: | |||
- "1.11.x" | |||
- "1.13.x" | |||
env: | |||
- GO111MODULE=on | |||
@@ -25,28 +25,40 @@ build-no-sqlite: assets-no-sqlite deps-no-sqlite | |||
build-linux: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u github.com/karalabe/xgo; \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=linux/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-windows: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u github.com/karalabe/xgo; \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=windows/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-darwin: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u github.com/karalabe/xgo; \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=darwin/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-arm6: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=linux/arm-6, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-arm7: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u github.com/karalabe/xgo; \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-arm64: deps | |||
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ | |||
$(GOGET) -u src.techknowlogick.com/xgo; \ | |||
fi | |||
xgo --targets=linux/arm64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely | |||
build-docker : | |||
$(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) . | |||
@@ -79,10 +91,18 @@ release : clean ui assets | |||
mv build/$(BINARY_NAME)-linux-amd64 $(BUILDPATH)/$(BINARY_NAME) | |||
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_amd64.tar.gz -C build $(BINARY_NAME) | |||
rm $(BUILDPATH)/$(BINARY_NAME) | |||
$(MAKE) build-arm6 | |||
mv build/$(BINARY_NAME)-linux-arm-6 $(BUILDPATH)/$(BINARY_NAME) | |||
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm6.tar.gz -C build $(BINARY_NAME) | |||
rm $(BUILDPATH)/$(BINARY_NAME) | |||
$(MAKE) build-arm7 | |||
mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME) | |||
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME) | |||
rm $(BUILDPATH)/$(BINARY_NAME) | |||
$(MAKE) build-arm64 | |||
mv build/$(BINARY_NAME)-linux-arm64 $(BUILDPATH)/$(BINARY_NAME) | |||
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm64.tar.gz -C build $(BINARY_NAME) | |||
rm $(BUILDPATH)/$(BINARY_NAME) | |||
$(MAKE) build-darwin | |||
mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME) | |||
tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME) | |||
@@ -135,7 +155,7 @@ $(TMPBIN)/go-bindata: deps $(TMPBIN) | |||
$(GOBUILD) -o $(TMPBIN)/go-bindata github.com/jteeuwen/go-bindata/go-bindata | |||
$(TMPBIN)/xgo: deps $(TMPBIN) | |||
$(GOBUILD) -o $(TMPBIN)/xgo github.com/karalabe/xgo | |||
$(GOBUILD) -o $(TMPBIN)/xgo src.techknowlogick.com/xgo | |||
ci-assets : $(TMPBIN)/go-bindata | |||
$(TMPBIN)/go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql | |||
@@ -156,17 +156,9 @@ func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWr | |||
Username: signup.Alias, | |||
HashedPass: hashedPass, | |||
HasPass: createdWithPass, | |||
Email: zero.NewString("", signup.Email != ""), | |||
Email: prepareUserEmail(signup.Email, app.keys.EmailKey), | |||
Created: time.Now().Truncate(time.Second).UTC(), | |||
} | |||
if signup.Email != "" { | |||
encEmail, err := data.Encrypt(app.keys.EmailKey, signup.Email) | |||
if err != nil { | |||
log.Error("Unable to encrypt email: %s\n", err) | |||
} else { | |||
u.Email.String = string(encEmail) | |||
} | |||
} | |||
// Create actual user | |||
if err := app.db.CreateUser(app.cfg, u, desiredUsername); err != nil { | |||
@@ -314,12 +306,16 @@ func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error { | |||
Message template.HTML | |||
Flashes []template.HTML | |||
LoginUsername string | |||
OauthSlack bool | |||
OauthWriteAs bool | |||
}{ | |||
pageForReq(app, r), | |||
r.FormValue("to"), | |||
template.HTML(""), | |||
[]template.HTML{}, | |||
getTempInfo(app, "login-user", r, w), | |||
app.Config().SlackOauth.ClientID != "", | |||
app.Config().WriteAsOauth.ClientID != "", | |||
} | |||
if earlyError != "" { | |||
@@ -1097,3 +1093,16 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s | |||
// Return value | |||
return s | |||
} | |||
func prepareUserEmail(input string, emailKey []byte) zero.String { | |||
email := zero.NewString("", input != "") | |||
if len(input) > 0 { | |||
encEmail, err := data.Encrypt(emailKey, input) | |||
if err != nil { | |||
log.Error("Unable to encrypt email: %s\n", err) | |||
} else { | |||
email.String = string(encEmail) | |||
} | |||
} | |||
return email | |||
} |
@@ -0,0 +1,195 @@ | |||
package writefreely | |||
import ( | |||
"encoding/json" | |||
"fmt" | |||
"html/template" | |||
"io" | |||
"io/ioutil" | |||
"net/http" | |||
"os" | |||
"path/filepath" | |||
"strings" | |||
"time" | |||
"github.com/hashicorp/go-multierror" | |||
"github.com/writeas/impart" | |||
wfimport "github.com/writeas/import" | |||
"github.com/writeas/web-core/log" | |||
) | |||
func viewImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error { | |||
// Fetch extra user data | |||
p := NewUserPage(app, r, u, "Import Posts", nil) | |||
c, err := app.db.GetCollections(u, app.Config().App.Host) | |||
if err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("unable to fetch collections: %v", err)} | |||
} | |||
d := struct { | |||
*UserPage | |||
Collections *[]Collection | |||
Flashes []template.HTML | |||
Message string | |||
InfoMsg bool | |||
}{ | |||
UserPage: p, | |||
Collections: c, | |||
Flashes: []template.HTML{}, | |||
} | |||
flashes, _ := getSessionFlashes(app, w, r, nil) | |||
for _, flash := range flashes { | |||
if strings.HasPrefix(flash, "SUCCESS: ") { | |||
d.Message = strings.TrimPrefix(flash, "SUCCESS: ") | |||
} else if strings.HasPrefix(flash, "INFO: ") { | |||
d.Message = strings.TrimPrefix(flash, "INFO: ") | |||
d.InfoMsg = true | |||
} else { | |||
d.Flashes = append(d.Flashes, template.HTML(flash)) | |||
} | |||
} | |||
showUserPage(w, "import", d) | |||
return nil | |||
} | |||
func handleImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error { | |||
// limit 10MB per submission | |||
r.ParseMultipartForm(10 << 20) | |||
collAlias := r.PostFormValue("collection") | |||
coll := &Collection{ | |||
ID: 0, | |||
} | |||
var err error | |||
if collAlias != "" { | |||
coll, err = app.db.GetCollection(collAlias) | |||
if err != nil { | |||
log.Error("Unable to get collection for import: %s", err) | |||
return err | |||
} | |||
// Only allow uploading to collection if current user is owner | |||
if coll.OwnerID != u.ID { | |||
err := ErrUnauthorizedGeneral | |||
_ = addSessionFlash(app, w, r, err.Message, nil) | |||
return err | |||
} | |||
coll.hostName = app.cfg.App.Host | |||
} | |||
fileDates := make(map[string]int64) | |||
err = json.Unmarshal([]byte(r.FormValue("fileDates")), &fileDates) | |||
if err != nil { | |||
log.Error("invalid form data for file dates: %v", err) | |||
return impart.HTTPError{http.StatusBadRequest, "form data for file dates was invalid"} | |||
} | |||
files := r.MultipartForm.File["files"] | |||
var fileErrs []error | |||
filesSubmitted := len(files) | |||
var filesImported int | |||
for _, formFile := range files { | |||
fname := "" | |||
ok := func() bool { | |||
file, err := formFile.Open() | |||
if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("Unable to read file %s", formFile.Filename)) | |||
log.Error("import file: open from form: %v", err) | |||
return false | |||
} | |||
defer file.Close() | |||
tempFile, err := ioutil.TempFile("", "post-upload-*.txt") | |||
if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) | |||
log.Error("import file: create temp file %s: %v", formFile.Filename, err) | |||
return false | |||
} | |||
defer tempFile.Close() | |||
_, err = io.Copy(tempFile, file) | |||
if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) | |||
log.Error("import file: copy to temp location %s: %v", formFile.Filename, err) | |||
return false | |||
} | |||
info, err := tempFile.Stat() | |||
if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("Internal error for %s", formFile.Filename)) | |||
log.Error("import file: stat temp file %s: %v", formFile.Filename, err) | |||
return false | |||
} | |||
fname = info.Name() | |||
return true | |||
}() | |||
if !ok { | |||
continue | |||
} | |||
post, err := wfimport.FromFile(filepath.Join(os.TempDir(), fname)) | |||
if err == wfimport.ErrEmptyFile { | |||
// not a real error so don't log | |||
_ = addSessionFlash(app, w, r, fmt.Sprintf("%s was empty, import skipped", formFile.Filename), nil) | |||
continue | |||
} else if err == wfimport.ErrInvalidContentType { | |||
// same as above | |||
_ = addSessionFlash(app, w, r, fmt.Sprintf("%s is not a supported post file", formFile.Filename), nil) | |||
continue | |||
} else if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("failed to read copy of %s", formFile.Filename)) | |||
log.Error("import textfile: file to post: %v", err) | |||
continue | |||
} | |||
if collAlias != "" { | |||
post.Collection = collAlias | |||
} | |||
dateTime := time.Unix(fileDates[formFile.Filename], 0) | |||
post.Created = &dateTime | |||
created := post.Created.Format("2006-01-02T15:04:05Z") | |||
submittedPost := SubmittedPost{ | |||
Title: &post.Title, | |||
Content: &post.Content, | |||
Font: "norm", | |||
Created: &created, | |||
} | |||
rp, err := app.db.CreatePost(u.ID, coll.ID, &submittedPost) | |||
if err != nil { | |||
fileErrs = append(fileErrs, fmt.Errorf("failed to create post from %s", formFile.Filename)) | |||
log.Error("import textfile: create db post: %v", err) | |||
continue | |||
} | |||
// Federate post, if necessary | |||
if app.cfg.App.Federation && coll.ID > 0 { | |||
go federatePost( | |||
app, | |||
&PublicPost{ | |||
Post: rp, | |||
Collection: &CollectionObj{ | |||
Collection: *coll, | |||
}, | |||
}, | |||
coll.ID, | |||
false, | |||
) | |||
} | |||
filesImported++ | |||
} | |||
if len(fileErrs) != 0 { | |||
_ = addSessionFlash(app, w, r, multierror.ListFormatFunc(fileErrs), nil) | |||
} | |||
if filesImported == filesSubmitted { | |||
verb := "posts" | |||
if filesSubmitted == 1 { | |||
verb = "post" | |||
} | |||
_ = addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: Import complete, %d %s imported.", filesImported, verb), nil) | |||
} else if filesImported > 0 { | |||
_ = addSessionFlash(app, w, r, fmt.Sprintf("INFO: %d of %d posts imported, see details below.", filesImported, filesSubmitted), nil) | |||
} | |||
return impart.HTTPError{http.StatusFound, "/me/import"} | |||
} |
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018-2019 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -37,6 +37,8 @@ import ( | |||
const ( | |||
// TODO: delete. don't use this! | |||
apCustomHandleDefault = "blog" | |||
apCacheTime = time.Minute | |||
) | |||
type RemoteUser struct { | |||
@@ -44,6 +46,7 @@ type RemoteUser struct { | |||
ActorID string | |||
Inbox string | |||
SharedInbox string | |||
Handle string | |||
} | |||
func (ru *RemoteUser) AsPerson() *activitystreams.Person { | |||
@@ -92,6 +95,7 @@ func handleFetchCollectionActivities(app *App, w http.ResponseWriter, r *http.Re | |||
p := c.PersonObject() | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, p, http.StatusOK) | |||
} | |||
@@ -148,11 +152,12 @@ func handleFetchCollectionOutbox(app *App, w http.ResponseWriter, r *http.Reques | |||
posts, err := app.db.GetPosts(app.cfg, c, p, false, true, false) | |||
for _, pp := range *posts { | |||
pp.Collection = res | |||
o := pp.ActivityObject(app.cfg) | |||
o := pp.ActivityObject(app) | |||
a := activitystreams.NewCreateActivity(o) | |||
ocp.OrderedItems = append(ocp.OrderedItems, *a) | |||
} | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, ocp, http.StatusOK) | |||
} | |||
@@ -207,6 +212,7 @@ func handleFetchCollectionFollowers(app *App, w http.ResponseWriter, r *http.Req | |||
ocp.OrderedItems = append(ocp.OrderedItems, f.ActorID) | |||
} | |||
*/ | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, ocp, http.StatusOK) | |||
} | |||
@@ -251,6 +257,7 @@ func handleFetchCollectionFollowing(app *App, w http.ResponseWriter, r *http.Req | |||
// Return outbox page | |||
ocp := activitystreams.NewOrderedCollectionPage(accountRoot, "following", 0, p) | |||
ocp.OrderedItems = []interface{}{} | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, ocp, http.StatusOK) | |||
} | |||
@@ -564,7 +571,7 @@ func deleteFederatedPost(app *App, p *PublicPost, collID int64) error { | |||
} | |||
p.Collection.hostName = app.cfg.App.Host | |||
actor := p.Collection.PersonObject(collID) | |||
na := p.ActivityObject(app.cfg) | |||
na := p.ActivityObject(app) | |||
// Add followers | |||
p.Collection.ID = collID | |||
@@ -610,7 +617,7 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { | |||
} | |||
} | |||
actor := p.Collection.PersonObject(collID) | |||
na := p.ActivityObject(app.cfg) | |||
na := p.ActivityObject(app) | |||
// Add followers | |||
p.Collection.ID = collID | |||
@@ -628,18 +635,25 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { | |||
inbox = f.Inbox | |||
} | |||
if _, ok := inboxes[inbox]; ok { | |||
// check if we're already sending to this shared inbox | |||
inboxes[inbox] = append(inboxes[inbox], f.ActorID) | |||
} else { | |||
// add the new shared inbox to the list | |||
inboxes[inbox] = []string{f.ActorID} | |||
} | |||
} | |||
var activity *activitystreams.Activity | |||
// for each one of the shared inboxes | |||
for si, instFolls := range inboxes { | |||
// add all followers from that instance | |||
// to the CC field | |||
na.CC = []string{} | |||
for _, f := range instFolls { | |||
na.CC = append(na.CC, f) | |||
} | |||
var activity *activitystreams.Activity | |||
// create a new "Create" activity | |||
// with our article as object | |||
if isUpdate { | |||
activity = activitystreams.NewUpdateActivity(na) | |||
} else { | |||
@@ -647,17 +661,42 @@ func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { | |||
activity.To = na.To | |||
activity.CC = na.CC | |||
} | |||
// and post it to that sharedInbox | |||
err = makeActivityPost(app.cfg.App.Host, actor, si, activity) | |||
if err != nil { | |||
log.Error("Couldn't post! %v", err) | |||
} | |||
} | |||
// re-create the object so that the CC list gets reset and has | |||
// the mentioned users. This might seem wasteful but the code is | |||
// cleaner than adding the mentioned users to CC here instead of | |||
// in p.ActivityObject() | |||
na = p.ActivityObject(app) | |||
for _, tag := range na.Tag { | |||
if tag.Type == "Mention" { | |||
activity = activitystreams.NewCreateActivity(na) | |||
activity.To = na.To | |||
activity.CC = na.CC | |||
// This here might be redundant in some cases as we might have already | |||
// sent this to the sharedInbox of this instance above, but we need too | |||
// much logic to catch this at the expense of the odd extra request. | |||
// I don't believe we'd ever have too many mentions in a single post that this | |||
// could become a burden. | |||
remoteUser, err := getRemoteUser(app, tag.HRef) | |||
err = makeActivityPost(app.cfg.App.Host, actor, remoteUser.Inbox, activity) | |||
if err != nil { | |||
log.Error("Couldn't post! %v", err) | |||
} | |||
} | |||
} | |||
return nil | |||
} | |||
func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { | |||
u := RemoteUser{ActorID: actorID} | |||
err := app.db.QueryRow("SELECT id, inbox, shared_inbox FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox) | |||
err := app.db.QueryRow("SELECT id, inbox, shared_inbox, handle FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox, &u.Handle) | |||
switch { | |||
case err == sql.ErrNoRows: | |||
return nil, impart.HTTPError{http.StatusNotFound, "No remote user with that ID."} | |||
@@ -669,6 +708,21 @@ func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { | |||
return &u, nil | |||
} | |||
// getRemoteUserFromHandle retrieves the profile page of a remote user | |||
// from the @user@server.tld handle | |||
func getRemoteUserFromHandle(app *App, handle string) (*RemoteUser, error) { | |||
u := RemoteUser{Handle: handle} | |||
err := app.db.QueryRow("SELECT id, actor_id, inbox, shared_inbox FROM remoteusers WHERE handle = ?", handle).Scan(&u.ID, &u.ActorID, &u.Inbox, &u.SharedInbox) | |||
switch { | |||
case err == sql.ErrNoRows: | |||
return nil, ErrRemoteUserNotFound | |||
case err != nil: | |||
log.Error("Couldn't get remote user %s: %v", handle, err) | |||
return nil, err | |||
} | |||
return &u, nil | |||
} | |||
func getActor(app *App, actorIRI string) (*activitystreams.Person, *RemoteUser, error) { | |||
log.Info("Fetching actor %s locally", actorIRI) | |||
actor := &activitystreams.Person{} | |||
@@ -743,3 +797,7 @@ func unmarshalActor(actorResp []byte, actor *activitystreams.Person) error { | |||
return nil | |||
} | |||
func setCacheControl(w http.ResponseWriter, ttl time.Duration) { | |||
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%.0f", ttl.Seconds())) | |||
} |
@@ -187,7 +187,11 @@ func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Reque | |||
var err error | |||
p.User, err = app.db.GetUserForAuth(username) | |||
if err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user: %v", err)} | |||
if err == ErrUserNotFound { | |||
return err | |||
} | |||
log.Error("Could not get user: %v", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
flashes, _ := getSessionFlashes(app, w, r, nil) | |||
@@ -260,7 +264,7 @@ func handleAdminToggleUserStatus(app *App, u *User, w http.ResponseWriter, r *ht | |||
} | |||
if err != nil { | |||
log.Error("toggle user silenced: %v", err) | |||
return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v")} | |||
return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v", err)} | |||
} | |||
return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s#status", username)} | |||
} | |||
@@ -30,7 +30,7 @@ import ( | |||
"github.com/gorilla/schema" | |||
"github.com/gorilla/sessions" | |||
"github.com/manifoldco/promptui" | |||
"github.com/writeas/go-strip-markdown" | |||
stripmd "github.com/writeas/go-strip-markdown" | |||
"github.com/writeas/impart" | |||
"github.com/writeas/web-core/auth" | |||
"github.com/writeas/web-core/converter" | |||
@@ -56,7 +56,7 @@ var ( | |||
debugging bool | |||
// Software version can be set from git env using -ldflags | |||
softwareVer = "0.11.0" | |||
softwareVer = "0.11.2" | |||
// DEPRECATED VARS | |||
isSingleUser bool | |||
@@ -70,7 +70,7 @@ type App struct { | |||
cfg *config.Config | |||
cfgFile string | |||
keys *key.Keychain | |||
sessionStore *sessions.CookieStore | |||
sessionStore sessions.Store | |||
formDecoder *schema.Decoder | |||
timeline *localTimeline | |||
@@ -101,6 +101,14 @@ func (app *App) SetKeys(k *key.Keychain) { | |||
app.keys = k | |||
} | |||
func (app *App) SessionStore() sessions.Store { | |||
return app.sessionStore | |||
} | |||
func (app *App) SetSessionStore(s sessions.Store) { | |||
app.sessionStore = s | |||
} | |||
// Apper is the interface for getting data into and out of a WriteFreely | |||
// instance (or "App"). | |||
// | |||
@@ -681,6 +689,52 @@ func ResetPassword(apper Apper, username string) error { | |||
return nil | |||
} | |||
// DoDeleteAccount runs the confirmation and account delete process. | |||
func DoDeleteAccount(apper Apper, username string) error { | |||
// Connect to the database | |||
apper.LoadConfig() | |||
connectToDatabase(apper.App()) | |||
defer shutdown(apper.App()) | |||
// check user exists | |||
u, err := apper.App().db.GetUserForAuth(username) | |||
if err != nil { | |||
log.Error("%s", err) | |||
os.Exit(1) | |||
} | |||
userID := u.ID | |||
// do not delete the admin account | |||
// TODO: check for other admins and skip? | |||
if u.IsAdmin() { | |||
log.Error("Can not delete admin account") | |||
os.Exit(1) | |||
} | |||
// confirm deletion, w/ w/out posts | |||
prompt := promptui.Prompt{ | |||
Templates: &promptui.PromptTemplates{ | |||
Success: "{{ . | bold | faint }}: ", | |||
}, | |||
Label: fmt.Sprintf("Really delete user : %s", username), | |||
IsConfirm: true, | |||
} | |||
_, err = prompt.Run() | |||
if err != nil { | |||
log.Info("Aborted...") | |||
os.Exit(0) | |||
} | |||
log.Info("Deleting...") | |||
err = apper.App().db.DeleteAccount(userID) | |||
if err != nil { | |||
log.Error("%s", err) | |||
os.Exit(1) | |||
} | |||
log.Info("Success.") | |||
return nil | |||
} | |||
func connectToDatabase(app *App) { | |||
log.Info("Connecting to %s database...", app.cfg.Database.Type) | |||
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -65,6 +65,7 @@ var reservedUsernames = map[string]bool{ | |||
"metadata": true, | |||
"new": true, | |||
"news": true, | |||
"oauth": true, | |||
"post": true, | |||
"posts": true, | |||
"privacy": true, | |||
@@ -13,11 +13,12 @@ package main | |||
import ( | |||
"flag" | |||
"fmt" | |||
"os" | |||
"strings" | |||
"github.com/gorilla/mux" | |||
"github.com/writeas/web-core/log" | |||
"github.com/writeas/writefreely" | |||
"os" | |||
"strings" | |||
) | |||
func main() { | |||
@@ -38,6 +39,7 @@ func main() { | |||
// Admin actions | |||
createAdmin := flag.String("create-admin", "", "Create an admin with the given username:password") | |||
createUser := flag.String("create-user", "", "Create a regular user with the given username:password") | |||
deleteUsername := flag.String("delete-user", "", "Delete a user with the given username") | |||
resetPassUser := flag.String("reset-pass", "", "Reset the given user's password") | |||
outputVersion := flag.Bool("v", false, "Output the current version") | |||
flag.Parse() | |||
@@ -102,6 +104,13 @@ func main() { | |||
os.Exit(1) | |||
} | |||
os.Exit(0) | |||
} else if *deleteUsername != "" { | |||
err := writefreely.DoDeleteAccount(app, *deleteUsername) | |||
if err != nil { | |||
log.Error(err.Error()) | |||
os.Exit(1) | |||
} | |||
os.Exit(0) | |||
} else if *migrate { | |||
err := writefreely.Migrate(app) | |||
if err != nil { | |||
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -63,6 +63,7 @@ type ( | |||
TotalPosts int `json:"total_posts"` | |||
Owner *User `json:"owner,omitempty"` | |||
Posts *[]PublicPost `json:"posts,omitempty"` | |||
Format *CollectionFormat | |||
} | |||
DisplayCollection struct { | |||
*CollectionObj | |||
@@ -556,6 +557,13 @@ type CollectionPage struct { | |||
CanInvite bool | |||
} | |||
func NewCollectionObj(c *Collection) *CollectionObj { | |||
return &CollectionObj{ | |||
Collection: *c, | |||
Format: c.NewFormat(), | |||
} | |||
} | |||
func (c *CollectionObj) ScriptDisplay() template.JS { | |||
return template.JS(c.Script) | |||
} | |||
@@ -648,6 +656,16 @@ func processCollectionPermissions(app *App, cr *collectionReq, u *User, w http.R | |||
uname = u.Username | |||
} | |||
// TODO: move this to all permission checks? | |||
suspended, err := app.db.IsUserSuspended(c.OwnerID) | |||
if err != nil { | |||
log.Error("process protected collection permissions: %v", err) | |||
return nil, err | |||
} | |||
if suspended { | |||
return nil, ErrCollectionNotFound | |||
} | |||
// See if we've authorized this collection | |||
authd := isAuthorizedForCollection(app, c.Alias, r) | |||
@@ -695,11 +713,10 @@ func checkUserForCollection(app *App, cr *collectionReq, r *http.Request, isPost | |||
func newDisplayCollection(c *Collection, cr *collectionReq, page int) *DisplayCollection { | |||
coll := &DisplayCollection{ | |||
CollectionObj: &CollectionObj{Collection: *c}, | |||
CollectionObj: NewCollectionObj(c), | |||
CurrentPage: page, | |||
Prefix: cr.prefix, | |||
IsTopLevel: isSingleUser, | |||
Format: c.NewFormat(), | |||
} | |||
c.db.GetPostsCount(coll.CollectionObj, cr.isCollOwner) | |||
return coll | |||
@@ -748,6 +765,7 @@ func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) erro | |||
if strings.Contains(r.Header.Get("Accept"), "application/activity+json") { | |||
ac := c.PersonObject() | |||
ac.Context = []interface{}{activitystreams.Namespace} | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, ac, http.StatusOK) | |||
} | |||
@@ -840,6 +858,19 @@ func handleViewCollection(app *App, w http.ResponseWriter, r *http.Request) erro | |||
return err | |||
} | |||
func handleViewMention(app *App, w http.ResponseWriter, r *http.Request) error { | |||
vars := mux.Vars(r) | |||
handle := vars["handle"] | |||
remoteUser, err := app.db.GetProfilePageFromHandle(app, handle) | |||
if err != nil || remoteUser == "" { | |||
log.Error("Couldn't find user %s: %v", handle, err) | |||
return ErrRemoteUserNotFound | |||
} | |||
return impart.HTTPError{Status: http.StatusFound, Message: remoteUser} | |||
} | |||
func handleViewCollectionTag(app *App, w http.ResponseWriter, r *http.Request) error { | |||
vars := mux.Vars(r) | |||
tag := vars["tag"] | |||
@@ -905,11 +936,11 @@ func handleViewCollectionTag(app *App, w http.ResponseWriter, r *http.Request) e | |||
// Log the error and just continue | |||
log.Error("Error getting user for collection: %v", err) | |||
} | |||
if owner.IsSilenced() { | |||
return ErrCollectionNotFound | |||
} | |||
} | |||
if !isOwner && u.IsSilenced() { | |||
return ErrCollectionNotFound | |||
} | |||
displayPage.Silenced = u.IsSilenced() | |||
displayPage.Silenced = owner != nil && owner.IsSilenced() | |||
displayPage.Owner = owner | |||
coll.Owner = displayPage.Owner | |||
// Add more data | |||
@@ -42,6 +42,8 @@ type ( | |||
PagesParentDir string `ini:"pages_parent_dir"` | |||
KeysParentDir string `ini:"keys_parent_dir"` | |||
HashSeed string `ini:"hash_seed"` | |||
Dev bool `ini:"-"` | |||
} | |||
@@ -56,6 +58,24 @@ type ( | |||
Port int `ini:"port"` | |||
} | |||
WriteAsOauthCfg struct { | |||
ClientID string `ini:"client_id"` | |||
ClientSecret string `ini:"client_secret"` | |||
AuthLocation string `ini:"auth_location"` | |||
TokenLocation string `ini:"token_location"` | |||
InspectLocation string `ini:"inspect_location"` | |||
CallbackProxy string `ini:"callback_proxy"` | |||
CallbackProxyAPI string `ini:"callback_proxy_api"` | |||
} | |||
SlackOauthCfg struct { | |||
ClientID string `ini:"client_id"` | |||
ClientSecret string `ini:"client_secret"` | |||
TeamID string `ini:"team_id"` | |||
CallbackProxy string `ini:"callback_proxy"` | |||
CallbackProxyAPI string `ini:"callback_proxy_api"` | |||
} | |||
// AppCfg holds values that affect how the application functions | |||
AppCfg struct { | |||
SiteName string `ini:"site_name"` | |||
@@ -98,9 +118,11 @@ type ( | |||
// Config holds the complete configuration for running a writefreely instance | |||
Config struct { | |||
Server ServerCfg `ini:"server"` | |||
Database DatabaseCfg `ini:"database"` | |||
App AppCfg `ini:"app"` | |||
Server ServerCfg `ini:"server"` | |||
Database DatabaseCfg `ini:"database"` | |||
App AppCfg `ini:"app"` | |||
SlackOauth SlackOauthCfg `ini:"oauth.slack"` | |||
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"` | |||
} | |||
) | |||
@@ -11,7 +11,9 @@ | |||
package config | |||
import ( | |||
"net/http" | |||
"strings" | |||
"time" | |||
) | |||
// FriendlyHost returns the app's Host sans any schema | |||
@@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool { | |||
} | |||
return int(currentlyUsed) < ac.MaxBlogs | |||
} | |||
// OrDefaultString returns input or a default value if input is empty. | |||
func OrDefaultString(input, defaultValue string) string { | |||
if len(input) == 0 { | |||
return defaultValue | |||
} | |||
return input | |||
} | |||
// DefaultHTTPClient returns a sane default HTTP client. | |||
func DefaultHTTPClient() *http.Client { | |||
return &http.Client{Timeout: 10 * time.Second} | |||
} |
@@ -11,8 +11,10 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"database/sql" | |||
"fmt" | |||
wf_db "github.com/writeas/writefreely/db" | |||
"net/http" | |||
"strings" | |||
"time" | |||
@@ -20,6 +22,7 @@ import ( | |||
"github.com/guregu/null" | |||
"github.com/guregu/null/zero" | |||
uuid "github.com/nu7hatch/gouuid" | |||
"github.com/writeas/activityserve" | |||
"github.com/writeas/impart" | |||
"github.com/writeas/nerds/store" | |||
"github.com/writeas/web-core/activitypub" | |||
@@ -61,7 +64,7 @@ type writestore interface { | |||
GetAccessToken(userID int64) (string, error) | |||
GetTemporaryAccessToken(userID int64, validSecs int) (string, error) | |||
GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) | |||
DeleteAccount(userID int64) (l *string, err error) | |||
DeleteAccount(userID int64) error | |||
ChangeSettings(app *App, u *User, s *userSettings) error | |||
ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error | |||
@@ -124,6 +127,11 @@ type writestore interface { | |||
GetUserLastPostTime(id int64) (*time.Time, error) | |||
GetCollectionLastPostTime(id int64) (*time.Time, error) | |||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) | |||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error | |||
ValidateOAuthState(context.Context, string) (string, string, error) | |||
GenerateOAuthState(context.Context, string, string) (string, error) | |||
DatabaseInitialized() bool | |||
} | |||
@@ -132,6 +140,8 @@ type datastore struct { | |||
driverName string | |||
} | |||
var _ writestore = &datastore{} | |||
func (db *datastore) now() string { | |||
if db.driverName == driverSQLite { | |||
return "strftime('%Y-%m-%d %H:%M:%S','now')" | |||
@@ -2104,22 +2114,13 @@ func (db *datastore) CollectionHasAttribute(id int64, attr string) bool { | |||
return true | |||
} | |||
func (db *datastore) DeleteAccount(userID int64) (l *string, err error) { | |||
debug := "" | |||
l = &debug | |||
t, err := db.Begin() | |||
if err != nil { | |||
stringLogln(l, "Unable to begin: %v", err) | |||
return | |||
} | |||
// DeleteAccount will delete the entire account for userID | |||
func (db *datastore) DeleteAccount(userID int64) error { | |||
// Get all collections | |||
rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to get collections: %v", err) | |||
return | |||
log.Error("Unable to get collections: %v", err) | |||
return err | |||
} | |||
defer rows.Close() | |||
colls := []Collection{} | |||
@@ -2127,103 +2128,158 @@ func (db *datastore) DeleteAccount(userID int64) (l *string, err error) { | |||
for rows.Next() { | |||
err = rows.Scan(&c.ID, &c.Alias) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to scan collection cols: %v", err) | |||
return | |||
log.Error("Unable to scan collection cols: %v", err) | |||
return err | |||
} | |||
colls = append(colls, c) | |||
} | |||
// Start transaction | |||
t, err := db.Begin() | |||
if err != nil { | |||
log.Error("Unable to begin: %v", err) | |||
return err | |||
} | |||
// Clean up all collection related information | |||
var res sql.Result | |||
for _, c := range colls { | |||
// TODO: user deleteCollection() func | |||
// Delete tokens | |||
res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err) | |||
return | |||
log.Error("Unable to delete attributes on %s: %v", c.Alias, err) | |||
return err | |||
} | |||
rs, _ := res.RowsAffected() | |||
stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias) | |||
log.Info("Deleted %d for %s from collectionattributes", rs, c.Alias) | |||
// Remove any optional collection password | |||
res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err) | |||
return | |||
log.Error("Unable to delete passwords on %s: %v", c.Alias, err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias) | |||
log.Info("Deleted %d for %s from collectionpasswords", rs, c.Alias) | |||
// Remove redirects to this collection | |||
res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err) | |||
return | |||
log.Error("Unable to delete redirects on %s: %v", c.Alias, err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
log.Info("Deleted %d for %s from collectionredirects", rs, c.Alias) | |||
// Remove any collection keys | |||
res, err = t.Exec("DELETE FROM collectionkeys WHERE collection_id = ?", c.ID) | |||
if err != nil { | |||
t.Rollback() | |||
log.Error("Unable to delete keys on %s: %v", c.Alias, err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
log.Info("Deleted %d for %s from collectionkeys", rs, c.Alias) | |||
// TODO: federate delete collection | |||
// Remove remote follows | |||
res, err = t.Exec("DELETE FROM remotefollows WHERE collection_id = ?", c.ID) | |||
if err != nil { | |||
t.Rollback() | |||
log.Error("Unable to delete remote follows on %s: %v", c.Alias, err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias) | |||
log.Info("Deleted %d for %s from remotefollows", rs, c.Alias) | |||
} | |||
// Delete collections | |||
res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete collections: %v", err) | |||
return | |||
log.Error("Unable to delete collections: %v", err) | |||
return err | |||
} | |||
rs, _ := res.RowsAffected() | |||
stringLogln(l, "Deleted %d from collections", rs) | |||
log.Info("Deleted %d from collections", rs) | |||
// Delete tokens | |||
res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete access tokens: %v", err) | |||
return | |||
log.Error("Unable to delete access tokens: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d from accesstokens", rs) | |||
log.Info("Deleted %d from accesstokens", rs) | |||
// Delete user attributes | |||
res, err = t.Exec("DELETE FROM oauth_users WHERE user_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
log.Error("Unable to delete oauth_users: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
log.Info("Deleted %d from oauth_users", rs) | |||
// Delete posts | |||
// TODO: should maybe get each row so we can federate a delete | |||
// if so needs to be outside of transaction like collections | |||
res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete posts: %v", err) | |||
return | |||
log.Error("Unable to delete posts: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d from posts", rs) | |||
log.Info("Deleted %d from posts", rs) | |||
// Delete user attributes | |||
res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete attributes: %v", err) | |||
return | |||
log.Error("Unable to delete attributes: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
log.Info("Deleted %d from userattributes", rs) | |||
// Delete user invites | |||
res, err = t.Exec("DELETE FROM userinvites WHERE owner_id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
log.Error("Unable to delete invites: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d from userattributes", rs) | |||
log.Info("Deleted %d from userinvites", rs) | |||
// Delete the user | |||
res, err = t.Exec("DELETE FROM users WHERE id = ?", userID) | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to delete user: %v", err) | |||
return | |||
log.Error("Unable to delete user: %v", err) | |||
return err | |||
} | |||
rs, _ = res.RowsAffected() | |||
stringLogln(l, "Deleted %d from users", rs) | |||
log.Info("Deleted %d from users", rs) | |||
// Commit all changes to the database | |||
err = t.Commit() | |||
if err != nil { | |||
t.Rollback() | |||
stringLogln(l, "Unable to commit: %v", err) | |||
return | |||
log.Error("Unable to commit: %v", err) | |||
return err | |||
} | |||
return | |||
// TODO: federate delete actor | |||
return nil | |||
} | |||
func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) { | |||
@@ -2453,6 +2509,69 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { | |||
return &t, nil | |||
} | |||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { | |||
state := store.Generate62RandomString(24) | |||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) | |||
if err != nil { | |||
return "", fmt.Errorf("unable to record oauth client state: %w", err) | |||
} | |||
return state, nil | |||
} | |||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { | |||
var provider string | |||
var clientID string | |||
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { | |||
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) | |||
if err != nil { | |||
return err | |||
} | |||
res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state) | |||
if err != nil { | |||
return err | |||
} | |||
rowsAffected, err := res.RowsAffected() | |||
if err != nil { | |||
return err | |||
} | |||
if rowsAffected != 1 { | |||
return fmt.Errorf("state not found") | |||
} | |||
return nil | |||
}) | |||
if err != nil { | |||
return "", "", nil | |||
} | |||
return provider, clientID, nil | |||
} | |||
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { | |||
var err error | |||
if db.driverName == driverSQLite { | |||
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken) | |||
} else { | |||
_, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken) | |||
} | |||
if err != nil { | |||
log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err) | |||
} | |||
return err | |||
} | |||
// GetIDForRemoteUser returns a user ID associated with a remote user ID. | |||
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { | |||
var userID int64 = -1 | |||
err := db. | |||
QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID). | |||
Scan(&userID) | |||
// Not finding a record is OK. | |||
if err != nil && err != sql.ErrNoRows { | |||
return -1, err | |||
} | |||
return userID, nil | |||
} | |||
// DatabaseInitialized returns whether or not the current datastore has been | |||
// initialized with the correct schema. | |||
// Currently, it checks to see if the `users` table exists. | |||
@@ -2483,3 +2602,40 @@ func handleFailedPostInsert(err error) error { | |||
log.Error("Couldn't insert into posts: %v", err) | |||
return err | |||
} | |||
func (db *datastore) GetProfilePageFromHandle(app *App, handle string) (string, error) { | |||
actorIRI := "" | |||
remoteUser, err := getRemoteUserFromHandle(app, handle) | |||
if err != nil { | |||
// can't find using handle in the table but the table may already have this user without | |||
// handle from a previous version | |||
// TODO: Make this determination. We should know whether a user exists without a handle, or doesn't exist at all | |||
actorIRI = RemoteLookup(handle) | |||
_, errRemoteUser := getRemoteUser(app, actorIRI) | |||
// if it exists then we need to update the handle | |||
if errRemoteUser == nil { | |||
_, err := app.db.Exec("UPDATE remoteusers SET handle = ? WHERE actor_id = ?", handle, actorIRI) | |||
if err != nil { | |||
log.Error("Can't update handle (" + handle + ") in database for user " + actorIRI) | |||
} | |||
} else { | |||
// this probably means we don't have the user in the table so let's try to insert it | |||
// here we need to ask the server for the inboxes | |||
remoteActor, err := activityserve.NewRemoteActor(actorIRI) | |||
if err != nil { | |||
log.Error("Couldn't fetch remote actor", err) | |||
} | |||
if debugging { | |||
log.Info("%s %s %s %s", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle) | |||
} | |||
_, err = app.db.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox, handle) VALUES(?, ?, ?, ?)", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle) | |||
if err != nil { | |||
log.Error("Can't insert remote user in database", err) | |||
return "", err | |||
} | |||
} | |||
} else { | |||
actorIRI = remoteUser.ActorID | |||
} | |||
return actorIRI, nil | |||
} |
@@ -0,0 +1,50 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"database/sql" | |||
"github.com/stretchr/testify/assert" | |||
"testing" | |||
) | |||
func TestOAuthDatastore(t *testing.T) { | |||
if !runMySQLTests() { | |||
t.Skip("skipping mysql tests") | |||
} | |||
withTestDB(t, func(db *sql.DB) { | |||
ctx := context.Background() | |||
ds := &datastore{ | |||
DB: db, | |||
driverName: "", | |||
} | |||
state, err := ds.GenerateOAuthState(ctx, "test", "development") | |||
assert.NoError(t, err) | |||
assert.Len(t, state, 24) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) | |||
_, _, err = ds.ValidateOAuthState(ctx, state) | |||
assert.NoError(t, err) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) | |||
var localUserID int64 = 99 | |||
var remoteUserID = "100" | |||
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a") | |||
assert.NoError(t, err) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID) | |||
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b") | |||
assert.NoError(t, err) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`") | |||
foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test") | |||
assert.NoError(t, err) | |||
assert.Equal(t, localUserID, foundUserID) | |||
}) | |||
} |
@@ -0,0 +1,52 @@ | |||
package db | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
type AlterTableSqlBuilder struct { | |||
Dialect DialectType | |||
Name string | |||
Changes []string | |||
} | |||
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { | |||
if colVal, err := col.String(); err == nil { | |||
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) | |||
} | |||
return b | |||
} | |||
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { | |||
if colVal, err := col.String(); err == nil { | |||
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) | |||
} | |||
return b | |||
} | |||
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { | |||
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) | |||
return b | |||
} | |||
func (b *AlterTableSqlBuilder) ToSQL() (string, error) { | |||
var str strings.Builder | |||
str.WriteString("ALTER TABLE ") | |||
str.WriteString(b.Name) | |||
str.WriteString(" ") | |||
if len(b.Changes) == 0 { | |||
return "", fmt.Errorf("no changes provide for table: %s", b.Name) | |||
} | |||
changeCount := len(b.Changes) | |||
for i, thing := range b.Changes { | |||
str.WriteString(thing) | |||
if i < changeCount-1 { | |||
str.WriteString(", ") | |||
} | |||
} | |||
return str.String(), nil | |||
} |
@@ -0,0 +1,56 @@ | |||
package db | |||
import "testing" | |||
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { | |||
type fields struct { | |||
Dialect DialectType | |||
Name string | |||
Changes []string | |||
} | |||
tests := []struct { | |||
name string | |||
builder *AlterTableSqlBuilder | |||
want string | |||
wantErr bool | |||
}{ | |||
{ | |||
name: "MySQL add int", | |||
builder: DialectMySQL. | |||
AlterTable("the_table"). | |||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), | |||
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", | |||
wantErr: false, | |||
}, | |||
{ | |||
name: "MySQL add string", | |||
builder: DialectMySQL. | |||
AlterTable("the_table"). | |||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), | |||
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", | |||
wantErr: false, | |||
}, | |||
{ | |||
name: "MySQL add int and string", | |||
builder: DialectMySQL. | |||
AlterTable("the_table"). | |||
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). | |||
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), | |||
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", | |||
wantErr: false, | |||
}, | |||
} | |||
for _, tt := range tests { | |||
t.Run(tt.name, func(t *testing.T) { | |||
got, err := tt.builder.ToSQL() | |||
if (err != nil) != tt.wantErr { | |||
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) | |||
return | |||
} | |||
if got != tt.want { | |||
t.Errorf("ToSQL() got = %v, want %v", got, tt.want) | |||
} | |||
}) | |||
} | |||
} |
@@ -0,0 +1,244 @@ | |||
package db | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
type ColumnType int | |||
type OptionalInt struct { | |||
Set bool | |||
Value int | |||
} | |||
type OptionalString struct { | |||
Set bool | |||
Value string | |||
} | |||
type SQLBuilder interface { | |||
ToSQL() (string, error) | |||
} | |||
type Column struct { | |||
Dialect DialectType | |||
Name string | |||
Nullable bool | |||
Default OptionalString | |||
Type ColumnType | |||
Size OptionalInt | |||
PrimaryKey bool | |||
} | |||
type CreateTableSqlBuilder struct { | |||
Dialect DialectType | |||
Name string | |||
IfNotExists bool | |||
ColumnOrder []string | |||
Columns map[string]*Column | |||
Constraints []string | |||
} | |||
const ( | |||
ColumnTypeBool ColumnType = iota | |||
ColumnTypeSmallInt ColumnType = iota | |||
ColumnTypeInteger ColumnType = iota | |||
ColumnTypeChar ColumnType = iota | |||
ColumnTypeVarChar ColumnType = iota | |||
ColumnTypeText ColumnType = iota | |||
ColumnTypeDateTime ColumnType = iota | |||
) | |||
var _ SQLBuilder = &CreateTableSqlBuilder{} | |||
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} | |||
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} | |||
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { | |||
if dialect != DialectMySQL && dialect != DialectSQLite { | |||
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) | |||
} | |||
switch d { | |||
case ColumnTypeSmallInt: | |||
{ | |||
if dialect == DialectSQLite { | |||
return "INTEGER", nil | |||
} | |||
mod := "" | |||
if size.Set { | |||
mod = fmt.Sprintf("(%d)", size.Value) | |||
} | |||
return "SMALLINT" + mod, nil | |||
} | |||
case ColumnTypeInteger: | |||
{ | |||
if dialect == DialectSQLite { | |||
return "INTEGER", nil | |||
} | |||
mod := "" | |||
if size.Set { | |||
mod = fmt.Sprintf("(%d)", size.Value) | |||
} | |||
return "INT" + mod, nil | |||
} | |||
case ColumnTypeChar: | |||
{ | |||
if dialect == DialectSQLite { | |||
return "TEXT", nil | |||
} | |||
mod := "" | |||
if size.Set { | |||
mod = fmt.Sprintf("(%d)", size.Value) | |||
} | |||
return "CHAR" + mod, nil | |||
} | |||
case ColumnTypeVarChar: | |||
{ | |||
if dialect == DialectSQLite { | |||
return "TEXT", nil | |||
} | |||
mod := "" | |||
if size.Set { | |||
mod = fmt.Sprintf("(%d)", size.Value) | |||
} | |||
return "VARCHAR" + mod, nil | |||
} | |||
case ColumnTypeBool: | |||
{ | |||
if dialect == DialectSQLite { | |||
return "INTEGER", nil | |||
} | |||
return "TINYINT(1)", nil | |||
} | |||
case ColumnTypeDateTime: | |||
return "DATETIME", nil | |||
case ColumnTypeText: | |||
return "TEXT", nil | |||
} | |||
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) | |||
} | |||
func (c *Column) SetName(name string) *Column { | |||
c.Name = name | |||
return c | |||
} | |||
func (c *Column) SetNullable(nullable bool) *Column { | |||
c.Nullable = nullable | |||
return c | |||
} | |||
func (c *Column) SetPrimaryKey(pk bool) *Column { | |||
c.PrimaryKey = pk | |||
return c | |||
} | |||
func (c *Column) SetDefault(value string) *Column { | |||
c.Default = OptionalString{Set: true, Value: value} | |||
return c | |||
} | |||
func (c *Column) SetType(t ColumnType) *Column { | |||
c.Type = t | |||
return c | |||
} | |||
func (c *Column) SetSize(size int) *Column { | |||
c.Size = OptionalInt{Set: true, Value: size} | |||
return c | |||
} | |||
func (c *Column) String() (string, error) { | |||
var str strings.Builder | |||
str.WriteString(c.Name) | |||
str.WriteString(" ") | |||
typeStr, err := c.Type.Format(c.Dialect, c.Size) | |||
if err != nil { | |||
return "", err | |||
} | |||
str.WriteString(typeStr) | |||
if !c.Nullable { | |||
str.WriteString(" NOT NULL") | |||
} | |||
if c.Default.Set { | |||
str.WriteString(" DEFAULT ") | |||
str.WriteString(c.Default.Value) | |||
} | |||
if c.PrimaryKey { | |||
str.WriteString(" PRIMARY KEY") | |||
} | |||
return str.String(), nil | |||
} | |||
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { | |||
if b.Columns == nil { | |||
b.Columns = make(map[string]*Column) | |||
} | |||
b.Columns[column.Name] = column | |||
b.ColumnOrder = append(b.ColumnOrder, column.Name) | |||
return b | |||
} | |||
func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { | |||
for _, column := range columns { | |||
if _, ok := b.Columns[column]; !ok { | |||
// This fails silently. | |||
return b | |||
} | |||
} | |||
b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) | |||
return b | |||
} | |||
func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { | |||
b.IfNotExists = ine | |||
return b | |||
} | |||
func (b *CreateTableSqlBuilder) ToSQL() (string, error) { | |||
var str strings.Builder | |||
str.WriteString("CREATE TABLE ") | |||
if b.IfNotExists { | |||
str.WriteString("IF NOT EXISTS ") | |||
} | |||
str.WriteString(b.Name) | |||
var things []string | |||
for _, columnName := range b.ColumnOrder { | |||
column, ok := b.Columns[columnName] | |||
if !ok { | |||
return "", fmt.Errorf("column not found: %s", columnName) | |||
} | |||
columnStr, err := column.String() | |||
if err != nil { | |||
return "", err | |||
} | |||
things = append(things, columnStr) | |||
} | |||
for _, constraint := range b.Constraints { | |||
things = append(things, constraint) | |||
} | |||
if thingLen := len(things); thingLen > 0 { | |||
str.WriteString(" ( ") | |||
for i, thing := range things { | |||
str.WriteString(thing) | |||
if i < thingLen-1 { | |||
str.WriteString(", ") | |||
} | |||
} | |||
str.WriteString(" )") | |||
} | |||
return str.String(), nil | |||
} | |||
@@ -0,0 +1,146 @@ | |||
package db | |||
import ( | |||
"github.com/stretchr/testify/assert" | |||
"testing" | |||
) | |||
func TestDialect_Column(t *testing.T) { | |||
c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) | |||
assert.Equal(t, DialectSQLite, c1.Dialect) | |||
c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) | |||
assert.Equal(t, DialectMySQL, c2.Dialect) | |||
} | |||
func TestColumnType_Format(t *testing.T) { | |||
type args struct { | |||
dialect DialectType | |||
size OptionalInt | |||
} | |||
tests := []struct { | |||
name string | |||
d ColumnType | |||
args args | |||
want string | |||
wantErr bool | |||
}{ | |||
{"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, | |||
{"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, | |||
{"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, | |||
{"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, | |||
{"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, | |||
{"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, | |||
{"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, | |||
{"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, | |||
{"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, | |||
{"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, | |||
{"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, | |||
{"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, | |||
{"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, | |||
{"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, | |||
{"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, | |||
{"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, | |||
{"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, | |||
{"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, | |||
{"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, | |||
{"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, | |||
} | |||
for _, tt := range tests { | |||
t.Run(tt.name, func(t *testing.T) { | |||
got, err := tt.d.Format(tt.args.dialect, tt.args.size) | |||
if (err != nil) != tt.wantErr { | |||
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) | |||
return | |||
} | |||
if got != tt.want { | |||
t.Errorf("Format() got = %v, want %v", got, tt.want) | |||
} | |||
}) | |||
} | |||
} | |||
func TestColumn_Build(t *testing.T) { | |||
type fields struct { | |||
Dialect DialectType | |||
Name string | |||
Nullable bool | |||
Default OptionalString | |||
Type ColumnType | |||
Size OptionalInt | |||
PrimaryKey bool | |||
} | |||
tests := []struct { | |||
name string | |||
fields fields | |||
want string | |||
wantErr bool | |||
}{ | |||
{"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, | |||
{"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, | |||
{"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, | |||
{"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, | |||
{"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, | |||
{"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, | |||
{"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, | |||
{"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, | |||
{"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, | |||
{"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, | |||
{"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, | |||
{"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, | |||
{"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, | |||
{"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, | |||
{"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, | |||
{"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, | |||
{"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, | |||
{"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, | |||
{"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, | |||
{"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, | |||
{"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, | |||
{"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, | |||
{"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, | |||
{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, | |||
{"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, | |||
{"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, | |||
{"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, | |||
{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, | |||
} | |||
for _, tt := range tests { | |||
t.Run(tt.name, func(t *testing.T) { | |||
c := &Column{ | |||
Dialect: tt.fields.Dialect, | |||
Name: tt.fields.Name, | |||
Nullable: tt.fields.Nullable, | |||
Default: tt.fields.Default, | |||
Type: tt.fields.Type, | |||
Size: tt.fields.Size, | |||
PrimaryKey: tt.fields.PrimaryKey, | |||
} | |||
if got, err := c.String(); got != tt.want { | |||
if (err != nil) != tt.wantErr { | |||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) | |||
return | |||
} | |||
if got != tt.want { | |||
t.Errorf("String() got = %v, want %v", got, tt.want) | |||
} | |||
} | |||
}) | |||
} | |||
} | |||
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { | |||
sql, err := DialectMySQL. | |||
Table("foo"). | |||
SetIfNotExists(true). | |||
Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). | |||
Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). | |||
Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). | |||
UniqueConstraint("bar"). | |||
UniqueConstraint("bar", "baz"). | |||
ToSQL() | |||
assert.NoError(t, err) | |||
assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql) | |||
} |
@@ -0,0 +1,76 @@ | |||
package db | |||
import "fmt" | |||
type DialectType int | |||
const ( | |||
DialectSQLite DialectType = iota | |||
DialectMySQL DialectType = iota | |||
) | |||
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { | |||
switch d { | |||
case DialectSQLite: | |||
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} | |||
case DialectMySQL: | |||
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} | |||
func (d DialectType) Table(name string) *CreateTableSqlBuilder { | |||
switch d { | |||
case DialectSQLite: | |||
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} | |||
case DialectMySQL: | |||
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} | |||
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { | |||
switch d { | |||
case DialectSQLite: | |||
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} | |||
case DialectMySQL: | |||
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} | |||
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { | |||
switch d { | |||
case DialectSQLite: | |||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} | |||
case DialectMySQL: | |||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} | |||
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { | |||
switch d { | |||
case DialectSQLite: | |||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} | |||
case DialectMySQL: | |||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} | |||
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { | |||
switch d { | |||
case DialectSQLite: | |||
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} | |||
case DialectMySQL: | |||
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} | |||
default: | |||
panic(fmt.Sprintf("unexpected dialect: %d", d)) | |||
} | |||
} |
@@ -0,0 +1,53 @@ | |||
package db | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
type CreateIndexSqlBuilder struct { | |||
Dialect DialectType | |||
Name string | |||
Table string | |||
Unique bool | |||
Columns []string | |||
} | |||
type DropIndexSqlBuilder struct { | |||
Dialect DialectType | |||
Name string | |||
Table string | |||
} | |||
func (b *CreateIndexSqlBuilder) ToSQL() (string, error) { | |||
var str strings.Builder | |||
str.WriteString("CREATE ") | |||
if b.Unique { | |||
str.WriteString("UNIQUE ") | |||
} | |||
str.WriteString("INDEX ") | |||
str.WriteString(b.Name) | |||
str.WriteString(" on ") | |||
str.WriteString(b.Table) | |||
if len(b.Columns) == 0 { | |||
return "", fmt.Errorf("columns provided for this index: %s", b.Name) | |||
} | |||
str.WriteString(" (") | |||
columnCount := len(b.Columns) | |||
for i, thing := range b.Columns { | |||
str.WriteString(thing) | |||
if i < columnCount-1 { | |||
str.WriteString(", ") | |||
} | |||
} | |||
str.WriteString(")") | |||
return str.String(), nil | |||
} | |||
func (b *DropIndexSqlBuilder) ToSQL() (string, error) { | |||
return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil | |||
} |
@@ -0,0 +1,9 @@ | |||
package db | |||
type RawSqlBuilder struct { | |||
Query string | |||
} | |||
func (b *RawSqlBuilder) ToSQL() (string, error) { | |||
return b.Query, nil | |||
} |
@@ -0,0 +1,26 @@ | |||
package db | |||
import ( | |||
"context" | |||
"database/sql" | |||
) | |||
// TransactionScopedWork describes code executed within a database transaction. | |||
type TransactionScopedWork func(ctx context.Context, db *sql.Tx) error | |||
// RunTransactionWithOptions executes a block of code within a database transaction. | |||
func RunTransactionWithOptions(ctx context.Context, db *sql.DB, txOpts *sql.TxOptions, txWork TransactionScopedWork) error { | |||
tx, err := db.BeginTx(ctx, txOpts) | |||
if err != nil { | |||
return err | |||
} | |||
if err = txWork(ctx, tx); err != nil { | |||
if txErr := tx.Rollback(); txErr != nil { | |||
return txErr | |||
} | |||
return err | |||
} | |||
return tx.Commit() | |||
} | |||
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -45,8 +45,9 @@ var ( | |||
ErrPostUnpublished = impart.HTTPError{Status: http.StatusGone, Message: "Post unpublished by author."} | |||
ErrPostFetchError = impart.HTTPError{Status: http.StatusInternalServerError, Message: "We encountered an error getting the post. The humans have been alerted."} | |||
ErrUserNotFound = impart.HTTPError{http.StatusNotFound, "User doesn't exist."} | |||
ErrUserNotFoundEmail = impart.HTTPError{http.StatusNotFound, "Please enter your username instead of your email address."} | |||
ErrUserNotFound = impart.HTTPError{http.StatusNotFound, "User doesn't exist."} | |||
ErrRemoteUserNotFound = impart.HTTPError{http.StatusNotFound, "Remote user not found."} | |||
ErrUserNotFoundEmail = impart.HTTPError{http.StatusNotFound, "Please enter your username instead of your email address."} | |||
ErrUserSilenced = impart.HTTPError{http.StatusForbidden, "Account is silenced."} | |||
) | |||
@@ -6,17 +6,21 @@ require ( | |||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf // indirect | |||
github.com/captncraig/cors v0.0.0-20180620154129-376d45073b49 // indirect | |||
github.com/clbanning/mxj v1.8.4 // indirect | |||
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9 // indirect | |||
github.com/dustin/go-humanize v1.0.0 | |||
github.com/fatih/color v1.7.0 | |||
github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d // indirect | |||
github.com/go-sql-driver/mysql v1.4.1 | |||
github.com/go-test/deep v1.0.1 // indirect | |||
github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect | |||
github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8 // indirect | |||
github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e // indirect | |||
github.com/gorilla/feeds v1.1.0 | |||
github.com/gorilla/mux v1.7.0 | |||
github.com/gorilla/schema v1.0.2 | |||
github.com/gorilla/sessions v1.1.3 | |||
github.com/gorilla/sessions v1.2.0 | |||
github.com/guregu/null v3.4.0+incompatible | |||
github.com/hashicorp/go-multierror v1.0.0 | |||
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 | |||
github.com/jtolds/gls v4.2.1+incompatible // indirect | |||
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec | |||
@@ -31,18 +35,18 @@ require ( | |||
github.com/pelletier/go-toml v1.2.0 // indirect | |||
github.com/pkg/errors v0.8.1 // indirect | |||
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect | |||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect | |||
github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 // indirect | |||
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c // indirect | |||
github.com/stretchr/testify v1.3.0 // indirect | |||
github.com/stretchr/testify v1.3.0 | |||
github.com/writeas/activity v0.1.2 | |||
github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89 | |||
github.com/writeas/go-strip-markdown v2.0.1+incompatible | |||
github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 | |||
github.com/writeas/httpsig v1.0.0 | |||
github.com/writeas/impart v1.1.0 | |||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d | |||
github.com/writeas/import v0.2.0 | |||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 | |||
github.com/writeas/nerds v1.0.0 | |||
github.com/writeas/openssl-go v1.0.0 // indirect | |||
github.com/writeas/saturday v1.7.1 | |||
github.com/writeas/slug v1.2.0 | |||
github.com/writeas/web-core v1.2.0 | |||
@@ -55,6 +59,8 @@ require ( | |||
google.golang.org/appengine v1.4.0 // indirect | |||
gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c // indirect | |||
gopkg.in/ini.v1 v1.41.0 | |||
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 // indirect | |||
gopkg.in/yaml.v2 v2.2.2 // indirect | |||
src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b // indirect | |||
) | |||
go 1.13 |
@@ -1,3 +1,5 @@ | |||
code.as/core/socks v1.0.0 h1:SPQXNp4SbEwjOAP9VzUahLHak8SDqy5n+9cm9tpjZOs= | |||
code.as/core/socks v1.0.0/go.mod h1:BAXBy5O9s2gmw6UxLqNJcVbWY7C/UPs+801CcSsfWOY= | |||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= | |||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= | |||
github.com/alecthomas/gometalinter v2.0.11+incompatible/go.mod h1:qfIpQGGz3d+NmgyPBqv+LSh50emm1pt72EtcX2vKYQk= | |||
@@ -23,13 +25,18 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk | |||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | |||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9 h1:74lLNRzvsdIlkTgfDSMuaPjBr4cf6k7pwQQANm/yLKU= | |||
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4= | |||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= | |||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= | |||
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= | |||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= | |||
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= | |||
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= | |||
github.com/go-fed/httpsig v0.1.0 h1:6F2OxRVnNTN4OPN+Mc2jxs2WEay9/qiHT/jphlvAwIY= | |||
github.com/go-fed/httpsig v0.1.0/go.mod h1:T56HUNYZUQ1AGUzhAYPugZfp36sKApVnGBgKlIY+aIE= | |||
github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d h1:+uoOvOnNDgsYbWtAij4xP6Rgir3eJGjocFPxBJETU/U= | |||
github.com/go-fed/httpsig v0.1.1-0.20190924171022-f4c36041199d/go.mod h1:T56HUNYZUQ1AGUzhAYPugZfp36sKApVnGBgKlIY+aIE= | |||
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= | |||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= | |||
github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= | |||
@@ -38,14 +45,14 @@ github.com/golang/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:tluoj9z5200j | |||
github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 h1:6DVPu65tee05kY0/rciBQ47ue+AnuY8KTayV6VHikIo= | |||
github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= | |||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | |||
github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8 h1:WD8iJ37bRNwvETMfVTusVSAi0WdXTpfNVGY2aHycNKY= | |||
github.com/gologme/log v0.0.0-20181207131047-4e5d8ccb38e8/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= | |||
github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf h1:7+FW5aGwISbqUtkfmIpZJGRgNFg2ioYPvFaUxdqpDsg= | |||
github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf/go.mod h1:RpwtwJQFrIEPstU94h88MWPXP2ektJZ8cZ0YntAmXiE= | |||
github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e h1:JKmoR8x90Iww1ks85zJ1lfDGgIiMDuIptTOhJq+zKyg= | |||
github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= | |||
github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc h1:cJlkeAx1QYgO5N80aF5xRGstVsRQwgLR7uA2FnP1ZjY= | |||
github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= | |||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= | |||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= | |||
github.com/gorilla/feeds v1.1.0 h1:pcgLJhbdYgaUESnj3AmXPcB7cS3vy63+jC/TI14AGXk= | |||
github.com/gorilla/feeds v1.1.0/go.mod h1:Nk0jZrvPFZX1OBe5NPiddPw7CfwF6Q9eqzaBbaightA= | |||
github.com/gorilla/mux v1.7.0 h1:tOSd0UKHQd6urX6ApfOn4XdBMY6Sh1MfxV3kmaazO+U= | |||
@@ -54,10 +61,14 @@ github.com/gorilla/schema v1.0.2 h1:sAgNfOcNYvdDSrzGHVy9nzCQahG+qmsg+nE8dK85QRA= | |||
github.com/gorilla/schema v1.0.2/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= | |||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= | |||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= | |||
github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9RU= | |||
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= | |||
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= | |||
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= | |||
github.com/guregu/null v3.4.0+incompatible h1:a4mw37gBO7ypcBlTJeZGuMpSxxFTV9qFfFKgWxQSGaM= | |||
github.com/guregu/null v3.4.0+incompatible/go.mod h1:ePGpQaN9cw0tj45IR5E5ehMvsFlLlQZAkkOXZurJ3NM= | |||
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= | |||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= | |||
github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= | |||
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= | |||
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 h1:wIdDEle9HEy7vBPjC6oKz6ejs3Ut+jmsYvuOoAW2pSM= | |||
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2/go.mod h1:WtaVKD9TeruTED9ydiaOJU08qGoEPP/LyzTKiD3jEsw= | |||
github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE= | |||
@@ -115,32 +126,46 @@ github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9 h1:vY5WqiEon0ZSTG | |||
github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9/go.mod h1:q+QjxYvZ+fpjMXqs+XEriussHjSYqeXVnAdSV1tkMYk= | |||
github.com/writeas/activity v0.1.2 h1:Y12B5lIrabfqKE7e7HFCWiXrlfXljr9tlkFm2mp7DgY= | |||
github.com/writeas/activity v0.1.2/go.mod h1:mYYgiewmEM+8tlifirK/vl6tmB2EbjYaxwb+ndUw5T0= | |||
github.com/writeas/activityserve v0.0.0-20191008122325-5fc3b48e70c5 h1:nG84xWpxBM8YU/FJchezJqg7yZH8ImSRow6NoYtbSII= | |||
github.com/writeas/activityserve v0.0.0-20191008122325-5fc3b48e70c5/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= | |||
github.com/writeas/activityserve v0.0.0-20191011072627-3a81f7784d5b h1:rd2wX/bTqD55hxtBjAhwLcUgaQE36c70KX3NzpDAwVI= | |||
github.com/writeas/activityserve v0.0.0-20191011072627-3a81f7784d5b/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= | |||
github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89 h1:NJhzq9aTccL3SSSZMrcnYhkD6sObdY9otNZ1X6/ZKNE= | |||
github.com/writeas/activityserve v0.0.0-20191115095800-dd6d19cc8b89/go.mod h1:Kz62mzYsCnrFTSTSFLXFj3fGYBQOntmBWTDDq57b46A= | |||
github.com/writeas/go-strip-markdown v2.0.1+incompatible h1:IIqxTM5Jr7RzhigcL6FkrCNfXkvbR+Nbu1ls48pXYcw= | |||
github.com/writeas/go-strip-markdown v2.0.1+incompatible/go.mod h1:Rsyu10ZhbEK9pXdk8V6MVnZmTzRG0alMNLMwa0J01fE= | |||
github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 h1:DUsp4OhdfI+e6iUqcPQlwx8QYXuUDsToTz/x82D3Zuo= | |||
github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2/go.mod h1:w2VxyRO/J5vfNjJHYVubsjUGHd3RLDoVciz0DE3ApOc= | |||
github.com/writeas/go-writeas v1.1.0 h1:WHGm6wriBkxYAOGbvriXH8DlMUGOi6jhSZLUZKQ+4mQ= | |||
github.com/writeas/go-writeas v1.1.0/go.mod h1:oh9U1rWaiE0p3kzdKwwvOpNXgp0P0IELI7OLOwV4fkA= | |||
github.com/writeas/go-writeas/v2 v2.0.2 h1:akvdMg89U5oBJiCkBwOXljVLTqP354uN6qnG2oOMrbk= | |||
github.com/writeas/go-writeas/v2 v2.0.2/go.mod h1:9sjczQJKmru925fLzg0usrU1R1tE4vBmQtGnItUMR0M= | |||
github.com/writeas/httpsig v1.0.0 h1:peIAoIA3DmlP8IG8tMNZqI4YD1uEnWBmkcC9OFPjt3A= | |||
github.com/writeas/httpsig v1.0.0/go.mod h1:7ClMGSrSVXJbmiLa17bZ1LrG1oibGZmUMlh3402flPY= | |||
github.com/writeas/impart v1.1.0 h1:nPnoO211VscNkp/gnzir5UwCDEvdHThL5uELU60NFSE= | |||
github.com/writeas/impart v1.1.0/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= | |||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d h1:PK7DOj3JE6MGf647esPrKzXEHFjGWX2hl22uX79ixaE= | |||
github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= | |||
github.com/writeas/import v0.2.0 h1:Ov23JW9Rnjxk06rki1Spar45bNX647HhwhAZj3flJiY= | |||
github.com/writeas/import v0.2.0/go.mod h1:gFe0Pl7ZWYiXbI0TJxeMMyylPGZmhVvCfQxhMEc8CxM= | |||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 h1:baEp0631C8sT2r/hqwypIw2snCFZa6h7U6TojoLHu/c= | |||
github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219/go.mod h1:NyM35ayknT7lzO6O/1JpfgGyv+0W9Z9q7aE0J8bXxfQ= | |||
github.com/writeas/nerds v1.0.0 h1:ZzRcCN+Sr3MWID7o/x1cr1ZbLvdpej9Y1/Ho+JKlqxo= | |||
github.com/writeas/nerds v1.0.0/go.mod h1:Gn2bHy1EwRcpXeB7ZhVmuUwiweK0e+JllNf66gvNLdU= | |||
github.com/writeas/openssl-go v1.0.0 h1:YXM1tDXeYOlTyJjoMlYLQH1xOloUimSR1WMF8kjFc5o= | |||
github.com/writeas/openssl-go v1.0.0/go.mod h1:WsKeK5jYl0B5y8ggOmtVjbmb+3rEGqSD25TppjJnETA= | |||
github.com/writeas/saturday v1.6.0/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= | |||
github.com/writeas/saturday v1.7.1 h1:lYo1EH6CYyrFObQoA9RNWHVlpZA5iYL5Opxo7PYAnZE= | |||
github.com/writeas/saturday v1.7.1/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= | |||
github.com/writeas/slug v1.2.0 h1:EMQ+cwLiOcA6EtFwUgyw3Ge18x9uflUnOnR6bp/J+/g= | |||
github.com/writeas/slug v1.2.0/go.mod h1:RE8shOqQP3YhsfsQe0L3RnuejfQ4Mk+JjY5YJQFubfQ= | |||
github.com/writeas/web-core v1.0.0 h1:5VKkCakQgdKZcbfVKJXtRpc5VHrkflusCl/KRCPzpQ0= | |||
github.com/writeas/web-core v1.0.0/go.mod h1:Si3chV7VWgY8CsV+3gRolMXSO2Vx1ZFAQ/mkrpvmyEE= | |||
github.com/writeas/web-core v1.2.0 h1:CYqvBd+byi1cK4mCr1NZ6CjILuMOFmiFecv+OACcmG0= | |||
github.com/writeas/web-core v1.2.0/go.mod h1:vTYajviuNBAxjctPp2NUYdgjofywVkxUGpeaERF3SfI= | |||
github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAvRLR7F8ss= | |||
github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg= | |||
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo= | |||
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= | |||
golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= | |||
golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f h1:ETU2VEl7TnT5bl7IvuKEzTDpplg5wzGYsOCAPhdoEIg= | |||
golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= | |||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= | |||
@@ -172,3 +197,5 @@ gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 h1:POO/ycCATvegFmVuPpQzZFJ+p | |||
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= | |||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= | |||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||
src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b h1:rPAdjgXks4ToezTjygsnKZroxKVnA1L35DSpsJXPtfc= | |||
src.techknowlogick.com/xgo v0.0.0-20200129005940-d0fae26e014b/go.mod h1:31CE1YKtDOrKTk9PSnjTpe6YbO6W/0LTYZ1VskL09oU= |
@@ -73,7 +73,7 @@ type ( | |||
type Handler struct { | |||
errors *ErrorPages | |||
sessionStore *sessions.CookieStore | |||
sessionStore sessions.Store | |||
app Apper | |||
} | |||
@@ -96,7 +96,7 @@ func NewHandler(apper Apper) *Handler { | |||
InternalServerError: template.Must(template.New("").Parse("{{define \"base\"}}<html><head><title>500</title></head><body><p>Internal server error.</p></body></html>{{end}}")), | |||
Blank: template.Must(template.New("").Parse("{{define \"base\"}}<html><head><title>{{.Title}}</title></head><body><p>{{.Content}}</p></body></html>{{end}}")), | |||
}, | |||
sessionStore: apper.App().sessionStore, | |||
sessionStore: apper.App().SessionStore(), | |||
app: apper, | |||
} | |||
@@ -549,6 +549,37 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc { | |||
} | |||
} | |||
func (h *Handler) OAuth(f handlerFunc) http.HandlerFunc { | |||
return func(w http.ResponseWriter, r *http.Request) { | |||
h.handleOAuthError(w, r, func() error { | |||
// TODO: return correct "success" status | |||
status := 200 | |||
start := time.Now() | |||
defer func() { | |||
if e := recover(); e != nil { | |||
log.Error("%s:\n%s", e, debug.Stack()) | |||
impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."}) | |||
status = 500 | |||
} | |||
log.Info(h.app.ReqLog(r, status, time.Since(start))) | |||
}() | |||
err := f(h.app.App(), w, r) | |||
if err != nil { | |||
if err, ok := err.(impart.HTTPError); ok { | |||
status = err.Status | |||
} else { | |||
status = 500 | |||
} | |||
} | |||
return err | |||
}()) | |||
} | |||
} | |||
func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc { | |||
return func(w http.ResponseWriter, r *http.Request) { | |||
h.handleError(w, r, func() error { | |||
@@ -779,6 +810,25 @@ func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error) | |||
h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) | |||
} | |||
func (h *Handler) handleOAuthError(w http.ResponseWriter, r *http.Request, err error) { | |||
if err == nil { | |||
return | |||
} | |||
if err, ok := err.(impart.HTTPError); ok { | |||
if err.Status >= 300 && err.Status < 400 { | |||
sendRedirect(w, err.Status, err.Message) | |||
return | |||
} | |||
impart.WriteOAuthError(w, err) | |||
return | |||
} | |||
impart.WriteOAuthError(w, impart.HTTPError{http.StatusInternalServerError, "This is an unhelpful error message for a miscellaneous internal error."}) | |||
return | |||
} | |||
func correctPageFromLoginAttempt(r *http.Request) string { | |||
to := r.FormValue("to") | |||
if to == "" { | |||
@@ -684,18 +684,19 @@ select.inputform, textarea.inputform { | |||
border: 1px solid #999; | |||
} | |||
input, button, select.inputform, textarea.inputform { | |||
input, button, select.inputform, textarea.inputform, a.btn { | |||
padding: 0.5em; | |||
font-family: @serifFont; | |||
font-size: 100%; | |||
.rounded(.25em); | |||
&[type=submit], &.submit { | |||
&[type=submit], &.submit, &.cta { | |||
border: 1px solid @primary; | |||
background: @primary; | |||
color: white; | |||
.transition(0.2s); | |||
&:hover { | |||
background-color: lighten(@primary, 3%); | |||
text-decoration: none; | |||
} | |||
&:disabled { | |||
cursor: default; | |||
@@ -1317,6 +1318,24 @@ form { | |||
font-size: 0.86em; | |||
line-height: 2; | |||
} | |||
&.prominent { | |||
margin: 1em 0; | |||
label { | |||
font-weight: bold; | |||
} | |||
input, select { | |||
width: 100%; | |||
} | |||
select { | |||
font-size: 1em; | |||
padding: 0.5rem; | |||
display: block; | |||
border-radius: 0.25rem; | |||
margin: 0.5rem 0; | |||
} | |||
} | |||
} | |||
div.row { | |||
display: flex; | |||
@@ -17,6 +17,16 @@ body { | |||
font-size: 1.6em; | |||
} | |||
} | |||
article { | |||
h2#title.dated { | |||
margin-bottom: 0.5em; | |||
} | |||
time.dt-published { | |||
display: block; | |||
color: #666; | |||
margin-bottom: 1em; | |||
} | |||
} | |||
} | |||
} | |||
@@ -0,0 +1,153 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"database/sql" | |||
"encoding/gob" | |||
"errors" | |||
"fmt" | |||
uuid "github.com/nu7hatch/gouuid" | |||
"github.com/stretchr/testify/assert" | |||
"math/rand" | |||
"os" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
var testDB *sql.DB | |||
type ScopedTestBody func(*sql.DB) | |||
// TestMain provides testing infrastructure within this package. | |||
func TestMain(m *testing.M) { | |||
rand.Seed(time.Now().UTC().UnixNano()) | |||
gob.Register(&User{}) | |||
if runMySQLTests() { | |||
var err error | |||
testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST")) | |||
if err != nil { | |||
fmt.Println(err) | |||
return | |||
} | |||
} | |||
code := m.Run() | |||
if runMySQLTests() { | |||
if closeErr := testDB.Close(); closeErr != nil { | |||
fmt.Println(closeErr) | |||
} | |||
} | |||
os.Exit(code) | |||
} | |||
func runMySQLTests() bool { | |||
return len(os.Getenv("TEST_MYSQL")) > 0 | |||
} | |||
func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) { | |||
if dbUser == "" || dbPassword == "" { | |||
return nil, errors.New("database user or password not set") | |||
} | |||
if dbHost == "" { | |||
dbHost = "localhost" | |||
} | |||
if dbName == "" { | |||
dbName = "writefreely" | |||
} | |||
dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName) | |||
db, err := sql.Open("mysql", dsn) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err := ensureMySQL(db); err != nil { | |||
return nil, err | |||
} | |||
return db, nil | |||
} | |||
func ensureMySQL(db *sql.DB) error { | |||
if err := db.Ping(); err != nil { | |||
return err | |||
} | |||
db.SetMaxOpenConns(250) | |||
return nil | |||
} | |||
// withTestDB provides a scoped database connection. | |||
func withTestDB(t *testing.T, testBody ScopedTestBody) { | |||
db, cleanup, err := newTestDatabase(testDB, | |||
os.Getenv("WF_USER"), | |||
os.Getenv("WF_PASSWORD"), | |||
os.Getenv("WF_DB"), | |||
os.Getenv("WF_HOST"), | |||
) | |||
assert.NoError(t, err) | |||
defer func() { | |||
assert.NoError(t, cleanup()) | |||
}() | |||
testBody(db) | |||
} | |||
// newTestDatabase creates a new temporary test database. When a test | |||
// database connection is returned, it will have created a new database and | |||
// initialized it with tables from a reference database. | |||
func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) { | |||
var err error | |||
var baseName = dbName | |||
if baseName == "" { | |||
row := base.QueryRow("SELECT DATABASE()") | |||
err := row.Scan(&baseName) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
} | |||
tUUID, _ := uuid.NewV4() | |||
suffix := strings.Replace(tUUID.String(), "-", "_", -1) | |||
newDBName := baseName + suffix | |||
_, err = base.Exec("CREATE DATABASE " + newDBName) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
rows, err := base.Query("SHOW TABLES IN " + baseName) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
for rows.Next() { | |||
var tableName string | |||
if err := rows.Scan(&tableName); err != nil { | |||
return nil, nil, err | |||
} | |||
query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName) | |||
if _, err := newDB.Exec(query); err != nil { | |||
return nil, nil, err | |||
} | |||
} | |||
cleanup := func() error { | |||
if closeErr := newDB.Close(); closeErr != nil { | |||
fmt.Println(closeErr) | |||
} | |||
_, err = base.Exec("DROP DATABASE " + newDBName) | |||
return err | |||
} | |||
return newDB, cleanup, nil | |||
} | |||
func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) { | |||
var returned int | |||
err := db.QueryRowContext(ctx, query, args...).Scan(&returned) | |||
assert.NoError(t, err, "error executing query %s and args %s", query, args) | |||
assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %s", returned, count, query, args) | |||
} |
@@ -56,9 +56,12 @@ func (m *migration) Migrate(db *datastore) error { | |||
} | |||
var migrations = []Migration{ | |||
New("support user invites", supportUserInvites), // -> V1 (v0.8.0) | |||
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) | |||
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) | |||
New("support user invites", supportUserInvites), // -> V1 (v0.8.0) | |||
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) | |||
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) | |||
New("support oauth", oauth), // V3 -> V4 | |||
New("support slack oauth", oauthSlack), // V4 -> v5 | |||
New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 (v0.12.0) | |||
} | |||
// CurrentVer returns the current migration version the application is on | |||
@@ -0,0 +1,46 @@ | |||
package migrations | |||
import ( | |||
"context" | |||
"database/sql" | |||
wf_db "github.com/writeas/writefreely/db" | |||
) | |||
func oauth(db *datastore) error { | |||
dialect := wf_db.DialectMySQL | |||
if db.driverName == driverSQLite { | |||
dialect = wf_db.DialectSQLite | |||
} | |||
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { | |||
createTableUsersOauth, err := dialect. | |||
Table("oauth_users"). | |||
SetIfNotExists(true). | |||
Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). | |||
Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). | |||
UniqueConstraint("user_id"). | |||
UniqueConstraint("remote_user_id"). | |||
ToSQL() | |||
if err != nil { | |||
return err | |||
} | |||
createTableOauthClientState, err := dialect. | |||
Table("oauth_client_states"). | |||
SetIfNotExists(true). | |||
Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). | |||
Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). | |||
Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefault("NOW()")). | |||
UniqueConstraint("state"). | |||
ToSQL() | |||
if err != nil { | |||
return err | |||
} | |||
for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { | |||
if _, err := tx.ExecContext(ctx, table); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
}) | |||
} |
@@ -0,0 +1,67 @@ | |||
package migrations | |||
import ( | |||
"context" | |||
"database/sql" | |||
wf_db "github.com/writeas/writefreely/db" | |||
) | |||
func oauthSlack(db *datastore) error { | |||
dialect := wf_db.DialectMySQL | |||
if db.driverName == driverSQLite { | |||
dialect = wf_db.DialectSQLite | |||
} | |||
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { | |||
builders := []wf_db.SQLBuilder{ | |||
dialect. | |||
AlterTable("oauth_client_states"). | |||
AddColumn(dialect. | |||
Column( | |||
"provider", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 24,})). | |||
AddColumn(dialect. | |||
Column( | |||
"client_id", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 128,})), | |||
dialect. | |||
AlterTable("oauth_users"). | |||
ChangeColumn("remote_user_id", | |||
dialect. | |||
Column( | |||
"remote_user_id", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 128,})). | |||
AddColumn(dialect. | |||
Column( | |||
"provider", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 24,})). | |||
AddColumn(dialect. | |||
Column( | |||
"client_id", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 128,})). | |||
AddColumn(dialect. | |||
Column( | |||
"access_token", | |||
wf_db.ColumnTypeVarChar, | |||
wf_db.OptionalInt{Set: true, Value: 512,})), | |||
dialect.DropIndex("remote_user_id", "oauth_users"), | |||
dialect.DropIndex("user_id", "oauth_users"), | |||
dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), | |||
} | |||
for _, builder := range builders { | |||
query, err := builder.ToSQL() | |||
if err != nil { | |||
return err | |||
} | |||
if _, err := tx.ExecContext(ctx, query); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
}) | |||
} |
@@ -0,0 +1,29 @@ | |||
/* | |||
* Copyright © 2019 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
* WriteFreely is free software: you can redistribute it and/or modify | |||
* it under the terms of the GNU Affero General Public License, included | |||
* in the LICENSE file in this source code package. | |||
*/ | |||
package migrations | |||
func supportActivityPubMentions(db *datastore) error { | |||
t, err := db.Begin() | |||
_, err = t.Exec(`ALTER TABLE remoteusers ADD COLUMN handle ` + db.typeVarChar(255) + ` DEFAULT '' NOT NULL`) | |||
if err != nil { | |||
t.Rollback() | |||
return err | |||
} | |||
err = t.Commit() | |||
if err != nil { | |||
t.Rollback() | |||
return err | |||
} | |||
return nil | |||
} |
@@ -0,0 +1,291 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"github.com/gorilla/mux" | |||
"github.com/gorilla/sessions" | |||
"github.com/writeas/impart" | |||
"github.com/writeas/web-core/log" | |||
"github.com/writeas/writefreely/config" | |||
"io" | |||
"io/ioutil" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"time" | |||
) | |||
// TokenResponse contains data returned when a token is created either | |||
// through a code exchange or using a refresh token. | |||
type TokenResponse struct { | |||
AccessToken string `json:"access_token"` | |||
ExpiresIn int `json:"expires_in"` | |||
RefreshToken string `json:"refresh_token"` | |||
TokenType string `json:"token_type"` | |||
Error string `json:"error"` | |||
} | |||
// InspectResponse contains data returned when an access token is inspected. | |||
type InspectResponse struct { | |||
ClientID string `json:"client_id"` | |||
UserID string `json:"user_id"` | |||
ExpiresAt time.Time `json:"expires_at"` | |||
Username string `json:"username"` | |||
DisplayName string `json:"-"` | |||
Email string `json:"email"` | |||
Error string `json:"error"` | |||
} | |||
// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token | |||
// endpoint. One megabyte is plenty. | |||
const tokenRequestMaxLen = 1000000 | |||
// infoRequestMaxLen is the most bytes that we'll read from the | |||
// /oauth/inspect endpoint. | |||
const infoRequestMaxLen = 1000000 | |||
// OAuthDatastoreProvider provides a minimal interface of data store, config, | |||
// and session store for use with the oauth handlers. | |||
type OAuthDatastoreProvider interface { | |||
DB() OAuthDatastore | |||
Config() *config.Config | |||
SessionStore() sessions.Store | |||
} | |||
// OAuthDatastore provides a minimal interface of data store methods used in | |||
// oauth functionality. | |||
type OAuthDatastore interface { | |||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) | |||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error | |||
ValidateOAuthState(context.Context, string) (string, string, error) | |||
GenerateOAuthState(context.Context, string, string) (string, error) | |||
CreateUser(*config.Config, *User, string) error | |||
GetUserByID(int64) (*User, error) | |||
} | |||
type HttpClient interface { | |||
Do(req *http.Request) (*http.Response, error) | |||
} | |||
type oauthClient interface { | |||
GetProvider() string | |||
GetClientID() string | |||
GetCallbackLocation() string | |||
buildLoginURL(state string) (string, error) | |||
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) | |||
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) | |||
} | |||
type callbackProxyClient struct { | |||
server string | |||
callbackLocation string | |||
httpClient HttpClient | |||
} | |||
type oauthHandler struct { | |||
Config *config.Config | |||
DB OAuthDatastore | |||
Store sessions.Store | |||
EmailKey []byte | |||
oauthClient oauthClient | |||
callbackProxy *callbackProxyClient | |||
} | |||
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { | |||
ctx := r.Context() | |||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) | |||
if err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} | |||
} | |||
if h.callbackProxy != nil { | |||
if err := h.callbackProxy.register(ctx, state); err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} | |||
} | |||
} | |||
location, err := h.oauthClient.buildLoginURL(state) | |||
if err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} | |||
} | |||
return impart.HTTPError{http.StatusTemporaryRedirect, location} | |||
} | |||
func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { | |||
if app.Config().SlackOauth.ClientID != "" { | |||
callbackLocation := app.Config().App.Host + "/oauth/callback/slack" | |||
var stateRegisterClient *callbackProxyClient = nil | |||
if app.Config().SlackOauth.CallbackProxyAPI != "" { | |||
stateRegisterClient = &callbackProxyClient{ | |||
server: app.Config().SlackOauth.CallbackProxyAPI, | |||
callbackLocation: app.Config().App.Host + "/oauth/callback/slack", | |||
httpClient: config.DefaultHTTPClient(), | |||
} | |||
callbackLocation = app.Config().SlackOauth.CallbackProxy | |||
} | |||
oauthClient := slackOauthClient{ | |||
ClientID: app.Config().SlackOauth.ClientID, | |||
ClientSecret: app.Config().SlackOauth.ClientSecret, | |||
TeamID: app.Config().SlackOauth.TeamID, | |||
HttpClient: config.DefaultHTTPClient(), | |||
CallbackLocation: callbackLocation, | |||
} | |||
configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) | |||
} | |||
} | |||
func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { | |||
if app.Config().WriteAsOauth.ClientID != "" { | |||
callbackLocation := app.Config().App.Host + "/oauth/callback/write.as" | |||
var callbackProxy *callbackProxyClient = nil | |||
if app.Config().WriteAsOauth.CallbackProxy != "" { | |||
callbackProxy = &callbackProxyClient{ | |||
server: app.Config().WriteAsOauth.CallbackProxyAPI, | |||
callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", | |||
httpClient: config.DefaultHTTPClient(), | |||
} | |||
callbackLocation = app.Config().SlackOauth.CallbackProxy | |||
} | |||
oauthClient := writeAsOauthClient{ | |||
ClientID: app.Config().WriteAsOauth.ClientID, | |||
ClientSecret: app.Config().WriteAsOauth.ClientSecret, | |||
ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), | |||
InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), | |||
AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), | |||
HttpClient: config.DefaultHTTPClient(), | |||
CallbackLocation: callbackLocation, | |||
} | |||
configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) | |||
} | |||
} | |||
func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { | |||
handler := &oauthHandler{ | |||
Config: app.Config(), | |||
DB: app.DB(), | |||
Store: app.SessionStore(), | |||
oauthClient: oauthClient, | |||
EmailKey: app.keys.EmailKey, | |||
callbackProxy: callbackProxy, | |||
} | |||
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") | |||
r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") | |||
r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") | |||
} | |||
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { | |||
ctx := r.Context() | |||
code := r.FormValue("code") | |||
state := r.FormValue("state") | |||
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) | |||
if err != nil { | |||
log.Error("Unable to ValidateOAuthState: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) | |||
if err != nil { | |||
log.Error("Unable to exchangeOauthCode: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
// Now that we have the access token, let's use it real quick to make sur | |||
// it really really works. | |||
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) | |||
if err != nil { | |||
log.Error("Unable to inspectOauthAccessToken: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) | |||
if err != nil { | |||
log.Error("Unable to GetIDForRemoteUser: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
if localUserID != -1 { | |||
user, err := h.DB.GetUserByID(localUserID) | |||
if err != nil { | |||
log.Error("Unable to GetUserByID %d: %s", localUserID, err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
if err = loginOrFail(h.Store, w, r, user); err != nil { | |||
log.Error("Unable to loginOrFail %d: %s", localUserID, err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
return nil | |||
} | |||
displayName := tokenInfo.DisplayName | |||
if len(displayName) == 0 { | |||
displayName = tokenInfo.Username | |||
} | |||
tp := &oauthSignupPageParams{ | |||
AccessToken: tokenResponse.AccessToken, | |||
TokenUsername: tokenInfo.Username, | |||
TokenAlias: tokenInfo.DisplayName, | |||
TokenEmail: tokenInfo.Email, | |||
TokenRemoteUser: tokenInfo.UserID, | |||
Provider: provider, | |||
ClientID: clientID, | |||
} | |||
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) | |||
return h.showOauthSignupPage(app, w, r, tp, nil) | |||
} | |||
func (r *callbackProxyClient) register(ctx context.Context, state string) error { | |||
form := url.Values{} | |||
form.Add("state", state) | |||
form.Add("location", r.callbackLocation) | |||
req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) | |||
if err != nil { | |||
return err | |||
} | |||
req.Header.Set("User-Agent", "writefreely") | |||
req.Header.Set("Accept", "application/json") | |||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | |||
resp, err := r.httpClient.Do(req) | |||
if err != nil { | |||
return err | |||
} | |||
if resp.StatusCode != http.StatusCreated { | |||
return fmt.Errorf("unable register state location: %d", resp.StatusCode) | |||
} | |||
return nil | |||
} | |||
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { | |||
lr := io.LimitReader(body, int64(n+1)) | |||
data, err := ioutil.ReadAll(lr) | |||
if err != nil { | |||
return err | |||
} | |||
if len(data) == n+1 { | |||
return fmt.Errorf("content larger than max read allowance: %d", n) | |||
} | |||
return json.Unmarshal(data, thing) | |||
} | |||
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { | |||
// An error may be returned, but a valid session should always be returned. | |||
session, _ := store.Get(r, cookieName) | |||
session.Values[cookieUserVal] = user.Cookie() | |||
if err := session.Save(r, w); err != nil { | |||
fmt.Println("error saving session", err) | |||
return err | |||
} | |||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) | |||
return nil | |||
} |
@@ -0,0 +1,10 @@ | |||
package oauth | |||
import "context" | |||
// ClientStateStore provides state management used by the OAuth client. | |||
type ClientStateStore interface { | |||
Generate(ctx context.Context) (string, error) | |||
Validate(ctx context.Context, state string) error | |||
} | |||
@@ -0,0 +1,218 @@ | |||
/* | |||
* Copyright © 2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
* WriteFreely is free software: you can redistribute it and/or modify | |||
* it under the terms of the GNU Affero General Public License, included | |||
* in the LICENSE file in this source code package. | |||
*/ | |||
package writefreely | |||
import ( | |||
"crypto/sha256" | |||
"encoding/hex" | |||
"fmt" | |||
"github.com/writeas/impart" | |||
"github.com/writeas/web-core/auth" | |||
"github.com/writeas/web-core/log" | |||
"github.com/writeas/writefreely/page" | |||
"html/template" | |||
"net/http" | |||
"strings" | |||
"time" | |||
) | |||
type viewOauthSignupVars struct { | |||
page.StaticPage | |||
To string | |||
Message template.HTML | |||
Flashes []template.HTML | |||
AccessToken string | |||
TokenUsername string | |||
TokenAlias string // TODO: rename this to match the data it represents: the collection title | |||
TokenEmail string | |||
TokenRemoteUser string | |||
Provider string | |||
ClientID string | |||
TokenHash string | |||
LoginUsername string | |||
Alias string // TODO: rename this to match the data it represents: the collection title | |||
Email string | |||
} | |||
const ( | |||
oauthParamAccessToken = "access_token" | |||
oauthParamTokenUsername = "token_username" | |||
oauthParamTokenAlias = "token_alias" | |||
oauthParamTokenEmail = "token_email" | |||
oauthParamTokenRemoteUserID = "token_remote_user" | |||
oauthParamClientID = "client_id" | |||
oauthParamProvider = "provider" | |||
oauthParamHash = "signature" | |||
oauthParamUsername = "username" | |||
oauthParamAlias = "alias" | |||
oauthParamEmail = "email" | |||
oauthParamPassword = "password" | |||
) | |||
type oauthSignupPageParams struct { | |||
AccessToken string | |||
TokenUsername string | |||
TokenAlias string // TODO: rename this to match the data it represents: the collection title | |||
TokenEmail string | |||
TokenRemoteUser string | |||
ClientID string | |||
Provider string | |||
TokenHash string | |||
} | |||
func (p oauthSignupPageParams) HashTokenParams(key string) string { | |||
hasher := sha256.New() | |||
hasher.Write([]byte(key)) | |||
hasher.Write([]byte(p.AccessToken)) | |||
hasher.Write([]byte(p.TokenUsername)) | |||
hasher.Write([]byte(p.TokenAlias)) | |||
hasher.Write([]byte(p.TokenEmail)) | |||
hasher.Write([]byte(p.TokenRemoteUser)) | |||
hasher.Write([]byte(p.ClientID)) | |||
hasher.Write([]byte(p.Provider)) | |||
return hex.EncodeToString(hasher.Sum(nil)) | |||
} | |||
func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error { | |||
tp := &oauthSignupPageParams{ | |||
AccessToken: r.FormValue(oauthParamAccessToken), | |||
TokenUsername: r.FormValue(oauthParamTokenUsername), | |||
TokenAlias: r.FormValue(oauthParamTokenAlias), | |||
TokenEmail: r.FormValue(oauthParamTokenEmail), | |||
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), | |||
ClientID: r.FormValue(oauthParamClientID), | |||
Provider: r.FormValue(oauthParamProvider), | |||
} | |||
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { | |||
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} | |||
} | |||
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) | |||
if err := h.validateOauthSignup(r); err != nil { | |||
return h.showOauthSignupPage(app, w, r, tp, err) | |||
} | |||
var err error | |||
hashedPass := []byte{} | |||
clearPass := r.FormValue(oauthParamPassword) | |||
hasPass := clearPass != "" | |||
if hasPass { | |||
hashedPass, err = auth.HashPass([]byte(clearPass)) | |||
if err != nil { | |||
return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password")) | |||
} | |||
} | |||
newUser := &User{ | |||
Username: r.FormValue(oauthParamUsername), | |||
HashedPass: hashedPass, | |||
HasPass: hasPass, | |||
Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey), | |||
Created: time.Now().Truncate(time.Second).UTC(), | |||
} | |||
displayName := r.FormValue(oauthParamAlias) | |||
if len(displayName) == 0 { | |||
displayName = r.FormValue(oauthParamUsername) | |||
} | |||
err = h.DB.CreateUser(h.Config, newUser, displayName) | |||
if err != nil { | |||
return h.showOauthSignupPage(app, w, r, tp, err) | |||
} | |||
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) | |||
if err != nil { | |||
return h.showOauthSignupPage(app, w, r, tp, err) | |||
} | |||
if err := loginOrFail(h.Store, w, r, newUser); err != nil { | |||
return h.showOauthSignupPage(app, w, r, tp, err) | |||
} | |||
return nil | |||
} | |||
func (h oauthHandler) validateOauthSignup(r *http.Request) error { | |||
username := r.FormValue(oauthParamUsername) | |||
if len(username) < h.Config.App.MinUsernameLen { | |||
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."} | |||
} | |||
if len(username) > 100 { | |||
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."} | |||
} | |||
collTitle := r.FormValue(oauthParamAlias) | |||
if len(collTitle) == 0 { | |||
collTitle = username | |||
} | |||
email := r.FormValue(oauthParamEmail) | |||
if len(email) > 0 { | |||
parts := strings.Split(email, "@") | |||
if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) { | |||
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"} | |||
} | |||
} | |||
return nil | |||
} | |||
func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error { | |||
username := tp.TokenUsername | |||
collTitle := tp.TokenAlias | |||
email := tp.TokenEmail | |||
session, err := app.sessionStore.Get(r, cookieName) | |||
if err != nil { | |||
// Ignore this | |||
log.Error("Unable to get session; ignoring: %v", err) | |||
} | |||
if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 { | |||
username = tmpValue | |||
} | |||
if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 { | |||
collTitle = tmpValue | |||
} | |||
if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 { | |||
email = tmpValue | |||
} | |||
p := &viewOauthSignupVars{ | |||
StaticPage: pageForReq(app, r), | |||
To: r.FormValue("to"), | |||
Flashes: []template.HTML{}, | |||
AccessToken: tp.AccessToken, | |||
TokenUsername: tp.TokenUsername, | |||
TokenAlias: tp.TokenAlias, | |||
TokenEmail: tp.TokenEmail, | |||
TokenRemoteUser: tp.TokenRemoteUser, | |||
Provider: tp.Provider, | |||
ClientID: tp.ClientID, | |||
TokenHash: tp.TokenHash, | |||
LoginUsername: username, | |||
Alias: collTitle, | |||
Email: email, | |||
} | |||
// Display any error messages | |||
flashes, _ := getSessionFlashes(app, w, r, session) | |||
for _, flash := range flashes { | |||
p.Flashes = append(p.Flashes, template.HTML(flash)) | |||
} | |||
if errMsg != nil { | |||
p.Flashes = append(p.Flashes, template.HTML(errMsg.Error())) | |||
} | |||
err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p) | |||
if err != nil { | |||
log.Error("Unable to render signup-oauth: %v", err) | |||
return err | |||
} | |||
return nil | |||
} |
@@ -0,0 +1,180 @@ | |||
/* | |||
* Copyright © 2019-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
* WriteFreely is free software: you can redistribute it and/or modify | |||
* it under the terms of the GNU Affero General Public License, included | |||
* in the LICENSE file in this source code package. | |||
*/ | |||
package writefreely | |||
import ( | |||
"context" | |||
"errors" | |||
"fmt" | |||
"github.com/writeas/nerds/store" | |||
"github.com/writeas/slug" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
) | |||
type slackOauthClient struct { | |||
ClientID string | |||
ClientSecret string | |||
TeamID string | |||
CallbackLocation string | |||
HttpClient HttpClient | |||
} | |||
type slackExchangeResponse struct { | |||
OK bool `json:"ok"` | |||
AccessToken string `json:"access_token"` | |||
Scope string `json:"scope"` | |||
TeamName string `json:"team_name"` | |||
TeamID string `json:"team_id"` | |||
Error string `json:"error"` | |||
} | |||
type slackIdentity struct { | |||
Name string `json:"name"` | |||
ID string `json:"id"` | |||
Email string `json:"email"` | |||
} | |||
type slackTeam struct { | |||
Name string `json:"name"` | |||
ID string `json:"id"` | |||
} | |||
type slackUserIdentityResponse struct { | |||
OK bool `json:"ok"` | |||
User slackIdentity `json:"user"` | |||
Team slackTeam `json:"team"` | |||
Error string `json:"error"` | |||
} | |||
const ( | |||
slackAuthLocation = "https://slack.com/oauth/authorize" | |||
slackExchangeLocation = "https://slack.com/api/oauth.access" | |||
slackIdentityLocation = "https://slack.com/api/users.identity" | |||
) | |||
var _ oauthClient = slackOauthClient{} | |||
func (c slackOauthClient) GetProvider() string { | |||
return "slack" | |||
} | |||
func (c slackOauthClient) GetClientID() string { | |||
return c.ClientID | |||
} | |||
func (c slackOauthClient) GetCallbackLocation() string { | |||
return c.CallbackLocation | |||
} | |||
func (c slackOauthClient) buildLoginURL(state string) (string, error) { | |||
u, err := url.Parse(slackAuthLocation) | |||
if err != nil { | |||
return "", err | |||
} | |||
q := u.Query() | |||
q.Set("client_id", c.ClientID) | |||
q.Set("scope", "identity.basic identity.email identity.team") | |||
q.Set("redirect_uri", c.CallbackLocation) | |||
q.Set("state", state) | |||
// If this param is not set, the user can select which team they | |||
// authenticate through and then we'd have to match the configured team | |||
// against the profile get. That is extra work in the post-auth phase | |||
// that we don't want to do. | |||
q.Set("team", c.TeamID) | |||
// The Slack OAuth docs don't explicitly list this one, but it is part of | |||
// the spec, so we include it anyway. | |||
q.Set("response_type", "code") | |||
u.RawQuery = q.Encode() | |||
return u.String(), nil | |||
} | |||
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { | |||
form := url.Values{} | |||
// The oauth.access documentation doesn't explicitly mention this | |||
// parameter, but it is part of the spec, so we include it anyway. | |||
// https://api.slack.com/methods/oauth.access | |||
form.Add("grant_type", "authorization_code") | |||
form.Add("redirect_uri", c.CallbackLocation) | |||
form.Add("code", code) | |||
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.WithContext(ctx) | |||
req.Header.Set("User-Agent", "writefreely") | |||
req.Header.Set("Accept", "application/json") | |||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | |||
req.SetBasicAuth(c.ClientID, c.ClientSecret) | |||
resp, err := c.HttpClient.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if resp.StatusCode != http.StatusOK { | |||
return nil, errors.New("unable to exchange code for access token") | |||
} | |||
var tokenResponse slackExchangeResponse | |||
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { | |||
return nil, err | |||
} | |||
if !tokenResponse.OK { | |||
return nil, errors.New(tokenResponse.Error) | |||
} | |||
return tokenResponse.TokenResponse(), nil | |||
} | |||
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { | |||
req, err := http.NewRequest("GET", slackIdentityLocation, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.WithContext(ctx) | |||
req.Header.Set("User-Agent", "writefreely") | |||
req.Header.Set("Accept", "application/json") | |||
req.Header.Set("Authorization", "Bearer "+accessToken) | |||
resp, err := c.HttpClient.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if resp.StatusCode != http.StatusOK { | |||
return nil, errors.New("unable to inspect access token") | |||
} | |||
var inspectResponse slackUserIdentityResponse | |||
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { | |||
return nil, err | |||
} | |||
if !inspectResponse.OK { | |||
return nil, errors.New(inspectResponse.Error) | |||
} | |||
return inspectResponse.InspectResponse(), nil | |||
} | |||
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { | |||
return &InspectResponse{ | |||
UserID: resp.User.ID, | |||
Username: fmt.Sprintf("%s-%s", slug.Make(resp.User.Name), store.GenerateRandomString("0123456789bcdfghjklmnpqrstvwxyz", 5)), | |||
DisplayName: resp.User.Name, | |||
Email: resp.User.Email, | |||
} | |||
} | |||
func (resp slackExchangeResponse) TokenResponse() *TokenResponse { | |||
return &TokenResponse{ | |||
AccessToken: resp.AccessToken, | |||
} | |||
} |
@@ -0,0 +1,253 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"fmt" | |||
"github.com/gorilla/sessions" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/writeas/impart" | |||
"github.com/writeas/nerds/store" | |||
"github.com/writeas/writefreely/config" | |||
"net/http" | |||
"net/http/httptest" | |||
"net/url" | |||
"strings" | |||
"testing" | |||
) | |||
type MockOAuthDatastoreProvider struct { | |||
DoDB func() OAuthDatastore | |||
DoConfig func() *config.Config | |||
DoSessionStore func() sessions.Store | |||
} | |||
type MockOAuthDatastore struct { | |||
DoGenerateOAuthState func(context.Context, string, string) (string, error) | |||
DoValidateOAuthState func(context.Context, string) (string, string, error) | |||
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) | |||
DoCreateUser func(*config.Config, *User, string) error | |||
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error | |||
DoGetUserByID func(int64) (*User, error) | |||
} | |||
var _ OAuthDatastore = &MockOAuthDatastore{} | |||
type StringReadCloser struct { | |||
*strings.Reader | |||
} | |||
func (src *StringReadCloser) Close() error { | |||
return nil | |||
} | |||
type MockHTTPClient struct { | |||
DoDo func(req *http.Request) (*http.Response, error) | |||
} | |||
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { | |||
if m.DoDo != nil { | |||
return m.DoDo(req) | |||
} | |||
return &http.Response{}, nil | |||
} | |||
func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { | |||
if m.DoSessionStore != nil { | |||
return m.DoSessionStore() | |||
} | |||
return sessions.NewCookieStore([]byte("secret-key")) | |||
} | |||
func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { | |||
if m.DoDB != nil { | |||
return m.DoDB() | |||
} | |||
return &MockOAuthDatastore{} | |||
} | |||
func (m *MockOAuthDatastoreProvider) Config() *config.Config { | |||
if m.DoConfig != nil { | |||
return m.DoConfig() | |||
} | |||
cfg := config.New() | |||
cfg.UseSQLite(true) | |||
cfg.WriteAsOauth = config.WriteAsOauthCfg{ | |||
ClientID: "development", | |||
ClientSecret: "development", | |||
AuthLocation: "https://write.as/oauth/login", | |||
TokenLocation: "https://write.as/oauth/token", | |||
InspectLocation: "https://write.as/oauth/inspect", | |||
} | |||
cfg.SlackOauth = config.SlackOauthCfg{ | |||
ClientID: "development", | |||
ClientSecret: "development", | |||
TeamID: "development", | |||
} | |||
return cfg | |||
} | |||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { | |||
if m.DoValidateOAuthState != nil { | |||
return m.DoValidateOAuthState(ctx, state) | |||
} | |||
return "", "", nil | |||
} | |||
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { | |||
if m.DoGetIDForRemoteUser != nil { | |||
return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) | |||
} | |||
return -1, nil | |||
} | |||
func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error { | |||
if m.DoCreateUser != nil { | |||
return m.DoCreateUser(cfg, u, username) | |||
} | |||
u.ID = 1 | |||
return nil | |||
} | |||
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { | |||
if m.DoRecordRemoteUserID != nil { | |||
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) | |||
} | |||
return nil | |||
} | |||
func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { | |||
if m.DoGetUserByID != nil { | |||
return m.DoGetUserByID(userID) | |||
} | |||
user := &User{ | |||
} | |||
return user, nil | |||
} | |||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { | |||
if m.DoGenerateOAuthState != nil { | |||
return m.DoGenerateOAuthState(ctx, provider, clientID) | |||
} | |||
return store.Generate62RandomString(14), nil | |||
} | |||
func TestViewOauthInit(t *testing.T) { | |||
t.Run("success", func(t *testing.T) { | |||
app := &MockOAuthDatastoreProvider{} | |||
h := oauthHandler{ | |||
Config: app.Config(), | |||
DB: app.DB(), | |||
Store: app.SessionStore(), | |||
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, | |||
oauthClient: writeAsOauthClient{ | |||
ClientID: app.Config().WriteAsOauth.ClientID, | |||
ClientSecret: app.Config().WriteAsOauth.ClientSecret, | |||
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, | |||
InspectLocation: app.Config().WriteAsOauth.InspectLocation, | |||
AuthLocation: app.Config().WriteAsOauth.AuthLocation, | |||
CallbackLocation: "http://localhost/oauth/callback", | |||
HttpClient: nil, | |||
}, | |||
} | |||
req, err := http.NewRequest("GET", "/oauth/client", nil) | |||
assert.NoError(t, err) | |||
rr := httptest.NewRecorder() | |||
err = h.viewOauthInit(nil, rr, req) | |||
assert.NotNil(t, err) | |||
httpErr, ok := err.(impart.HTTPError) | |||
assert.True(t, ok) | |||
assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status) | |||
assert.NotEmpty(t, httpErr.Message) | |||
locURI, err := url.Parse(httpErr.Message) | |||
assert.NoError(t, err) | |||
assert.Equal(t, "/oauth/login", locURI.Path) | |||
assert.Equal(t, "development", locURI.Query().Get("client_id")) | |||
assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) | |||
assert.Equal(t, "code", locURI.Query().Get("response_type")) | |||
assert.NotEmpty(t, locURI.Query().Get("state")) | |||
}) | |||
t.Run("state failure", func(t *testing.T) { | |||
app := &MockOAuthDatastoreProvider{ | |||
DoDB: func() OAuthDatastore { | |||
return &MockOAuthDatastore{ | |||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { | |||
return "", fmt.Errorf("pretend unable to write state error") | |||
}, | |||
} | |||
}, | |||
} | |||
h := oauthHandler{ | |||
Config: app.Config(), | |||
DB: app.DB(), | |||
Store: app.SessionStore(), | |||
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, | |||
oauthClient: writeAsOauthClient{ | |||
ClientID: app.Config().WriteAsOauth.ClientID, | |||
ClientSecret: app.Config().WriteAsOauth.ClientSecret, | |||
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, | |||
InspectLocation: app.Config().WriteAsOauth.InspectLocation, | |||
AuthLocation: app.Config().WriteAsOauth.AuthLocation, | |||
CallbackLocation: "http://localhost/oauth/callback", | |||
HttpClient: nil, | |||
}, | |||
} | |||
req, err := http.NewRequest("GET", "/oauth/client", nil) | |||
assert.NoError(t, err) | |||
rr := httptest.NewRecorder() | |||
err = h.viewOauthInit(nil, rr, req) | |||
httpErr, ok := err.(impart.HTTPError) | |||
assert.True(t, ok) | |||
assert.NotEmpty(t, httpErr.Message) | |||
assert.Equal(t, http.StatusInternalServerError, httpErr.Status) | |||
assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message) | |||
}) | |||
} | |||
func TestViewOauthCallback(t *testing.T) { | |||
t.Run("success", func(t *testing.T) { | |||
app := &MockOAuthDatastoreProvider{} | |||
h := oauthHandler{ | |||
Config: app.Config(), | |||
DB: app.DB(), | |||
Store: app.SessionStore(), | |||
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, | |||
oauthClient: writeAsOauthClient{ | |||
ClientID: app.Config().WriteAsOauth.ClientID, | |||
ClientSecret: app.Config().WriteAsOauth.ClientSecret, | |||
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, | |||
InspectLocation: app.Config().WriteAsOauth.InspectLocation, | |||
AuthLocation: app.Config().WriteAsOauth.AuthLocation, | |||
CallbackLocation: "http://localhost/oauth/callback", | |||
HttpClient: &MockHTTPClient{ | |||
DoDo: func(req *http.Request) (*http.Response, error) { | |||
switch req.URL.String() { | |||
case "https://write.as/oauth/token": | |||
return &http.Response{ | |||
StatusCode: 200, | |||
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, | |||
}, nil | |||
case "https://write.as/oauth/inspect": | |||
return &http.Response{ | |||
StatusCode: 200, | |||
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, | |||
}, nil | |||
} | |||
return &http.Response{ | |||
StatusCode: http.StatusNotFound, | |||
}, nil | |||
}, | |||
}, | |||
}, | |||
} | |||
req, err := http.NewRequest("GET", "/oauth/callback", nil) | |||
assert.NoError(t, err) | |||
rr := httptest.NewRecorder() | |||
err = h.viewOauthCallback(nil, rr, req) | |||
assert.NoError(t, err) | |||
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) | |||
}) | |||
} |
@@ -0,0 +1,114 @@ | |||
package writefreely | |||
import ( | |||
"context" | |||
"errors" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
) | |||
type writeAsOauthClient struct { | |||
ClientID string | |||
ClientSecret string | |||
AuthLocation string | |||
ExchangeLocation string | |||
InspectLocation string | |||
CallbackLocation string | |||
HttpClient HttpClient | |||
} | |||
var _ oauthClient = writeAsOauthClient{} | |||
const ( | |||
writeAsAuthLocation = "https://write.as/oauth/login" | |||
writeAsExchangeLocation = "https://write.as/oauth/token" | |||
writeAsIdentityLocation = "https://write.as/oauth/inspect" | |||
) | |||
func (c writeAsOauthClient) GetProvider() string { | |||
return "write.as" | |||
} | |||
func (c writeAsOauthClient) GetClientID() string { | |||
return c.ClientID | |||
} | |||
func (c writeAsOauthClient) GetCallbackLocation() string { | |||
return c.CallbackLocation | |||
} | |||
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { | |||
u, err := url.Parse(c.AuthLocation) | |||
if err != nil { | |||
return "", err | |||
} | |||
q := u.Query() | |||
q.Set("client_id", c.ClientID) | |||
q.Set("redirect_uri", c.CallbackLocation) | |||
q.Set("response_type", "code") | |||
q.Set("state", state) | |||
u.RawQuery = q.Encode() | |||
return u.String(), nil | |||
} | |||
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { | |||
form := url.Values{} | |||
form.Add("grant_type", "authorization_code") | |||
form.Add("redirect_uri", c.CallbackLocation) | |||
form.Add("code", code) | |||
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.WithContext(ctx) | |||
req.Header.Set("User-Agent", "writefreely") | |||
req.Header.Set("Accept", "application/json") | |||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | |||
req.SetBasicAuth(c.ClientID, c.ClientSecret) | |||
resp, err := c.HttpClient.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if resp.StatusCode != http.StatusOK { | |||
return nil, errors.New("unable to exchange code for access token") | |||
} | |||
var tokenResponse TokenResponse | |||
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { | |||
return nil, err | |||
} | |||
if tokenResponse.Error != "" { | |||
return nil, errors.New(tokenResponse.Error) | |||
} | |||
return &tokenResponse, nil | |||
} | |||
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { | |||
req, err := http.NewRequest("GET", c.InspectLocation, nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.WithContext(ctx) | |||
req.Header.Set("User-Agent", "writefreely") | |||
req.Header.Set("Accept", "application/json") | |||
req.Header.Set("Authorization", "Bearer "+accessToken) | |||
resp, err := c.HttpClient.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if resp.StatusCode != http.StatusOK { | |||
return nil, errors.New("unable to inspect access token") | |||
} | |||
var inspectResponse InspectResponse | |||
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { | |||
return nil, err | |||
} | |||
if inspectResponse.Error != "" { | |||
return nil, errors.New(inspectResponse.Error) | |||
} | |||
return &inspectResponse, nil | |||
} |
@@ -92,6 +92,7 @@ func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { | |||
if err != nil { | |||
return err | |||
} | |||
appData.EditCollection.hostName = app.cfg.App.Host | |||
} else { | |||
// Editing a floating article | |||
appData.Post = getRawPost(app, action) | |||
@@ -161,6 +162,7 @@ func handleViewMeta(app *App, w http.ResponseWriter, r *http.Request) error { | |||
if err != nil { | |||
return err | |||
} | |||
appData.EditCollection.hostName = app.cfg.App.Host | |||
} else { | |||
// Editing a floating article | |||
appData.Post = getRawPost(app, action) | |||
@@ -1,7 +1,38 @@ | |||
{{define "head"}}<title>Log in — {{.SiteName}}</title> | |||
<meta name="description" content="Log in to {{.SiteName}}."> | |||
<meta itemprop="description" content="Log in to {{.SiteName}}."> | |||
<style>input{margin-bottom:0.5em;}</style> | |||
<style> | |||
input{margin-bottom:0.5em;} | |||
.or { | |||
text-align: center; | |||
margin-bottom: 3.5em; | |||
} | |||
.or p { | |||
display: inline-block; | |||
background-color: white; | |||
padding: 0 1em; | |||
} | |||
.or hr { | |||
margin-top: -1.6em; | |||
margin-bottom: 0; | |||
} | |||
hr.short { | |||
max-width: 30rem; | |||
} | |||
.row.signinbtns { | |||
justify-content: space-evenly; | |||
font-size: 1em; | |||
margin-top: 3em; | |||
margin-bottom: 2em; | |||
} | |||
.loginbtn { | |||
height: 40px; | |||
} | |||
#writeas-login { | |||
box-sizing: border-box; | |||
font-size: 17px; | |||
} | |||
</style> | |||
{{end}} | |||
{{define "content"}} | |||
<div class="tight content-container"> | |||
@@ -11,6 +42,22 @@ | |||
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}} | |||
</ul>{{end}} | |||
{{ if or .OauthSlack .OauthWriteAs }} | |||
<div class="row content-container signinbtns"> | |||
{{ if .OauthSlack }} | |||
<a class="loginbtn" href="/oauth/slack"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a> | |||
{{ end }} | |||
{{ if .OauthWriteAs }} | |||
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as">Sign in with <strong>Write.as</strong></a> | |||
{{ end }} | |||
</div> | |||
<div class="or"> | |||
<p>or</p> | |||
<hr class="short" /> | |||
</div> | |||
{{ end }} | |||
<form action="/auth/login" method="post" style="text-align: center;margin-top:1em;" onsubmit="disableSubmit()"> | |||
<input type="text" name="alias" placeholder="Username" value="{{.LoginUsername}}" {{if not .LoginUsername}}autofocus{{end}} /><br /> | |||
<input type="password" name="pass" placeholder="Password" {{if .LoginUsername}}autofocus{{end}} /><br /> | |||
@@ -0,0 +1,174 @@ | |||
{{define "head"}}<title>Log in — {{.SiteName}}</title> | |||
<meta name="description" content="Log in to {{.SiteName}}."> | |||
<meta itemprop="description" content="Log in to {{.SiteName}}."> | |||
<style>input{margin-bottom:0.5em;}</style> | |||
<style type="text/css"> | |||
h2 { | |||
font-weight: normal; | |||
} | |||
#pricing.content-container div.form-container #payment-form { | |||
display: block !important; | |||
} | |||
#pricing #signup-form table { | |||
max-width: inherit !important; | |||
width: 100%; | |||
} | |||
#pricing #payment-form table { | |||
margin-top: 0 !important; | |||
max-width: inherit !important; | |||
width: 100%; | |||
} | |||
tr.subscription { | |||
border-spacing: 0; | |||
} | |||
#pricing.content-container tr.subscription button { | |||
margin-top: 0 !important; | |||
margin-bottom: 0 !important; | |||
width: 100%; | |||
} | |||
#pricing tr.subscription td { | |||
padding: 0 0.5em; | |||
} | |||
#pricing table.billing > tbody > tr > td:first-child { | |||
vertical-align: middle !important; | |||
} | |||
.billing-section { | |||
display: none; | |||
} | |||
.billing-section.bill-me { | |||
display: table-row; | |||
} | |||
#btn-create { | |||
color: white !important; | |||
} | |||
#total-price { | |||
padding-left: 0.5em; | |||
} | |||
#alias-site.demo { | |||
color: #999; | |||
} | |||
#alias-site { | |||
text-align: left; | |||
margin: 0.5em 0; | |||
} | |||
form dd { | |||
margin: 0; | |||
} | |||
</style> | |||
{{end}} | |||
{{define "content"}} | |||
<div id="pricing" class="tight content-container"> | |||
<h1>Log in to {{.SiteName}}</h1> | |||
{{if .Flashes}}<ul class="errors"> | |||
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}} | |||
</ul>{{end}} | |||
<div id="billing"> | |||
<form action="/oauth/signup" method="post" style="text-align: center;margin-top:1em;" onsubmit="return disableSubmit()"> | |||
<input type="hidden" name="access_token" value="{{ .AccessToken }}" /> | |||
<input type="hidden" name="token_username" value="{{ .TokenUsername }}" /> | |||
<input type="hidden" name="token_alias" value="{{ .TokenAlias }}" /> | |||
<input type="hidden" name="token_email" value="{{ .TokenEmail }}" /> | |||
<input type="hidden" name="token_remote_user" value="{{ .TokenRemoteUser }}" /> | |||
<input type="hidden" name="provider" value="{{ .Provider }}" /> | |||
<input type="hidden" name="client_id" value="{{ .ClientID }}" /> | |||
<input type="hidden" name="signature" value="{{ .TokenHash }}" /> | |||
<dl class="billing"> | |||
<label> | |||
<dt>Display Name</dt> | |||
<dd> | |||
<input type="text" style="width: 100%; box-sizing: border-box;" name="alias" placeholder="Name"{{ if .Alias }} value="{{.Alias}}"{{ end }} /> | |||
</dd> | |||
</label> | |||
<label> | |||
<dt>Username</dt> | |||
<dd> | |||
<input type="text" id="username" name="username" style="width: 100%; box-sizing: border-box;" placeholder="Username" value="{{.LoginUsername}}" /><br /> | |||
{{if .Federation}}<p id="alias-site" class="demo">@<strong>your-username</strong>@{{.FriendlyHost}}</p>{{else}}<p id="alias-site" class="demo">{{.FriendlyHost}}/<strong>your-username</strong></p>{{end}} | |||
</dd> | |||
</label> | |||
<label> | |||
<dt>Email</dt> | |||
<dd> | |||
<input type="text" name="email" style="width: 100%; box-sizing: border-box;" placeholder="Email"{{ if .Email }} value="{{.Email}}"{{ end }} /> | |||
</dd> | |||
</label> | |||
<dt> | |||
<input type="submit" id="btn-login" value="Login" /> | |||
</dt> | |||
</dl> | |||
</form> | |||
</div> | |||
<script type="text/javascript" src="/js/h.js"></script> | |||
<script type="text/javascript"> | |||
// Copied from signup.tmpl | |||
// NOTE: this element is named "alias" on signup.tmpl and "username" here | |||
var $alias = H.getEl('username'); | |||
function disableSubmit() { | |||
// Validate input | |||
if (!aliasOK) { | |||
var $a = $alias; | |||
$a.el.className = 'error'; | |||
$a.el.focus(); | |||
$a.el.scrollIntoView(); | |||
return false; | |||
} | |||
var $btn = document.getElementById("btn-login"); | |||
$btn.value = "Logging in..."; | |||
$btn.disabled = true; | |||
return true; | |||
} | |||
// Copied from signup.tmpl | |||
var $aliasSite = document.getElementById('alias-site'); | |||
var aliasOK = true; | |||
var typingTimer; | |||
var doneTypingInterval = 750; | |||
var doneTyping = function() { | |||
// Check on username | |||
var alias = $alias.el.value; | |||
if (alias != "") { | |||
var params = { | |||
username: alias | |||
}; | |||
var http = new XMLHttpRequest(); | |||
http.open("POST", '/api/alias', true); | |||
// Send the proper header information along with the request | |||
http.setRequestHeader("Content-type", "application/json"); | |||
http.onreadystatechange = function() { | |||
if (http.readyState == 4) { | |||
data = JSON.parse(http.responseText); | |||
if (http.status == 200) { | |||
aliasOK = true; | |||
$alias.removeClass('error'); | |||
$aliasSite.className = $aliasSite.className.replace(/(?:^|\s)demo(?!\S)/g, ''); | |||
$aliasSite.className = $aliasSite.className.replace(/(?:^|\s)error(?!\S)/g, ''); | |||
$aliasSite.innerHTML = '{{ if .Federation }}@<strong>' + data.data + '</strong>@{{.FriendlyHost}}{{ else }}{{.FriendlyHost}}/<strong>' + data.data + '</strong>/{{ end }}'; | |||
} else { | |||
aliasOK = false; | |||
$alias.setClass('error'); | |||
$aliasSite.className = 'error'; | |||
$aliasSite.textContent = data.error_msg; | |||
} | |||
} | |||
} | |||
http.send(JSON.stringify(params)); | |||
} else { | |||
$aliasSite.className += ' demo'; | |||
$aliasSite.innerHTML = '{{ if .Federation }}@<strong>your-username</strong>@{{.FriendlyHost}}{{ else }}{{.FriendlyHost}}/<strong>your-username</strong>/{{ end }}'; | |||
} | |||
}; | |||
$alias.on('keyup input', function() { | |||
clearTimeout(typingTimer); | |||
typingTimer = setTimeout(doneTyping, doneTypingInterval); | |||
}); | |||
doneTyping(); | |||
</script> | |||
{{end}} |
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -11,9 +11,11 @@ | |||
package writefreely | |||
import ( | |||
"encoding/json" | |||
"fmt" | |||
"html" | |||
"html/template" | |||
"net/http" | |||
"regexp" | |||
"strings" | |||
"unicode" | |||
@@ -21,7 +23,9 @@ import ( | |||
"github.com/microcosm-cc/bluemonday" | |||
stripmd "github.com/writeas/go-strip-markdown" | |||
"github.com/writeas/impart" | |||
blackfriday "github.com/writeas/saturday" | |||
"github.com/writeas/web-core/log" | |||
"github.com/writeas/web-core/stringmanip" | |||
"github.com/writeas/writefreely/config" | |||
"github.com/writeas/writefreely/parse" | |||
@@ -34,6 +38,7 @@ var ( | |||
titleElementReg = regexp.MustCompile("</?h[1-6]>") | |||
hashtagReg = regexp.MustCompile(`{{\[\[\|\|([^|]+)\|\|\]\]}}`) | |||
markeddownReg = regexp.MustCompile("<p>(.+)</p>") | |||
mentionReg = regexp.MustCompile(`@([A-Za-z0-9._%+-]+)(@[A-Za-z0-9.-]+\.[A-Za-z]+)\b`) | |||
) | |||
func (p *Post) formatContent(cfg *config.Config, c *Collection, isOwner bool) { | |||
@@ -82,6 +87,8 @@ func applyMarkdownSpecial(data []byte, skipNoFollow bool, baseURL string, cfg *c | |||
tagPrefix = "/read/t/" | |||
} | |||
md = []byte(hashtagReg.ReplaceAll(md, []byte("<a href=\""+tagPrefix+"$1\" class=\"hashtag\"><span>#</span><span class=\"p-category\">$1</span></a>"))) | |||
handlePrefix := cfg.App.Host + "/@/" | |||
md = []byte(mentionReg.ReplaceAll(md, []byte("<a href=\""+handlePrefix+"$1$2\" class=\"u-url mention\">@<span>$1$2</span></a>"))) | |||
} | |||
// Strip out bad HTML | |||
policy := getSanitizationPolicy() | |||
@@ -234,3 +241,29 @@ func shortPostDescription(content string) string { | |||
} | |||
return strings.TrimSpace(fmt.Sprintf(fmtStr, strings.Replace(stringmanip.Substring(content, 0, maxLen-truncation), "\n", " ", -1))) | |||
} | |||
func handleRenderMarkdown(app *App, w http.ResponseWriter, r *http.Request) error { | |||
if !IsJSON(r) { | |||
return impart.HTTPError{Status: http.StatusUnsupportedMediaType, Message: "Markdown API only supports JSON requests"} | |||
} | |||
in := struct { | |||
CollectionURL string `json:"collection_url"` | |||
RawBody string `json:"raw_body"` | |||
}{} | |||
decoder := json.NewDecoder(r.Body) | |||
err := decoder.Decode(&in) | |||
if err != nil { | |||
log.Error("Couldn't parse markdown JSON request: %v", err) | |||
return ErrBadJSON | |||
} | |||
out := struct { | |||
Body string `json:"body"` | |||
}{ | |||
Body: applyMarkdown([]byte(in.RawBody), in.CollectionURL, app.cfg), | |||
} | |||
return impart.WriteSuccess(w, out, http.StatusOK) | |||
} |
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018-2019 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -35,7 +35,6 @@ import ( | |||
"github.com/writeas/web-core/i18n" | |||
"github.com/writeas/web-core/log" | |||
"github.com/writeas/web-core/tags" | |||
"github.com/writeas/writefreely/config" | |||
"github.com/writeas/writefreely/page" | |||
"github.com/writeas/writefreely/parse" | |||
) | |||
@@ -229,6 +228,10 @@ func (p Post) Summary() string { | |||
return shortPostDescription(p.Content) | |||
} | |||
func (p Post) SummaryHTML() template.HTML { | |||
return template.HTML(p.Summary()) | |||
} | |||
// Excerpt shows any text that comes before a (more) tag. | |||
// TODO: use HTMLExcerpt in templates instead of this method | |||
func (p *Post) Excerpt() template.HTML { | |||
@@ -381,10 +384,12 @@ func handleViewPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
} | |||
} | |||
silenced, err := app.db.IsUserSilenced(ownerID.Int64) | |||
if err != nil { | |||
log.Error("view post: %v", err) | |||
return ErrInternalGeneral | |||
var silenced bool | |||
if found { | |||
silenced, err = app.db.IsUserSuspended(ownerID.Int64) | |||
if err != nil { | |||
log.Error("view post: %v", err) | |||
} | |||
} | |||
// Check if post has been unpublished | |||
@@ -511,7 +516,6 @@ func newPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
silenced, err := app.db.IsUserSilenced(userID) | |||
if err != nil { | |||
log.Error("new post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
if silenced { | |||
return ErrUserSilenced | |||
@@ -685,7 +689,6 @@ func existingPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
silenced, err := app.db.IsUserSilenced(userID) | |||
if err != nil { | |||
log.Error("existing post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
if silenced { | |||
return ErrUserSilenced | |||
@@ -888,7 +891,6 @@ func addPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
silenced, err := app.db.IsUserSilenced(ownerID) | |||
if err != nil { | |||
log.Error("add post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
if silenced { | |||
return ErrUserSilenced | |||
@@ -991,7 +993,6 @@ func pinPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
silenced, err := app.db.IsUserSilenced(userID) | |||
if err != nil { | |||
log.Error("pin post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
if silenced { | |||
return ErrUserSilenced | |||
@@ -1039,7 +1040,6 @@ func pinPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
var collID int64 | |||
var ownerID int64 | |||
var coll *Collection | |||
var err error | |||
vars := mux.Vars(r) | |||
@@ -1049,25 +1049,32 @@ func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
if err != nil { | |||
return err | |||
} | |||
coll.hostName = app.cfg.App.Host | |||
_, err = apiCheckCollectionPermissions(app, r, coll) | |||
if err != nil { | |||
return err | |||
} | |||
collID = coll.ID | |||
ownerID = coll.OwnerID | |||
} | |||
p, err := app.db.GetPost(vars["post"], collID) | |||
if err != nil { | |||
return err | |||
} | |||
silenced, err := app.db.IsUserSilenced(ownerID) | |||
if coll == nil && p.CollectionID.Valid { | |||
// Collection post is getting fetched by post ID, not coll alias + post slug, so get coll info now. | |||
coll, err = app.db.GetCollectionByID(p.CollectionID.Int64) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
if coll != nil { | |||
coll.hostName = app.cfg.App.Host | |||
_, err = apiCheckCollectionPermissions(app, r, coll) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
silenced, err := app.db.IsUserSilenced(p.OwnerID.Int64) | |||
if err != nil { | |||
log.Error("fetch post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
if silenced { | |||
return ErrPostNotFound | |||
} | |||
@@ -1076,13 +1083,6 @@ func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
accept := r.Header.Get("Accept") | |||
if strings.Contains(accept, "application/activity+json") { | |||
// Fetch information about the collection this belongs to | |||
if coll == nil && p.CollectionID.Valid { | |||
coll, err = app.db.GetCollectionByID(p.CollectionID.Int64) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
if coll == nil { | |||
// This is a draft post; 404 for now | |||
// TODO: return ActivityObject | |||
@@ -1090,8 +1090,9 @@ func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { | |||
} | |||
p.Collection = &CollectionObj{Collection: *coll} | |||
po := p.ActivityObject(app.cfg) | |||
po := p.ActivityObject(app) | |||
po.Context = []interface{}{activitystreams.Namespace} | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, po, http.StatusOK) | |||
} | |||
@@ -1125,7 +1126,8 @@ func (p *PublicPost) CanonicalURL(hostName string) string { | |||
return p.Collection.CanonicalURL() + p.Slug.String | |||
} | |||
func (p *PublicPost) ActivityObject(cfg *config.Config) *activitystreams.Object { | |||
func (p *PublicPost) ActivityObject(app *App) *activitystreams.Object { | |||
cfg := app.cfg | |||
o := activitystreams.NewArticleObject() | |||
o.ID = p.Collection.FederatedAPIBase() + "api/posts/" + p.ID | |||
o.Published = p.Created | |||
@@ -1165,6 +1167,27 @@ func (p *PublicPost) ActivityObject(cfg *config.Config) *activitystreams.Object | |||
}) | |||
} | |||
} | |||
// Find mentioned users | |||
mentionedUsers := make(map[string]string) | |||
stripper := bluemonday.StrictPolicy() | |||
content := stripper.Sanitize(p.Content) | |||
mentionRegex := regexp.MustCompile(`@[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]+\b`) | |||
mentions := mentionRegex.FindAllString(content, -1) | |||
for _, handle := range mentions { | |||
actorIRI, err := app.db.GetProfilePageFromHandle(app, handle) | |||
if err != nil { | |||
log.Info("Can't find this user either in the database nor in the remote instance") | |||
return nil | |||
} | |||
mentionedUsers[handle] = actorIRI | |||
} | |||
for handle, iri := range mentionedUsers { | |||
o.CC = append(o.CC, iri) | |||
o.Tag = append(o.Tag, activitystreams.Tag{Type: "Mention", HRef: iri, Name: handle}) | |||
} | |||
return o | |||
} | |||
@@ -1335,15 +1358,18 @@ func viewCollectionPost(app *App, w http.ResponseWriter, r *http.Request) error | |||
silenced, err := app.db.IsUserSilenced(c.OwnerID) | |||
if err != nil { | |||
log.Error("view collection post: %v", err) | |||
return ErrInternalGeneral | |||
} | |||
// Check collection permissions | |||
if c.IsPrivate() && (u == nil || u.ID != c.OwnerID) { | |||
return ErrPostNotFound | |||
} | |||
if c.IsProtected() && ((u == nil || u.ID != c.OwnerID) && !isAuthorizedForCollection(app, c.Alias, r)) { | |||
return impart.HTTPError{http.StatusFound, c.CanonicalURL() + "/?g=" + slug} | |||
if c.IsProtected() && (u == nil || u.ID != c.OwnerID) { | |||
if suspended { | |||
return ErrPostNotFound | |||
} else if !isAuthorizedForCollection(app, c.Alias, r) { | |||
return impart.HTTPError{http.StatusFound, c.CanonicalURL() + "/?g=" + slug} | |||
} | |||
} | |||
cr.isCollOwner = u != nil && c.OwnerID == u.ID | |||
@@ -1354,7 +1380,7 @@ func viewCollectionPost(app *App, w http.ResponseWriter, r *http.Request) error | |||
// Fetch extra data about the Collection | |||
// TODO: refactor out this logic, shared in collection.go:fetchCollection() | |||
coll := &CollectionObj{Collection: *c} | |||
coll := NewCollectionObj(c) | |||
owner, err := app.db.GetUserByID(coll.OwnerID) | |||
if err != nil { | |||
// Log the error and just continue | |||
@@ -1390,7 +1416,7 @@ Are you sure it was ever here?`, | |||
return err | |||
} | |||
} | |||
p.IsOwner = owner != nil && p.OwnerID.Valid && u.ID == p.OwnerID.Int64 | |||
p.IsOwner = owner != nil && p.OwnerID.Valid && owner.ID == p.OwnerID.Int64 | |||
p.Collection = coll | |||
p.IsTopLevel = app.cfg.App.SingleUser | |||
@@ -1428,8 +1454,9 @@ Are you sure it was ever here?`, | |||
return ErrCollectionPageNotFound | |||
} | |||
p.extractData() | |||
ap := p.ActivityObject(app.cfg) | |||
ap := p.ActivityObject(app) | |||
ap.Context = []interface{}{activitystreams.Namespace} | |||
setCacheControl(w, apCacheTime) | |||
return impart.RenderActivityJSON(w, ap, http.StatusOK) | |||
} else { | |||
p.extractData() | |||
@@ -70,6 +70,12 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) | |||
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) | |||
// handle mentions | |||
write.HandleFunc("/@/{handle}", handler.Web(handleViewMention, UserLevelReader)) | |||
configureSlackOauth(handler, write, apper.App()) | |||
configureWriteAsOauth(handler, write, apper.App()) | |||
// Set up dyamic page handlers | |||
// Handle auth | |||
auth := write.PathPrefix("/api/auth/").Subrouter() | |||
@@ -94,6 +100,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
me.HandleFunc("/posts/export.json", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET") | |||
me.HandleFunc("/export", handler.User(viewExportOptions)).Methods("GET") | |||
me.HandleFunc("/export.json", handler.Download(viewExportFull, UserLevelUser)).Methods("GET") | |||
me.HandleFunc("/import", handler.User(viewImport)).Methods("GET") | |||
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET") | |||
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET") | |||
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET") | |||
@@ -106,10 +113,13 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
apiMe.HandleFunc("/password", handler.All(updatePassphrase)).Methods("POST") | |||
apiMe.HandleFunc("/self", handler.All(updateSettings)).Methods("POST") | |||
apiMe.HandleFunc("/invites", handler.User(handleCreateUserInvite)).Methods("POST") | |||
apiMe.HandleFunc("/import", handler.User(handleImport)).Methods("POST") | |||
// Sign up validation | |||
write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST") | |||
write.HandleFunc("/api/markdown", handler.All(handleRenderMarkdown)).Methods("POST") | |||
// Handle collections | |||
write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST") | |||
apiColls := write.PathPrefix("/api/collections/").Subrouter() | |||
@@ -162,9 +172,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
draftEditPrefix := "" | |||
if apper.App().cfg.App.SingleUser { | |||
draftEditPrefix = "/d" | |||
write.HandleFunc("/me/new", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") | |||
write.HandleFunc("/me/new", handler.Web(handleViewPad, UserLevelUser)).Methods("GET") | |||
} else { | |||
write.HandleFunc("/new", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") | |||
write.HandleFunc("/new", handler.Web(handleViewPad, UserLevelUser)).Methods("GET") | |||
} | |||
// All the existing stuff | |||
@@ -181,6 +191,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
} | |||
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) | |||
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) | |||
return r | |||
} | |||
@@ -11,7 +11,7 @@ | |||
## have not installed the binary `writefreely` in another location. ## | |||
############################################################################### | |||
# | |||
# Copyright © 2019 A Bunch Tell LLC. | |||
# Copyright © 2019-2020 A Bunch Tell LLC. | |||
# | |||
# This file is part of WriteFreely. | |||
# | |||
@@ -31,7 +31,7 @@ fi | |||
# go ahead and check for the latest release on linux | |||
echo "Checking for updates..." | |||
url=`curl -s https://api.github.com/repos/writeas/writefreely/releases/latest | grep 'browser_' | grep linux | cut -d\" -f4` | |||
url=`curl -s https://api.github.com/repos/writeas/writefreely/releases/latest | grep 'browser_' | grep 'linux' | grep 'amd64' | cut -d\" -f4` | |||
# check current version | |||
@@ -82,13 +82,25 @@ filename=${parts[-1]} | |||
echo "Extracting files..." | |||
tar -zxf $tempdir/$filename -C $tempdir | |||
# stop service | |||
echo "Stopping writefreely systemd service..." | |||
if `systemctl start writefreely`; then | |||
echo "Success, service stopped." | |||
else | |||
echo "Upgrade failed to stop the systemd service, exiting early." | |||
exit 1 | |||
fi | |||
# copy files | |||
echo "Copying files..." | |||
cp -r $tempdir/{pages,static,templates,writefreely} . | |||
cp -r $tempdir/writefreely/{pages,static,templates,writefreely} . | |||
# migrate db | |||
./writefreely -migrate | |||
# restart service | |||
echo "Restarting writefreely systemd service..." | |||
if `systemctl restart writefreely`; then | |||
echo "Starting writefreely systemd service..." | |||
if `systemctl start writefreely`; then | |||
echo "Success, version has been upgraded to $latest." | |||
else | |||
echo "Upgrade complete, but failed to restart service." | |||
@@ -0,0 +1,16 @@ | |||
function toLocalDate(dateEl, displayEl) { | |||
var d = new Date(dateEl.getAttribute("datetime")); | |||
displayEl.textContent = d.toLocaleDateString(navigator.language || "en-US", { year: 'numeric', month: 'long', day: 'numeric' }); | |||
} | |||
// Adjust dates on individual post pages, and on posts in a list *with* an explicit title | |||
var $dates = document.querySelectorAll("article > time"); | |||
for (var i=0; i < $dates.length; i++) { | |||
toLocalDate($dates[i], $dates[i]); | |||
} | |||
// Adjust dates on posts in a list without an explicit title, where they act as the header | |||
$dates = document.querySelectorAll("h2.post-title > time"); | |||
for (i=0; i < $dates.length; i++) { | |||
toLocalDate($dates[i], $dates[i].querySelector('a')); | |||
} |
@@ -22,7 +22,7 @@ | |||
{{ end }} | |||
{{if not .SingleUser}} | |||
<nav id="user-nav"> | |||
{{if and .Chorus .Username}} | |||
{{if .Username}} | |||
<nav class="dropdown-nav"> | |||
<ul><li><a>{{.Username}}</a> <img class="ic-18dp" src="/img/ic_down_arrow_dark@2x.png" /><ul> | |||
{{if .IsAdmin}}<li><a href="/admin">Admin dashboard</a></li>{{end}} | |||
@@ -39,10 +39,10 @@ | |||
{{ if and .SimpleNav (not .SingleUser) }} | |||
{{if and (and .LocalTimeline .CanViewReader) .Chorus}}<a href="/"{{if eq .Path "/"}} class="selected"{{end}}>Home</a>{{end}} | |||
{{ end }} | |||
<a href="/about"{{if eq .Path "/about"}} class="selected"{{end}}>About</a> | |||
{{if or .Chorus (not .Username)}}<a href="/about"{{if eq .Path "/about"}} class="selected"{{end}}>About</a>{{end}} | |||
{{ if not .SingleUser }} | |||
{{ if .Username }} | |||
{{if gt .MaxBlogs 1}}<a href="/me/c/"{{if eq .Path "/me/c/"}} class="selected"{{end}}>Blogs</a>{{end}} | |||
{{if or (not .Chorus) (gt .MaxBlogs 1)}}<a href="/me/c/"{{if eq .Path "/me/c/"}} class="selected"{{end}}>Blogs</a>{{end}} | |||
{{if and (and .Chorus (eq .MaxBlogs 1)) .Username}}<a href="/{{.Username}}/"{{if eq .Path (printf "/%s/" .Username)}} class="selected"{{end}}>My Posts</a>{{end}} | |||
{{if not .DisableDrafts}}<a href="/me/posts/"{{if eq .Path "/me/posts/"}} class="selected"{{end}}>Drafts</a>{{end}} | |||
{{ end }} | |||
@@ -38,16 +38,6 @@ body footer { | |||
body#post header { | |||
padding: 1em 1rem; | |||
} | |||
article time.dt-published { | |||
display: block; | |||
color: #666; | |||
} | |||
body#post article h2#title{ | |||
margin-bottom: 0.5em; | |||
} | |||
article time.dt-published { | |||
margin-bottom: 1em; | |||
} | |||
</style> | |||
{{if .Collection.RenderMathJax}} | |||
@@ -68,7 +58,7 @@ article time.dt-published { | |||
{{if .Silenced}} | |||
{{template "user-silenced"}} | |||
{{end}} | |||
<article id="post-body" class="{{.Font}} h-entry">{{if .IsScheduled}}<p class="badge">Scheduled</p>{{end}}{{if .Title.String}}<h2 id="title" class="p-name">{{.FormattedDisplayTitle}}</h2>{{end}}{{/* TODO: check format: if .Collection.Format.ShowDates*/}}<time class="dt-published" datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}">{{.DisplayDate}}</time><div class="e-content">{{.HTMLContent}}</div></article> | |||
<article id="post-body" class="{{.Font}} h-entry">{{if .IsScheduled}}<p class="badge">Scheduled</p>{{end}}{{if .Title.String}}<h2 id="title" class="p-name{{if $.Collection.Format.ShowDates}} dated{{end}}">{{.FormattedDisplayTitle}}</h2>{{end}}{{if $.Collection.Format.ShowDates}}<time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}">{{.DisplayDate}}</time>{{end}}<div class="e-content">{{.HTMLContent}}</div></article> | |||
{{ if .Collection.ShowFooterBranding }} | |||
<footer dir="ltr"> | |||
@@ -80,7 +70,7 @@ article time.dt-published { | |||
</p> | |||
<nav> | |||
{{if .PinnedPosts}} | |||
{{range .PinnedPosts}}<a class="pinned{{if eq .Slug.String $.Slug.String}} selected{{end}}" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}} | |||
{{range .PinnedPosts}}<a class="pinned{{if eq .Slug.String $.Slug.String}} selected{{end}}" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL $.Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}} | |||
{{end}} | |||
</nav> | |||
<hr> | |||
@@ -93,6 +83,7 @@ article time.dt-published { | |||
{{range .Collection.ExternalScripts}}<script type="text/javascript" src="{{.}}" async></script>{{end}} | |||
{{if .Collection.Script}}<script type="text/javascript">{{.Collection.ScriptDisplay}}</script>{{end}} | |||
{{end}} | |||
<script src="/js/localdate.js"></script> | |||
<script type="text/javascript"> | |||
var pinning = false; | |||
@@ -71,7 +71,7 @@ body#collection header nav.tabs a:first-child { | |||
<!--p class="meta-note"><span>Private collection</span>. Only you can see this page.</p--> | |||
{{/*end*/}} | |||
{{if .PinnedPosts}}<nav class="pinned-posts"> | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}}</nav> | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL $.Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}}</nav> | |||
{{end}} | |||
</header> | |||
@@ -115,6 +115,7 @@ body#collection header nav.tabs a:first-child { | |||
{{if .Script}}<script type="text/javascript">{{.ScriptDisplay}}</script>{{end}} | |||
{{end}} | |||
<script src="/js/h.js"></script> | |||
<script src="/js/localdate.js"></script> | |||
<script src="/js/postactions.js"></script> | |||
<script type="text/javascript"> | |||
var deleting = false; | |||
@@ -50,7 +50,7 @@ | |||
<h1 dir="{{.Direction}}" id="blog-title"><a rel="author" href="{{if .IsTopLevel}}/{{else}}/{{.Collection.Alias}}/{{end}}" class="h-card p-author">{{.Collection.DisplayTitle}}</a></h1> | |||
<nav> | |||
{{if .PinnedPosts}} | |||
{{range .PinnedPosts}}<a class="pinned{{if eq .Slug.String $.Slug.String}} selected{{end}}" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}} | |||
{{range .PinnedPosts}}<a class="pinned{{if eq .Slug.String $.Slug.String}} selected{{end}}" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL $.Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}} | |||
{{end}} | |||
{{ if and .IsOwner .IsFound }}<span class="views" dir="ltr"><strong>{{largeNumFmt .Views}}</strong> {{pluralize "view" "views" .Views}}</span> | |||
<a class="xtra-feature" href="/{{if not .SingleUser}}{{.Collection.Alias}}/{{end}}{{.Slug.String}}/edit" dir="{{.Direction}}">Edit</a> | |||
@@ -62,7 +62,7 @@ | |||
{{if .Silenced}} | |||
{{template "user-silenced"}} | |||
{{end}} | |||
<article id="post-body" class="{{.Font}} h-entry {{if not .IsFound}}error-page{{end}}">{{if .IsScheduled}}<p class="badge">Scheduled</p>{{end}}{{if .Title.String}}<h2 id="title" class="p-name">{{.FormattedDisplayTitle}}</h2>{{end}}<div class="e-content">{{.HTMLContent}}</div></article> | |||
<article id="post-body" class="{{.Font}} h-entry {{if not .IsFound}}error-page{{end}}">{{if .IsScheduled}}<p class="badge">Scheduled</p>{{end}}{{if .Title.String}}<h2 id="title" class="p-name{{if $.Collection.Format.ShowDates}} dated{{end}}">{{.FormattedDisplayTitle}}</h2>{{end}}{{if $.Collection.Format.ShowDates}}<time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}">{{.DisplayDate}}</time>{{end}}<div class="e-content">{{.HTMLContent}}</div></article> | |||
{{ if .Collection.ShowFooterBranding }} | |||
<footer dir="ltr"><hr><nav><p style="font-size: 0.9em">{{localhtml "published with write.as" .Language.String}}</p></nav></footer> | |||
@@ -73,6 +73,7 @@ | |||
{{range .Collection.ExternalScripts}}<script type="text/javascript" src="{{.}}" async></script>{{end}} | |||
{{if .Collection.Script}}<script type="text/javascript">{{.Collection.ScriptDisplay}}</script>{{end}} | |||
{{end}} | |||
<script src="/js/localdate.js"></script> | |||
<script type="text/javascript"> | |||
var pinning = false; | |||
@@ -48,7 +48,7 @@ | |||
<h1 dir="{{.Direction}}" id="blog-title"><a href="{{if .IsTopLevel}}/{{else}}/{{.Collection.Alias}}/{{end}}" class="h-card p-author">{{.Collection.DisplayTitle}}</a></h1> | |||
<nav> | |||
{{if .PinnedPosts}} | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}{{end}}">{{.DisplayTitle}}</a>{{end}} | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Collection.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL $.Host}}{{end}}">{{.DisplayTitle}}</a>{{end}} | |||
{{end}} | |||
</nav> | |||
</header> | |||
@@ -75,6 +75,7 @@ | |||
{{range .ExternalScripts}}<script type="text/javascript" src="{{.}}" async></script>{{end}} | |||
{{if .Collection.Script}}<script type="text/javascript">{{.ScriptDisplay}}</script>{{end}} | |||
{{end}} | |||
<script src="/js/localdate.js"></script> | |||
{{if .IsOwner}} | |||
<script src="/js/h.js"></script> | |||
<script src="/js/postactions.js"></script> | |||
@@ -71,7 +71,7 @@ | |||
<!--p class="meta-note"><span>Private collection</span>. Only you can see this page.</p--> | |||
{{/*end*/}} | |||
{{if .PinnedPosts}}<nav> | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}}</nav> | |||
{{range .PinnedPosts}}<a class="pinned" href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{.CanonicalURL $.Host}}{{end}}">{{.PlainDisplayTitle}}</a>{{end}}</nav> | |||
{{end}} | |||
</header> | |||
@@ -116,6 +116,7 @@ | |||
{{end}} | |||
<script src="/js/h.js"></script> | |||
<script src="/js/postactions.js"></script> | |||
<script src="/js/localdate.js"></script> | |||
<script type="text/javascript"> | |||
var deleting = false; | |||
function delPost(e, id, owned) { | |||
@@ -270,7 +270,7 @@ | |||
<script> | |||
function updateMeta() { | |||
if ({{.Silenced}}) { | |||
alert('Your account is currently silenced, editing posts is disabled.'); | |||
alert("Your account is silenced, so you can't edit posts."); | |||
return | |||
} | |||
document.getElementById('create-error').style.display = 'none'; | |||
@@ -21,10 +21,10 @@ | |||
{{end}} | |||
{{end}} | |||
</h2> | |||
{{if $.Format.ShowDates}}<time class="dt-published" datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}">{{if not .Title.String}}<a href="{{$.CanonicalURL}}{{.Slug.String}}" itemprop="url">{{end}}{{.DisplayDate}}{{if not .Title.String}}</a>{{end}}</time>{{end}} | |||
{{if $.Format.ShowDates}}<time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}">{{if not .Title.String}}<a href="{{$.CanonicalURL}}{{.Slug.String}}" itemprop="url">{{end}}{{.DisplayDate}}{{if not .Title.String}}</a>{{end}}</time>{{end}} | |||
{{else}} | |||
<h2 class="post-title" itemprop="name"> | |||
{{if $.Format.ShowDates}}<time class="dt-published" datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}"><a href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{$.CanonicalURL}}{{.Slug.String}}{{end}}" itemprop="url" class="u-url">{{.DisplayDate}}</a></time>{{end}} | |||
{{if $.Format.ShowDates}}<time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}"><a href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{$.CanonicalURL}}{{.Slug.String}}{{end}}" itemprop="url" class="u-url">{{.DisplayDate}}</a></time>{{end}} | |||
{{if $.IsOwner}} | |||
{{if not $.Format.ShowDates}}<a class="user hidden action" href="{{if not $.SingleUser}}/{{$.Alias}}/{{.Slug.String}}{{else}}{{$.CanonicalURL}}{{.Slug.String}}{{end}}">view</a>{{end}} | |||
<a class="user hidden action" href="/{{if not $.SingleUser}}{{$.Alias}}/{{end}}{{.Slug.String}}/edit">edit</a> | |||
@@ -25,9 +25,6 @@ | |||
</head> | |||
<body id="collection" itemscope itemtype="http://schema.org/WebPage"> | |||
{{if .Silenced}} | |||
{{template "user-silenced"}} | |||
{{end}} | |||
<header> | |||
<h1 dir="{{.Direction}}" id="blog-title"><a href="/{{.Alias}}/" class="h-card p-author u-url" rel="me author">{{.DisplayTitle}}</a></h1> | |||
</header> | |||
@@ -88,9 +88,9 @@ | |||
<section itemscope itemtype="http://schema.org/Blog"> | |||
{{range .Posts}}<article class="{{.Font}} h-entry" itemscope itemtype="http://schema.org/BlogPosting"> | |||
{{if .Title.String}}<h2 class="post-title" itemprop="name" class="p-name"><a href="{{if .Slug.String}}{{.Collection.CanonicalURL}}{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}.md{{end}}" itemprop="url" class="u-url">{{.PlainDisplayTitle}}</a></h2> | |||
<time class="dt-published" datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}">{{if not .Title.String}}<a href="{{.Collection.CanonicalURL}}{{.Slug.String}}" itemprop="url">{{end}}{{.DisplayDate}}{{if not .Title.String}}</a>{{end}}</time> | |||
<time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}">{{if not .Title.String}}<a href="{{.Collection.CanonicalURL}}{{.Slug.String}}" itemprop="url">{{end}}{{.DisplayDate}}{{if not .Title.String}}</a>{{end}}</time> | |||
{{else}} | |||
<h2 class="post-title" itemprop="name"><time class="dt-published" datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}"><a href="{{if .Collection}}{{.Collection.CanonicalURL}}{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}.md{{end}}" itemprop="url" class="u-url">{{.DisplayDate}}</a></time></h2> | |||
<h2 class="post-title" itemprop="name"><time class="dt-published" datetime="{{.Created8601}}" pubdate itemprop="datePublished" content="{{.Created}}"><a href="{{if .Collection}}{{.Collection.CanonicalURL}}{{.Slug.String}}{{else}}{{.CanonicalURL .Host}}.md{{end}}" itemprop="url" class="u-url">{{.DisplayDate}}</a></time></h2> | |||
{{end}} | |||
<p class="source">{{if .Collection}}from <a href="{{.Collection.CanonicalURL}}">{{.Collection.DisplayTitle}}</a>{{else}}<em>Anonymous</em>{{end}}</p> | |||
{{if .Excerpt}}<div class="p-summary" {{if .Language}}lang="{{.Language.String}}"{{end}} dir="{{.Direction}}">{{.Excerpt}}</div> | |||
@@ -112,7 +112,7 @@ | |||
</nav>{{end}} | |||
</div> | |||
<script src="/js/localdate.js"> | |||
<script type="text/javascript"> | |||
(function() { | |||
var $articles = document.querySelectorAll('article'); | |||
@@ -12,7 +12,10 @@ | |||
<h2 id="posts-header">drafts</h2> | |||
{{ if .AnonymousPosts }}<div class="atoms posts"> | |||
{{ if .AnonymousPosts }} | |||
<p>These are your draft posts. You can share them individually (without a blog) or move them to your blog when you're ready.</p> | |||
<div class="atoms posts"> | |||
{{ range $el := .AnonymousPosts }}<div id="post-{{.ID}}" class="post"> | |||
<h3><a href="/{{if $.SingleUser}}d/{{end}}{{.ID}}" itemprop="url">{{.DisplayTitle}}</a></h3> | |||
<h4> | |||
@@ -34,10 +37,11 @@ | |||
{{end}} | |||
{{ end }} | |||
</h4> | |||
{{if .Summary}}<p>{{.Summary}}</p>{{end}} | |||
{{if .Summary}}<p>{{.SummaryHTML}}</p>{{end}} | |||
</div>{{end}} | |||
</div>{{ else }}<div id="no-posts-published"><p>You haven't saved any drafts yet.</p> | |||
<p>They'll show up here once you do. {{if not .SingleUser}}Find your blog posts from the <a href="/me/c/">Blogs</a> page.{{end}}</p> | |||
</div>{{ else }}<div id="no-posts-published"> | |||
<p>Your anonymous and draft posts will show up here once you've published some. You'll be able to share them individually (without a blog) or move them to a blog when you're ready.</p> | |||
{{if not .SingleUser}}<p>Alternatively, see your blogs and their posts on your <a href="/me/c/">Blogs</a> page.</p>{{end}} | |||
<p class="text-cta"><a href="{{if .SingleUser}}/me/new{{else}}/{{end}}">Start writing</a></p></div>{{ end }} | |||
<div id="moving"></div> | |||
@@ -0,0 +1,61 @@ | |||
{{define "import"}} | |||
{{template "header" .}} | |||
<style> | |||
input[type=file] { | |||
padding: 0; | |||
font-size: 0.86em; | |||
display: block; | |||
margin: 0.5rem 0; | |||
} | |||
label { | |||
display: block; | |||
margin: 1em 0; | |||
} | |||
</style> | |||
<div class="snug content-container"> | |||
<h1 id="import-header">Import posts</h1> | |||
{{if .Message}} | |||
<div class="alert {{if .InfoMsg}}info{{else}}success{{end}}"> | |||
<p>{{.Message}}</p> | |||
</div> | |||
{{end}} | |||
{{if .Flashes}} | |||
<ul class="errors"> | |||
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}} | |||
</ul> | |||
{{end}} | |||
<p>Publish plain text or Markdown files to your account by uploading them below.</p> | |||
<div class="formContainer"> | |||
<form id="importPosts" class="prominent" enctype="multipart/form-data" action="/api/me/import" method="POST"> | |||
<label>Select some files to import: | |||
<input id="fileInput" class="fileInput" name="files" type="file" multiple accept="text/markdown, text/plain"/> | |||
</label> | |||
<input id="fileDates" name="fileDates" hidden/> | |||
<label> | |||
Import these posts to: | |||
<select name="collection"> | |||
{{range $i, $el := .Collections}} | |||
<option value="{{.Alias}}" {{if eq $i 0}}selected{{end}}>{{.DisplayTitle}}</option> | |||
{{end}} | |||
<option value="">Drafts</option> | |||
</select> | |||
</label> | |||
<script> | |||
const fileInput = document.getElementById('fileInput'); | |||
const fileDates = document.getElementById('fileDates'); | |||
fileInput.addEventListener('change', (e) => { | |||
const files = e.target.files; | |||
let dateMap = {}; | |||
for (let file of files) { | |||
dateMap[file.name] = file.lastModified / 1000; | |||
} | |||
fileDates.value = JSON.stringify(dateMap); | |||
}) | |||
</script> | |||
<input type="submit" value="Import" /> | |||
</form> | |||
</div> | |||
</div> | |||
{{template "footer" .}} | |||
{{end}} |
@@ -10,6 +10,7 @@ | |||
<li class="separator"><hr /></li> | |||
{{if .IsAdmin}}<li><a href="/admin">Admin</a></li>{{end}} | |||
<li><a href="/me/settings">Settings</a></li> | |||
<li><a href="/me/import">Import posts</a></li> | |||
<li><a href="/me/export">Export</a></li> | |||
<li class="separator"><hr /></li> | |||
<li><a href="/me/logout">Log out</a></li> | |||
@@ -22,19 +23,17 @@ | |||
</nav> | |||
</nav> | |||
{{else}} | |||
{{ if .Chorus }}<nav id="full-nav"> | |||
<nav id="full-nav"> | |||
<div class="left-side"> | |||
<h1><a href="/" title="Return to editor">{{.SiteName}}</a></h1> | |||
</div> | |||
{{ else }} | |||
<h1><a href="/" title="Return to editor">{{.SiteName}}</a></h1> | |||
{{ end }} | |||
<nav id="user-nav"> | |||
{{if .Username}} | |||
<nav class="dropdown-nav"> | |||
<ul><li><a>{{.Username}}</a> <img class="ic-18dp" src="/img/ic_down_arrow_dark@2x.png" /><ul> | |||
{{if .IsAdmin}}<li><a href="/admin">Admin dashboard</a></li>{{end}} | |||
<li><a href="/me/settings">Account settings</a></li> | |||
<li><a href="/me/import">Import posts</a></li> | |||
<li><a href="/me/export">Export</a></li> | |||
{{if .CanInvite}}<li><a href="/me/invites">Invite people</a></li>{{end}} | |||
<li class="separator"><hr /></li> | |||
@@ -62,6 +61,7 @@ | |||
{{else}} | |||
<a href="/me/c/"{{if eq .Path "/me/c/"}} class="selected"{{end}}>Blogs</a> | |||
{{if not .DisableDrafts}}<a href="/me/posts/"{{if eq .Path "/me/posts/"}} class="selected"{{end}}>Drafts</a>{{end}} | |||
{{if and (and .LocalTimeline .CanViewReader) (not .Chorus)}}<a href="/read">Reader</a>{{end}} | |||
{{end}} | |||
</nav> | |||
</nav> | |||
@@ -8,18 +8,7 @@ | |||
margin-left: 0.5em; | |||
margin-right: 0; | |||
} | |||
label { | |||
font-weight: bold; | |||
} | |||
select { | |||
font-size: 1em; | |||
width: 100%; | |||
padding: 0.5rem; | |||
display: block; | |||
border-radius: 0.25rem; | |||
margin: 0.5rem 0; | |||
} | |||
input, table.classy { | |||
table.classy { | |||
width: 100%; | |||
} | |||
table.classy.export a { | |||
@@ -34,7 +23,7 @@ table td { | |||
<h1>Invite people</h1> | |||
<p>Invite others to join <em>{{.SiteName}}</em> by generating and sharing invite links below.</p> | |||
<form style="margin: 2em 0" action="/api/me/invites" method="post"> | |||
<form style="margin: 2em 0" class="prominent" action="/api/me/invites" method="post"> | |||
<div class="row"> | |||
<div class="half"> | |||
<label for="uses">Maximum number of uses:</label> | |||
@@ -1,5 +1,5 @@ | |||
/* | |||
* Copyright © 2018 A Bunch Tell LLC. | |||
* Copyright © 2018-2020 A Bunch Tell LLC. | |||
* | |||
* This file is part of WriteFreely. | |||
* | |||
@@ -11,7 +11,10 @@ | |||
package writefreely | |||
import ( | |||
"encoding/json" | |||
"io/ioutil" | |||
"net/http" | |||
"strings" | |||
"github.com/writeas/go-webfinger" | |||
"github.com/writeas/impart" | |||
@@ -89,3 +92,49 @@ func (wfr wfResolver) DummyUser(username string, hostname string, r []webfinger. | |||
func (wfr wfResolver) IsNotFoundError(err error) bool { | |||
return err == wfUserNotFoundErr | |||
} | |||
// RemoteLookup looks up a user by handle at a remote server | |||
// and returns the actor URL | |||
func RemoteLookup(handle string) string { | |||
handle = strings.TrimLeft(handle, "@") | |||
// let's take the server part of the handle | |||
parts := strings.Split(handle, "@") | |||
resp, err := http.Get("https://" + parts[1] + "/.well-known/webfinger?resource=acct:" + handle) | |||
if err != nil { | |||
log.Error("Error performing webfinger request", err) | |||
return "" | |||
} | |||
body, err := ioutil.ReadAll(resp.Body) | |||
if err != nil { | |||
log.Error("Error reading webfinger response", err) | |||
return "" | |||
} | |||
var result webfinger.Resource | |||
err = json.Unmarshal(body, &result) | |||
if err != nil { | |||
log.Error("Unsupported webfinger response received: %v", err) | |||
return "" | |||
} | |||
var href string | |||
// iterate over webfinger links and find the one with | |||
// a self "rel" | |||
for _, link := range result.Links { | |||
if link.Rel == "self" { | |||
href = link.HRef | |||
} | |||
} | |||
// if we didn't find it with the above then | |||
// try using aliases | |||
if href == "" { | |||
// take the last alias because mastodon has the | |||
// https://instance.tld/@user first which | |||
// doesn't work as an href | |||
href = result.Aliases[len(result.Aliases)-1] | |||
} | |||
return href | |||
} |