A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

291 regels
8.6 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gorilla/sessions"
  7. "github.com/guregu/null/zero"
  8. "github.com/writeas/nerds/store"
  9. "github.com/writeas/web-core/auth"
  10. "github.com/writeas/web-core/log"
  11. "github.com/writeas/writefreely/config"
  12. "io"
  13. "io/ioutil"
  14. "net/http"
  15. "net/url"
  16. "strings"
  17. "time"
  18. )
  19. // TokenResponse contains data returned when a token is created either
  20. // through a code exchange or using a refresh token.
  21. type TokenResponse struct {
  22. AccessToken string `json:"access_token"`
  23. ExpiresIn int `json:"expires_in"`
  24. RefreshToken string `json:"refresh_token"`
  25. TokenType string `json:"token_type"`
  26. Error string `json:"error"`
  27. }
  28. // InspectResponse contains data returned when an access token is inspected.
  29. type InspectResponse struct {
  30. ClientID string `json:"client_id"`
  31. UserID int64 `json:"user_id"`
  32. ExpiresAt time.Time `json:"expires_at"`
  33. Username string `json:"username"`
  34. Email string `json:"email"`
  35. }
  36. // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
  37. // endpoint. One megabyte is plenty.
  38. const tokenRequestMaxLen = 1000000
  39. // infoRequestMaxLen is the most bytes that we'll read from the
  40. // /oauth/inspect endpoint.
  41. const infoRequestMaxLen = 1000000
  42. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  43. // and session store for use with the oauth handlers.
  44. type OAuthDatastoreProvider interface {
  45. DB() OAuthDatastore
  46. Config() *config.Config
  47. SessionStore() sessions.Store
  48. }
  49. // OAuthDatastore provides a minimal interface of data store methods used in
  50. // oauth functionality.
  51. type OAuthDatastore interface {
  52. GenerateOAuthState(context.Context) (string, error)
  53. ValidateOAuthState(context.Context, string) error
  54. GetIDForRemoteUser(context.Context, int64) (int64, error)
  55. CreateUser(*config.Config, *User, string) error
  56. RecordRemoteUserID(context.Context, int64, int64) error
  57. GetUserForAuthByID(int64) (*User, error)
  58. }
  59. type HttpClient interface {
  60. Do(req *http.Request) (*http.Response, error)
  61. }
  62. type oauthHandler struct {
  63. Config *config.Config
  64. DB OAuthDatastore
  65. Store sessions.Store
  66. HttpClient HttpClient
  67. }
  68. // buildAuthURL returns a URL used to initiate authentication.
  69. func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
  70. state, err := db.GenerateOAuthState(ctx)
  71. if err != nil {
  72. return "", err
  73. }
  74. u, err := url.Parse(authLocation)
  75. if err != nil {
  76. return "", err
  77. }
  78. q := u.Query()
  79. q.Set("client_id", clientID)
  80. q.Set("redirect_uri", callbackURL)
  81. q.Set("response_type", "code")
  82. q.Set("state", state)
  83. u.RawQuery = q.Encode()
  84. return u.String(), nil
  85. }
  86. // app *App, w http.ResponseWriter, r *http.Request
  87. func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
  88. location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
  89. if err != nil {
  90. failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
  91. return
  92. }
  93. http.Redirect(w, r, location, http.StatusTemporaryRedirect)
  94. }
  95. func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
  96. ctx := r.Context()
  97. code := r.FormValue("code")
  98. state := r.FormValue("state")
  99. err := h.DB.ValidateOAuthState(ctx, state)
  100. if err != nil {
  101. log.Error("Unable to ValidateOAuthState: %s", err)
  102. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  103. return
  104. }
  105. tokenResponse, err := h.exchangeOauthCode(ctx, code)
  106. if err != nil {
  107. log.Error("Unable to exchangeOauthCode: %s", err)
  108. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  109. return
  110. }
  111. // Now that we have the access token, let's use it real quick to make sur
  112. // it really really works.
  113. tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  114. if err != nil {
  115. log.Error("Unable to inspectOauthAccessToken: %s", err)
  116. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  117. return
  118. }
  119. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
  120. if err != nil {
  121. log.Error("Unable to GetIDForRemoteUser: %s", err)
  122. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  123. return
  124. }
  125. fmt.Println("local user id", localUserID)
  126. if localUserID == -1 {
  127. // We don't have, nor do we want, the password from the origin, so we
  128. //create a random string. If the user needs to set a password, they
  129. //can do so through the settings page or through the password reset
  130. //flow.
  131. randPass := store.Generate62RandomString(14)
  132. hashedPass, err := auth.HashPass([]byte(randPass))
  133. if err != nil {
  134. log.ErrorLog.Println(err)
  135. failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
  136. return
  137. }
  138. newUser := &User{
  139. Username: tokenInfo.Username,
  140. HashedPass: hashedPass,
  141. HasPass: true,
  142. Email: zero.NewString("", tokenInfo.Email != ""),
  143. Created: time.Now().Truncate(time.Second).UTC(),
  144. }
  145. err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
  146. if err != nil {
  147. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  148. return
  149. }
  150. err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
  151. if err != nil {
  152. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  153. return
  154. }
  155. if err := loginOrFail(h.Store, w, r, newUser); err != nil {
  156. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  157. }
  158. return
  159. }
  160. user, err := h.DB.GetUserForAuthByID(localUserID)
  161. if err != nil {
  162. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  163. return
  164. }
  165. if err = loginOrFail(h.Store, w, r, user); err != nil {
  166. failOAuthRequest(w, http.StatusInternalServerError, err.Error())
  167. }
  168. }
  169. func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
  170. form := url.Values{}
  171. form.Add("grant_type", "authorization_code")
  172. form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
  173. form.Add("code", code)
  174. req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
  175. if err != nil {
  176. return nil, err
  177. }
  178. req.WithContext(ctx)
  179. req.Header.Set("User-Agent", "writefreely")
  180. req.Header.Set("Accept", "application/json")
  181. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  182. req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
  183. resp, err := h.HttpClient.Do(req)
  184. if err != nil {
  185. return nil, err
  186. }
  187. // Nick: I like using limited readers to reduce the risk of an endpoint
  188. // being broken or compromised.
  189. lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
  190. body, err := ioutil.ReadAll(lr)
  191. if err != nil {
  192. return nil, err
  193. }
  194. var tokenResponse TokenResponse
  195. err = json.Unmarshal(body, &tokenResponse)
  196. if err != nil {
  197. return nil, err
  198. }
  199. // Check the response for an error message, and return it if there is one.
  200. if tokenResponse.Error != "" {
  201. return nil, fmt.Errorf(tokenResponse.Error)
  202. }
  203. return &tokenResponse, nil
  204. }
  205. func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
  206. req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
  207. if err != nil {
  208. return nil, err
  209. }
  210. req.WithContext(ctx)
  211. req.Header.Set("User-Agent", "writefreely")
  212. req.Header.Set("Accept", "application/json")
  213. req.Header.Set("Authorization", "Bearer "+accessToken)
  214. resp, err := h.HttpClient.Do(req)
  215. if err != nil {
  216. return nil, err
  217. }
  218. // Nick: I like using limited readers to reduce the risk of an endpoint
  219. // being broken or compromised.
  220. lr := io.LimitReader(resp.Body, infoRequestMaxLen)
  221. body, err := ioutil.ReadAll(lr)
  222. if err != nil {
  223. return nil, err
  224. }
  225. var inspectResponse InspectResponse
  226. err = json.Unmarshal(body, &inspectResponse)
  227. if err != nil {
  228. return nil, err
  229. }
  230. return &inspectResponse, nil
  231. }
  232. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  233. // An error may be returned, but a valid session should always be returned.
  234. session, _ := store.Get(r, cookieName)
  235. session.Values[cookieUserVal] = user.Cookie()
  236. if err := session.Save(r, w); err != nil {
  237. fmt.Println("error saving session", err)
  238. return err
  239. }
  240. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  241. return nil
  242. }
  243. // failOAuthRequest is an HTTP handler helper that formats returned error
  244. // messages.
  245. func failOAuthRequest(w http.ResponseWriter, statusCode int, message string) {
  246. w.Header().Set("Content-Type", "application/json")
  247. w.WriteHeader(statusCode)
  248. err := json.NewEncoder(w).Encode(map[string]interface{}{
  249. "error": message,
  250. })
  251. if err != nil {
  252. panic(err)
  253. }
  254. }