|
- package writefreely
-
- import (
- "context"
- "encoding/json"
- "fmt"
- "github.com/gorilla/sessions"
- "github.com/guregu/null/zero"
- "github.com/writeas/web-core/auth"
- "github.com/writeas/web-core/log"
- "github.com/writeas/writefreely/config"
- "io"
- "io/ioutil"
- "net/http"
- "net/url"
- "strings"
- "time"
- )
-
// TokenResponse is the JSON document decoded from the provider's token
// endpoint. It is produced when a token is created, either through a code
// exchange or by using a refresh token.
type TokenResponse struct {
	// AccessToken is the bearer token, read from "access_token".
	AccessToken string `json:"access_token"`
	// ExpiresIn is the token lifetime, read from "expires_in".
	// NOTE(review): presumably seconds, as in OAuth 2.0 — confirm with the
	// provider.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is read from "refresh_token".
	RefreshToken string `json:"refresh_token"`
	// TokenType is read from "token_type".
	TokenType string `json:"token_type"`
}
-
// InspectResponse is the JSON document decoded from the provider's inspect
// endpoint when an access token is inspected.
type InspectResponse struct {
	// ClientID is read from "client_id".
	ClientID string `json:"client_id"`
	// UserID is the user's identifier on the provider side, read from
	// "user_id".
	UserID int64 `json:"user_id"`
	// ExpiresAt is read from "expires_at".
	ExpiresAt time.Time `json:"expires_at"`
	// Username is read from "username".
	Username string `json:"username"`
	// Email is read from "email".
	Email string `json:"email"`
}
-
// Upper bounds on how much of a provider response body is read.
const (
	// tokenRequestMaxLen is the most bytes that we'll read from the
	// /oauth/token endpoint. One megabyte is plenty.
	tokenRequestMaxLen = 1000000

	// infoRequestMaxLen is the most bytes that we'll read from the
	// /oauth/inspect endpoint.
	infoRequestMaxLen = 1000000
)
-
- // OAuthDatastoreProvider provides a minimal interface of data store, config,
- // and session store for use with the oauth handlers.
- type OAuthDatastoreProvider interface {
- DB() OAuthDatastore
- Config() *config.Config
- SessionStore() sessions.Store
- }
-
- // OAuthDatastore provides a minimal interface of data store methods used in
- // oauth functionality.
- type OAuthDatastore interface {
- GenerateOAuthState(context.Context) (string, error)
- ValidateOAuthState(context.Context, string) error
- GetIDForRemoteUser(context.Context, int64) (int64, error)
- CreateUser(*config.Config, *User, string) error
- RecordRemoteUserID(context.Context, int64, int64) error
- GetUserForAuthByID(int64) (*User, error)
- }
-
// HttpClient is the single method of an HTTP client that the oauth handlers
// rely on when talking to the provider. *http.Client satisfies it, and a stub
// can be substituted where needed.
type HttpClient interface {
	// Do sends the request and returns the response.
	Do(req *http.Request) (*http.Response, error)
}
-
- type oauthHandler struct {
- Config *config.Config
- DB OAuthDatastore
- Store sessions.Store
- HttpClient HttpClient
- }
-
- // buildAuthURL returns a URL used to initiate authentication.
- func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
- state, err := db.GenerateOAuthState(ctx)
- if err != nil {
- return "", err
- }
-
- u, err := url.Parse(authLocation)
- if err != nil {
- return "", err
- }
- q := u.Query()
- q.Set("client_id", clientID)
- q.Set("redirect_uri", callbackURL)
- q.Set("response_type", "code")
- q.Set("state", state)
- u.RawQuery = q.Encode()
-
- return u.String(), nil
- }
-
- // app *App, w http.ResponseWriter, r *http.Request
- func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
- location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
- return
- }
- http.Redirect(w, r, location, http.StatusTemporaryRedirect)
- }
-
- func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
- ctx := r.Context()
-
- code := r.FormValue("code")
- state := r.FormValue("state")
-
- err := h.DB.ValidateOAuthState(ctx, state)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- tokenResponse, err := h.exchangeOauthCode(ctx, code)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- // Now that we have the access token, let's use it real quick to make sur
- // it really really works.
- tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- fmt.Println("local user id", localUserID)
-
- if localUserID == -1 {
- // We don't have, nor do we want, the password from the origin, so we
- //create a random string. If the user needs to set a password, they
- //can do so through the settings page or through the password reset
- //flow.
- randPass, err := randString(14)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
- hashedPass, err := auth.HashPass([]byte(randPass))
- if err != nil {
- log.ErrorLog.Println(err)
- failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
- return
- }
- newUser := &User{
- Username: tokenInfo.Username,
- HashedPass: hashedPass,
- HasPass: true,
- Email: zero.NewString("", tokenInfo.Email != ""),
- Created: time.Now().Truncate(time.Second).UTC(),
- }
-
- err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
-
- if err := loginOrFail(h.Store, w, r, newUser); err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- }
- return
- }
-
- user, err := h.DB.GetUserForAuthByID(localUserID)
- if err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- return
- }
- if err = loginOrFail(h.Store, w, r, user); err != nil {
- failOAuthRequest(w, http.StatusInternalServerError, err.Error())
- }
- }
-
- func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
- form := url.Values{}
- form.Add("grant_type", "authorization_code")
- form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
- form.Add("code", code)
- req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
- if err != nil {
- return nil, err
- }
- req.WithContext(ctx)
- req.Header.Set("User-Agent", "writefreely")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
-
- resp, err := h.HttpClient.Do(req)
- if err != nil {
- return nil, err
- }
-
- // Nick: I like using limited readers to reduce the risk of an endpoint
- // being broken or compromised.
- lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
- body, err := ioutil.ReadAll(lr)
- if err != nil {
- return nil, err
- }
-
- var tokenResponse TokenResponse
- err = json.Unmarshal(body, &tokenResponse)
- if err != nil {
- return nil, err
- }
- return &tokenResponse, nil
- }
-
- func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
- req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
- if err != nil {
- return nil, err
- }
- req.WithContext(ctx)
- req.Header.Set("User-Agent", "writefreely")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Bearer "+accessToken)
-
- resp, err := h.HttpClient.Do(req)
- if err != nil {
- return nil, err
- }
-
- // Nick: I like using limited readers to reduce the risk of an endpoint
- // being broken or compromised.
- lr := io.LimitReader(resp.Body, infoRequestMaxLen)
- body, err := ioutil.ReadAll(lr)
- if err != nil {
- return nil, err
- }
-
- var inspectResponse InspectResponse
- err = json.Unmarshal(body, &inspectResponse)
- if err != nil {
- return nil, err
- }
- return &inspectResponse, nil
- }
-
- func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
- // An error may be returned, but a valid session should always be returned.
- session, _ := store.Get(r, cookieName)
- session.Values[cookieUserVal] = user.Cookie()
- if err := session.Save(r, w); err != nil {
- fmt.Println("error saving session", err)
- return err
- }
- http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
- return nil
- }
-
// failOAuthRequest is an HTTP handler helper that formats returned error
// messages. It sets the Content-Type to application/json, writes statusCode,
// and sends a JSON object of the form {"error": message}.
//
// If the body cannot be written (for example because the client has gone
// away), the failure is printed instead of panicking: the status line has
// already been sent by then, so nothing more can be reported to the client.
func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	payload := map[string]interface{}{
		"error": message,
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		fmt.Println("error writing oauth failure response", err)
	}
}
|