A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.
 
 
 
 
 

142 lignes
3.1 KiB

  1. /*
  2. * Copyright © 2018-2019 A Bunch Tell LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "encoding/gob"
  13. "github.com/gorilla/sessions"
  14. "github.com/writeas/web-core/log"
  15. "net/http"
  16. "strings"
  17. )
  18. const (
  19. day = 86400
  20. sessionLength = 180 * day
  21. cookieName = "wfu"
  22. cookieUserVal = "u"
  23. blogPassCookieName = "ub"
  24. )
  25. // InitSession creates the cookie store. It depends on the keychain already
  26. // being loaded.
  27. func (app *App) InitSession() {
  28. // Register complex data types we'll be storing in cookies
  29. gob.Register(&User{})
  30. // Create the cookie store
  31. store := sessions.NewCookieStore(app.keys.CookieAuthKey, app.keys.CookieKey)
  32. store.Options = &sessions.Options{
  33. Path: "/",
  34. MaxAge: sessionLength,
  35. HttpOnly: true,
  36. Secure: strings.HasPrefix(app.cfg.App.Host, "https://"),
  37. }
  38. if store.Options.Secure {
  39. store.Options.SameSite = http.SameSiteNoneMode
  40. }
  41. app.sessionStore = store
  42. }
  43. func getSessionFlashes(app *App, w http.ResponseWriter, r *http.Request, session *sessions.Session) ([]string, error) {
  44. var err error
  45. if session == nil {
  46. session, err = app.sessionStore.Get(r, cookieName)
  47. if err != nil {
  48. return nil, err
  49. }
  50. }
  51. f := []string{}
  52. if flashes := session.Flashes(); len(flashes) > 0 {
  53. for _, flash := range flashes {
  54. if str, ok := flash.(string); ok {
  55. f = append(f, str)
  56. }
  57. }
  58. }
  59. saveUserSession(app, r, w)
  60. return f, nil
  61. }
  62. func addSessionFlash(app *App, w http.ResponseWriter, r *http.Request, m string, session *sessions.Session) error {
  63. var err error
  64. if session == nil {
  65. session, err = app.sessionStore.Get(r, cookieName)
  66. }
  67. if err != nil {
  68. log.Error("Unable to add flash '%s': %v", m, err)
  69. return err
  70. }
  71. session.AddFlash(m)
  72. saveUserSession(app, r, w)
  73. return nil
  74. }
  75. func getUserAndSession(app *App, r *http.Request) (*User, *sessions.Session) {
  76. session, err := app.sessionStore.Get(r, cookieName)
  77. if err == nil {
  78. // Got the currently logged-in user
  79. val := session.Values[cookieUserVal]
  80. var u = &User{}
  81. var ok bool
  82. if u, ok = val.(*User); ok {
  83. return u, session
  84. }
  85. }
  86. return nil, nil
  87. }
  88. func getUserSession(app *App, r *http.Request) *User {
  89. u, _ := getUserAndSession(app, r)
  90. return u
  91. }
  92. func saveUserSession(app *App, r *http.Request, w http.ResponseWriter) error {
  93. session, err := app.sessionStore.Get(r, cookieName)
  94. if err != nil {
  95. return ErrInternalCookieSession
  96. }
  97. // Extend the session
  98. session.Options.MaxAge = int(sessionLength)
  99. // Remove any information that accidentally got added
  100. // FIXME: find where Plan information is getting saved to cookie.
  101. val := session.Values[cookieUserVal]
  102. var u = &User{}
  103. var ok bool
  104. if u, ok = val.(*User); ok {
  105. session.Values[cookieUserVal] = u.Cookie()
  106. }
  107. err = session.Save(r, w)
  108. if err != nil {
  109. log.Error("Couldn't saveUserSession: %v", err)
  110. }
  111. return err
  112. }
  113. func getFullUserSession(app *App, r *http.Request) *User {
  114. u := getUserSession(app, r)
  115. if u == nil {
  116. return nil
  117. }
  118. u, _ = app.db.GetUserByID(u.ID)
  119. return u
  120. }