A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

284 lines
8.2 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gorilla/sessions"
  7. "github.com/guregu/null/zero"
  8. "github.com/writeas/web-core/auth"
  9. "github.com/writeas/web-core/log"
  10. "github.com/writeas/writefreely/config"
  11. "io"
  12. "io/ioutil"
  13. "net/http"
  14. "net/url"
  15. "strings"
  16. "time"
  17. )
  // TokenResponse is the JSON document returned by the provider's token
  // endpoint, whether the token was obtained through a code exchange or by
  // using a refresh token.
  type TokenResponse struct {
  	// AccessToken is decoded from the "access_token" key.
  	AccessToken string `json:"access_token"`
  	// ExpiresIn is decoded from the "expires_in" key.
  	// NOTE(review): presumably a lifetime in seconds, as in the usual OAuth 2.0
  	// token response; nothing in this file reads it, so confirm before use.
  	ExpiresIn int `json:"expires_in"`
  	// RefreshToken is decoded from the "refresh_token" key.
  	RefreshToken string `json:"refresh_token"`
  	// TokenType is decoded from the "token_type" key.
  	TokenType string `json:"token_type"`
  }
  // InspectResponse is the JSON document returned by the provider when an
  // access token is inspected.
  type InspectResponse struct {
  	// ClientID is decoded from the "client_id" key.
  	ClientID string `json:"client_id"`
  	// UserID is decoded from the "user_id" key; it identifies the user on the
  	// provider's side, not in the local database.
  	UserID int64 `json:"user_id"`
  	// ExpiresAt is decoded from the "expires_at" key.
  	ExpiresAt time.Time `json:"expires_at"`
  	// Username is decoded from the "username" key.
  	Username string `json:"username"`
  	// Email is decoded from the "email" key.
  	Email string `json:"email"`
  }
  // Upper bounds on how much of a provider response body will be read.
  const (
  	// tokenRequestMaxLen is the most bytes that we'll read from the
  	// /oauth/token endpoint. One megabyte is plenty.
  	tokenRequestMaxLen = 1000 * 1000

  	// infoRequestMaxLen is the most bytes that we'll read from the
  	// /oauth/inspect endpoint.
  	infoRequestMaxLen = 1000 * 1000
  )
  40. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  41. // and session store for use with the oauth handlers.
  42. type OAuthDatastoreProvider interface {
  43. DB() OAuthDatastore
  44. Config() *config.Config
  45. SessionStore() sessions.Store
  46. }
  47. // OAuthDatastore provides a minimal interface of data store methods used in
  48. // oauth functionality.
  49. type OAuthDatastore interface {
  50. GenerateOAuthState(context.Context) (string, error)
  51. ValidateOAuthState(context.Context, string) error
  52. GetIDForRemoteUser(context.Context, int64) (int64, error)
  53. CreateUser(*config.Config, *User, string) error
  54. RecordRemoteUserID(context.Context, int64, int64) error
  55. GetUserForAuthByID(int64) (*User, error)
  56. }
  // HttpClient is the single-method subset of *http.Client that the oauth
  // handlers need, so that a stand-in client can be supplied in its place.
  type HttpClient interface {
  	// Do sends req and returns the response, mirroring (*http.Client).Do.
  	Do(req *http.Request) (*http.Response, error)
  }
  60. type oauthHandler struct {
  61. Config *config.Config
  62. DB OAuthDatastore
  63. Store sessions.Store
  64. HttpClient HttpClient
  65. }
  66. // buildAuthURL returns a URL used to initiate authentication.
  67. func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
  68. state, err := db.GenerateOAuthState(ctx)
  69. if err != nil {
  70. return "", err
  71. }
  72. u, err := url.Parse(authLocation)
  73. if err != nil {
  74. return "", err
  75. }
  76. q := u.Query()
  77. q.Set("client_id", clientID)
  78. q.Set("redirect_uri", callbackURL)
  79. q.Set("response_type", "code")
  80. q.Set("state", state)
  81. u.RawQuery = q.Encode()
  82. return u.String(), nil
  83. }
  84. // app *App, w http.ResponseWriter, r *http.Request
  85. func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
  86. location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
  87. if err != nil {
  88. failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
  89. return
  90. }
  91. http.Redirect(w, r, location, http.StatusTemporaryRedirect)
  92. }
  93. func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
  94. ctx := r.Context()
  95. code := r.FormValue("code")
  96. state := r.FormValue("state")
  97. err := h.DB.ValidateOAuthState(ctx, state)
  98. if err != nil {
  99. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  100. return
  101. }
  102. tokenResponse, err := h.exchangeOauthCode(ctx, code)
  103. if err != nil {
  104. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  105. return
  106. }
  107. // Now that we have the access token, let's use it real quick to make sur
  108. // it really really works.
  109. tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  110. if err != nil {
  111. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  112. return
  113. }
  114. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
  115. if err != nil {
  116. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  117. return
  118. }
  119. fmt.Println("local user id", localUserID)
  120. if localUserID == -1 {
  121. // We don't have, nor do we want, the password from the origin, so we
  122. //create a random string. If the user needs to set a password, they
  123. //can do so through the settings page or through the password reset
  124. //flow.
  125. randPass, err := randString(14)
  126. if err != nil {
  127. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  128. return
  129. }
  130. hashedPass, err := auth.HashPass([]byte(randPass))
  131. if err != nil {
  132. log.ErrorLog.Println(err)
  133. failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
  134. return
  135. }
  136. newUser := &User{
  137. Username: tokenInfo.Username,
  138. HashedPass: hashedPass,
  139. HasPass: true,
  140. Email: zero.NewString("", tokenInfo.Email != ""),
  141. Created: time.Now().Truncate(time.Second).UTC(),
  142. }
  143. err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
  144. if err != nil {
  145. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  146. return
  147. }
  148. err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
  149. if err != nil {
  150. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  151. return
  152. }
  153. if err := loginOrFail(h.Store, w, r, newUser); err != nil {
  154. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  155. }
  156. return
  157. }
  158. user, err := h.DB.GetUserForAuthByID(localUserID)
  159. if err != nil {
  160. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  161. return
  162. }
  163. if err = loginOrFail(h.Store, w, r, user); err != nil {
  164. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  165. }
  166. }
  167. func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
  168. form := url.Values{}
  169. form.Add("grant_type", "authorization_code")
  170. form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
  171. form.Add("code", code)
  172. req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
  173. if err != nil {
  174. return nil, err
  175. }
  176. req.WithContext(ctx)
  177. req.Header.Set("User-Agent", "writefreely")
  178. req.Header.Set("Accept", "application/json")
  179. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  180. req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
  181. resp, err := h.HttpClient.Do(req)
  182. if err != nil {
  183. return nil, err
  184. }
  185. // Nick: I like using limited readers to reduce the risk of an endpoint
  186. // being broken or compromised.
  187. lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
  188. body, err := ioutil.ReadAll(lr)
  189. if err != nil {
  190. return nil, err
  191. }
  192. var tokenResponse TokenResponse
  193. err = json.Unmarshal(body, &tokenResponse)
  194. if err != nil {
  195. return nil, err
  196. }
  197. return &tokenResponse, nil
  198. }
  199. func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
  200. req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
  201. if err != nil {
  202. return nil, err
  203. }
  204. req.WithContext(ctx)
  205. req.Header.Set("User-Agent", "writefreely")
  206. req.Header.Set("Accept", "application/json")
  207. req.Header.Set("Authorization", "Bearer "+accessToken)
  208. resp, err := h.HttpClient.Do(req)
  209. if err != nil {
  210. return nil, err
  211. }
  212. // Nick: I like using limited readers to reduce the risk of an endpoint
  213. // being broken or compromised.
  214. lr := io.LimitReader(resp.Body, infoRequestMaxLen)
  215. body, err := ioutil.ReadAll(lr)
  216. if err != nil {
  217. return nil, err
  218. }
  219. var inspectResponse InspectResponse
  220. err = json.Unmarshal(body, &inspectResponse)
  221. if err != nil {
  222. return nil, err
  223. }
  224. return &inspectResponse, nil
  225. }
  226. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  227. // An error may be returned, but a valid session should always be returned.
  228. session, _ := store.Get(r, cookieName)
  229. session.Values[cookieUserVal] = user.Cookie()
  230. if err := session.Save(r, w); err != nil {
  231. fmt.Println("error saving session", err)
  232. return err
  233. }
  234. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  235. return nil
  236. }
  // failOAuthRequest is an HTTP handler helper that reports a failure to the
  // client. It sets the Content-Type to application/json, writes statusCode,
  // and sends a JSON object whose single "error" key holds message. It panics
  // if the response body cannot be encoded or written.
  func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) {
  	payload := struct {
  		Error string `json:"error"`
  	}{Error: message}

  	// Headers have to be set before WriteHeader sends them.
  	w.Header().Set("Content-Type", "application/json")
  	w.WriteHeader(statusCode)

  	if err := json.NewEncoder(w).Encode(payload); err != nil {
  		panic(err)
  	}
  }
  248. }