A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

199 lines
5.7 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gorilla/sessions"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/writeas/impart"
  8. "github.com/writeas/writefreely/config"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "strings"
  13. "testing"
  14. )
  15. type MockOAuthDatastoreProvider struct {
  16. DoDB func() OAuthDatastore
  17. DoConfig func() *config.Config
  18. DoSessionStore func() sessions.Store
  19. }
  20. type MockOAuthDatastore struct {
  21. DoGenerateOAuthState func(ctx context.Context) (string, error)
  22. DoValidateOAuthState func(context.Context, string) error
  23. DoGetIDForRemoteUser func(context.Context, int64) (int64, error)
  24. DoCreateUser func(*config.Config, *User, string) error
  25. DoRecordRemoteUserID func(context.Context, int64, int64) error
  26. DoGetUserForAuthByID func(int64) (*User, error)
  27. }
  28. type StringReadCloser struct {
  29. *strings.Reader
  30. }
  31. func (src *StringReadCloser) Close() error {
  32. return nil
  33. }
  34. type MockHTTPClient struct {
  35. DoDo func(req *http.Request) (*http.Response, error)
  36. }
  37. func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
  38. if m.DoDo != nil {
  39. return m.DoDo(req)
  40. }
  41. return &http.Response{}, nil
  42. }
  43. func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
  44. if m.DoSessionStore != nil {
  45. return m.DoSessionStore()
  46. }
  47. return sessions.NewCookieStore([]byte("secret-key"))
  48. }
  49. func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
  50. if m.DoDB != nil {
  51. return m.DoDB()
  52. }
  53. return &MockOAuthDatastore{}
  54. }
  55. func (m *MockOAuthDatastoreProvider) Config() *config.Config {
  56. if m.DoConfig != nil {
  57. return m.DoConfig()
  58. }
  59. cfg := config.New()
  60. cfg.UseSQLite(true)
  61. cfg.App.EnableOAuth = true
  62. cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login"
  63. cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token"
  64. cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect"
  65. cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback"
  66. cfg.App.OAuthClientID = "development"
  67. cfg.App.OAuthClientSecret = "development"
  68. return cfg
  69. }
  70. func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error {
  71. if m.DoValidateOAuthState != nil {
  72. return m.DoValidateOAuthState(ctx, state)
  73. }
  74. return nil
  75. }
  76. func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
  77. if m.DoGetIDForRemoteUser != nil {
  78. return m.DoGetIDForRemoteUser(ctx, remoteUserID)
  79. }
  80. return -1, nil
  81. }
  82. func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
  83. if m.DoCreateUser != nil {
  84. return m.DoCreateUser(cfg, u, username)
  85. }
  86. u.ID = 1
  87. return nil
  88. }
  89. func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error {
  90. if m.DoRecordRemoteUserID != nil {
  91. return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID)
  92. }
  93. return nil
  94. }
  95. func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
  96. if m.DoGetUserForAuthByID != nil {
  97. return m.DoGetUserForAuthByID(userID)
  98. }
  99. user := &User{
  100. }
  101. return user, nil
  102. }
  103. func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) {
  104. if m.DoGenerateOAuthState != nil {
  105. return m.DoGenerateOAuthState(ctx)
  106. }
  107. return randString(14)
  108. }
  109. func TestViewOauthInit(t *testing.T) {
  110. h := oauthHandler{}
  111. t.Run("success", func(t *testing.T) {
  112. app := &MockOAuthDatastoreProvider{}
  113. req, err := http.NewRequest("GET", "/oauth/client", nil)
  114. assert.NoError(t, err)
  115. rr := httptest.NewRecorder()
  116. err = h.viewOauthInit(app, rr, req)
  117. assert.NoError(t, err)
  118. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  119. locURI, err := url.Parse(rr.Header().Get("Location"))
  120. assert.NoError(t, err)
  121. assert.Equal(t, "/oauth/login", locURI.Path)
  122. assert.Equal(t, "development", locURI.Query().Get("client_id"))
  123. assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
  124. assert.Equal(t, "code", locURI.Query().Get("response_type"))
  125. assert.NotEmpty(t, locURI.Query().Get("state"))
  126. })
  127. t.Run("state failure", func(t *testing.T) {
  128. app := &MockOAuthDatastoreProvider{
  129. DoDB: func() OAuthDatastore {
  130. return &MockOAuthDatastore{
  131. DoGenerateOAuthState: func(ctx context.Context) (string, error) {
  132. return "", fmt.Errorf("pretend unable to write state error")
  133. },
  134. }
  135. },
  136. }
  137. req, err := http.NewRequest("GET", "/oauth/client", nil)
  138. assert.NoError(t, err)
  139. rr := httptest.NewRecorder()
  140. err = h.viewOauthInit(app, rr, req)
  141. assert.Error(t, err)
  142. assert.Equal(t, impart.HTTPError{Status: http.StatusInternalServerError, Message: "Could not prepare OAuth redirect URL."}, err)
  143. })
  144. }
  145. func TestViewOauthCallback(t *testing.T) {
  146. t.Run("success", func(t *testing.T) {
  147. app := &MockOAuthDatastoreProvider{}
  148. h := oauthHandler{
  149. HttpClient: &MockHTTPClient{
  150. DoDo: func(req *http.Request) (*http.Response, error) {
  151. switch req.URL.String() {
  152. case "https://write.as/oauth/token":
  153. return &http.Response{
  154. StatusCode: 200,
  155. Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
  156. }, nil
  157. case "https://write.as/oauth/inspect":
  158. return &http.Response{
  159. StatusCode: 200,
  160. Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
  161. }, nil
  162. }
  163. return &http.Response{
  164. StatusCode: http.StatusNotFound,
  165. }, nil
  166. },
  167. },
  168. }
  169. req, err := http.NewRequest("GET", "/oauth/callback", nil)
  170. assert.NoError(t, err)
  171. rr := httptest.NewRecorder()
  172. err = h.viewOauthCallback(app, rr, req)
  173. assert.NoError(t, err)
  174. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  175. })
  176. }