A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

474 lines
17 KiB

  1. /*
  2. * Copyright © 2019-2021 Musing Studio LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "encoding/json"
  14. "fmt"
  15. "io"
  16. "io/ioutil"
  17. "net/http"
  18. "net/url"
  19. "strings"
  20. "time"
  21. "github.com/gorilla/mux"
  22. "github.com/gorilla/sessions"
  23. "github.com/writeas/impart"
  24. "github.com/writeas/web-core/log"
  25. "github.com/writefreely/writefreely/config"
  26. )
  27. // OAuthButtons holds display information for different OAuth providers we support.
  28. type OAuthButtons struct {
  29. SlackEnabled bool
  30. WriteAsEnabled bool
  31. GitLabEnabled bool
  32. GitLabDisplayName string
  33. GiteaEnabled bool
  34. GiteaDisplayName string
  35. GenericEnabled bool
  36. GenericDisplayName string
  37. }
  38. // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
  39. func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
  40. return &OAuthButtons{
  41. SlackEnabled: cfg.SlackOauth.ClientID != "",
  42. WriteAsEnabled: cfg.WriteAsOauth.ClientID != "",
  43. GitLabEnabled: cfg.GitlabOauth.ClientID != "",
  44. GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
  45. GiteaEnabled: cfg.GiteaOauth.ClientID != "",
  46. GiteaDisplayName: config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName),
  47. GenericEnabled: cfg.GenericOauth.ClientID != "",
  48. GenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName),
  49. }
  50. }
  51. // TokenResponse contains data returned when a token is created either
  52. // through a code exchange or using a refresh token.
  53. type TokenResponse struct {
  54. AccessToken string `json:"access_token"`
  55. ExpiresIn int `json:"expires_in"`
  56. RefreshToken string `json:"refresh_token"`
  57. TokenType string `json:"token_type"`
  58. Error string `json:"error"`
  59. }
  60. // InspectResponse contains data returned when an access token is inspected.
  61. type InspectResponse struct {
  62. ClientID string `json:"client_id"`
  63. UserID string `json:"user_id"`
  64. ExpiresAt time.Time `json:"expires_at"`
  65. Username string `json:"username"`
  66. DisplayName string `json:"-"`
  67. Email string `json:"email"`
  68. Error string `json:"error"`
  69. }
  70. // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
  71. // endpoint. One megabyte is plenty.
  72. const tokenRequestMaxLen = 1000000
  73. // infoRequestMaxLen is the most bytes that we'll read from the
  74. // /oauth/inspect endpoint.
  75. const infoRequestMaxLen = 1000000
  76. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  77. // and session store for use with the oauth handlers.
  78. type OAuthDatastoreProvider interface {
  79. DB() OAuthDatastore
  80. Config() *config.Config
  81. SessionStore() sessions.Store
  82. }
  83. // OAuthDatastore provides a minimal interface of data store methods used in
  84. // oauth functionality.
  85. type OAuthDatastore interface {
  86. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  87. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  88. ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
  89. GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
  90. CreateUser(*config.Config, *User, string, string) error
  91. GetUserByID(int64) (*User, error)
  92. }
  93. type HttpClient interface {
  94. Do(req *http.Request) (*http.Response, error)
  95. }
  96. type oauthClient interface {
  97. GetProvider() string
  98. GetClientID() string
  99. GetCallbackLocation() string
  100. buildLoginURL(state string) (string, error)
  101. exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
  102. inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
  103. }
  104. type callbackProxyClient struct {
  105. server string
  106. callbackLocation string
  107. httpClient HttpClient
  108. }
  109. type oauthHandler struct {
  110. Config *config.Config
  111. DB OAuthDatastore
  112. Store sessions.Store
  113. EmailKey []byte
  114. oauthClient oauthClient
  115. callbackProxy *callbackProxyClient
  116. }
  117. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  118. ctx := r.Context()
  119. var attachUser int64
  120. if attach := r.URL.Query().Get("attach"); attach == "t" {
  121. user, _ := getUserAndSession(app, r)
  122. if user == nil {
  123. return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
  124. }
  125. attachUser = user.ID
  126. }
  127. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
  128. if err != nil {
  129. log.Error("viewOauthInit error: %s", err)
  130. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  131. }
  132. if h.callbackProxy != nil {
  133. if err := h.callbackProxy.register(ctx, state); err != nil {
  134. log.Error("viewOauthInit error: %s", err)
  135. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  136. }
  137. }
  138. location, err := h.oauthClient.buildLoginURL(state)
  139. if err != nil {
  140. log.Error("viewOauthInit error: %s", err)
  141. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  142. }
  143. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  144. }
  145. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  146. if app.Config().SlackOauth.ClientID != "" {
  147. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  148. var stateRegisterClient *callbackProxyClient = nil
  149. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  150. stateRegisterClient = &callbackProxyClient{
  151. server: app.Config().SlackOauth.CallbackProxyAPI,
  152. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  153. httpClient: config.DefaultHTTPClient(),
  154. }
  155. callbackLocation = app.Config().SlackOauth.CallbackProxy
  156. }
  157. oauthClient := slackOauthClient{
  158. ClientID: app.Config().SlackOauth.ClientID,
  159. ClientSecret: app.Config().SlackOauth.ClientSecret,
  160. TeamID: app.Config().SlackOauth.TeamID,
  161. HttpClient: config.DefaultHTTPClient(),
  162. CallbackLocation: callbackLocation,
  163. }
  164. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  165. }
  166. }
  167. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  168. if app.Config().WriteAsOauth.ClientID != "" {
  169. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  170. var callbackProxy *callbackProxyClient = nil
  171. if app.Config().WriteAsOauth.CallbackProxy != "" {
  172. callbackProxy = &callbackProxyClient{
  173. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  174. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  175. httpClient: config.DefaultHTTPClient(),
  176. }
  177. callbackLocation = app.Config().WriteAsOauth.CallbackProxy
  178. }
  179. oauthClient := writeAsOauthClient{
  180. ClientID: app.Config().WriteAsOauth.ClientID,
  181. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  182. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  183. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  184. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  185. HttpClient: config.DefaultHTTPClient(),
  186. CallbackLocation: callbackLocation,
  187. }
  188. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  189. }
  190. }
  191. func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
  192. if app.Config().GitlabOauth.ClientID != "" {
  193. callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
  194. var callbackProxy *callbackProxyClient = nil
  195. if app.Config().GitlabOauth.CallbackProxy != "" {
  196. callbackProxy = &callbackProxyClient{
  197. server: app.Config().GitlabOauth.CallbackProxyAPI,
  198. callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
  199. httpClient: config.DefaultHTTPClient(),
  200. }
  201. callbackLocation = app.Config().GitlabOauth.CallbackProxy
  202. }
  203. address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost)
  204. oauthClient := gitlabOauthClient{
  205. ClientID: app.Config().GitlabOauth.ClientID,
  206. ClientSecret: app.Config().GitlabOauth.ClientSecret,
  207. ExchangeLocation: address + "/oauth/token",
  208. InspectLocation: address + "/api/v4/user",
  209. AuthLocation: address + "/oauth/authorize",
  210. HttpClient: config.DefaultHTTPClient(),
  211. CallbackLocation: callbackLocation,
  212. }
  213. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  214. }
  215. }
  216. func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) {
  217. if app.Config().GenericOauth.ClientID != "" {
  218. callbackLocation := app.Config().App.Host + "/oauth/callback/generic"
  219. var callbackProxy *callbackProxyClient = nil
  220. if app.Config().GenericOauth.CallbackProxy != "" {
  221. callbackProxy = &callbackProxyClient{
  222. server: app.Config().GenericOauth.CallbackProxyAPI,
  223. callbackLocation: app.Config().App.Host + "/oauth/callback/generic",
  224. httpClient: config.DefaultHTTPClient(),
  225. }
  226. callbackLocation = app.Config().GenericOauth.CallbackProxy
  227. }
  228. oauthClient := genericOauthClient{
  229. ClientID: app.Config().GenericOauth.ClientID,
  230. ClientSecret: app.Config().GenericOauth.ClientSecret,
  231. ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint,
  232. InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint,
  233. AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint,
  234. HttpClient: config.DefaultHTTPClient(),
  235. CallbackLocation: callbackLocation,
  236. Scope: config.OrDefaultString(app.Config().GenericOauth.Scope, "read_user"),
  237. MapUserID: config.OrDefaultString(app.Config().GenericOauth.MapUserID, "user_id"),
  238. MapUsername: config.OrDefaultString(app.Config().GenericOauth.MapUsername, "username"),
  239. MapDisplayName: config.OrDefaultString(app.Config().GenericOauth.MapDisplayName, "-"),
  240. MapEmail: config.OrDefaultString(app.Config().GenericOauth.MapEmail, "email"),
  241. }
  242. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  243. }
  244. }
  245. func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) {
  246. if app.Config().GiteaOauth.ClientID != "" {
  247. callbackLocation := app.Config().App.Host + "/oauth/callback/gitea"
  248. var callbackProxy *callbackProxyClient = nil
  249. if app.Config().GiteaOauth.CallbackProxy != "" {
  250. callbackProxy = &callbackProxyClient{
  251. server: app.Config().GiteaOauth.CallbackProxyAPI,
  252. callbackLocation: app.Config().App.Host + "/oauth/callback/gitea",
  253. httpClient: config.DefaultHTTPClient(),
  254. }
  255. callbackLocation = app.Config().GiteaOauth.CallbackProxy
  256. }
  257. oauthClient := giteaOauthClient{
  258. ClientID: app.Config().GiteaOauth.ClientID,
  259. ClientSecret: app.Config().GiteaOauth.ClientSecret,
  260. ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token",
  261. InspectLocation: app.Config().GiteaOauth.Host + "/login/oauth/userinfo",
  262. AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize",
  263. HttpClient: config.DefaultHTTPClient(),
  264. CallbackLocation: callbackLocation,
  265. Scope: "openid profile email",
  266. MapUserID: "sub",
  267. MapUsername: "login",
  268. MapDisplayName: "full_name",
  269. MapEmail: "email",
  270. }
  271. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  272. }
  273. }
  274. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  275. handler := &oauthHandler{
  276. Config: app.Config(),
  277. DB: app.DB(),
  278. Store: app.SessionStore(),
  279. oauthClient: oauthClient,
  280. EmailKey: app.keys.EmailKey,
  281. callbackProxy: callbackProxy,
  282. }
  283. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  284. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  285. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  286. }
  287. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  288. ctx := r.Context()
  289. code := r.FormValue("code")
  290. state := r.FormValue("state")
  291. provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
  292. if err != nil {
  293. log.Error("Unable to ValidateOAuthState: %s", err)
  294. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  295. }
  296. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  297. if err != nil {
  298. log.Error("Unable to exchangeOauthCode: %s", err)
  299. // TODO: show user friendly message if needed
  300. // TODO: show NO message for cases like user pressing "Cancel" on authorize step
  301. addSessionFlash(app, w, r, err.Error(), nil)
  302. if attachUserID > 0 {
  303. return impart.HTTPError{http.StatusFound, "/me/settings"}
  304. }
  305. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  306. }
  307. // Now that we have the access token, let's use it real quick to make sure
  308. // it really really works.
  309. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  310. if err != nil {
  311. log.Error("Unable to inspectOauthAccessToken: %s", err)
  312. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  313. }
  314. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  315. if err != nil {
  316. log.Error("Unable to GetIDForRemoteUser: %s", err)
  317. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  318. }
  319. if localUserID != -1 && attachUserID > 0 {
  320. if err = addSessionFlash(app, w, r, "This OAuth account is already attached to another user.", nil); err != nil {
  321. return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
  322. }
  323. return impart.HTTPError{http.StatusFound, "/me/settings"}
  324. }
  325. if localUserID != -1 {
  326. // Existing user, so log in now
  327. user, err := h.DB.GetUserByID(localUserID)
  328. if err != nil {
  329. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  330. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  331. }
  332. if err = loginOrFail(h.Store, w, r, user); err != nil {
  333. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  334. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  335. }
  336. return nil
  337. }
  338. if attachUserID > 0 {
  339. log.Info("attaching to user %d", attachUserID)
  340. log.Info("OAuth userid: %s", tokenInfo.UserID)
  341. err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
  342. if err != nil {
  343. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  344. }
  345. return impart.HTTPError{http.StatusFound, "/me/settings"}
  346. }
  347. // New user registration below.
  348. // First, verify that user is allowed to register
  349. if inviteCode != "" {
  350. // Verify invite code is valid
  351. i, err := app.db.GetUserInvite(inviteCode)
  352. if err != nil {
  353. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  354. }
  355. if !i.Active(app.db) {
  356. return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
  357. }
  358. } else if !app.cfg.App.OpenRegistration {
  359. addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
  360. return impart.HTTPError{http.StatusFound, "/login"}
  361. }
  362. displayName := tokenInfo.DisplayName
  363. if len(displayName) == 0 {
  364. displayName = tokenInfo.Username
  365. }
  366. tp := &oauthSignupPageParams{
  367. AccessToken: tokenResponse.AccessToken,
  368. TokenUsername: tokenInfo.Username,
  369. TokenAlias: tokenInfo.DisplayName,
  370. TokenEmail: tokenInfo.Email,
  371. TokenRemoteUser: tokenInfo.UserID,
  372. Provider: provider,
  373. ClientID: clientID,
  374. InviteCode: inviteCode,
  375. }
  376. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  377. return h.showOauthSignupPage(app, w, r, tp, nil)
  378. }
  379. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  380. form := url.Values{}
  381. form.Add("state", state)
  382. form.Add("location", r.callbackLocation)
  383. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  384. if err != nil {
  385. return err
  386. }
  387. req.Header.Set("User-Agent", ServerUserAgent(""))
  388. req.Header.Set("Accept", "application/json")
  389. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  390. resp, err := r.httpClient.Do(req)
  391. if err != nil {
  392. return err
  393. }
  394. if resp.StatusCode != http.StatusCreated {
  395. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  396. }
  397. return nil
  398. }
  399. func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  400. lr := io.LimitReader(body, int64(n+1))
  401. data, err := ioutil.ReadAll(lr)
  402. if err != nil {
  403. return err
  404. }
  405. if len(data) == n+1 {
  406. return fmt.Errorf("content larger than max read allowance: %d", n)
  407. }
  408. return json.Unmarshal(data, thing)
  409. }
  410. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  411. // An error may be returned, but a valid session should always be returned.
  412. session, _ := store.Get(r, cookieName)
  413. session.Values[cookieUserVal] = user.Cookie()
  414. if err := session.Save(r, w); err != nil {
  415. fmt.Println("error saving session", err)
  416. return err
  417. }
  418. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  419. return nil
  420. }