A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。
 
 
 
 
 

319 行
11 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gorilla/mux"
  7. "github.com/gorilla/sessions"
  8. "github.com/writeas/impart"
  9. "github.com/writeas/web-core/log"
  10. "github.com/writeas/writefreely/config"
  11. "io"
  12. "io/ioutil"
  13. "net/http"
  14. "net/url"
  15. "strings"
  16. "time"
  17. )
  18. // TokenResponse contains data returned when a token is created either
  19. // through a code exchange or using a refresh token.
  20. type TokenResponse struct {
  21. AccessToken string `json:"access_token"`
  22. ExpiresIn int `json:"expires_in"`
  23. RefreshToken string `json:"refresh_token"`
  24. TokenType string `json:"token_type"`
  25. Error string `json:"error"`
  26. }
  27. // InspectResponse contains data returned when an access token is inspected.
  28. type InspectResponse struct {
  29. ClientID string `json:"client_id"`
  30. UserID string `json:"user_id"`
  31. ExpiresAt time.Time `json:"expires_at"`
  32. Username string `json:"username"`
  33. DisplayName string `json:"-"`
  34. Email string `json:"email"`
  35. Error string `json:"error"`
  36. }
  37. // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
  38. // endpoint. One megabyte is plenty.
  39. const tokenRequestMaxLen = 1000000
  40. // infoRequestMaxLen is the most bytes that we'll read from the
  41. // /oauth/inspect endpoint.
  42. const infoRequestMaxLen = 1000000
  43. // OAuthDatastoreProvider provides a minimal interface of data store, config,
  44. // and session store for use with the oauth handlers.
  45. type OAuthDatastoreProvider interface {
  46. DB() OAuthDatastore
  47. Config() *config.Config
  48. SessionStore() sessions.Store
  49. }
  50. // OAuthDatastore provides a minimal interface of data store methods used in
  51. // oauth functionality.
  52. type OAuthDatastore interface {
  53. GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
  54. RecordRemoteUserID(context.Context, int64, string, string, string, string) error
  55. ValidateOAuthState(context.Context, string) (string, string, error)
  56. GenerateOAuthState(context.Context, string, string) (string, error)
  57. CreateUser(*config.Config, *User, string) error
  58. GetUserByID(int64) (*User, error)
  59. }
  60. type HttpClient interface {
  61. Do(req *http.Request) (*http.Response, error)
  62. }
  63. type oauthClient interface {
  64. GetProvider() string
  65. GetClientID() string
  66. GetCallbackLocation() string
  67. buildLoginURL(state string) (string, error)
  68. exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
  69. inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
  70. }
  71. type callbackProxyClient struct {
  72. server string
  73. callbackLocation string
  74. httpClient HttpClient
  75. }
  76. type oauthHandler struct {
  77. Config *config.Config
  78. DB OAuthDatastore
  79. Store sessions.Store
  80. EmailKey []byte
  81. oauthClient oauthClient
  82. callbackProxy *callbackProxyClient
  83. }
  84. func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
  85. ctx := r.Context()
  86. state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
  87. if err != nil {
  88. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  89. }
  90. if h.callbackProxy != nil {
  91. if err := h.callbackProxy.register(ctx, state); err != nil {
  92. return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
  93. }
  94. }
  95. location, err := h.oauthClient.buildLoginURL(state)
  96. if err != nil {
  97. return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
  98. }
  99. return impart.HTTPError{http.StatusTemporaryRedirect, location}
  100. }
  101. func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
  102. if app.Config().SlackOauth.ClientID != "" {
  103. callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
  104. var stateRegisterClient *callbackProxyClient = nil
  105. if app.Config().SlackOauth.CallbackProxyAPI != "" {
  106. stateRegisterClient = &callbackProxyClient{
  107. server: app.Config().SlackOauth.CallbackProxyAPI,
  108. callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
  109. httpClient: config.DefaultHTTPClient(),
  110. }
  111. callbackLocation = app.Config().SlackOauth.CallbackProxy
  112. }
  113. oauthClient := slackOauthClient{
  114. ClientID: app.Config().SlackOauth.ClientID,
  115. ClientSecret: app.Config().SlackOauth.ClientSecret,
  116. TeamID: app.Config().SlackOauth.TeamID,
  117. HttpClient: config.DefaultHTTPClient(),
  118. CallbackLocation: callbackLocation,
  119. }
  120. configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
  121. }
  122. }
  123. func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
  124. if app.Config().WriteAsOauth.ClientID != "" {
  125. callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
  126. var callbackProxy *callbackProxyClient = nil
  127. if app.Config().WriteAsOauth.CallbackProxy != "" {
  128. callbackProxy = &callbackProxyClient{
  129. server: app.Config().WriteAsOauth.CallbackProxyAPI,
  130. callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
  131. httpClient: config.DefaultHTTPClient(),
  132. }
  133. callbackLocation = app.Config().WriteAsOauth.CallbackProxy
  134. }
  135. oauthClient := writeAsOauthClient{
  136. ClientID: app.Config().WriteAsOauth.ClientID,
  137. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  138. ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
  139. InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
  140. AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
  141. HttpClient: config.DefaultHTTPClient(),
  142. CallbackLocation: callbackLocation,
  143. }
  144. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  145. }
  146. }
  147. func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
  148. if app.Config().GitlabOauth.ClientID != "" {
  149. callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
  150. var callbackProxy *callbackProxyClient = nil
  151. if app.Config().GitlabOauth.CallbackProxy != "" {
  152. callbackProxy = &callbackProxyClient{
  153. server: app.Config().GitlabOauth.CallbackProxyAPI,
  154. callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
  155. httpClient: config.DefaultHTTPClient(),
  156. }
  157. callbackLocation = app.Config().GitlabOauth.CallbackProxy
  158. }
  159. oauthClient := gitlabOauthClient{
  160. ClientID: app.Config().GitlabOauth.ClientID,
  161. ClientSecret: app.Config().GitlabOauth.ClientSecret,
  162. ExchangeLocation: config.OrDefaultString(app.Config().GitlabOauth.TokenLocation, writeAsExchangeLocation),
  163. InspectLocation: config.OrDefaultString(app.Config().GitlabOauth.InspectLocation, writeAsIdentityLocation),
  164. AuthLocation: config.OrDefaultString(app.Config().GitlabOauth.AuthLocation, writeAsAuthLocation),
  165. HttpClient: config.DefaultHTTPClient(),
  166. CallbackLocation: callbackLocation,
  167. }
  168. configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
  169. }
  170. }
  171. func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
  172. handler := &oauthHandler{
  173. Config: app.Config(),
  174. DB: app.DB(),
  175. Store: app.SessionStore(),
  176. oauthClient: oauthClient,
  177. EmailKey: app.keys.EmailKey,
  178. callbackProxy: callbackProxy,
  179. }
  180. r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
  181. r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
  182. r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
  183. }
  184. func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
  185. ctx := r.Context()
  186. code := r.FormValue("code")
  187. state := r.FormValue("state")
  188. provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
  189. if err != nil {
  190. log.Error("Unable to ValidateOAuthState: %s", err)
  191. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  192. }
  193. tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
  194. if err != nil {
  195. log.Error("Unable to exchangeOauthCode: %s", err)
  196. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  197. }
  198. // Now that we have the access token, let's use it real quick to make sur
  199. // it really really works.
  200. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
  201. if err != nil {
  202. log.Error("Unable to inspectOauthAccessToken: %s", err)
  203. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  204. }
  205. localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
  206. if err != nil {
  207. log.Error("Unable to GetIDForRemoteUser: %s", err)
  208. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  209. }
  210. if localUserID != -1 {
  211. user, err := h.DB.GetUserByID(localUserID)
  212. if err != nil {
  213. log.Error("Unable to GetUserByID %d: %s", localUserID, err)
  214. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  215. }
  216. if err = loginOrFail(h.Store, w, r, user); err != nil {
  217. log.Error("Unable to loginOrFail %d: %s", localUserID, err)
  218. return impart.HTTPError{http.StatusInternalServerError, err.Error()}
  219. }
  220. return nil
  221. }
  222. displayName := tokenInfo.DisplayName
  223. if len(displayName) == 0 {
  224. displayName = tokenInfo.Username
  225. }
  226. tp := &oauthSignupPageParams{
  227. AccessToken: tokenResponse.AccessToken,
  228. TokenUsername: tokenInfo.Username,
  229. TokenAlias: tokenInfo.DisplayName,
  230. TokenEmail: tokenInfo.Email,
  231. TokenRemoteUser: tokenInfo.UserID,
  232. Provider: provider,
  233. ClientID: clientID,
  234. }
  235. tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
  236. return h.showOauthSignupPage(app, w, r, tp, nil)
  237. }
  238. func (r *callbackProxyClient) register(ctx context.Context, state string) error {
  239. form := url.Values{}
  240. form.Add("state", state)
  241. form.Add("location", r.callbackLocation)
  242. req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
  243. if err != nil {
  244. return err
  245. }
  246. req.Header.Set("User-Agent", "writefreely")
  247. req.Header.Set("Accept", "application/json")
  248. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  249. resp, err := r.httpClient.Do(req)
  250. if err != nil {
  251. return err
  252. }
  253. if resp.StatusCode != http.StatusCreated {
  254. return fmt.Errorf("unable register state location: %d", resp.StatusCode)
  255. }
  256. return nil
  257. }
  258. func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
  259. lr := io.LimitReader(body, int64(n+1))
  260. data, err := ioutil.ReadAll(lr)
  261. if err != nil {
  262. return err
  263. }
  264. if len(data) == n+1 {
  265. return fmt.Errorf("content larger than max read allowance: %d", n)
  266. }
  267. return json.Unmarshal(data, thing)
  268. }
  269. func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
  270. // An error may be returned, but a valid session should always be returned.
  271. session, _ := store.Get(r, cookieName)
  272. session.Values[cookieUserVal] = user.Cookie()
  273. if err := session.Save(r, w); err != nil {
  274. fmt.Println("error saving session", err)
  275. return err
  276. }
  277. http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
  278. return nil
  279. }