A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。
 
 
 
 
 

252 行
8.2 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gorilla/sessions"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/writeas/impart"
  8. "github.com/writeas/nerds/store"
  9. "github.com/writeas/writefreely/config"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "strings"
  14. "testing"
  15. )
  16. type MockOAuthDatastoreProvider struct {
  17. DoDB func() OAuthDatastore
  18. DoConfig func() *config.Config
  19. DoSessionStore func() sessions.Store
  20. }
  21. type MockOAuthDatastore struct {
  22. DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
  23. DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
  24. DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
  25. DoCreateUser func(*config.Config, *User, string) error
  26. DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
  27. DoGetUserByID func(int64) (*User, error)
  28. }
  29. var _ OAuthDatastore = &MockOAuthDatastore{}
  30. type StringReadCloser struct {
  31. *strings.Reader
  32. }
  33. func (src *StringReadCloser) Close() error {
  34. return nil
  35. }
  36. type MockHTTPClient struct {
  37. DoDo func(req *http.Request) (*http.Response, error)
  38. }
  39. func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
  40. if m.DoDo != nil {
  41. return m.DoDo(req)
  42. }
  43. return &http.Response{}, nil
  44. }
  45. func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
  46. if m.DoSessionStore != nil {
  47. return m.DoSessionStore()
  48. }
  49. return sessions.NewCookieStore([]byte("secret-key"))
  50. }
  51. func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
  52. if m.DoDB != nil {
  53. return m.DoDB()
  54. }
  55. return &MockOAuthDatastore{}
  56. }
  57. func (m *MockOAuthDatastoreProvider) Config() *config.Config {
  58. if m.DoConfig != nil {
  59. return m.DoConfig()
  60. }
  61. cfg := config.New()
  62. cfg.UseSQLite(true)
  63. cfg.WriteAsOauth = config.WriteAsOauthCfg{
  64. ClientID: "development",
  65. ClientSecret: "development",
  66. AuthLocation: "https://write.as/oauth/login",
  67. TokenLocation: "https://write.as/oauth/token",
  68. InspectLocation: "https://write.as/oauth/inspect",
  69. }
  70. cfg.SlackOauth = config.SlackOauthCfg{
  71. ClientID: "development",
  72. ClientSecret: "development",
  73. TeamID: "development",
  74. }
  75. return cfg
  76. }
  // ValidateOAuthState delegates to DoValidateOAuthState when set. Otherwise it
// reports success for any state, returning empty strings and a zero ID.
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
	if m.DoValidateOAuthState == nil {
		return "", "", 0, "", nil
	}
	return m.DoValidateOAuthState(ctx, state)
}
  // GetIDForRemoteUser delegates to DoGetIDForRemoteUser when set. Otherwise it
// returns -1 with a nil error. NOTE(review): -1 presumably means "no linked
// local user"; confirm against the real OAuthDatastore implementation.
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
	if m.DoGetIDForRemoteUser == nil {
		return -1, nil
	}
	return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
}
  89. func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
  90. if m.DoCreateUser != nil {
  91. return m.DoCreateUser(cfg, u, username)
  92. }
  93. u.ID = 1
  94. return nil
  95. }
  // RecordRemoteUserID delegates to DoRecordRemoteUserID when set. Otherwise it
// stores nothing and reports success.
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
	if m.DoRecordRemoteUserID == nil {
		return nil
	}
	return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
}
  102. func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
  103. if m.DoGetUserByID != nil {
  104. return m.DoGetUserByID(userID)
  105. }
  106. user := &User{}
  107. return user, nil
  108. }
  // GenerateOAuthState delegates to DoGenerateOAuthState when set. Otherwise it
// ignores its arguments and returns store.Generate62RandomString(14) as the
// state token, with a nil error.
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
	if m.DoGenerateOAuthState == nil {
		return store.Generate62RandomString(14), nil
	}
	return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
}
  // TestViewOauthInit exercises oauthHandler.viewOauthInit with the Write.as
