@@ -1038,18 +1038,30 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err | |||
flashes, _ := getSessionFlashes(app, w, r, nil) | |||
oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID) | |||
if err != nil { | |||
log.Error("Unable to get oauth accounts for settings: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."} | |||
} | |||
obj := struct { | |||
*UserPage | |||
Email string | |||
HasPass bool | |||
IsLogOut bool | |||
Suspended bool | |||
Email string | |||
HasPass bool | |||
IsLogOut bool | |||
Suspended bool | |||
OauthAccounts []oauthAccountInfo | |||
OauthSlack bool | |||
OauthWriteAs bool | |||
}{ | |||
UserPage: NewUserPage(app, r, u, "Account Settings", flashes), | |||
Email: fullUser.EmailClear(app.keys), | |||
HasPass: passIsSet, | |||
IsLogOut: r.FormValue("logout") == "1", | |||
Suspended: fullUser.IsSilenced(), | |||
UserPage: NewUserPage(app, r, u, "Account Settings", flashes), | |||
Email: fullUser.EmailClear(app.keys), | |||
HasPass: passIsSet, | |||
IsLogOut: r.FormValue("logout") == "1", | |||
Suspended: fullUser.IsSilenced(), | |||
OauthAccounts: oauthAccounts, | |||
OauthSlack: app.Config().SlackOauth.ClientID != "", | |||
OauthWriteAs: app.Config().WriteAsOauth.ClientID != "", | |||
} | |||
showUserPage(w, "settings", obj) | |||
@@ -1094,6 +1106,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s | |||
return s | |||
} | |||
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error { | |||
provider := r.FormValue("provider") | |||
clientID := r.FormValue("client_id") | |||
remoteUserID := r.FormValue("remote_user_id") | |||
err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID) | |||
if err != nil { | |||
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()} | |||
} | |||
return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"} | |||
} | |||
func prepareUserEmail(input string, emailKey []byte) zero.String { | |||
email := zero.NewString("", input != "") | |||
if len(input) > 0 { | |||
@@ -128,8 +128,10 @@ type writestore interface { | |||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) | |||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error | |||
ValidateOAuthState(context.Context, string) (string, string, error) | |||
GenerateOAuthState(context.Context, string, string) (string, error) | |||
ValidateOAuthState(context.Context, string) (string, string, int64, error) | |||
GenerateOAuthState(context.Context, string, string, int64) (string, error) | |||
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) | |||
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error | |||
DatabaseInitialized() bool | |||
} | |||
@@ -2462,20 +2464,23 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { | |||
return &t, nil | |||
} | |||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { | |||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) { | |||
state := store.Generate62RandomString(24) | |||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) | |||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, NOW(), ?)", state, provider, clientID, attachUser) | |||
if err != nil { | |||
return "", fmt.Errorf("unable to record oauth client state: %w", err) | |||
} | |||
return state, nil | |||
} | |||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { | |||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { | |||
var provider string | |||
var clientID string | |||
var attachUserID int64 | |||
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { | |||
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) | |||
err := tx. | |||
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state). | |||
Scan(&provider, &clientID, &attachUserID) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -2494,9 +2499,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri | |||
return nil | |||
}) | |||
if err != nil { | |||
return "", "", nil | |||
return "", "", 0, nil | |||
} | |||
return provider, clientID, nil | |||
return provider, clientID, attachUserID, nil | |||
} | |||
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { | |||
@@ -2525,6 +2530,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi | |||
return userID, nil | |||
} | |||
type oauthAccountInfo struct { | |||
Provider string | |||
ClientID string | |||
RemoteUserID string | |||
} | |||
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) { | |||
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID) | |||
if err != nil { | |||
log.Error("Failed selecting from oauth_users: %v", err) | |||
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."} | |||
} | |||
defer rows.Close() | |||
var records []oauthAccountInfo | |||
for rows.Next() { | |||
info := oauthAccountInfo{} | |||
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID) | |||
if err != nil { | |||
log.Error("Failed scanning GetAllUsers() row: %v", err) | |||
break | |||
} | |||
records = append(records, info) | |||
} | |||
return records, nil | |||
} | |||
// DatabaseInitialized returns whether or not the current datastore has been | |||
// initialized with the correct schema. | |||
// Currently, it checks to see if the `users` table exists. | |||
@@ -2547,6 +2579,11 @@ func (db *datastore) DatabaseInitialized() bool { | |||
return true | |||
} | |||
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error { | |||
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID) | |||
return err | |||
} | |||
func stringLogln(log *string, s string, v ...interface{}) { | |||
*log += fmt.Sprintf(s+"\n", v...) | |||
} | |||
@@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) { | |||
driverName: "", | |||
} | |||
state, err := ds.GenerateOAuthState(ctx, "test", "development") | |||
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0) | |||
assert.NoError(t, err) | |||
assert.Len(t, state, 24) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) | |||
_, _, err = ds.ValidateOAuthState(ctx, state) | |||
_, _, _, err = ds.ValidateOAuthState(ctx, state) | |||
assert.NoError(t, err) | |||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) | |||
@@ -61,6 +61,7 @@ var migrations = []Migration{ | |||
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) | |||
New("support oauth", oauth), // V3 -> V4 | |||
New("support slack oauth", oauthSlack), // V4 -> v5 | |||
New("support oauth attach", oauthAttach), // V5 -> V6 | |||
} | |||
// CurrentVer returns the current migration version the application is on | |||
@@ -0,0 +1,36 @@ | |||
package migrations | |||
import ( | |||
"context" | |||
"database/sql" | |||
wf_db "github.com/writeas/writefreely/db" | |||
) | |||
func oauthAttach(db *datastore) error { | |||
dialect := wf_db.DialectMySQL | |||
if db.driverName == driverSQLite { | |||
dialect = wf_db.DialectSQLite | |||
} | |||
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { | |||
builders := []wf_db.SQLBuilder{ | |||
dialect. | |||
AlterTable("oauth_client_states"). | |||
AddColumn(dialect. | |||
Column( | |||
"attach_user_id", | |||
wf_db.ColumnTypeInteger, | |||
wf_db.OptionalInt{Set: true, Value: 24,}).SetNullable(false).SetDefault("0")), | |||
} | |||
for _, builder := range builders { | |||
query, err := builder.ToSQL() | |||
if err != nil { | |||
return err | |||
} | |||
if _, err := tx.ExecContext(ctx, query); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
}) | |||
} |
@@ -59,8 +59,8 @@ type OAuthDatastoreProvider interface { | |||
type OAuthDatastore interface { | |||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) | |||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error | |||
ValidateOAuthState(context.Context, string) (string, string, error) | |||
GenerateOAuthState(context.Context, string, string) (string, error) | |||
ValidateOAuthState(context.Context, string) (string, string, int64, error) | |||
GenerateOAuthState(context.Context, string, string, int64) (string, error) | |||
CreateUser(*config.Config, *User, string) error | |||
GetUserByID(int64) (*User, error) | |||
@@ -96,19 +96,32 @@ type oauthHandler struct { | |||
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { | |||
ctx := r.Context() | |||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) | |||
var attachUser int64 | |||
if attach := r.URL.Query().Get("attach"); attach == "t" { | |||
user, _ := getUserAndSession(app, r) | |||
if user == nil { | |||
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"} | |||
} | |||
attachUser = user.ID | |||
} | |||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser) | |||
if err != nil { | |||
log.Error("viewOauthInit error: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} | |||
} | |||
if h.callbackProxy != nil { | |||
if err := h.callbackProxy.register(ctx, state); err != nil { | |||
log.Error("viewOauthInit error: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} | |||
} | |||
} | |||
location, err := h.oauthClient.buildLoginURL(state) | |||
if err != nil { | |||
log.Error("viewOauthInit error: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} | |||
} | |||
return impart.HTTPError{http.StatusTemporaryRedirect, location} | |||
@@ -185,7 +198,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http | |||
code := r.FormValue("code") | |||
state := r.FormValue("state") | |||
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) | |||
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state) | |||
if err != nil { | |||
log.Error("Unable to ValidateOAuthState: %s", err) | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
@@ -223,6 +236,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http | |||
} | |||
return nil | |||
} | |||
if attachUserID > 0 { | |||
log.Info("attaching to user %d", attachUserID) | |||
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) | |||
if err != nil { | |||
return impart.HTTPError{http.StatusInternalServerError, err.Error()} | |||
} | |||
return impart.HTTPError{http.StatusFound, "/me/settings"} | |||
} | |||
displayName := tokenInfo.DisplayName | |||
if len(displayName) == 0 { | |||
@@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct { | |||
} | |||
type MockOAuthDatastore struct { | |||
DoGenerateOAuthState func(context.Context, string, string) (string, error) | |||
DoValidateOAuthState func(context.Context, string) (string, string, error) | |||
DoGenerateOAuthState func(context.Context, string, string, int64) (string, error) | |||
DoValidateOAuthState func(context.Context, string) (string, string, int64, error) | |||
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) | |||
DoCreateUser func(*config.Config, *User, string) error | |||
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error | |||
@@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config { | |||
return cfg | |||
} | |||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { | |||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { | |||
if m.DoValidateOAuthState != nil { | |||
return m.DoValidateOAuthState(ctx, state) | |||
} | |||
return "", "", nil | |||
return "", "", 0, nil | |||
} | |||
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { | |||
@@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { | |||
return user, nil | |||
} | |||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { | |||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) { | |||
if m.DoGenerateOAuthState != nil { | |||
return m.DoGenerateOAuthState(ctx, provider, clientID) | |||
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID) | |||
} | |||
return store.Generate62RandomString(14), nil | |||
} | |||
@@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) { | |||
app := &MockOAuthDatastoreProvider{ | |||
DoDB: func() OAuthDatastore { | |||
return &MockOAuthDatastore{ | |||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { | |||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) { | |||
return "", fmt.Errorf("pretend unable to write state error") | |||
}, | |||
} | |||
@@ -101,6 +101,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { | |||
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET") | |||
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET") | |||
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET") | |||
me.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST") | |||
write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET") | |||
apiMe := write.PathPrefix("/api/me/").Subrouter() | |||
@@ -66,6 +66,28 @@ h3 { font-weight: normal; } | |||
<input type="submit" value="Save changes" tabindex="4" /> | |||
</div> | |||
</form> | |||
{{ if .OauthAccounts }} | |||
{{ range $oauth_account := .OauthAccounts }} | |||
<form method="post" action="/me/oauth/remove" autocomplete="false"> | |||
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" /> | |||
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" /> | |||
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" /> | |||
<div class="option"> | |||
<h3>{{ $oauth_account.Provider }} </h3> | |||
<div class="section"> | |||
<input type="submit" value="Remove" style="margin-left: 1em;" /> | |||
</div> | |||
</div> | |||
</form> | |||
{{ end }} | |||
{{ end }} | |||
{{ if .OauthSlack }} | |||
<a class="loginbtn" href="/oauth/slack?attach=t"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a> | |||
{{ end }} | |||
{{ if .OauthWriteAs }} | |||
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">Link your <strong>Write.as</strong> account.</a> | |||
{{ end }} | |||
</div> | |||
<script> | |||