A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

211 lines
5.9 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gorilla/sessions"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/writeas/writefreely/config"
  8. "net/http"
  9. "net/http/httptest"
  10. "net/url"
  11. "strings"
  12. "testing"
  13. )
  14. type MockOAuthDatastoreProvider struct {
  15. DoDB func() OAuthDatastore
  16. DoConfig func() *config.Config
  17. DoSessionStore func() sessions.Store
  18. }
  19. type MockOAuthDatastore struct {
  20. DoGenerateOAuthState func(ctx context.Context) (string, error)
  21. DoValidateOAuthState func(context.Context, string) error
  22. DoGetIDForRemoteUser func(context.Context, int64) (int64, error)
  23. DoCreateUser func(*config.Config, *User, string) error
  24. DoRecordRemoteUserID func(context.Context, int64, int64) error
  25. DoGetUserForAuthByID func(int64) (*User, error)
  26. }
  // StringReadCloser wraps a *strings.Reader so it satisfies io.ReadCloser and
// can be used as an http.Response body in tests. Reads are served by the
// embedded reader.
type StringReadCloser struct {
	*strings.Reader
}

// Close is a no-op: there is no underlying resource to release, so it always
// reports success.
func (s *StringReadCloser) Close() error { return nil }
  // MockHTTPClient is a stub HTTP client whose behavior is supplied by the
// DoDo hook, letting tests script responses per request.
type MockHTTPClient struct {
	// DoDo, when non-nil, handles every request passed to Do.
	DoDo func(req *http.Request) (*http.Response, error)
}

// Do returns DoDo(req) when the hook is set. Otherwise it returns an empty
// response with a nil error. The response's Body is http.NoBody rather than
// nil, because http.Client.Do guarantees a non-nil Body whenever err is nil.
// Code under test may therefore read or close the body unconditionally.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if m.DoDo != nil {
		return m.DoDo(req)
	}
	return &http.Response{Body: http.NoBody}, nil
}
  42. func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
  43. if m.DoSessionStore != nil {
  44. return m.DoSessionStore()
  45. }
  46. return sessions.NewCookieStore([]byte("secret-key"))
  47. }
  48. func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
  49. if m.DoDB != nil {
  50. return m.DoDB()
  51. }
  52. return &MockOAuthDatastore{}
  53. }
  54. func (m *MockOAuthDatastoreProvider) Config() *config.Config {
  55. if m.DoConfig != nil {
  56. return m.DoConfig()
  57. }
  58. cfg := config.New()
  59. cfg.UseSQLite(true)
  60. cfg.App.EnableOAuth = true
  61. cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login"
  62. cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token"
  63. cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect"
  64. cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback"
  65. cfg.App.OAuthClientID = "development"
  66. cfg.App.OAuthClientSecret = "development"
  67. return cfg
  68. }
  69. func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error {
  70. if m.DoValidateOAuthState != nil {
  71. return m.DoValidateOAuthState(ctx, state)
  72. }
  73. return nil
  74. }
  75. func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
  76. if m.DoGetIDForRemoteUser != nil {
  77. return m.DoGetIDForRemoteUser(ctx, remoteUserID)
  78. }
  79. return -1, nil
  80. }
  81. func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
  82. if m.DoCreateUser != nil {
  83. return m.DoCreateUser(cfg, u, username)
  84. }
  85. u.ID = 1
  86. return nil
  87. }
  88. func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error {
  89. if m.DoRecordRemoteUserID != nil {
  90. return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID)
  91. }
  92. return nil
  93. }
  94. func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
  95. if m.DoGetUserForAuthByID != nil {
  96. return m.DoGetUserForAuthByID(userID)
  97. }
  98. user := &User{
  99. }
  100. return user, nil
  101. }
  102. func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) {
  103. if m.DoGenerateOAuthState != nil {
  104. return m.DoGenerateOAuthState(ctx)
  105. }
  106. return randString(14)
  107. }
  108. func TestViewOauthInit(t *testing.T) {
  109. t.Run("success", func(t *testing.T) {
  110. app := &MockOAuthDatastoreProvider{}
  111. h := oauthHandler{
  112. Config: app.Config(),
  113. DB: app.DB(),
  114. Store: app.SessionStore(),
  115. }
  116. req, err := http.NewRequest("GET", "/oauth/client", nil)
  117. assert.NoError(t, err)
  118. rr := httptest.NewRecorder()
  119. h.viewOauthInit(rr, req)
  120. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  121. locURI, err := url.Parse(rr.Header().Get("Location"))
  122. assert.NoError(t, err)
  123. assert.Equal(t, "/oauth/login", locURI.Path)
  124. assert.Equal(t, "development", locURI.Query().Get("client_id"))
  125. assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
  126. assert.Equal(t, "code", locURI.Query().Get("response_type"))
  127. assert.NotEmpty(t, locURI.Query().Get("state"))
  128. })
  129. t.Run("state failure", func(t *testing.T) {
  130. app := &MockOAuthDatastoreProvider{
  131. DoDB: func() OAuthDatastore {
  132. return &MockOAuthDatastore{
  133. DoGenerateOAuthState: func(ctx context.Context) (string, error) {
  134. return "", fmt.Errorf("pretend unable to write state error")
  135. },
  136. }
  137. },
  138. }
  139. h := oauthHandler{
  140. Config: app.Config(),
  141. DB: app.DB(),
  142. Store: app.SessionStore(),
  143. }
  144. req, err := http.NewRequest("GET", "/oauth/client", nil)
  145. assert.NoError(t, err)
  146. rr := httptest.NewRecorder()
  147. h.viewOauthInit(rr, req)
  148. assert.Equal(t, http.StatusInternalServerError, rr.Code)
  149. expected := `{"error":"could not prepare oauth redirect url"}` + "\n"
  150. assert.Equal(t, expected, rr.Body.String())
  151. })
  152. }
  153. func TestViewOauthCallback(t *testing.T) {
  154. t.Run("success", func(t *testing.T) {
  155. app := &MockOAuthDatastoreProvider{}
  156. h := oauthHandler{
  157. Config: app.Config(),
  158. DB: app.DB(),
  159. Store: app.SessionStore(),
  160. HttpClient: &MockHTTPClient{
  161. DoDo: func(req *http.Request) (*http.Response, error) {
  162. switch req.URL.String() {
  163. case "https://write.as/oauth/token":
  164. return &http.Response{
  165. StatusCode: 200,
  166. Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
  167. }, nil
  168. case "https://write.as/oauth/inspect":
  169. return &http.Response{
  170. StatusCode: 200,
  171. Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
  172. }, nil
  173. }
  174. return &http.Response{
  175. StatusCode: http.StatusNotFound,
  176. }, nil
  177. },
  178. },
  179. }
  180. req, err := http.NewRequest("GET", "/oauth/callback", nil)
  181. assert.NoError(t, err)
  182. rr := httptest.NewRecorder()
  183. h.viewOauthCallback(rr, req)
  184. assert.NoError(t, err)
  185. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  186. })
  187. }