|
- package writefreely
-
- import (
- "context"
- "encoding/json"
- "fmt"
- "github.com/gorilla/mux"
- "github.com/gorilla/sessions"
- "github.com/writeas/impart"
- "github.com/writeas/web-core/log"
- "github.com/writeas/writefreely/config"
- "io"
- "io/ioutil"
- "net/http"
- "net/url"
- "strings"
- "time"
- )
-
// TokenResponse contains data returned when a token is created either
// through a code exchange or using a refresh token. The struct tags give the
// JSON key that each field maps to.
type TokenResponse struct {
	// AccessToken is the value viewOauthCallback passes on to
	// inspectOauthAccessToken and to the signup page parameters.
	AccessToken  string `json:"access_token"`
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	// NOTE(review): presumably populated by the provider when the request
	// fails; nothing in this part of the file reads it — confirm.
	Error        string `json:"error"`
}
-
// InspectResponse contains data returned when an access token is inspected.
// The struct tags give the JSON key that each field maps to.
type InspectResponse struct {
	ClientID    string    `json:"client_id"`
	// UserID is the remote user ID; viewOauthCallback uses it to look up the
	// matching local user.
	UserID      string    `json:"user_id"`
	ExpiresAt   time.Time `json:"expires_at"`
	Username    string    `json:"username"`
	// DisplayName is tagged "-", so encoding/json neither reads nor writes
	// it; any value has to be assigned by code.
	DisplayName string    `json:"-"`
	Email       string    `json:"email"`
	// NOTE(review): presumably populated by the provider when inspection
	// fails; nothing in this part of the file reads it — confirm.
	Error       string    `json:"error"`
}
-
// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
// endpoint. One megabyte (1,000,000 bytes) is plenty.
const tokenRequestMaxLen = 1000000

