|
- package writefreely
-
- import (
- "context"
- "database/sql"
- "encoding/gob"
- "errors"
- "fmt"
- uuid "github.com/nu7hatch/gouuid"
- "github.com/stretchr/testify/assert"
- "math/rand"
- "os"
- "strings"
- "testing"
- "time"
- )
-
- var testDB *sql.DB
-
- type ScopedTestBody func(*sql.DB)
-
- // TestMain provides testing infrastructure within this package.
- func TestMain(m *testing.M) {
- rand.Seed(time.Now().UTC().UnixNano())
- gob.Register(&User{})
-
- if runMySQLTests() {
- var err error
-
- testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST"))
- if err != nil {
- fmt.Println(err)
- return
- }
- }
-
- code := m.Run()
- if runMySQLTests() {
- if closeErr := testDB.Close(); closeErr != nil {
- fmt.Println(closeErr)
- }
- }
- os.Exit(code)
- }
-
- func runMySQLTests() bool {
- return len(os.Getenv("TEST_MYSQL")) > 0
- }
-
- func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) {
- if dbUser == "" || dbPassword == "" {
- return nil, errors.New("database user or password not set")
- }
- if dbHost == "" {
- dbHost = "localhost"
- }
- if dbName == "" {
- dbName = "writefreely"
- }
-
- dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName)
- db, err := sql.Open("mysql", dsn)
- if err != nil {
- return nil, err
- }
- if err := ensureMySQL(db); err != nil {
- return nil, err
- }
- return db, nil
- }
-
- func ensureMySQL(db *sql.DB) error {
- if err := db.Ping(); err != nil {
- return err
- }
- db.SetMaxOpenConns(250)
- return nil
- }
-
- // withTestDB provides a scoped database connection.
- func withTestDB(t *testing.T, testBody ScopedTestBody) {
- db, cleanup, err := newTestDatabase(testDB,
- os.Getenv("WF_USER"),
- os.Getenv("WF_PASSWORD"),
- os.Getenv("WF_DB"),
- os.Getenv("WF_HOST"),
- )
- assert.NoError(t, err)
- defer func() {
- assert.NoError(t, cleanup())
- }()
-
- testBody(db)
- }
-
- // newTestDatabase creates a new temporary test database. When a test
- // database connection is returned, it will have created a new database and
- // initialized it with tables from a reference database.
- func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) {
- var err error
- var baseName = dbName
-
- if baseName == "" {
- row := base.QueryRow("SELECT DATABASE()")
- err := row.Scan(&baseName)
- if err != nil {
- return nil, nil, err
- }
- }
- tUUID, _ := uuid.NewV4()
- suffix := strings.Replace(tUUID.String(), "-", "_", -1)
- newDBName := baseName + suffix
- _, err = base.Exec("CREATE DATABASE " + newDBName)
- if err != nil {
- return nil, nil, err
- }
- newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost)
- if err != nil {
- return nil, nil, err
- }
-
- rows, err := base.Query("SHOW TABLES IN " + baseName)
- if err != nil {
- return nil, nil, err
- }
- for rows.Next() {
- var tableName string
- if err := rows.Scan(&tableName); err != nil {
- return nil, nil, err
- }
- query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName)
- if _, err := newDB.Exec(query); err != nil {
- return nil, nil, err
- }
- }
-
- cleanup := func() error {
- if closeErr := newDB.Close(); closeErr != nil {
- fmt.Println(closeErr)
- }
-
- _, err = base.Exec("DROP DATABASE " + newDBName)
- return err
- }
- return newDB, cleanup, nil
- }
-
- func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) {
- var returned int
- err := db.QueryRowContext(ctx, query, args...).Scan(&returned)
- assert.NoError(t, err, "error executing query %s and args %s", query, args)
- assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %s", returned, count, query, args)
- }
|