// OAuth client. The happy path must answer with a temporary redirect to the
// provider's login URL carrying client_id, redirect_uri, response_type and
// state. A datastore that fails to generate the OAuth state must surface a 500.
func TestViewOauthInit(t *testing.T) {
	// newHandler builds an oauthHandler wired to app's mocks. The config is
	// fetched once so the handler and its OAuth client share one instance;
	// with DoConfig unset, every app.Config() call builds a new config.
	newHandler := func(app *MockOAuthDatastoreProvider) oauthHandler {
		cfg := app.Config()
		return oauthHandler{
			Config:   cfg,
			DB:       app.DB(),
			Store:    app.SessionStore(),
			EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
			oauthClient: writeAsOauthClient{
				ClientID:         cfg.WriteAsOauth.ClientID,
				ClientSecret:     cfg.WriteAsOauth.ClientSecret,
				ExchangeLocation: cfg.WriteAsOauth.TokenLocation,
				InspectLocation:  cfg.WriteAsOauth.InspectLocation,
				AuthLocation:     cfg.WriteAsOauth.AuthLocation,
				CallbackLocation: "http://localhost/oauth/callback",
				// No HTTP client: these cases are not expected to make outbound calls.
				HttpClient: nil,
			},
		}
	}

	t.Run("success", func(t *testing.T) {
		h := newHandler(&MockOAuthDatastoreProvider{})

		req, err := http.NewRequest("GET", "/oauth/client", nil)
		if !assert.NoError(t, err) {
			return
		}
		rr := httptest.NewRecorder()

		err = h.viewOauthInit(nil, rr, req)
		if !assert.Error(t, err) {
			return
		}
		// The redirect is reported as an impart.HTTPError whose Message holds
		// the target URL; stop early if that shape is missing so later
		// assertions don't fail misleadingly or dereference a nil URL.
		httpErr, ok := err.(impart.HTTPError)
		if !assert.True(t, ok, "expected impart.HTTPError, got %T", err) {
			return
		}
		assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
		assert.NotEmpty(t, httpErr.Message)

		locURI, err := url.Parse(httpErr.Message)
		if !assert.NoError(t, err) {
			return
		}
		query := locURI.Query()
		assert.Equal(t, "/oauth/login", locURI.Path)
		assert.Equal(t, "development", query.Get("client_id"))
		assert.Equal(t, "http://localhost/oauth/callback", query.Get("redirect_uri"))
		assert.Equal(t, "code", query.Get("response_type"))
		assert.NotEmpty(t, query.Get("state"))
	})

	t.Run("state failure", func(t *testing.T) {
		h := newHandler(&MockOAuthDatastoreProvider{
			DoDB: func() OAuthDatastore {
				return &MockOAuthDatastore{
					DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
						return "", fmt.Errorf("pretend unable to write state error")
					},
				}
			},
		})

		req, err := http.NewRequest("GET", "/oauth/client", nil)
		if !assert.NoError(t, err) {
			return
		}
		rr := httptest.NewRecorder()

		err = h.viewOauthInit(nil, rr, req)
		if !assert.Error(t, err) {
			return
		}
		httpErr, ok := err.(impart.HTTPError)
		if !assert.True(t, ok, "expected impart.HTTPError, got %T", err) {
			return
		}
		assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
		assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
	})
}
  // TestViewOauthCallback drives oauthHandler.viewOauthCallback through a
// Write.as token exchange and token inspection served by a fake HTTP client,
// and expects the handler to finish with a temporary redirect.
func TestViewOauthCallback(t *testing.T) {
	t.Run("success", func(t *testing.T) {
		app := &MockOAuthDatastoreProvider{}
		// The mock provider builds a new config and session store on every
		// call. Fetch each once so the handler and the App passed to it share
		// the same instances.
		cfg := app.Config()
		sessionStore := app.SessionStore()

		// fakeWriteAs answers the token and inspect endpoints with canned JSON
		// and returns 404 (with an empty, non-nil body) for anything else.
		fakeWriteAs := &MockHTTPClient{
			DoDo: func(req *http.Request) (*http.Response, error) {
				switch req.URL.String() {
				case "https://write.as/oauth/token":
					return &http.Response{
						StatusCode: 200,
						Body:       &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
					}, nil
				case "https://write.as/oauth/inspect":
					return &http.Response{
						StatusCode: 200,
						Body:       &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
					}, nil
				}
				return &http.Response{
					StatusCode: http.StatusNotFound,
					Body:       http.NoBody,
				}, nil
			},
		}

		h := oauthHandler{
			Config:   cfg,
			DB:       app.DB(),
			Store:    sessionStore,
			EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
			oauthClient: writeAsOauthClient{
				ClientID:         cfg.WriteAsOauth.ClientID,
				ClientSecret:     cfg.WriteAsOauth.ClientSecret,
				ExchangeLocation: cfg.WriteAsOauth.TokenLocation,
				InspectLocation:  cfg.WriteAsOauth.InspectLocation,
				AuthLocation:     cfg.WriteAsOauth.AuthLocation,
				CallbackLocation: "http://localhost/oauth/callback",
				HttpClient:       fakeWriteAs,
			},
		}

		req, err := http.NewRequest("GET", "/oauth/callback", nil)
		if !assert.NoError(t, err) {
			return
		}
		rr := httptest.NewRecorder()

		err = h.viewOauthCallback(&App{cfg: cfg, sessionStore: sessionStore}, rr, req)
		assert.NoError(t, err)
		assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
	})
}