@@ -2,14 +2,17 @@ package writefreely
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/guregu/null/zero"
"github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config"
"hash/fnv"
"io"
"io/ioutil"
"net/http"
@@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
GenerateOAuthState(context.Context) (string, error)
ValidateOAuthState(context.Context, string) error
GetIDForRemoteUser(context.Context, int64) (int64, error)
CreateUser(*config.Config, *User, string) error
RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error
GetUserForAuthByID(int64) (*User, error)
}
@@ -75,8 +79,8 @@ type oauthHandler struct {
}
// buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx)
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx, provider, clientID )
if err != nil {
return "", err
}
@@ -95,9 +99,8 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation
return u.String(), nil
}
// app *App, w http.ResponseWriter, r *http.Request
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
@@ -105,94 +108,128 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state)
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error() )
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return
}
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
tokenResponse, err := h.exchangeOauthCode(ctx, code)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
func (h oauthHandler) configureRoutes(r *mux.Router) {
if h.Config.App.OAuth.Enabled {
if h.Config.App.OAuth.WriteAsClientID != "" {
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID)
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET")
}
if h.Config.App.OAuth.SlackClientID != "" {
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
}
}
// Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
}
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
func oauthProviderHash(provider, clientID string) string {
hasher := fnv.New32()
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID)))
}
fmt.Println("local user id", localUserID)
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID)
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
newUser := &User{
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Usernam e)
tokenResponse, err := h.exchangeOauthCode(ctx, code)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
// Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
return
}
user, err := h.DB.GetUserForAuthByID(localUserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
fmt.Println("local user id", localUserID)
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return
}
newUser := &User{
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
return
}
user, err := h.DB.GetUserForAuthByID(localUserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
}
}
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
form.Add("redirect_uri", h.Config.App.OAuth.WriteAs ClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAs ProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
@@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
req.SetBasicAuth(h.Config.App.OAuth.WriteAs ClientID, h.Config.App.OAuth.WriteAs ClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
@@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
}
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAs ProviderInspectLocation, nil)
if err != nil {
return nil, err
}