A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

253 lines
7.2 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "github.com/gorilla/sessions"
  6. "github.com/guregu/null/zero"
  7. "github.com/writeas/impart"
  8. "github.com/writeas/web-core/auth"
  9. "github.com/writeas/web-core/log"
  10. "github.com/writeas/writefreely/config"
  11. "io"
  12. "io/ioutil"
  13. "net/http"
  14. "net/url"
  15. "strings"
  16. "time"
  17. )
  // TokenResponse is the JSON body the OAuth provider's token endpoint sends
// back when it issues a token, whether that token came from exchanging an
// authorization code or from presenting a refresh token.
type TokenResponse struct {
	// AccessToken is the credential later sent as a Bearer token.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the lifetime reported by the provider
	// (presumably seconds, per OAuth 2.0 convention — confirm per provider).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken, when present, can be used to obtain a new access token.
	RefreshToken string `json:"refresh_token"`
	// TokenType is the token type string reported by the provider.
	TokenType string `json:"token_type"`
}
  // InspectResponse is the JSON body the OAuth provider's inspect endpoint
// returns to describe the owner of an access token.
type InspectResponse struct {
	// ClientID identifies the OAuth client the token was issued to.
	ClientID string `json:"client_id"`
	// UserID is the user's ID on the provider (not a local user ID).
	UserID int64 `json:"user_id"`
	// ExpiresAt is when the inspected token stops being valid.
	ExpiresAt time.Time `json:"expires_at"`
	// Username is the user's name on the provider.
	Username string `json:"username"`
	// Email may be empty; callers treat "" as "no email supplied".
	Email string `json:"email"`
}
  // Caps, in bytes, on how much of each OAuth provider response is read into
// memory. Bounding the read limits the harm a broken or compromised provider
// endpoint can cause; one megabyte is plenty for these payloads.
const (
	// tokenRequestMaxLen bounds the body read from the token endpoint
	// (e.g. /oauth/token).
	tokenRequestMaxLen = 1000000
	// infoRequestMaxLen bounds the body read from the inspect endpoint
	// (e.g. /oauth/inspect).
	infoRequestMaxLen = 1000000
)
  40. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  41. // and session store for use with the oauth handlers.
  42. type OAuthDatastoreProvider interface {
  43. DB() OAuthDatastore
  44. Config() *config.Config
  45. SessionStore() sessions.Store
  46. }
  47. // OAuthDatastore provides a minimal interface of data store methods used in
  48. // oauth functionality.
  49. type OAuthDatastore interface {
  50. GenerateOAuthState(context.Context) (string, error)
  51. ValidateOAuthState(context.Context, string) error
  52. GetIDForRemoteUser(context.Context, int64) (int64, error)
  53. CreateUser(*config.Config, *User, string) error
  54. RecordRemoteUserID(context.Context, int64, int64) error
  55. GetUserForAuthByID(int64) (*User, error)
  56. }
  // HttpClient is the subset of *http.Client that oauthHandler uses to talk to
// the OAuth provider. Depending on this interface rather than *http.Client
// lets callers substitute their own implementation.
type HttpClient interface {
	// Do sends req and returns the response, with the same contract as
	// (*http.Client).Do.
	Do(req *http.Request) (*http.Response, error)
}

