A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

143 lines
4.1 KiB

  1. /*
  2. * Copyright © 2020-2021 A Bunch Tell LLC and respective authors.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "errors"
  14. "fmt"
  15. "github.com/writeas/web-core/log"
  16. "net/http"
  17. "net/url"
  18. "strings"
  19. )
  20. type genericOauthClient struct {
  21. ClientID string
  22. ClientSecret string
  23. AuthLocation string
  24. ExchangeLocation string
  25. InspectLocation string
  26. CallbackLocation string
  27. Scope string
  28. MapUserID string
  29. MapUsername string
  30. MapDisplayName string
  31. MapEmail string
  32. HttpClient HttpClient
  33. }
  34. var _ oauthClient = genericOauthClient{}
  35. const (
  36. genericOauthDisplayName = "OAuth"
  37. )
  38. func (c genericOauthClient) GetProvider() string {
  39. return "generic"
  40. }
  41. func (c genericOauthClient) GetClientID() string {
  42. return c.ClientID
  43. }
  44. func (c genericOauthClient) GetCallbackLocation() string {
  45. return c.CallbackLocation
  46. }
  47. func (c genericOauthClient) buildLoginURL(state string) (string, error) {
  48. u, err := url.Parse(c.AuthLocation)
  49. if err != nil {
  50. return "", err
  51. }
  52. q := u.Query()
  53. q.Set("client_id", c.ClientID)
  54. q.Set("redirect_uri", c.CallbackLocation)
  55. q.Set("response_type", "code")
  56. q.Set("state", state)
  57. q.Set("scope", c.Scope)
  58. u.RawQuery = q.Encode()
  59. return u.String(), nil
  60. }
  61. func (c genericOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
  62. form := url.Values{}
  63. form.Add("grant_type", "authorization_code")
  64. form.Add("redirect_uri", c.CallbackLocation)
  65. form.Add("scope", c.Scope)
  66. form.Add("code", code)
  67. req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
  68. if err != nil {
  69. return nil, err
  70. }
  71. req.WithContext(ctx)
  72. req.Header.Set("User-Agent", ServerUserAgent(""))
  73. req.Header.Set("Accept", "application/json")
  74. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  75. req.SetBasicAuth(c.ClientID, c.ClientSecret)
  76. resp, err := c.HttpClient.Do(req)
  77. if err != nil {
  78. return nil, err
  79. }
  80. if resp.StatusCode != http.StatusOK {
  81. return nil, errors.New("unable to exchange code for access token")
  82. }
  83. var tokenResponse TokenResponse
  84. if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
  85. return nil, err
  86. }
  87. if tokenResponse.Error != "" {
  88. return nil, errors.New(tokenResponse.Error)
  89. }
  90. return &tokenResponse, nil
  91. }
  92. func (c genericOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
  93. req, err := http.NewRequest("GET", c.InspectLocation, nil)
  94. if err != nil {
  95. return nil, err
  96. }
  97. req.WithContext(ctx)
  98. req.Header.Set("User-Agent", ServerUserAgent(""))
  99. req.Header.Set("Accept", "application/json")
  100. req.Header.Set("Authorization", "Bearer "+accessToken)
  101. resp, err := c.HttpClient.Do(req)
  102. if err != nil {
  103. return nil, err
  104. }
  105. if resp.StatusCode != http.StatusOK {
  106. return nil, errors.New("unable to inspect access token")
  107. }
  108. // since we don't know what the JSON from the server will look like, we create a
  109. // generic interface and then map manually to values set in the config
  110. var genericInterface map[string]interface{}
  111. if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &genericInterface); err != nil {
  112. return nil, err
  113. }
  114. // map each relevant field in inspectResponse to the mapped field from the config
  115. var inspectResponse InspectResponse
  116. inspectResponse.UserID, _ = genericInterface[c.MapUserID].(string)
  117. if inspectResponse.UserID == "" {
  118. log.Error("[CONFIGURATION ERROR] Generic OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?", c.MapUserID)
  119. return nil, fmt.Errorf("no UserID (`%s`) value returned", c.MapUserID)
  120. }
  121. inspectResponse.Username, _ = genericInterface[c.MapUsername].(string)
  122. inspectResponse.DisplayName, _ = genericInterface[c.MapDisplayName].(string)
  123. inspectResponse.Email, _ = genericInterface[c.MapEmail].(string)
  124. return &inspectResponse, nil
  125. }