fleet/server/goose/migration.go
2024-10-18 12:38:26 -05:00

219 lines
5.3 KiB
Go

package goose
import (
"database/sql"
"errors"
"fmt"
"log"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"time"
)
// MigrationRecord is a single entry in the migration history.
// NOTE(review): presumably mirrors one row of the goose version table — confirm.
type MigrationRecord struct {
	// VersionId is the migration version this entry refers to.
	VersionId int64
	// TStamp is the timestamp stored with this entry.
	TStamp time.Time
	// IsApplied records whether this entry was the result of up() or down().
	IsApplied bool
}
// Migration describes a single migration known to the client, backed either
// by a .sql script or by Go Up/Down functions.
type Migration struct {
	// Version is the migration's version number.
	Version int64
	// Next is the next version, or -1 if there is none.
	Next int64
	// Previous is the previous version, or -1 if there is none.
	Previous int64
	// Source is the path to the migration's source file (.sql script or .go file).
	Source string
	// UpFn is the Up function of a Go migration.
	UpFn func(*sql.Tx) error
	// DownFn is the Down function of a Go migration.
	DownFn func(*sql.Tx) error
}

// Values for the boolean direction parameter used by runMigration and
// FinalizeMigration: true applies a migration, false rolls it back.
const (
	migrateUp   = true
	migrateDown = !migrateUp
)

// String returns the migration's source path.
func (m *Migration) String() string {
	return m.Source
}
// runMigration applies migration m in the given direction (true = up,
// false = down).
//
// SQL migrations (.sql) are delegated to runSQLMigration. Go migrations (.go)
// log a "[date] name" line, open a transaction, run UpFn or DownFn
// (whichever matches direction) inside it, and hand the transaction to
// FinalizeMigration, which records the version and commits. A nil migration
// function runs nothing but the version is still recorded. Sources with any
// other extension are ignored and nil is returned.
//
// Failures are returned to the caller rather than terminating the process,
// so the caller decides how to stop the migration run.
func (c *Client) runMigration(db *sql.DB, m *Migration, direction bool) error {
	switch filepath.Ext(m.Source) {
	case ".sql":
		if err := c.runSQLMigration(db, m.Source, m.Version, direction); err != nil {
			return fmt.Errorf("failed to run migration: %w", err)
		}
	case ".go":
		base := filepath.Base(m.Source)
		name, date := parseNameAndDate(m.Source)
		log.Printf("[%s] %s\n", date, name)

		tx, err := db.Begin()
		if err != nil {
			return fmt.Errorf("begin transaction for migration %s: %w", base, err)
		}

		fn := m.UpFn
		if !direction {
			fn = m.DownFn
		}
		if fn != nil {
			if err := fn(tx); err != nil {
				tx.Rollback() //nolint:errcheck // the migration error is the one worth reporting
				return fmt.Errorf("run migration %s: %w", base, err)
			}
		}

		// FinalizeMigration rolls back on its own insert failure and commits otherwise.
		if err := c.FinalizeMigration(tx, direction, m.Version); err != nil {
			return fmt.Errorf("finalize migration %s: %w", base, err)
		}
	}
	return nil
}
var (
	// upperReplace splits at a lower-to-upper case boundary,
	// e.g. UpdateBuiltin -> Update Builtin.
	upperReplace = regexp.MustCompile("([a-z])([A-Z])")
	// allUpperWordsReplace splits an acronym from the capitalized word that
	// follows it, e.g. IDIn -> ID In.
	allUpperWordsReplace = regexp.MustCompile("([A-Z]+)([A-Z][a-z])")
)

