A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

254 lines
8.0 KiB

  1. package writefreely
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gorilla/sessions"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/writeas/impart"
  8. "github.com/writeas/nerds/store"
  9. "github.com/writeas/writefreely/config"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "strings"
  14. "testing"
  15. )
  16. type MockOAuthDatastoreProvider struct {
  17. DoDB func() OAuthDatastore
  18. DoConfig func() *config.Config
  19. DoSessionStore func() sessions.Store
  20. }
  21. type MockOAuthDatastore struct {
  22. DoGenerateOAuthState func(context.Context, string, string) (string, error)
  23. DoValidateOAuthState func(context.Context, string) (string, string, error)
  24. DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
  25. DoCreateUser func(*config.Config, *User, string) error
  26. DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
  27. DoGetUserByID func(int64) (*User, error)
  28. }
  29. var _ OAuthDatastore = &MockOAuthDatastore{}
  // StringReadCloser adapts a *strings.Reader into an io.ReadCloser so that
  // canned strings can serve as http.Response bodies in tests.
  type StringReadCloser struct {
  	*strings.Reader
  }

  // Close satisfies io.Closer. There is nothing to release, so it always
  // reports success.
  func (*StringReadCloser) Close() error { return nil }
  // MockHTTPClient is a stub HTTP client for the OAuth tests. It can be
  // assigned to writeAsOauthClient.HttpClient in place of a real client.
  type MockHTTPClient struct {
  	// DoDo, when set, handles every request passed to Do.
  	DoDo func(req *http.Request) (*http.Response, error)
  }

  // Do forwards req to DoDo when it is set. Otherwise it answers with an
  // empty response and a nil error.
  func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
  	if m.DoDo != nil {
  		return m.DoDo(req)
  	}
  	// http.Client.Do promises a non-nil Body whenever err is nil. Returning
  	// http.NoBody keeps that contract, so code under test that closes or
  	// reads the body does not dereference nil.
  	return &http.Response{Body: http.NoBody}, nil
  }
  45. func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
  46. if m.DoSessionStore != nil {
  47. return m.DoSessionStore()
  48. }
  49. return sessions.NewCookieStore([]byte("secret-key"))
  50. }
  51. func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
  52. if m.DoDB != nil {
  53. return m.DoDB()
  54. }
  55. return &MockOAuthDatastore{}
  56. }
  57. func (m *MockOAuthDatastoreProvider) Config() *config.Config {
  58. if m.DoConfig != nil {
  59. return m.DoConfig()
  60. }
  61. cfg := config.New()
  62. cfg.UseSQLite(true)
  63. cfg.WriteAsOauth = config.WriteAsOauthCfg{
  64. ClientID: "development",
  65. ClientSecret: "development",
  66. AuthLocation: "https://write.as/oauth/login",
  67. TokenLocation: "https://write.as/oauth/token",
  68. InspectLocation: "https://write.as/oauth/inspect",
  69. }
  70. cfg.SlackOauth = config.SlackOauthCfg{
  71. ClientID: "development",
  72. ClientSecret: "development",
  73. TeamID: "development",
  74. }
  75. return cfg
  76. }
  77. func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
  78. if m.DoValidateOAuthState != nil {
  79. return m.DoValidateOAuthState(ctx, state)
  80. }
  81. return "", "", nil
  82. }
  83. func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
  84. if m.DoGetIDForRemoteUser != nil {
  85. return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
  86. }
  87. return -1, nil
  88. }
  89. func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
  90. if m.DoCreateUser != nil {
  91. return m.DoCreateUser(cfg, u, username)
  92. }
  93. u.ID = 1
  94. return nil
  95. }
  96. func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
  97. if m.DoRecordRemoteUserID != nil {
  98. return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
  99. }
  100. return nil
  101. }
  102. func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
  103. if m.DoGetUserByID != nil {
  104. return m.DoGetUserByID(userID)
  105. }
  106. user := &User{
  107. }
  108. return user, nil
  109. }
  110. func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
  111. if m.DoGenerateOAuthState != nil {
  112. return m.DoGenerateOAuthState(ctx, provider, clientID)
  113. }
  114. return store.Generate62RandomString(14), nil
  115. }
  116. func TestViewOauthInit(t *testing.T) {
  117. t.Run("success", func(t *testing.T) {
  118. app := &MockOAuthDatastoreProvider{}
  119. h := oauthHandler{
  120. Config: app.Config(),
  121. DB: app.DB(),
  122. Store: app.SessionStore(),
  123. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  124. oauthClient: writeAsOauthClient{
  125. ClientID: app.Config().WriteAsOauth.ClientID,
  126. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  127. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  128. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  129. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  130. CallbackLocation: "http://localhost/oauth/callback",
  131. HttpClient: nil,
  132. },
  133. }
  134. req, err := http.NewRequest("GET", "/oauth/client", nil)
  135. assert.NoError(t, err)
  136. rr := httptest.NewRecorder()
  137. err = h.viewOauthInit(nil, rr, req)
  138. assert.NotNil(t, err)
  139. httpErr, ok := err.(impart.HTTPError)
  140. assert.True(t, ok)
  141. assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
  142. assert.NotEmpty(t, httpErr.Message)
  143. locURI, err := url.Parse(httpErr.Message)
  144. assert.NoError(t, err)
  145. assert.Equal(t, "/oauth/login", locURI.Path)
  146. assert.Equal(t, "development", locURI.Query().Get("client_id"))
  147. assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
  148. assert.Equal(t, "code", locURI.Query().Get("response_type"))
  149. assert.NotEmpty(t, locURI.Query().Get("state"))
  150. })
  151. t.Run("state failure", func(t *testing.T) {
  152. app := &MockOAuthDatastoreProvider{
  153. DoDB: func() OAuthDatastore {
  154. return &MockOAuthDatastore{
  155. DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
  156. return "", fmt.Errorf("pretend unable to write state error")
  157. },
  158. }
  159. },
  160. }
  161. h := oauthHandler{
  162. Config: app.Config(),
  163. DB: app.DB(),
  164. Store: app.SessionStore(),
  165. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  166. oauthClient: writeAsOauthClient{
  167. ClientID: app.Config().WriteAsOauth.ClientID,
  168. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  169. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  170. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  171. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  172. CallbackLocation: "http://localhost/oauth/callback",
  173. HttpClient: nil,
  174. },
  175. }
  176. req, err := http.NewRequest("GET", "/oauth/client", nil)
  177. assert.NoError(t, err)
  178. rr := httptest.NewRecorder()
  179. err = h.viewOauthInit(nil, rr, req)
  180. httpErr, ok := err.(impart.HTTPError)
  181. assert.True(t, ok)
  182. assert.NotEmpty(t, httpErr.Message)
  183. assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
  184. assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
  185. })
  186. }
  187. func TestViewOauthCallback(t *testing.T) {
  188. t.Run("success", func(t *testing.T) {
  189. app := &MockOAuthDatastoreProvider{}
  190. h := oauthHandler{
  191. Config: app.Config(),
  192. DB: app.DB(),
  193. Store: app.SessionStore(),
  194. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  195. oauthClient: writeAsOauthClient{
  196. ClientID: app.Config().WriteAsOauth.ClientID,
  197. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  198. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  199. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  200. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  201. CallbackLocation: "http://localhost/oauth/callback",
  202. HttpClient: &MockHTTPClient{
  203. DoDo: func(req *http.Request) (*http.Response, error) {
  204. switch req.URL.String() {
  205. case "https://write.as/oauth/token":
  206. return &http.Response{
  207. StatusCode: 200,
  208. Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
  209. }, nil
  210. case "https://write.as/oauth/inspect":
  211. return &http.Response{
  212. StatusCode: 200,
  213. Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
  214. }, nil
  215. }
  216. return &http.Response{
  217. StatusCode: http.StatusNotFound,
  218. }, nil
  219. },
  220. },
  221. },
  222. }
  223. req, err := http.NewRequest("GET", "/oauth/callback", nil)
  224. assert.NoError(t, err)
  225. rr := httptest.NewRecorder()
  226. err = h.viewOauthCallback(nil, rr, req)
  227. assert.NoError(t, err)
  228. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  229. })
  230. }