A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

262 lines
8.5 KiB

  1. /*
  2. * Copyright © 2019-2021 Musing Studio LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "fmt"
  14. "github.com/gorilla/sessions"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/writeas/impart"
  17. "github.com/writeas/web-core/id"
  18. "github.com/writefreely/writefreely/config"
  19. "net/http"
  20. "net/http/httptest"
  21. "net/url"
  22. "strings"
  23. "testing"
  24. )
  25. type MockOAuthDatastoreProvider struct {
  26. DoDB func() OAuthDatastore
  27. DoConfig func() *config.Config
  28. DoSessionStore func() sessions.Store
  29. }
  30. type MockOAuthDatastore struct {
  31. DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
  32. DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
  33. DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
  34. DoCreateUser func(*config.Config, *User, string) error
  35. DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
  36. DoGetUserByID func(int64) (*User, error)
  37. }
  38. var _ OAuthDatastore = &MockOAuthDatastore{}
  39. type StringReadCloser struct {
  40. *strings.Reader
  41. }
  42. func (src *StringReadCloser) Close() error {
  43. return nil
  44. }
  45. type MockHTTPClient struct {
  46. DoDo func(req *http.Request) (*http.Response, error)
  47. }
  48. func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
  49. if m.DoDo != nil {
  50. return m.DoDo(req)
  51. }
  52. return &http.Response{}, nil
  53. }
  54. func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
  55. if m.DoSessionStore != nil {
  56. return m.DoSessionStore()
  57. }
  58. return sessions.NewCookieStore([]byte("secret-key"))
  59. }
  60. func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
  61. if m.DoDB != nil {
  62. return m.DoDB()
  63. }
  64. return &MockOAuthDatastore{}
  65. }
  66. func (m *MockOAuthDatastoreProvider) Config() *config.Config {
  67. if m.DoConfig != nil {
  68. return m.DoConfig()
  69. }
  70. cfg := config.New()
  71. cfg.UseSQLite(true)
  72. cfg.WriteAsOauth = config.WriteAsOauthCfg{
  73. ClientID: "development",
  74. ClientSecret: "development",
  75. AuthLocation: "https://write.as/oauth/login",
  76. TokenLocation: "https://write.as/oauth/token",
  77. InspectLocation: "https://write.as/oauth/inspect",
  78. }
  79. cfg.SlackOauth = config.SlackOauthCfg{
  80. ClientID: "development",
  81. ClientSecret: "development",
  82. TeamID: "development",
  83. }
  84. return cfg
  85. }
  86. func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
  87. if m.DoValidateOAuthState != nil {
  88. return m.DoValidateOAuthState(ctx, state)
  89. }
  90. return "", "", 0, "", nil
  91. }
  92. func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
  93. if m.DoGetIDForRemoteUser != nil {
  94. return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
  95. }
  96. return -1, nil
  97. }
  98. func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error {
  99. if m.DoCreateUser != nil {
  100. return m.DoCreateUser(cfg, u, username)
  101. }
  102. u.ID = 1
  103. return nil
  104. }
  105. func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
  106. if m.DoRecordRemoteUserID != nil {
  107. return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
  108. }
  109. return nil
  110. }
  111. func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
  112. if m.DoGetUserByID != nil {
  113. return m.DoGetUserByID(userID)
  114. }
  115. user := &User{}
  116. return user, nil
  117. }
  118. func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
  119. if m.DoGenerateOAuthState != nil {
  120. return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
  121. }
  122. return id.Generate62RandomString(14), nil
  123. }
  124. func TestViewOauthInit(t *testing.T) {
  125. t.Run("success", func(t *testing.T) {
  126. app := &MockOAuthDatastoreProvider{}
  127. h := oauthHandler{
  128. Config: app.Config(),
  129. DB: app.DB(),
  130. Store: app.SessionStore(),
  131. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  132. oauthClient: writeAsOauthClient{
  133. ClientID: app.Config().WriteAsOauth.ClientID,
  134. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  135. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  136. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  137. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  138. CallbackLocation: "http://localhost/oauth/callback",
  139. HttpClient: nil,
  140. },
  141. }
  142. req, err := http.NewRequest("GET", "/oauth/client", nil)
  143. assert.NoError(t, err)
  144. rr := httptest.NewRecorder()
  145. err = h.viewOauthInit(nil, rr, req)
  146. assert.NotNil(t, err)
  147. httpErr, ok := err.(impart.HTTPError)
  148. assert.True(t, ok)
  149. assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
  150. assert.NotEmpty(t, httpErr.Message)
  151. locURI, err := url.Parse(httpErr.Message)
  152. assert.NoError(t, err)
  153. assert.Equal(t, "/oauth/login", locURI.Path)
  154. assert.Equal(t, "development", locURI.Query().Get("client_id"))
  155. assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
  156. assert.Equal(t, "code", locURI.Query().Get("response_type"))
  157. assert.NotEmpty(t, locURI.Query().Get("state"))
  158. })
  159. t.Run("state failure", func(t *testing.T) {
  160. app := &MockOAuthDatastoreProvider{
  161. DoDB: func() OAuthDatastore {
  162. return &MockOAuthDatastore{
  163. DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
  164. return "", fmt.Errorf("pretend unable to write state error")
  165. },
  166. }
  167. },
  168. }
  169. h := oauthHandler{
  170. Config: app.Config(),
  171. DB: app.DB(),
  172. Store: app.SessionStore(),
  173. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  174. oauthClient: writeAsOauthClient{
  175. ClientID: app.Config().WriteAsOauth.ClientID,
  176. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  177. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  178. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  179. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  180. CallbackLocation: "http://localhost/oauth/callback",
  181. HttpClient: nil,
  182. },
  183. }
  184. req, err := http.NewRequest("GET", "/oauth/client", nil)
  185. assert.NoError(t, err)
  186. rr := httptest.NewRecorder()
  187. err = h.viewOauthInit(nil, rr, req)
  188. httpErr, ok := err.(impart.HTTPError)
  189. assert.True(t, ok)
  190. assert.NotEmpty(t, httpErr.Message)
  191. assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
  192. assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
  193. })
  194. }
  195. func TestViewOauthCallback(t *testing.T) {
  196. t.Run("success", func(t *testing.T) {
  197. app := &MockOAuthDatastoreProvider{}
  198. h := oauthHandler{
  199. Config: app.Config(),
  200. DB: app.DB(),
  201. Store: app.SessionStore(),
  202. EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
  203. oauthClient: writeAsOauthClient{
  204. ClientID: app.Config().WriteAsOauth.ClientID,
  205. ClientSecret: app.Config().WriteAsOauth.ClientSecret,
  206. ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
  207. InspectLocation: app.Config().WriteAsOauth.InspectLocation,
  208. AuthLocation: app.Config().WriteAsOauth.AuthLocation,
  209. CallbackLocation: "http://localhost/oauth/callback",
  210. HttpClient: &MockHTTPClient{
  211. DoDo: func(req *http.Request) (*http.Response, error) {
  212. switch req.URL.String() {
  213. case "https://write.as/oauth/token":
  214. return &http.Response{
  215. StatusCode: 200,
  216. Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
  217. }, nil
  218. case "https://write.as/oauth/inspect":
  219. return &http.Response{
  220. StatusCode: 200,
  221. Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
  222. }, nil
  223. }
  224. return &http.Response{
  225. StatusCode: http.StatusNotFound,
  226. }, nil
  227. },
  228. },
  229. },
  230. }
  231. req, err := http.NewRequest("GET", "/oauth/callback", nil)
  232. assert.NoError(t, err)
  233. rr := httptest.NewRecorder()
  234. err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req)
  235. assert.NoError(t, err)
  236. assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
  237. })
  238. }