A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

154 lines
3.6 KiB

  1. package writefreely
import (
	"context"
	"database/sql"
	"encoding/gob"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"os"
	"strings"
	"testing"
	"time"

	uuid "github.com/nu7hatch/gouuid"
	"github.com/stretchr/testify/assert"
)
  16. var testDB *sql.DB
  17. type ScopedTestBody func(*sql.DB)
  18. // TestMain provides testing infrastructure within this package.
  19. func TestMain(m *testing.M) {
  20. rand.Seed(time.Now().UTC().UnixNano())
  21. gob.Register(&User{})
  22. if runMySQLTests() {
  23. var err error
  24. testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST"))
  25. if err != nil {
  26. fmt.Println(err)
  27. return
  28. }
  29. }
  30. code := m.Run()
  31. if runMySQLTests() {
  32. if closeErr := testDB.Close(); closeErr != nil {
  33. fmt.Println(closeErr)
  34. }
  35. }
  36. os.Exit(code)
  37. }
  38. func runMySQLTests() bool {
  39. return len(os.Getenv("TEST_MYSQL")) > 0
  40. }
  41. func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) {
  42. if dbUser == "" || dbPassword == "" {
  43. return nil, errors.New("database user or password not set")
  44. }
  45. if dbHost == "" {
  46. dbHost = "localhost"
  47. }
  48. if dbName == "" {
  49. dbName = "writefreely"
  50. }
  51. dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName)
  52. db, err := sql.Open("mysql", dsn)
  53. if err != nil {
  54. return nil, err
  55. }
  56. if err := ensureMySQL(db); err != nil {
  57. return nil, err
  58. }
  59. return db, nil
  60. }
  61. func ensureMySQL(db *sql.DB) error {
  62. if err := db.Ping(); err != nil {
  63. return err
  64. }
  65. db.SetMaxOpenConns(250)
  66. return nil
  67. }
  68. // withTestDB provides a scoped database connection.
  69. func withTestDB(t *testing.T, testBody ScopedTestBody) {
  70. db, cleanup, err := newTestDatabase(testDB,
  71. os.Getenv("WF_USER"),
  72. os.Getenv("WF_PASSWORD"),
  73. os.Getenv("WF_DB"),
  74. os.Getenv("WF_HOST"),
  75. )
  76. assert.NoError(t, err)
  77. defer func() {
  78. assert.NoError(t, cleanup())
  79. }()
  80. testBody(db)
  81. }
  82. // newTestDatabase creates a new temporary test database. When a test
  83. // database connection is returned, it will have created a new database and
  84. // initialized it with tables from a reference database.
  85. func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) {
  86. var err error
  87. var baseName = dbName
  88. if baseName == "" {
  89. row := base.QueryRow("SELECT DATABASE()")
  90. err := row.Scan(&baseName)
  91. if err != nil {
  92. return nil, nil, err
  93. }
  94. }
  95. tUUID, _ := uuid.NewV4()
  96. suffix := strings.Replace(tUUID.String(), "-", "_", -1)
  97. newDBName := baseName + suffix
  98. _, err = base.Exec("CREATE DATABASE " + newDBName)
  99. if err != nil {
  100. return nil, nil, err
  101. }
  102. newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost)
  103. if err != nil {
  104. return nil, nil, err
  105. }
  106. rows, err := base.Query("SHOW TABLES IN " + baseName)
  107. if err != nil {
  108. return nil, nil, err
  109. }
  110. for rows.Next() {
  111. var tableName string
  112. if err := rows.Scan(&tableName); err != nil {
  113. return nil, nil, err
  114. }
  115. query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName)
  116. if _, err := newDB.Exec(query); err != nil {
  117. return nil, nil, err
  118. }
  119. }
  120. cleanup := func() error {
  121. if closeErr := newDB.Close(); closeErr != nil {
  122. fmt.Println(closeErr)
  123. }
  124. _, err = base.Exec("DROP DATABASE " + newDBName)
  125. return err
  126. }
  127. return newDB, cleanup, nil
  128. }
  129. func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) {
  130. var returned int
  131. err := db.QueryRowContext(ctx, query, args...).Scan(&returned)
  132. assert.NoError(t, err, "error executing query %s and args %s", query, args)
  133. assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %s", returned, count, query, args)
  134. }