A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 
 

473 righe
17 KiB

  1. /*
  2. * Copyright © 2019-2021 Musing Studio LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "encoding/json"
  14. "fmt"
  15. "io"
  16. "net/http"
  17. "net/url"
  18. "strings"
  19. "time"
  20. "github.com/gorilla/mux"
  21. "github.com/gorilla/sessions"
  22. "github.com/writeas/impart"
  23. "github.com/writeas/web-core/log"
  24. "github.com/writefreely/writefreely/config"
  25. )
  // OAuthButtons holds display information for the OAuth providers this
  // instance can offer: an enabled flag for every provider and, for the
  // GitLab, Gitea and generic providers, the label to put on the button.
  type OAuthButtons struct {
  	// Slack and Write.as have no configurable label, only an on/off flag.
  	SlackEnabled   bool
  	WriteAsEnabled bool

  	// GitLab button: shown when GitLabEnabled, labeled GitLabDisplayName.
  	GitLabEnabled     bool
  	GitLabDisplayName string

  	// Gitea button: shown when GiteaEnabled, labeled GiteaDisplayName.
  	GiteaEnabled     bool
  	GiteaDisplayName string

  	// Generic OAuth provider button: shown when GenericEnabled, labeled
  	// GenericDisplayName.
  	GenericEnabled     bool
  	GenericDisplayName string
  }
  37. // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
  38. func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
  39. return &OAuthButtons{
  40. SlackEnabled: cfg.SlackOauth.ClientID != "",
  41. WriteAsEnabled: cfg.WriteAsOauth.ClientID != "",
  42. GitLabEnabled: cfg.GitlabOauth.ClientID != "",
  43. GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
  44. GiteaEnabled: cfg.GiteaOauth.ClientID != "",
  45. GiteaDisplayName: config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName),
  46. GenericEnabled: cfg.GenericOauth.ClientID != "",
  47. GenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName),
  48. }
  49. }
  // TokenResponse is the JSON payload returned by a provider's token
  // endpoint. The same shape is used whether the token was obtained by
  // exchanging an authorization code or by presenting a refresh token.
  type TokenResponse struct {
  	AccessToken  string `json:"access_token"`
  	ExpiresIn    int    `json:"expires_in"` // token lifetime; seconds by OAuth 2.0 convention
  	RefreshToken string `json:"refresh_token"`
  	TokenType    string `json:"token_type"`
  	Error        string `json:"error"` // provider-reported error code, if any
  }
  // InspectResponse describes the owner of an access token, as returned when
  // the token is inspected. DisplayName is tagged json:"-", so JSON decoding
  // never fills it and JSON encoding never emits it. Whatever sets it must do
  // so explicitly (presumably the individual provider clients; not visible
  // here).
  type InspectResponse struct {
  	ClientID    string    `json:"client_id"`
  	UserID      string    `json:"user_id"` // the remote (provider-side) user ID
  	ExpiresAt   time.Time `json:"expires_at"`
  	Username    string    `json:"username"`
  	DisplayName string    `json:"-"`
  	Email       string    `json:"email"`
  	Error       string    `json:"error"`
  }
  // Read limits, in bytes, for responses from the OAuth provider endpoints.
  // One megabyte (10^6 bytes) is plenty for either payload.
  const (
  	// tokenRequestMaxLen is the most bytes read from the /oauth/token endpoint.
  	tokenRequestMaxLen = 1_000_000
  	// infoRequestMaxLen is the most bytes read from the /oauth/inspect endpoint.
  	infoRequestMaxLen = 1_000_000
  )
  75. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  76. // and session store for use with the oauth handlers.
  77. type OAuthDatastoreProvider interface {
  78. DB() OAuthDatastore
  79. Config() *config.Config
  80. SessionStore() sessions.Store
  81. }
  82. // OAuthDatastore provides a minimal interface of data store methods used in
  83. // oauth functionality.
  84. type OAuthDatastore interface {
  85. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  86. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  87. ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
  88. GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
  89. CreateUser(*config.Config, *User, string, string) error
  90. GetUserByID(int64) (*User, error)
  91. }
  // HttpClient is the minimal HTTP client surface needed for outbound OAuth
  // requests. *http.Client satisfies it.
  type HttpClient interface {
  	// Do sends an HTTP request and returns the response.
  	Do(*http.Request) (*http.Response, error)
  }
  95. type oauthClient interface {
  96. GetProvider() string
  97. GetClientID() string
  98. GetCallbackLocation() string
  99. buildLoginURL(state string) (string, error)
  100. exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
  101. inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
  102. }
  103. type callbackProxyClient struct {
  104. server string
  105. callbackLocation string
  106. httpClient HttpClient
  107. }
  108. type oauthHandler struct {
  109. Config *config.Config
  110. DB OAuthDatastore
  111. Store sessions.Store
  112. EmailKey []byte
  113. oauthClient oauthClient
  114. callbackProxy *callbackProxyClient
  115. }
  116. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  117. ctx := r.Context()
  118. var attachUser int64
  119. if attach := r.URL.Query().Get("attach"); attach == "t" {
  120. user, _ := getUserAndSession(app, r)
  121. if user == nil {
  122. return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
  123. }
  124. attachUser = user.ID
  125. }
  126. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
  127. if err != nil {
  128. log.Error("viewOauthInit error: %s", err)
  129. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  130. }
  131. if h.callbackProxy != nil {
  132. if err := h.callbackProxy.register(ctx, state); err != nil {
  133. log.Error("viewOauthInit error: %s", err)
  134. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  135. }
  136. }
  137. location, err := h.oauthClient.buildLoginURL(state)
  138. if err != nil {
  139. log.Error("viewOauthInit error: %s", err)
  140. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  141. }
  142. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  143. }
  144. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  145. if app.Config().SlackOauth.ClientID != "" {
  146. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  147. var stateRegisterClient *callbackProxyClient = nil
  148. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  149. stateRegisterClient = &callbackProxyClient{
  150. server: app.Config().SlackOauth.CallbackProxyAPI,
  151. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  152. httpClient: config.DefaultHTTPClient(),
  153. }
  154. callbackLocation = app.Config().SlackOauth.CallbackProxy
  155. }
  156. oauthClient := slackOauthClient{
  157. ClientID: app.Config().SlackOauth.ClientID,
  158. ClientSecret: app.Config().SlackOauth.ClientSecret,
  159. TeamID: app.Config().SlackOauth.TeamID,
  160. HttpClient: config.DefaultHTTPClient(),
  161. CallbackLocation: callbackLocation,
  162. }
  163. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  164. }
  165. }
  166. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  167. if app.Config().WriteAsOauth.ClientID != "" {
  168. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  169. var callbackProxy *callbackProxyClient = nil
  170. if app.Config().WriteAsOauth.CallbackProxy != "" {
  171. callbackProxy = &callbackProxyClient{
  172. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  173. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  174. httpClient: config.DefaultHTTPClient(),
  175. }
  176. callbackLocation = app.Config().WriteAsOauth.CallbackProxy
  177. }
  178. oauthClient := writeAsOauthClient{
  179. ClientID: app.Config().WriteAsOauth.ClientID,
  180. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  181. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  182. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  183. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  184. HttpClient: config.DefaultHTTPClient(),
  185. CallbackLocation: callbackLocation,
  186. }
  187. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  188. }
  189. }
  190. func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
  191. if app.Config().GitlabOauth.ClientID != "" {
  192. callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
  193. var callbackProxy *callbackProxyClient = nil
  194. if app.Config().GitlabOauth.CallbackProxy != "" {
  195. callbackProxy = &callbackProxyClient{
  196. server: app.Config().GitlabOauth.CallbackProxyAPI,
  197. callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
  198. httpClient: config.DefaultHTTPClient(),
  199. }
  200. callbackLocation = app.Config().GitlabOauth.CallbackProxy
  201. }
  202. address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost)
  203. oauthClient := gitlabOauthClient{
  204. ClientID: app.Config().GitlabOauth.ClientID,
  205. ClientSecret: app.Config().GitlabOauth.ClientSecret,
  206. ExchangeLocation: address + "/oauth/token",
  207. InspectLocation: address + "/api/v4/user",
  208. AuthLocation: address + "/oauth/authorize",
  209. HttpClient: config.DefaultHTTPClient(),
  210. CallbackLocation: callbackLocation,
  211. }
  212. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  213. }
  214. }
  215. func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) {
  216. if app.Config().GenericOauth.ClientID != "" {
  217. callbackLocation := app.Config().App.Host + "/oauth/callback/generic"
  218. var callbackProxy *callbackProxyClient = nil
  219. if app.Config().GenericOauth.CallbackProxy != "" {
  220. callbackProxy = &callbackProxyClient{
  221. server: app.Config().GenericOauth.CallbackProxyAPI,
  222. callbackLocation: app.Config().App.Host + "/oauth/callback/generic",
  223. httpClient: config.DefaultHTTPClient(),
  224. }
  225. callbackLocation = app.Config().GenericOauth.CallbackProxy
  226. }
  227. oauthClient := genericOauthClient{
  228. ClientID: app.Config().GenericOauth.ClientID,
  229. ClientSecret: app.Config().GenericOauth.ClientSecret,
  230. ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint,
  231. InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint,
  232. AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint,
  233. HttpClient: config.DefaultHTTPClient(),
  234. CallbackLocation: callbackLocation,
  235. Scope: config.OrDefaultString(app.Config().GenericOauth.Scope, "read_user"),
  236. MapUserID: config.OrDefaultString(app.Config().GenericOauth.MapUserID, "user_id"),
  237. MapUsername: config.OrDefaultString(app.Config().GenericOauth.MapUsername, "username"),
  238. MapDisplayName: config.OrDefaultString(app.Config().GenericOauth.MapDisplayName, "-"),
  239. MapEmail: config.OrDefaultString(app.Config().GenericOauth.MapEmail, "email"),
  240. }
  241. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  242. }
  243. }
  244. func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) {
  245. if app.Config().GiteaOauth.ClientID != "" {
  246. callbackLocation := app.Config().App.Host + "/oauth/callback/gitea"
  247. var callbackProxy *callbackProxyClient = nil
  248. if app.Config().GiteaOauth.CallbackProxy != "" {
  249. callbackProxy = &callbackProxyClient{
  250. server: app.Config().GiteaOauth.CallbackProxyAPI,
  251. callbackLocation: app.Config().App.Host + "/oauth/callback/gitea",
  252. httpClient: config.DefaultHTTPClient(),
  253. }
  254. callbackLocation = app.Config().GiteaOauth.CallbackProxy
  255. }
  256. oauthClient := giteaOauthClient{
  257. ClientID: app.Config().GiteaOauth.ClientID,
  258. ClientSecret: app.Config().GiteaOauth.ClientSecret,
  259. ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token",
  260. InspectLocation: app.Config().GiteaOauth.Host + "/login/oauth/userinfo",
  261. AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize",
  262. HttpClient: config.DefaultHTTPClient(),
  263. CallbackLocation: callbackLocation,
  264. Scope: "openid profile email",
  265. MapUserID: "sub",
  266. MapUsername: "login",
  267. MapDisplayName: "full_name",
  268. MapEmail: "email",
  269. }
  270. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  271. }
  272. }
  273. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  274. handler := &oauthHandler{
  275. Config: app.Config(),
  276. DB: app.DB(),
  277. Store: app.SessionStore(),
  278. oauthClient: oauthClient,
  279. EmailKey: app.keys.EmailKey,
  280. callbackProxy: callbackProxy,
  281. }
  282. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  283. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  284. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  285. }
  286. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  287. ctx := r.Context()
  288. code := r.FormValue("code")
  289. state := r.FormValue("state")
  290. provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
  291. if err != nil {
  292. log.Error("Unable to ValidateOAuthState: %s", err)
  293. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  294. }
  295. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  296. if err != nil {
  297. log.Error("Unable to exchangeOauthCode: %s", err)
  298. // TODO: show user friendly message if needed
  299. // TODO: show NO message for cases like user pressing "Cancel" on authorize step
  300. addSessionFlash(app, w, r, err.Error(), nil)
  301. if attachUserID > 0 {
  302. return impart.HTTPError{http.StatusFound, "/me/settings"}
  303. }
  304. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  305. }
  306. // Now that we have the access token, let's use it real quick to make sure
  307. // it really really works.
  308. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  309. if err != nil {
  310. log.Error("Unable to inspectOauthAccessToken: %s", err)
  311. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  312. }
  313. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  314. if err != nil {
  315. log.Error("Unable to GetIDForRemoteUser: %s", err)
  316. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  317. }
  318. if localUserID != -1 && attachUserID > 0 {
  319. if err = addSessionFlash(app, w, r, "This OAuth account is already attached to another user.", nil); err != nil {
  320. return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
  321. }
  322. return impart.HTTPError{http.StatusFound, "/me/settings"}
  323. }
  324. if localUserID != -1 {
  325. // Existing user, so log in now
  326. user, err := h.DB.GetUserByID(localUserID)
  327. if err != nil {
  328. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  329. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  330. }
  331. if err = loginOrFail(h.Store, w, r, user); err != nil {
  332. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  333. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  334. }
  335. return nil
  336. }
  337. if attachUserID > 0 {
  338. log.Info("attaching to user %d", attachUserID)
  339. log.Info("OAuth userid: %s", tokenInfo.UserID)
  340. err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
  341. if err != nil {
  342. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  343. }
  344. return impart.HTTPError{http.StatusFound, "/me/settings"}
  345. }
  346. // New user registration below.
  347. // First, verify that user is allowed to register
  348. if inviteCode != "" {
  349. // Verify invite code is valid
  350. i, err := app.db.GetUserInvite(inviteCode)
  351. if err != nil {
  352. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  353. }
  354. if !i.Active(app.db) {
  355. return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
  356. }
  357. } else if !app.cfg.App.OpenRegistration {
  358. addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
  359. return impart.HTTPError{http.StatusFound, "/login"}
  360. }
  361. displayName := tokenInfo.DisplayName
  362. if len(displayName) == 0 {
  363. displayName = tokenInfo.Username
  364. }
  365. tp := &oauthSignupPageParams{
  366. AccessToken: tokenResponse.AccessToken,
  367. TokenUsername: tokenInfo.Username,
  368. TokenAlias: tokenInfo.DisplayName,
  369. TokenEmail: tokenInfo.Email,
  370. TokenRemoteUser: tokenInfo.UserID,
  371. Provider: provider,
  372. ClientID: clientID,
  373. InviteCode: inviteCode,
  374. }
  375. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  376. return h.showOauthSignupPage(app, w, r, tp, nil)
  377. }
  378. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  379. form := url.Values{}
  380. form.Add("state", state)
  381. form.Add("location", r.callbackLocation)
  382. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  383. if err != nil {
  384. return err
  385. }
  386. req.Header.Set("User-Agent", ServerUserAgent(""))
  387. req.Header.Set("Accept", "application/json")
  388. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  389. resp, err := r.httpClient.Do(req)
  390. if err != nil {
  391. return err
  392. }
  393. if resp.StatusCode != http.StatusCreated {
  394. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  395. }
  396. return nil
  397. }
  // limitedJsonUnmarshal decodes the JSON in body into thing. It refuses
  // bodies longer than n bytes instead of decoding a truncated prefix.
  //
  // Exactly one byte past the limit is read. A body of exactly n bytes is
  // accepted, while any longer body produces an error. body is not closed;
  // that is left to the caller.
  func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  	ceiling := n + 1

  	raw, err := io.ReadAll(io.LimitReader(body, int64(ceiling)))
  	switch {
  	case err != nil:
  		return err
  	case len(raw) == ceiling:
  		return fmt.Errorf("content larger than max read allowance: %d", n)
  	default:
  		return json.Unmarshal(raw, thing)
  	}
  }
  409. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  410. // An error may be returned, but a valid session should always be returned.
  411. session, _ := store.Get(r, cookieName)
  412. session.Values[cookieUserVal] = user.Cookie()
  413. if err := session.Save(r, w); err != nil {
  414. fmt.Println("error saving session", err)
  415. return err
  416. }
  417. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  418. return nil
  419. }