// Compile-time check that the standard client still satisfies HttpClient.
var _ HttpClient = (*http.Client)(nil)
  60. type oauthHandler struct {
  61. HttpClient HttpClient
  62. }
  63. // buildAuthURL returns a URL used to initiate authentication.
  64. func buildAuthURL(app OAuthDatastoreProvider, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
  65. state, err := app.DB().GenerateOAuthState(ctx)
  66. if err != nil {
  67. return "", err
  68. }
  69. u, err := url.Parse(authLocation)
  70. if err != nil {
  71. return "", err
  72. }
  73. q := u.Query()
  74. q.Set("client_id", clientID)
  75. q.Set("redirect_uri", callbackURL)
  76. q.Set("response_type", "code")
  77. q.Set("state", state)
  78. u.RawQuery = q.Encode()
  79. return u.String(), nil
  80. }
  81. func (h oauthHandler) viewOauthInit(app OAuthDatastoreProvider, w http.ResponseWriter, r *http.Request) error {
  82. location, err := buildAuthURL(app, r.Context(), app.Config().App.OAuthClientID, app.Config().App.OAuthProviderAuthLocation, app.Config().App.OAuthClientCallbackLocation)
  83. if err != nil {
  84. log.ErrorLog.Println(err)
  85. return impart.HTTPError{Status: http.StatusInternalServerError, Message: "Could not prepare OAuth redirect URL."}
  86. }
  87. http.Redirect(w, r, location, http.StatusTemporaryRedirect)
  88. return nil
  89. }
  90. func (h oauthHandler) viewOauthCallback(app OAuthDatastoreProvider, w http.ResponseWriter, r *http.Request) error {
  91. ctx := r.Context()
  92. code := r.FormValue("code")
  93. state := r.FormValue("state")
  94. err := app.DB().ValidateOAuthState(ctx, state)
  95. if err != nil {
  96. return err
  97. }
  98. tokenResponse, err := h.exchangeOauthCode(app, ctx, code)
  99. if err != nil {
  100. return err
  101. }
  102. // Now that we have the access token, let's use it real quick to make sur
  103. // it really really works.
  104. tokenInfo, err := h.inspectOauthAccessToken(app, ctx, tokenResponse.AccessToken)
  105. if err != nil {
  106. return err
  107. }
  108. localUserID, err := app.DB().GetIDForRemoteUser(ctx, tokenInfo.UserID)
  109. if err != nil {
  110. return err
  111. }
  112. if localUserID == -1 {
  113. // We don't have, nor do we want, the password from the origin, so we
  114. //create a random string. If the user needs to set a password, they
  115. //can do so through the settings page or through the password reset
  116. //flow.
  117. randPass, err := randString(14)
  118. if err != nil {
  119. return err
  120. }
  121. hashedPass, err := auth.HashPass([]byte(randPass))
  122. if err != nil {
  123. log.ErrorLog.Println(err)
  124. return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  125. }
  126. newUser := &User{
  127. Username: tokenInfo.Username,
  128. HashedPass: hashedPass,
  129. HasPass: true,
  130. Email: zero.NewString("", tokenInfo.Email != ""),
  131. Created: time.Now().Truncate(time.Second).UTC(),
  132. }
  133. err = app.DB().CreateUser(app.Config(), newUser, newUser.Username)
  134. if err != nil {
  135. return err
  136. }
  137. err = app.DB().RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
  138. if err != nil {
  139. return err
  140. }
  141. return loginOrFail(app, w, r, newUser)
  142. }
  143. user, err := app.DB().GetUserForAuthByID(localUserID)
  144. if err != nil {
  145. return err
  146. }
  147. return loginOrFail(app, w, r, user)
  148. }
  149. func (h oauthHandler) exchangeOauthCode(app OAuthDatastoreProvider, ctx context.Context, code string) (*TokenResponse, error) {
  150. form := url.Values{}
  151. form.Add("grant_type", "authorization_code")
  152. form.Add("redirect_uri", app.Config().App.OAuthClientCallbackLocation)
  153. form.Add("code", code)
  154. req, err := http.NewRequest("POST", app.Config().App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
  155. if err != nil {
  156. return nil, err
  157. }
  158. req.WithContext(ctx)
  159. req.Header.Set("User-Agent", "writefreely")
  160. req.Header.Set("Accept", "application/json")
  161. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  162. req.SetBasicAuth(app.Config().App.OAuthClientID, app.Config().App.OAuthClientSecret)
  163. resp, err := h.HttpClient.Do(req)
  164. if err != nil {
  165. return nil, err
  166. }
  167. // Nick: I like using limited readers to reduce the risk of an endpoint
  168. // being broken or compromised.
  169. lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
  170. body, err := ioutil.ReadAll(lr)
  171. if err != nil {
  172. return nil, err
  173. }
  174. var tokenResponse TokenResponse
  175. err = json.Unmarshal(body, &tokenResponse)
  176. if err != nil {
  177. return nil, err
  178. }
  179. return &tokenResponse, nil
  180. }
  181. func (h oauthHandler) inspectOauthAccessToken(app OAuthDatastoreProvider, ctx context.Context, accessToken string) (*InspectResponse, error) {
  182. req, err := http.NewRequest("GET", app.Config().App.OAuthProviderInspectLocation, nil)
  183. if err != nil {
  184. return nil, err
  185. }
  186. req.WithContext(ctx)
  187. req.Header.Set("User-Agent", "writefreely")
  188. req.Header.Set("Accept", "application/json")
  189. req.Header.Set("Authorization", "Bearer "+accessToken)
  190. resp, err := h.HttpClient.Do(req)
  191. if err != nil {
  192. return nil, err
  193. }
  194. // Nick: I like using limited readers to reduce the risk of an endpoint
  195. // being broken or compromised.
  196. lr := io.LimitReader(resp.Body, infoRequestMaxLen)
  197. body, err := ioutil.ReadAll(lr)
  198. if err != nil {
  199. return nil, err
  200. }
  201. var inspectResponse InspectResponse
  202. err = json.Unmarshal(body, &inspectResponse)
  203. if err != nil {
  204. return nil, err
  205. }
  206. return &inspectResponse, nil
  207. }
  208. func loginOrFail(app OAuthDatastoreProvider, w http.ResponseWriter, r *http.Request, user *User) error {
  209. session, err := app.SessionStore().Get(r, cookieName)
  210. if err != nil {
  211. return err
  212. }
  213. session.Values[cookieUserVal] = user.Cookie()
  214. if err = session.Save(r, w); err != nil {
  215. return err
  216. }
  217. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  218. return nil
  219. }