// parseNameAndDate derives a human-readable name and date from a Go migration
// path of the form <timestamp>_<CamelCaseName>.go.
//
// The name is the part after the first underscore with spaces inserted at
// case boundaries ("UpdateBuiltinIDIn" -> "Update Builtin ID In"). The date is
// the first eight characters of the timestamp formatted as YYYY-MM-DD.
//
// Only the YYYYMMDD part is parsed because Fleet developers add seconds when
// re-arranging new migrations, which can produce out-of-range timestamps such
// as "20201021104586".
//
// Its only caller in this file uses the result for a log line. So malformed
// names fall back instead of panicking or exiting:
//   - a missing underscore makes the whole stem the name;
//   - a prefix that is not a valid YYYYMMDD date (including short numeric
//     versions such as "1") is returned verbatim as the date.
func parseNameAndDate(source string) (name string, date string) {
	stem := strings.TrimSuffix(filepath.Base(source), ".go")
	prefix, rest, found := strings.Cut(stem, "_")
	if !found {
		rest = stem
	}

	date = prefix
	if len(prefix) >= 8 {
		if mt, err := time.Parse("20060102", prefix[:8]); err == nil {
			date = mt.Format("2006-01-02")
		}
	}

	name = upperReplace.ReplaceAllString(rest, "$1 $2")
	name = allUpperWordsReplace.ReplaceAllString(name, "$1 $2")
	return name, date
}
// NumericComponent extracts the version number from a migration filename of
// the form
//
//	XXX_descriptivename.ext
//
// where XXX is a positive decimal version number and ext is .go or .sql.
// Any directory part of name is ignored.
//
// It returns an error when:
//   - the extension is not recognized;
//   - there is no underscore separator;
//   - the version is not greater than zero.
//
// If XXX cannot be parsed, it returns whatever strconv.ParseInt produced,
// together with its error.
func NumericComponent(name string) (int64, error) {
	base := filepath.Base(name)

	switch filepath.Ext(base) {
	case ".go", ".sql":
		// recognized migration type
	default:
		return 0, errors.New("not a recognized migration file type")
	}

	versionText, _, hasSeparator := strings.Cut(base, "_")
	if !hasSeparator {
		return 0, errors.New("no separator found")
	}

	version, parseErr := strconv.ParseInt(versionText, 10, 64)
	if parseErr != nil {
		return version, parseErr
	}
	if version <= 0 {
		return 0, errors.New("migration IDs must be greater than zero")
	}
	return version, nil
}
// CreateMigration writes a new migration skeleton into dir and returns the
// paths of the files it created.
//
// Parameters:
//   - name is the descriptive part of the filename.
//   - migrationType must be "go" or "sql".
//   - t supplies the version: it is formatted as YYYYMMDDhhmmss and used as
//     the filename prefix. The same string is passed to writeTemplateToFile,
//     presumably as the template data.
//
// For "go" migrations a companion <file>_test.go is also written, and both
// paths are returned. An error from either write is returned as-is. If the
// test file fails, the already-written migration file is left on disk.
func CreateMigration(name, migrationType, dir string, t time.Time) ([]string, error) {
	if migrationType != "go" && migrationType != "sql" {
		return nil, errors.New("migration type must be 'go' or 'sql'")
	}

	timestamp := t.Format("20060102150405")
	filename := fmt.Sprintf("%s_%s.%s", timestamp, name, migrationType)
	fpath := filepath.Join(dir, filename)

	tmpl := sqlMigrationTemplate
	if migrationType == "go" {
		tmpl = goSqlMigrationTemplate
	}
	migrationPath, err := writeTemplateToFile(fpath, tmpl, timestamp)
	if err != nil {
		return nil, err
	}
	paths := []string{migrationPath}

	if migrationType == "go" {
		// Swap only the trailing ".go" of the migration path. Replacing the
		// first ".go" anywhere in the path would corrupt directories or names
		// that merely contain ".go" (e.g. "repo.golang/" or "add.goals").
		testPath := strings.TrimSuffix(fpath, ".go") + "_test.go"
		migrationTestPath, err := writeTemplateToFile(testPath, goSqlMigrationTestTemplate, timestamp)
		if err != nil {
			return nil, err
		}
		paths = append(paths, migrationTestPath)
	}
	return paths, nil
}
// FinalizeMigration records version v in the version table c.TableName and
// completes tx.
//
// The insert statement comes from c.Dialect and is executed with v and
// direction as arguments. On success tx is committed and the Commit error (if
// any) is returned. If the insert fails, tx is rolled back (the rollback
// error is ignored) and the insert error is returned.
//
// XXX: drop goose_db_version table on some minimum version number?
func (c *Client) FinalizeMigration(tx *sql.Tx, direction bool, v int64) error {
	insertStmt := c.Dialect.insertVersionSql(c.TableName)

	_, execErr := tx.Exec(insertStmt, v, direction)
	if execErr == nil {
		return tx.Commit()
	}

	tx.Rollback() //nolint:errcheck
	return execErr
}
// sqlMigrationTemplate renders the skeleton of a new .sql migration: the
// "-- +goose Up" and "-- +goose Down" section markers, each with an
// explanatory comment. The template does not reference its data argument.
var sqlMigrationTemplate = template.Must(
	template.New("goose.sql-migration").Parse(`
-- +goose Up
-- SQL in section 'Up' is executed when this migration is applied
-- +goose Down
-- SQL section 'Down' is executed when this migration is rolled back
`),
)
// goSqlMigrationTemplate renders the skeleton of a new Go migration in
// package tables. The template data ({{.}}) is substituted into the Up_/Down_
// function names. The generated init registers both functions via
// MigrationClient.AddMigration.
var goSqlMigrationTemplate = template.Must(
	template.New("goose.go-migration").Parse(`
package tables
import (
"database/sql"
)
func init() {
MigrationClient.AddMigration(Up_{{.}}, Down_{{.}})
}
func Up_{{.}}(tx *sql.Tx) error {
return nil
}
func Down_{{.}}(tx *sql.Tx) error {
return nil
}
`),
)
// goSqlMigrationTestTemplate renders the skeleton _test.go file that
// accompanies a new Go migration in package tables. The template data ({{.}})
// names the generated TestUp_ function.
//
// It has its own template name, distinct from goSqlMigrationTemplate's, so
// that parse and execution errors identify which template failed.
var goSqlMigrationTestTemplate = template.Must(template.New("goose.go-migration-test").Parse(`
package tables
import "testing"
func TestUp_{{.}}(t *testing.T) {
db := applyUpToPrev(t)
//
// Insert data to test the migration
//
// ...
// Apply current migration.
applyNext(t, db)
//
// Check data, insert new entries, e.g. to verify migration is safe.
//
// ...
}`))