84 lines
2.1 KiB
Go
84 lines
2.1 KiB
Go
package test
|
|
|
|
import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"gitea.qpismont.fr/qpismont/trepa/internal/core"
	"github.com/jmoiron/sqlx"
)
|
|
|
|
func SetupTestDB(test *testing.T, rootPath string) *sqlx.DB {
|
|
core.LoadEnvVars(rootPath + "/.env")
|
|
|
|
dbHost := core.MustGetEnvVar("TEST_DB_HOST")
|
|
dbPort := core.MustGetEnvVar("TEST_DB_PORT")
|
|
dbUser := core.MustGetEnvVar("TEST_DB_USER")
|
|
dbPassword := core.MustGetEnvVar("TEST_DB_PASSWORD")
|
|
dbName := core.MustGetEnvVar("TEST_DB_NAME")
|
|
|
|
dbExecute := initTestDB(test, dbHost, dbPort, dbUser, dbPassword, "postgres")
|
|
resetTestDB(dbExecute, test, dbName)
|
|
dbExecute.Close()
|
|
|
|
dbTest := initTestDB(test, dbHost, dbPort, dbUser, dbPassword, dbName)
|
|
executeMigrations(dbTest, test, rootPath)
|
|
executeFixtures(dbTest, test, rootPath)
|
|
|
|
return dbTest
|
|
}
|
|
|
|
func initTestDB(t *testing.T, dbHost, dbPort, dbUser, dbPassword, dbName string) *sqlx.DB {
|
|
dbURL := core.ComputeDBURL(dbHost, dbPort, dbUser, dbPassword, dbName)
|
|
db, err := core.SetupDB(dbURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to test database: %v", err)
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
func resetTestDB(db *sqlx.DB, t *testing.T, dbName string) {
|
|
_, err := db.Exec("DROP DATABASE IF EXISTS " + dbName + " WITH (FORCE);")
|
|
if err != nil {
|
|
t.Fatalf("Failed to drop test database: %v", err)
|
|
}
|
|
|
|
_, err = db.Exec("CREATE DATABASE " + dbName + ";")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test database: %v", err)
|
|
}
|
|
}
|
|
func executeMigrations(db *sqlx.DB, t *testing.T, rootPath string) {
|
|
rootPath = filepath.Join(rootPath, "migrations")
|
|
|
|
executeSqlFolder(db, t, rootPath)
|
|
}
|
|
|
|
func executeFixtures(db *sqlx.DB, t *testing.T, rootPath string) {
|
|
rootPath = filepath.Join(rootPath, "test", "fixtures")
|
|
|
|
executeSqlFolder(db, t, rootPath)
|
|
}
|
|
|
|
func executeSqlFolder(db *sqlx.DB, t *testing.T, folder string) {
|
|
files, err := filepath.Glob(folder + "/*.sql")
|
|
if err != nil {
|
|
t.Fatalf("Failed to read sql folder: %v", err)
|
|
}
|
|
|
|
for _, file := range files {
|
|
t.Log("Executing " + file)
|
|
|
|
sql, err := os.ReadFile(file)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read sql file: %v", err)
|
|
}
|
|
|
|
_, err = db.Exec(string(sql))
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute sql file: %v", err)
|
|
}
|
|
}
|
|
}
|