// infoRequestMaxLen is the most bytes that we'll read from the
// /oauth/inspect endpoint. It is the same 1,000,000-byte cap as
// tokenRequestMaxLen.
const infoRequestMaxLen = 1000000
-
// OAuthDatastoreProvider provides a minimal interface of data store, config,
// and session store for use with the oauth handlers.
type OAuthDatastoreProvider interface {
	// DB returns the data store used for OAuth state and user lookups.
	DB() OAuthDatastore
	// Config returns the application configuration.
	Config() *config.Config
	// SessionStore returns the store used to read and save login sessions.
	SessionStore() sessions.Store
}
-
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
	// GetIDForRemoteUser looks up the local user ID for a remote user.
	// viewOauthCallback passes the remote user ID, provider, and client ID,
	// in that order, and treats a result of -1 as "no local account".
	GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
	// RecordRemoteUserID is not called in this part of the file.
	// NOTE(review): presumably links a local user ID to a remote account;
	// confirm the meaning of the four string arguments in the implementation.
	RecordRemoteUserID(context.Context, int64, string, string, string, string) error
	// ValidateOAuthState checks a state value received on the callback.
	// viewOauthCallback reads the two string results as the provider and the
	// client ID, in that order.
	ValidateOAuthState(context.Context, string) (string, string, error)
	// GenerateOAuthState creates a new state value. viewOauthInit passes the
	// provider and the client ID, in that order.
	GenerateOAuthState(context.Context, string, string) (string, error)

	// CreateUser is not called in this part of the file.
	// NOTE(review): presumably persists a new user; confirm what the string
	// argument carries.
	CreateUser(*config.Config, *User, string) error
	// GetUserByID returns the user with the given local ID.
	GetUserByID(int64) (*User, error)
}
-
// HttpClient is the single-method HTTP client abstraction used for outgoing
// OAuth requests. Its method has the same signature as (*http.Client).Do, so
// an *http.Client satisfies it, as can a stub in tests.
type HttpClient interface {
	// Do sends req and returns the response or an error.
	Do(req *http.Request) (*http.Response, error)
}
-
// oauthClient is the provider-specific behavior the oauth handlers rely on.
// slackOauthClient and writeAsOauthClient values are passed where this
// interface is expected.
type oauthClient interface {
	// GetProvider returns the provider name. It is used in the route paths
	// "/oauth/<provider>" and "/oauth/callback/<provider>" and when
	// generating OAuth state.
	GetProvider() string
	// GetClientID returns the client ID used when generating OAuth state.
	GetClientID() string
	// GetCallbackLocation returns the callback location configured on the
	// client.
	GetCallbackLocation() string
	// buildLoginURL returns the URL the user is redirected to in order to
	// log in, given the generated state value.
	buildLoginURL(state string) (string, error)
	// exchangeOauthCode trades the code received on the callback for a
	// token response.
	exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
	// inspectOauthAccessToken returns the details associated with an access
	// token.
	inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
-
// callbackProxyClient registers OAuth state values with a callback proxy
// service.
type callbackProxyClient struct {
	// server is the URL that register sends its POST request to.
	server           string
	// callbackLocation is sent to the proxy as the "location" form value.
	callbackLocation string
	// httpClient performs the outgoing request.
	httpClient       HttpClient
}
-
// oauthHandler bundles the dependencies shared by the OAuth HTTP handlers
// for one provider.
type oauthHandler struct {
	// Config is the application configuration; viewOauthCallback reads
	// Server.HashSeed from it.
	Config        *config.Config
	// DB stores OAuth state and resolves remote users to local users.
	DB            OAuthDatastore
	// Store is the session store used to log a user in.
	Store         sessions.Store
	// EmailKey is filled from app.keys.EmailKey; it is not read in this part
	// of the file.
	EmailKey      []byte
	// oauthClient is the provider-specific client.
	oauthClient   oauthClient
	// callbackProxy is optional; when nil, no state registration is done.
	callbackProxy *callbackProxyClient
}
-
- func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
- ctx := r.Context()
- state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
- if err != nil {
- return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
- }
-
- if h.callbackProxy != nil {
- if err := h.callbackProxy.register(ctx, state); err != nil {
- return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
- }
- }
-
- location, err := h.oauthClient.buildLoginURL(state)
- if err != nil {
- return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
- }
- return impart.HTTPError{http.StatusTemporaryRedirect, location}
- }
-
- func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
- if app.Config().SlackOauth.ClientID != "" {
- callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
-
- var stateRegisterClient *callbackProxyClient = nil
- if app.Config().SlackOauth.CallbackProxyAPI != "" {
- stateRegisterClient = &callbackProxyClient{
- server: app.Config().SlackOauth.CallbackProxyAPI,
- callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
- httpClient: config.DefaultHTTPClient(),
- }
- callbackLocation = app.Config().SlackOauth.CallbackProxy
- }
- oauthClient := slackOauthClient{
- ClientID: app.Config().SlackOauth.ClientID,
- ClientSecret: app.Config().SlackOauth.ClientSecret,
- TeamID: app.Config().SlackOauth.TeamID,
- HttpClient: config.DefaultHTTPClient(),
- CallbackLocation: callbackLocation,
- }
- configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
- }
- }
-
- func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
- if app.Config().WriteAsOauth.ClientID != "" {
- callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
-
- var callbackProxy *callbackProxyClient = nil
- if app.Config().WriteAsOauth.CallbackProxy != "" {
- callbackProxy = &callbackProxyClient{
- server: app.Config().WriteAsOauth.CallbackProxyAPI,
- callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
- httpClient: config.DefaultHTTPClient(),
- }
- callbackLocation = app.Config().SlackOauth.CallbackProxy
- }
-
- oauthClient := writeAsOauthClient{
- ClientID: app.Config().WriteAsOauth.ClientID,
- ClientSecret: app.Config().WriteAsOauth.ClientSecret,
- ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
- InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
- AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
- HttpClient: config.DefaultHTTPClient(),
- CallbackLocation: callbackLocation,
- }
- configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
- }
- }
-
- func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
- handler := &oauthHandler{
- Config: app.Config(),
- DB: app.DB(),
- Store: app.SessionStore(),
- oauthClient: oauthClient,
- EmailKey: app.keys.EmailKey,
- callbackProxy: callbackProxy,
- }
- r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
- r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
- r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
- }
-
- func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
- ctx := r.Context()
-
- code := r.FormValue("code")
- state := r.FormValue("state")
-
- provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
- if err != nil {
- log.Error("Unable to ValidateOAuthState: %s", err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
-
- tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
- if err != nil {
- log.Error("Unable to exchangeOauthCode: %s", err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
-
- // Now that we have the access token, let's use it real quick to make sur
- // it really really works.
- tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
- if err != nil {
- log.Error("Unable to inspectOauthAccessToken: %s", err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
-
- localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
- if err != nil {
- log.Error("Unable to GetIDForRemoteUser: %s", err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
-
- if localUserID != -1 {
- user, err := h.DB.GetUserByID(localUserID)
- if err != nil {
- log.Error("Unable to GetUserByID %d: %s", localUserID, err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
- if err = loginOrFail(h.Store, w, r, user); err != nil {
- log.Error("Unable to loginOrFail %d: %s", localUserID, err)
- return impart.HTTPError{http.StatusInternalServerError, err.Error()}
- }
- return nil
- }
-
- displayName := tokenInfo.DisplayName
- if len(displayName) == 0 {
- displayName = tokenInfo.Username
- }
-
- tp := &oauthSignupPageParams{
- AccessToken: tokenResponse.AccessToken,
- TokenUsername: tokenInfo.Username,
- TokenAlias: tokenInfo.DisplayName,
- TokenEmail: tokenInfo.Email,
- TokenRemoteUser: tokenInfo.UserID,
- Provider: provider,
- ClientID: clientID,
- }
- tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
-
- return h.showOauthSignupPage(app, w, r, tp, nil)
- }
-
- func (r *callbackProxyClient) register(ctx context.Context, state string) error {
- form := url.Values{}
- form.Add("state", state)
- form.Add("location", r.callbackLocation)
- req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
- if err != nil {
- return err
- }
- req.Header.Set("User-Agent", "writefreely")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- resp, err := r.httpClient.Do(req)
- if err != nil {
- return err
- }
- if resp.StatusCode != http.StatusCreated {
- return fmt.Errorf("unable register state location: %d", resp.StatusCode)
- }
-
- return nil
- }
-
// limitedJsonUnmarshal decodes JSON from body into thing, refusing payloads
// longer than n bytes.
//
// It returns the read error if reading fails, a "content larger than max read
// allowance" error if more than n bytes are available, and otherwise whatever
// json.Unmarshal returns. body is not closed here.
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
	// Read one byte past the limit so that an oversized payload can be told
	// apart from one that is exactly n bytes long.
	maxRead := int64(n + 1)
	payload, readErr := ioutil.ReadAll(io.LimitReader(body, maxRead))
	switch {
	case readErr != nil:
		return readErr
	case len(payload) == n+1:
		return fmt.Errorf("content larger than max read allowance: %d", n)
	}
	return json.Unmarshal(payload, thing)
}
-
- func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
- // An error may be returned, but a valid session should always be returned.
- session, _ := store.Get(r, cookieName)
- session.Values[cookieUserVal] = user.Cookie()
- if err := session.Save(r, w); err != nil {
- fmt.Println("error saving session", err)
- return err
- }
- http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
- return nil
- }
|