瀏覽代碼

Implemented oauth attach functionality, oauth detach functionality, and required data migration. T713

pull/243/head
Nick Gerakines 4 年之前
父節點
當前提交
c0317b4e93
共有 9 個檔案被更改,包括 173 行新增30 行删除
  1. +34
    -9
      account.go
  2. +45
    -8
      database.go
  3. +2
    -2
      database_test.go
  4. +1
    -0
      migrations/migrations.go
  5. +36
    -0
      migrations/v6.go
  6. +25
    -4
      oauth.go
  7. +7
    -7
      oauth_test.go
  8. +1
    -0
      routes.go
  9. +22
    -0
      templates/user/settings.tmpl

+ 34
- 9
account.go 查看文件

@@ -1038,18 +1038,30 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err

flashes, _ := getSessionFlashes(app, w, r, nil)

oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
if err != nil {
log.Error("Unable to get oauth accounts for settings: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."}
}

obj := struct {
*UserPage
Email string
HasPass bool
IsLogOut bool
Suspended bool
Email string
HasPass bool
IsLogOut bool
Suspended bool
OauthAccounts []oauthAccountInfo
OauthSlack bool
OauthWriteAs bool
}{
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1",
Suspended: fullUser.IsSilenced(),
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1",
Suspended: fullUser.IsSilenced(),
OauthAccounts: oauthAccounts,
OauthSlack: app.Config().SlackOauth.ClientID != "",
OauthWriteAs: app.Config().WriteAsOauth.ClientID != "",
}

showUserPage(w, "settings", obj)
@@ -1094,6 +1106,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s
return s
}

func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
provider := r.FormValue("provider")
clientID := r.FormValue("client_id")
remoteUserID := r.FormValue("remote_user_id")

err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID)
if err != nil {
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
}

return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
}

func prepareUserEmail(input string, emailKey []byte) zero.String {
email := zero.NewString("", input != "")
if len(input) > 0 {


+ 45
- 8
database.go 查看文件

@@ -128,8 +128,10 @@ type writestore interface {

GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string, int64) (string, error)
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error

DatabaseInitialized() bool
}
@@ -2462,20 +2464,23 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil
}

func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, NOW(), ?)", state, provider, clientID, attachUser)
if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err)
}
return state, nil
}

func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
var provider string
var clientID string
var attachUserID int64
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
err := tx.
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
Scan(&provider, &clientID, &attachUserID)
if err != nil {
return err
}
@@ -2494,9 +2499,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
return nil
})
if err != nil {
return "", "", nil
return "", "", 0, nil
}
return provider, clientID, nil
return provider, clientID, attachUserID, nil
}

func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
@@ -2525,6 +2530,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
return userID, nil
}

type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
}

func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
if err != nil {
log.Error("Failed selecting from oauth_users: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
}
defer rows.Close()

var records []oauthAccountInfo
for rows.Next() {
info := oauthAccountInfo{}
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID)
if err != nil {
log.Error("Failed scanning GetAllUsers() row: %v", err)
break
}
records = append(records, info)
}
return records, nil
}

// DatabaseInitialized returns whether or not the current datastore has been
// initialized with the correct schema.
// Currently, it checks to see if the `users` table exists.
@@ -2547,6 +2579,11 @@ func (db *datastore) DatabaseInitialized() bool {
return true
}

func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
return err
}

func stringLogln(log *string, s string, v ...interface{}) {
*log += fmt.Sprintf(s+"\n", v...)
}


+ 2
- 2
database_test.go 查看文件

@@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "",
}

state, err := ds.GenerateOAuthState(ctx, "test", "development")
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
assert.NoError(t, err)
assert.Len(t, state, 24)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)

_, _, err = ds.ValidateOAuthState(ctx, state)
_, _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)

countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)


+ 1
- 0
migrations/migrations.go 查看文件

@@ -61,6 +61,7 @@ var migrations = []Migration{
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauthSlack), // V4 -> v5
New("support oauth attach", oauthAttach), // V5 -> V6
}

// CurrentVer returns the current migration version the application is on


+ 36
- 0
migrations/v6.go 查看文件

@@ -0,0 +1,36 @@
package migrations

import (
"context"
"database/sql"

wf_db "github.com/writeas/writefreely/db"
)

func oauthAttach(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
"attach_user_id",
wf_db.ColumnTypeInteger,
wf_db.OptionalInt{Set: true, Value: 24,}).SetNullable(false).SetDefault("0")),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

+ 25
- 4
oauth.go 查看文件

@@ -59,8 +59,8 @@ type OAuthDatastoreProvider interface {
type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string, int64) (string, error)

CreateUser(*config.Config, *User, string) error
GetUserByID(int64) (*User, error)
@@ -96,19 +96,32 @@ type oauthHandler struct {

func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())

var attachUser int64
if attach := r.URL.Query().Get("attach"); attach == "t" {
user, _ := getUserAndSession(app, r)
if user == nil {
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
}
attachUser = user.ID
}

state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
}

if h.callbackProxy != nil {
if err := h.callbackProxy.register(ctx, state); err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
}
}

location, err := h.oauthClient.buildLoginURL(state)
if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
}
return impart.HTTPError{http.StatusTemporaryRedirect, location}
@@ -185,7 +198,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
code := r.FormValue("code")
state := r.FormValue("state")

provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil {
log.Error("Unable to ValidateOAuthState: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@@ -223,6 +236,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
}
return nil
}
if attachUserID > 0 {
log.Info("attaching to user %d", attachUserID)
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
}
return impart.HTTPError{http.StatusFound, "/me/settings"}
}

displayName := tokenInfo.DisplayName
if len(displayName) == 0 {


+ 7
- 7
oauth_test.go 查看文件

@@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
}

type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
@@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
return cfg
}

func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state)
}
return "", "", nil
return "", "", 0, nil
}

func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
@@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
return user, nil
}

func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx, provider, clientID)
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
}
return store.Generate62RandomString(14), nil
}
@@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
return "", fmt.Errorf("pretend unable to write state error")
},
}


+ 1
- 0
routes.go 查看文件

@@ -101,6 +101,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET")
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET")
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET")
me.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")

write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET")
apiMe := write.PathPrefix("/api/me/").Subrouter()


+ 22
- 0
templates/user/settings.tmpl 查看文件

@@ -66,6 +66,28 @@ h3 { font-weight: normal; }
<input type="submit" value="Save changes" tabindex="4" />
</div>
</form>

{{ if .OauthAccounts }}
{{ range $oauth_account := .OauthAccounts }}
<form method="post" action="/me/oauth/remove" autocomplete="false">
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" />
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" />
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" />
<div class="option">
<h3>{{ $oauth_account.Provider }} </h3>
<div class="section">
<input type="submit" value="Remove" style="margin-left: 1em;" />
</div>
</div>
</form>
{{ end }}
{{ end }}
{{ if .OauthSlack }}
<a class="loginbtn" href="/oauth/slack?attach=t"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a>
{{ end }}
{{ if .OauthWriteAs }}
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">Link your <strong>Write.as</strong> account.</a>
{{ end }}
</div>

<script>


Loading…
取消
儲存