A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

320 lines
11 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gorilla/mux"
  7. "github.com/gorilla/sessions"
  8. "github.com/writeas/impart"
  9. "github.com/writeas/web-core/log"
  10. "github.com/writeas/writefreely/config"
  11. "io"
  12. "io/ioutil"
  13. "net/http"
  14. "net/url"
  15. "strings"
  16. "time"
  17. )
// TokenResponse contains data returned when a token is created either
// through a code exchange or using a refresh token.
//
// Fields map to the JSON keys access_token, expires_in, refresh_token,
// token_type and error of the provider's token response.
type TokenResponse struct {
	// AccessToken is the OAuth access token issued by the provider.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the token lifetime reported by the provider
	// (presumably seconds, per OAuth 2.0 convention — TODO confirm).
	ExpiresIn    int    `json:"expires_in"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	// Error presumably carries the provider's error code when the token
	// request failed; nothing in this file's visible code reads it.
	Error string `json:"error"`
}
// InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct {
	ClientID string `json:"client_id"`
	// UserID identifies the account at the remote provider; it is the key
	// used to look up a linked local user.
	UserID    string    `json:"user_id"`
	ExpiresAt time.Time `json:"expires_at"`
	Username  string    `json:"username"`
	// DisplayName is excluded from JSON (un)marshalling by its `json:"-"`
	// tag, so it is empty after decoding unless populated separately —
	// presumably by provider-specific clients (TODO confirm).
	DisplayName string `json:"-"`
	Email       string `json:"email"`
	// Error presumably carries the provider's error code when inspection
	// failed.
	Error string `json:"error"`
}
  37. // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
  38. // endpoint. One megabyte is plenty.
  39. const tokenRequestMaxLen = 1000000
  40. // infoRequestMaxLen is the most bytes that we'll read from the
  41. // /oauth/inspect endpoint.
  42. const infoRequestMaxLen = 1000000
  43. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  44. // and session store for use with the oauth handlers.
  45. type OAuthDatastoreProvider interface {
  46. DB() OAuthDatastore
  47. Config() *config.Config
  48. SessionStore() sessions.Store
  49. }
  50. // OAuthDatastore provides a minimal interface of data store methods used in
  51. // oauth functionality.
  52. type OAuthDatastore interface {
  53. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  54. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  55. ValidateOAuthState(context.Context, string) (string, string, error)
  56. GenerateOAuthState(context.Context, string, string) (string, error)
  57. CreateUser(*config.Config, *User, string) error
  58. GetUserByID(int64) (*User, error)
  59. }
  60. type HttpClient interface {
  61. Do(req *http.Request) (*http.Response, error)
  62. }
  63. type oauthClient interface {
  64. GetProvider() string
  65. GetClientID() string
  66. GetCallbackLocation() string
  67. buildLoginURL(state string) (string, error)
  68. exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
  69. inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
  70. }
  71. type callbackProxyClient struct {
  72. server string
  73. callbackLocation string
  74. httpClient HttpClient
  75. }
// oauthHandler serves the OAuth routes for a single provider; one is created
// per configured provider by configureOauthRoutes.
type oauthHandler struct {
	// Config is the application configuration (its Server.HashSeed is used
	// when hashing signup parameters).
	Config *config.Config
	// DB stores OAuth state and remote-user links.
	DB OAuthDatastore
	// Store holds the session used to log users in.
	Store sessions.Store
	// EmailKey is copied from app.keys.EmailKey; presumably used to
	// encrypt/decrypt stored email addresses (TODO confirm).
	EmailKey []byte
	// oauthClient talks to this handler's OAuth provider.
	oauthClient oauthClient
	// callbackProxy is nil when no callback proxy is configured.
	callbackProxy *callbackProxyClient
}
  84. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  85. ctx := r.Context()
  86. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
  87. if err != nil {
  88. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  89. }
  90. if h.callbackProxy != nil {
  91. if err := h.callbackProxy.register(ctx, state); err != nil {
  92. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  93. }
  94. }
  95. location, err := h.oauthClient.buildLoginURL(state)
  96. if err != nil {
  97. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  98. }
  99. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  100. }
  101. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  102. if app.Config().SlackOauth.ClientID != "" {
  103. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  104. var stateRegisterClient *callbackProxyClient = nil
  105. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  106. stateRegisterClient = &callbackProxyClient{
  107. server: app.Config().SlackOauth.CallbackProxyAPI,
  108. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  109. httpClient: config.DefaultHTTPClient(),
  110. }
  111. callbackLocation = app.Config().SlackOauth.CallbackProxy
  112. }
  113. oauthClient := slackOauthClient{
  114. ClientID: app.Config().SlackOauth.ClientID,
  115. ClientSecret: app.Config().SlackOauth.ClientSecret,
  116. TeamID: app.Config().SlackOauth.TeamID,
  117. HttpClient: config.DefaultHTTPClient(),
  118. CallbackLocation: callbackLocation,
  119. }
  120. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  121. }
  122. }
  123. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  124. if app.Config().WriteAsOauth.ClientID != "" {
  125. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  126. var callbackProxy *callbackProxyClient = nil
  127. if app.Config().WriteAsOauth.CallbackProxy != "" {
  128. callbackProxy = &callbackProxyClient{
  129. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  130. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  131. httpClient: config.DefaultHTTPClient(),
  132. }
  133. callbackLocation = app.Config().WriteAsOauth.CallbackProxy
  134. }
  135. oauthClient := writeAsOauthClient{
  136. ClientID: app.Config().WriteAsOauth.ClientID,
  137. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  138. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  139. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  140. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  141. HttpClient: config.DefaultHTTPClient(),
  142. CallbackLocation: callbackLocation,
  143. }
  144. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  145. }
  146. }
  147. func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
  148. if app.Config().GitlabOauth.ClientID != "" {
  149. callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
  150. var callbackProxy *callbackProxyClient = nil
  151. if app.Config().GitlabOauth.CallbackProxy != "" {
  152. callbackProxy = &callbackProxyClient{
  153. server: app.Config().GitlabOauth.CallbackProxyAPI,
  154. callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
  155. httpClient: config.DefaultHTTPClient(),
  156. }
  157. callbackLocation = app.Config().GitlabOauth.CallbackProxy
  158. }
  159. address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost)
  160. oauthClient := gitlabOauthClient{
  161. ClientID: app.Config().GitlabOauth.ClientID,
  162. ClientSecret: app.Config().GitlabOauth.ClientSecret,
  163. ExchangeLocation: address + "/oauth/token",
  164. InspectLocation: address + "/api/v4/user",
  165. AuthLocation: address + "/oauth/authorize",
  166. HttpClient: config.DefaultHTTPClient(),
  167. CallbackLocation: callbackLocation,
  168. }
  169. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  170. }
  171. }
  172. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  173. handler := &oauthHandler{
  174. Config: app.Config(),
  175. DB: app.DB(),
  176. Store: app.SessionStore(),
  177. oauthClient: oauthClient,
  178. EmailKey: app.keys.EmailKey,
  179. callbackProxy: callbackProxy,
  180. }
  181. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  182. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  183. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  184. }
  185. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  186. ctx := r.Context()
  187. code := r.FormValue("code")
  188. state := r.FormValue("state")
  189. provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
  190. if err != nil {
  191. log.Error("Unable to ValidateOAuthState: %s", err)
  192. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  193. }
  194. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  195. if err != nil {
  196. log.Error("Unable to exchangeOauthCode: %s", err)
  197. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  198. }
  199. // Now that we have the access token, let's use it real quick to make sur
  200. // it really really works.
  201. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  202. if err != nil {
  203. log.Error("Unable to inspectOauthAccessToken: %s", err)
  204. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  205. }
  206. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  207. if err != nil {
  208. log.Error("Unable to GetIDForRemoteUser: %s", err)
  209. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  210. }
  211. if localUserID != -1 {
  212. user, err := h.DB.GetUserByID(localUserID)
  213. if err != nil {
  214. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  215. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  216. }
  217. if err = loginOrFail(h.Store, w, r, user); err != nil {
  218. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  219. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  220. }
  221. return nil
  222. }
  223. displayName := tokenInfo.DisplayName
  224. if len(displayName) == 0 {
  225. displayName = tokenInfo.Username
  226. }
  227. tp := &oauthSignupPageParams{
  228. AccessToken: tokenResponse.AccessToken,
  229. TokenUsername: tokenInfo.Username,
  230. TokenAlias: tokenInfo.DisplayName,
  231. TokenEmail: tokenInfo.Email,
  232. TokenRemoteUser: tokenInfo.UserID,
  233. Provider: provider,
  234. ClientID: clientID,
  235. }
  236. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  237. return h.showOauthSignupPage(app, w, r, tp, nil)
  238. }
  239. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  240. form := url.Values{}
  241. form.Add("state", state)
  242. form.Add("location", r.callbackLocation)
  243. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  244. if err != nil {
  245. return err
  246. }
  247. req.Header.Set("User-Agent", "writefreely")
  248. req.Header.Set("Accept", "application/json")
  249. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  250. resp, err := r.httpClient.Do(req)
  251. if err != nil {
  252. return err
  253. }
  254. if resp.StatusCode != http.StatusCreated {
  255. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  256. }
  257. return nil
  258. }
  259. func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  260. lr := io.LimitReader(body, int64(n+1))
  261. data, err := ioutil.ReadAll(lr)
  262. if err != nil {
  263. return err
  264. }
  265. if len(data) == n+1 {
  266. return fmt.Errorf("content larger than max read allowance: %d", n)
  267. }
  268. return json.Unmarshal(data, thing)
  269. }
  270. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  271. // An error may be returned, but a valid session should always be returned.
  272. session, _ := store.Get(r, cookieName)
  273. session.Values[cookieUserVal] = user.Cookie()
  274. if err := session.Save(r, w); err != nil {
  275. fmt.Println("error saving session", err)
  276. return err
  277. }
  278. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  279. return nil
  280. }