A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

292 lines
9.8 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gorilla/mux"
  7. "github.com/gorilla/sessions"
  8. "github.com/writeas/impart"
  9. "github.com/writeas/web-core/log"
  10. "github.com/writeas/writefreely/config"
  11. "io"
  12. "io/ioutil"
  13. "net/http"
  14. "net/url"
  15. "strings"
  16. "time"
  17. )
  // TokenResponse holds the fields decoded from the JSON body that comes back
  // when a token is issued, whether by exchanging an authorization code or by
  // redeeming a refresh token.
  type TokenResponse struct {
  	// AccessToken is the issued token; the callback handler passes it on to
  	// the token inspection call.
  	AccessToken string `json:"access_token"`
  	// ExpiresIn is the token lifetime as reported by the issuer
  	// (presumably in seconds, per OAuth convention; confirm with the provider).
  	ExpiresIn int `json:"expires_in"`
  	// RefreshToken is the token that can be redeemed for a new access token.
  	RefreshToken string `json:"refresh_token"`
  	// TokenType is the token type reported by the issuer.
  	TokenType string `json:"token_type"`
  	// Error carries the "error" member of the response, if any.
  	Error string `json:"error"`
  }
  // InspectResponse holds the fields decoded from the JSON body that comes back
  // when an access token is inspected.
  type InspectResponse struct {
  	// ClientID is the "client_id" member of the response.
  	ClientID string `json:"client_id"`
  	// UserID is the remote user's identifier; the callback handler uses it to
  	// look up the matching local account.
  	UserID string `json:"user_id"`
  	// ExpiresAt is the "expires_at" member of the response.
  	ExpiresAt time.Time `json:"expires_at"`
  	// Username is the remote account's username.
  	Username string `json:"username"`
  	// DisplayName is excluded from JSON encoding and decoding by its "-" tag,
  	// so it is never filled from the response body itself (presumably it is
  	// set by provider-specific code; confirm against the client
  	// implementations).
  	DisplayName string `json:"-"`
  	// Email is the remote account's email address.
  	Email string `json:"email"`
  	// Error carries the "error" member of the response, if any.
  	Error string `json:"error"`
  }
  // Upper bounds, in bytes, on how much of an OAuth endpoint response is read.
  const (
  	// tokenRequestMaxLen caps what is read from the /oauth/token endpoint.
  	// One megabyte is plenty.
  	tokenRequestMaxLen = 1000000

  	// infoRequestMaxLen caps what is read from the /oauth/inspect endpoint.
  	infoRequestMaxLen = 1000000
  )
  43. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  44. // and session store for use with the oauth handlers.
  45. type OAuthDatastoreProvider interface {
  46. DB() OAuthDatastore
  47. Config() *config.Config
  48. SessionStore() sessions.Store
  49. }
  50. // OAuthDatastore provides a minimal interface of data store methods used in
  51. // oauth functionality.
  52. type OAuthDatastore interface {
  53. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  54. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  55. ValidateOAuthState(context.Context, string) (string, string, error)
  56. GenerateOAuthState(context.Context, string, string) (string, error)
  57. CreateUser(*config.Config, *User, string) error
  58. GetUserByID(int64) (*User, error)
  59. }
  // HttpClient is the one-method view of an HTTP client that the OAuth code
  // depends on, so a substitute can be supplied in place of a real client.
  // *http.Client satisfies it.
  type HttpClient interface {
  	// Do sends request and returns the response or an error.
  	Do(request *http.Request) (*http.Response, error)
  }
  63. type oauthClient interface {
  64. GetProvider() string
  65. GetClientID() string
  66. GetCallbackLocation() string
  67. buildLoginURL(state string) (string, error)
  68. exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
  69. inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
  70. }
  71. type callbackProxyClient struct {
  72. server string
  73. callbackLocation string
  74. httpClient HttpClient
  75. }
  76. type oauthHandler struct {
  77. Config *config.Config
  78. DB OAuthDatastore
  79. Store sessions.Store
  80. EmailKey []byte
  81. oauthClient oauthClient
  82. callbackProxy *callbackProxyClient
  83. }
  84. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  85. ctx := r.Context()
  86. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
  87. if err != nil {
  88. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  89. }
  90. if h.callbackProxy != nil {
  91. if err := h.callbackProxy.register(ctx, state); err != nil {
  92. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  93. }
  94. }
  95. location, err := h.oauthClient.buildLoginURL(state)
  96. if err != nil {
  97. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  98. }
  99. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  100. }
  101. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  102. if app.Config().SlackOauth.ClientID != "" {
  103. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  104. var stateRegisterClient *callbackProxyClient = nil
  105. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  106. stateRegisterClient = &callbackProxyClient{
  107. server: app.Config().SlackOauth.CallbackProxyAPI,
  108. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  109. httpClient: config.DefaultHTTPClient(),
  110. }
  111. callbackLocation = app.Config().SlackOauth.CallbackProxy
  112. }
  113. oauthClient := slackOauthClient{
  114. ClientID: app.Config().SlackOauth.ClientID,
  115. ClientSecret: app.Config().SlackOauth.ClientSecret,
  116. TeamID: app.Config().SlackOauth.TeamID,
  117. HttpClient: config.DefaultHTTPClient(),
  118. CallbackLocation: callbackLocation,
  119. }
  120. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  121. }
  122. }
  123. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  124. if app.Config().WriteAsOauth.ClientID != "" {
  125. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  126. var callbackProxy *callbackProxyClient = nil
  127. if app.Config().WriteAsOauth.CallbackProxy != "" {
  128. callbackProxy = &callbackProxyClient{
  129. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  130. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  131. httpClient: config.DefaultHTTPClient(),
  132. }
  133. callbackLocation = app.Config().SlackOauth.CallbackProxy
  134. }
  135. oauthClient := writeAsOauthClient{
  136. ClientID: app.Config().WriteAsOauth.ClientID,
  137. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  138. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  139. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  140. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  141. HttpClient: config.DefaultHTTPClient(),
  142. CallbackLocation: callbackLocation,
  143. }
  144. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  145. }
  146. }
  147. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  148. handler := &oauthHandler{
  149. Config: app.Config(),
  150. DB: app.DB(),
  151. Store: app.SessionStore(),
  152. oauthClient: oauthClient,
  153. EmailKey: app.keys.EmailKey,
  154. callbackProxy: callbackProxy,
  155. }
  156. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  157. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  158. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  159. }
  160. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  161. ctx := r.Context()
  162. code := r.FormValue("code")
  163. state := r.FormValue("state")
  164. provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
  165. if err != nil {
  166. log.Error("Unable to ValidateOAuthState: %s", err)
  167. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  168. }
  169. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  170. if err != nil {
  171. log.Error("Unable to exchangeOauthCode: %s", err)
  172. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  173. }
  174. // Now that we have the access token, let's use it real quick to make sur
  175. // it really really works.
  176. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  177. if err != nil {
  178. log.Error("Unable to inspectOauthAccessToken: %s", err)
  179. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  180. }
  181. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  182. if err != nil {
  183. log.Error("Unable to GetIDForRemoteUser: %s", err)
  184. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  185. }
  186. if localUserID != -1 {
  187. user, err := h.DB.GetUserByID(localUserID)
  188. if err != nil {
  189. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  190. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  191. }
  192. if err = loginOrFail(h.Store, w, r, user); err != nil {
  193. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  194. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  195. }
  196. return nil
  197. }
  198. displayName := tokenInfo.DisplayName
  199. if len(displayName) == 0 {
  200. displayName = tokenInfo.Username
  201. }
  202. tp := &oauthSignupPageParams{
  203. AccessToken: tokenResponse.AccessToken,
  204. TokenUsername: tokenInfo.Username,
  205. TokenAlias: tokenInfo.DisplayName,
  206. TokenEmail: tokenInfo.Email,
  207. TokenRemoteUser: tokenInfo.UserID,
  208. Provider: provider,
  209. ClientID: clientID,
  210. }
  211. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  212. return h.showOauthSignupPage(app, w, r, tp, nil)
  213. }
  214. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  215. form := url.Values{}
  216. form.Add("state", state)
  217. form.Add("location", r.callbackLocation)
  218. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  219. if err != nil {
  220. return err
  221. }
  222. req.Header.Set("User-Agent", "writefreely")
  223. req.Header.Set("Accept", "application/json")
  224. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  225. resp, err := r.httpClient.Do(req)
  226. if err != nil {
  227. return err
  228. }
  229. if resp.StatusCode != http.StatusCreated {
  230. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  231. }
  232. return nil
  233. }
  // limitedJsonUnmarshal decodes JSON from body into thing, refusing to accept
  // more than n bytes of input.
  //
  // It returns the read error if reading fails, a "content larger than max read
  // allowance" error if body holds more than n bytes, and otherwise whatever
  // json.Unmarshal returns. It does not close body.
  func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  	// Read at most one byte past the allowance: receiving that extra byte is
  	// how oversized input is told apart from input of exactly n bytes.
  	overflowLen := n + 1

  	data, err := ioutil.ReadAll(io.LimitReader(body, int64(overflowLen)))
  	switch {
  	case err != nil:
  		return err
  	case len(data) == overflowLen:
  		return fmt.Errorf("content larger than max read allowance: %d", n)
  	}
  	return json.Unmarshal(data, thing)
  }
  245. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  246. // An error may be returned, but a valid session should always be returned.
  247. session, _ := store.Get(r, cookieName)
  248. session.Values[cookieUserVal] = user.Cookie()
  249. if err := session.Save(r, w); err != nil {
  250. fmt.Println("error saving session", err)
  251. return err
  252. }
  253. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  254. return nil
  255. }