A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

449 lines
16 KiB

  1. /*
  2. * Copyright © 2019-2020 A Bunch Tell LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "encoding/json"
  14. "fmt"
  15. "io"
  16. "io/ioutil"
  17. "net/http"
  18. "net/url"
  19. "strings"
  20. "time"
  21. "github.com/gorilla/mux"
  22. "github.com/gorilla/sessions"
  23. "github.com/writeas/impart"
  24. "github.com/writeas/web-core/log"
  25. "github.com/writeas/writefreely/config"
  26. )
  27. // OAuthButtons holds display information for different OAuth providers we support.
  28. type OAuthButtons struct {
  29. SlackEnabled bool
  30. WriteAsEnabled bool
  31. GitLabEnabled bool
  32. GitLabDisplayName string
  33. }
  34. // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
  35. func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
  36. return &OAuthButtons{
  37. SlackEnabled: cfg.SlackOauth.ClientID != "",
  38. WriteAsEnabled: cfg.WriteAsOauth.ClientID != "",
  39. GitLabEnabled: cfg.GitlabOauth.ClientID != "",
  40. GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
  41. }
  42. }
// TokenResponse contains data returned when a token is created either
// through a code exchange or using a refresh token.
//
// Each field is decoded from the JSON key named in its struct tag.
type TokenResponse struct {
	// AccessToken is the token that viewOauthCallback inspects and, when
	// attaching an account, records alongside the remote user ID.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the provider-reported lifetime of AccessToken
	// (presumably in seconds — TODO confirm per provider).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the provider-issued refresh token, if any.
	RefreshToken string `json:"refresh_token"`
	// TokenType is the provider-reported token type.
	TokenType string `json:"token_type"`
	// Error holds the provider's "error" value; presumably non-empty only
	// when the provider reports a failure — verify against each client.
	Error string `json:"error"`
}
// InspectResponse contains data returned when an access token is inspected.
//
// Each field except DisplayName is decoded from the JSON key named in its
// struct tag.
type InspectResponse struct {
	// ClientID is the OAuth client the token was issued to.
	ClientID string `json:"client_id"`
	// UserID identifies the user on the remote provider; it is the key used
	// to look up or record the link to a local user.
	UserID string `json:"user_id"`
	// ExpiresAt is when the token expires (encoding/json expects RFC 3339).
	ExpiresAt time.Time `json:"expires_at"`
	// Username is the remote account's username.
	Username string `json:"username"`
	// DisplayName is never filled by JSON decoding (tag "-"); presumably
	// provider-specific inspection code sets it — verify per client.
	DisplayName string `json:"-"`
	// Email is the remote account's email address, if provided.
	Email string `json:"email"`
	// Error holds the provider's "error" value, if any.
	Error string `json:"error"`
}
  62. // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
  63. // endpoint. One megabyte is plenty.
  64. const tokenRequestMaxLen = 1000000
  65. // infoRequestMaxLen is the most bytes that we'll read from the
  66. // /oauth/inspect endpoint.
  67. const infoRequestMaxLen = 1000000
// OAuthDatastoreProvider provides a minimal interface of data store, config,
// and session store for use with the oauth handlers.
type OAuthDatastoreProvider interface {
	// DB returns the data store used for OAuth state and remote-user links.
	DB() OAuthDatastore
	// Config returns the application configuration.
	Config() *config.Config
	// SessionStore returns the store used for user login sessions.
	SessionStore() sessions.Store
}
  75. // OAuthDatastore provides a minimal interface of data store methods used in
  76. // oauth functionality.
  77. type OAuthDatastore interface {
  78. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  79. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  80. ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
  81. GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
  82. CreateUser(*config.Config, *User, string) error
  83. GetUserByID(int64) (*User, error)
  84. }
// HttpClient is the one method of *http.Client that the OAuth code needs for
// outbound requests; *http.Client satisfies it. Presumably it exists so another
// implementation (e.g. a test fake) can be substituted.
// NOTE(review): idiomatic Go would name this HTTPClient; kept for compatibility.
type HttpClient interface {
	// Do sends req and returns the response, with the same contract as (*http.Client).Do.
	Do(req *http.Request) (*http.Response, error)
}
// oauthClient is implemented once per supported OAuth provider and supplies
// everything the shared oauthHandler needs to run the authorization-code flow.
type oauthClient interface {
	// GetProvider returns the provider name used in the /oauth/{provider} routes.
	GetProvider() string
	// GetClientID returns the OAuth client ID recorded with each state value.
	GetClientID() string
	// GetCallbackLocation returns the client's configured callback URL
	// (not used in this section; presumably the redirect URI).
	GetCallbackLocation() string
	// buildLoginURL returns the provider URL the browser is redirected to for
	// the given state value.
	buildLoginURL(state string) (string, error)
	// exchangeOauthCode trades an authorization code for a token.
	exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
	// inspectOauthAccessToken looks up the remote user that owns accessToken.
	inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
  96. type callbackProxyClient struct {
  97. server string
  98. callbackLocation string
  99. httpClient HttpClient
  100. }
// oauthHandler serves the OAuth routes for a single provider; one is built per
// configured provider by configureOauthRoutes.
type oauthHandler struct {
	// Config is the application configuration.
	Config *config.Config
	// DB stores OAuth state and remote-user links.
	DB OAuthDatastore
	// Store holds login sessions (written by loginOrFail).
	Store sessions.Store
	// EmailKey is copied from app.keys; not used in this section
	// (presumably for email encryption elsewhere — verify).
	EmailKey []byte
	// oauthClient is the provider-specific implementation.
	oauthClient oauthClient
	// callbackProxy is nil unless a callback proxy is configured for this provider.
	callbackProxy *callbackProxyClient
}
  109. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  110. ctx := r.Context()
  111. var attachUser int64
  112. if attach := r.URL.Query().Get("attach"); attach == "t" {
  113. user, _ := getUserAndSession(app, r)
  114. if user == nil {
  115. return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
  116. }
  117. attachUser = user.ID
  118. }
  119. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
  120. if err != nil {
  121. log.Error("viewOauthInit error: %s", err)
  122. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  123. }
  124. if h.callbackProxy != nil {
  125. if err := h.callbackProxy.register(ctx, state); err != nil {
  126. log.Error("viewOauthInit error: %s", err)
  127. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  128. }
  129. }
  130. location, err := h.oauthClient.buildLoginURL(state)
  131. if err != nil {
  132. log.Error("viewOauthInit error: %s", err)
  133. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  134. }
  135. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  136. }
  137. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  138. if app.Config().SlackOauth.ClientID != "" {
  139. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  140. var stateRegisterClient *callbackProxyClient = nil
  141. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  142. stateRegisterClient = &callbackProxyClient{
  143. server: app.Config().SlackOauth.CallbackProxyAPI,
  144. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  145. httpClient: config.DefaultHTTPClient(),
  146. }
  147. callbackLocation = app.Config().SlackOauth.CallbackProxy
  148. }
  149. oauthClient := slackOauthClient{
  150. ClientID: app.Config().SlackOauth.ClientID,
  151. ClientSecret: app.Config().SlackOauth.ClientSecret,
  152. TeamID: app.Config().SlackOauth.TeamID,
  153. HttpClient: config.DefaultHTTPClient(),
  154. CallbackLocation: callbackLocation,
  155. }
  156. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  157. }
  158. }
  159. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  160. if app.Config().WriteAsOauth.ClientID != "" {
  161. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  162. var callbackProxy *callbackProxyClient = nil
  163. if app.Config().WriteAsOauth.CallbackProxy != "" {
  164. callbackProxy = &callbackProxyClient{
  165. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  166. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  167. httpClient: config.DefaultHTTPClient(),
  168. }
  169. callbackLocation = app.Config().WriteAsOauth.CallbackProxy
  170. }
  171. oauthClient := writeAsOauthClient{
  172. ClientID: app.Config().WriteAsOauth.ClientID,
  173. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  174. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  175. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  176. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  177. HttpClient: config.DefaultHTTPClient(),
  178. CallbackLocation: callbackLocation,
  179. }
  180. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  181. }
  182. }
  183. func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
  184. if app.Config().GitlabOauth.ClientID != "" {
  185. callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
  186. var callbackProxy *callbackProxyClient = nil
  187. if app.Config().GitlabOauth.CallbackProxy != "" {
  188. callbackProxy = &callbackProxyClient{
  189. server: app.Config().GitlabOauth.CallbackProxyAPI,
  190. callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
  191. httpClient: config.DefaultHTTPClient(),
  192. }
  193. callbackLocation = app.Config().GitlabOauth.CallbackProxy
  194. }
  195. address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost)
  196. oauthClient := gitlabOauthClient{
  197. ClientID: app.Config().GitlabOauth.ClientID,
  198. ClientSecret: app.Config().GitlabOauth.ClientSecret,
  199. ExchangeLocation: address + "/oauth/token",
  200. InspectLocation: address + "/api/v4/user",
  201. AuthLocation: address + "/oauth/authorize",
  202. HttpClient: config.DefaultHTTPClient(),
  203. CallbackLocation: callbackLocation,
  204. }
  205. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  206. }
  207. }
  208. func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) {
  209. if app.Config().GenericOauth.ClientID != "" {
  210. callbackLocation := app.Config().App.Host + "/oauth/callback/generic"
  211. var callbackProxy *callbackProxyClient = nil
  212. if app.Config().GenericOauth.CallbackProxy != "" {
  213. callbackProxy = &callbackProxyClient{
  214. server: app.Config().GenericOauth.CallbackProxyAPI,
  215. callbackLocation: app.Config().App.Host + "/oauth/callback/generic",
  216. httpClient: config.DefaultHTTPClient(),
  217. }
  218. callbackLocation = app.Config().GenericOauth.CallbackProxy
  219. }
  220. oauthClient := genericOauthClient{
  221. ClientID: app.Config().GenericOauth.ClientID,
  222. ClientSecret: app.Config().GenericOauth.ClientSecret,
  223. ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint,
  224. InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint,
  225. AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint,
  226. HttpClient: config.DefaultHTTPClient(),
  227. CallbackLocation: callbackLocation,
  228. }
  229. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  230. }
  231. }
  232. func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) {
  233. if app.Config().GiteaOauth.ClientID != "" {
  234. callbackLocation := app.Config().App.Host + "/oauth/callback/gitea"
  235. var callbackProxy *callbackProxyClient = nil
  236. if app.Config().GiteaOauth.CallbackProxy != "" {
  237. callbackProxy = &callbackProxyClient{
  238. server: app.Config().GiteaOauth.CallbackProxyAPI,
  239. callbackLocation: app.Config().App.Host + "/oauth/callback/gitea",
  240. httpClient: config.DefaultHTTPClient(),
  241. }
  242. callbackLocation = app.Config().GiteaOauth.CallbackProxy
  243. }
  244. oauthClient := giteaOauthClient{
  245. ClientID: app.Config().GiteaOauth.ClientID,
  246. ClientSecret: app.Config().GiteaOauth.ClientSecret,
  247. ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token",
  248. InspectLocation: app.Config().GiteaOauth.Host + "/api/v1/user",
  249. AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize",
  250. HttpClient: config.DefaultHTTPClient(),
  251. CallbackLocation: callbackLocation,
  252. }
  253. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  254. }
  255. }
  256. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  257. handler := &oauthHandler{
  258. Config: app.Config(),
  259. DB: app.DB(),
  260. Store: app.SessionStore(),
  261. oauthClient: oauthClient,
  262. EmailKey: app.keys.EmailKey,
  263. callbackProxy: callbackProxy,
  264. }
  265. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  266. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  267. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  268. }
  269. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  270. ctx := r.Context()
  271. code := r.FormValue("code")
  272. state := r.FormValue("state")
  273. provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
  274. if err != nil {
  275. log.Error("Unable to ValidateOAuthState: %s", err)
  276. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  277. }
  278. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  279. if err != nil {
  280. log.Error("Unable to exchangeOauthCode: %s", err)
  281. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  282. }
  283. // Now that we have the access token, let's use it real quick to make sure
  284. // it really really works.
  285. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  286. if err != nil {
  287. log.Error("Unable to inspectOauthAccessToken: %s", err)
  288. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  289. }
  290. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  291. if err != nil {
  292. log.Error("Unable to GetIDForRemoteUser: %s", err)
  293. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  294. }
  295. if localUserID != -1 && attachUserID > 0 {
  296. if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil {
  297. return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
  298. }
  299. return impart.HTTPError{http.StatusFound, "/me/settings"}
  300. }
  301. if localUserID != -1 {
  302. // Existing user, so log in now
  303. user, err := h.DB.GetUserByID(localUserID)
  304. if err != nil {
  305. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  306. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  307. }
  308. if err = loginOrFail(h.Store, w, r, user); err != nil {
  309. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  310. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  311. }
  312. return nil
  313. }
  314. if attachUserID > 0 {
  315. log.Info("attaching to user %d", attachUserID)
  316. err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
  317. if err != nil {
  318. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  319. }
  320. return impart.HTTPError{http.StatusFound, "/me/settings"}
  321. }
  322. // New user registration below.
  323. // First, verify that user is allowed to register
  324. if inviteCode != "" {
  325. // Verify invite code is valid
  326. i, err := app.db.GetUserInvite(inviteCode)
  327. if err != nil {
  328. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  329. }
  330. if !i.Active(app.db) {
  331. return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
  332. }
  333. } else if !app.cfg.App.OpenRegistration {
  334. addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
  335. return impart.HTTPError{http.StatusFound, "/login"}
  336. }
  337. displayName := tokenInfo.DisplayName
  338. if len(displayName) == 0 {
  339. displayName = tokenInfo.Username
  340. }
  341. tp := &oauthSignupPageParams{
  342. AccessToken: tokenResponse.AccessToken,
  343. TokenUsername: tokenInfo.Username,
  344. TokenAlias: tokenInfo.DisplayName,
  345. TokenEmail: tokenInfo.Email,
  346. TokenRemoteUser: tokenInfo.UserID,
  347. Provider: provider,
  348. ClientID: clientID,
  349. InviteCode: inviteCode,
  350. }
  351. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  352. return h.showOauthSignupPage(app, w, r, tp, nil)
  353. }
  354. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  355. form := url.Values{}
  356. form.Add("state", state)
  357. form.Add("location", r.callbackLocation)
  358. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  359. if err != nil {
  360. return err
  361. }
  362. req.Header.Set("User-Agent", ServerUserAgent(""))
  363. req.Header.Set("Accept", "application/json")
  364. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  365. resp, err := r.httpClient.Do(req)
  366. if err != nil {
  367. return err
  368. }
  369. if resp.StatusCode != http.StatusCreated {
  370. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  371. }
  372. return nil
  373. }
  374. func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  375. lr := io.LimitReader(body, int64(n+1))
  376. data, err := ioutil.ReadAll(lr)
  377. if err != nil {
  378. return err
  379. }
  380. if len(data) == n+1 {
  381. return fmt.Errorf("content larger than max read allowance: %d", n)
  382. }
  383. return json.Unmarshal(data, thing)
  384. }
  385. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  386. // An error may be returned, but a valid session should always be returned.
  387. session, _ := store.Get(r, cookieName)
  388. session.Values[cookieUserVal] = user.Cookie()
  389. if err := session.Save(r, w); err != nil {
  390. fmt.Println("error saving session", err)
  391. return err
  392. }
  393. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  394. return nil
  395. }