fleet/tools/snapshot/snapshot.go
2025-12-05 11:08:40 -05:00

288 lines
7.3 KiB
Go

package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"slices"
"time"
"github.com/manifoldco/promptui"
// Force promptui to use our newer x/sys package,
// which doesn't have security vulnerabilities.
_ "golang.org/x/sys/unix"
)
// Snapshot describes one saved database snapshot.
// Snapshots are stored in folders under ~/.fleet/snapshots named after the
// snapshot name; each snapshot folder contains a db.sql.gz file.
type Snapshot struct {
	Name string    // The snapshot name (normally its folder name).
	Date time.Time // Modification time of the backup file; used to sort and label snapshots.
	Path string    // Path to the snapshot's backup file (normally <folder>/db.sql.gz), not the folder itself.
}

// DateStr formats the snapshot date for display, e.g. "Mar 05, 2024 02:07:09 PM".
func (s Snapshot) DateStr() string {
	return s.Date.Format("Jan 02, 2006 03:04:05 PM")
}
// Command identifies which sub-command the tool was asked to run.
type Command int

// Supported sub-commands. CMD_SNAPSHOT is the zero value of Command.
const (
	CMD_SNAPSHOT Command = 0 // take a new snapshot
	CMD_RESTORE  Command = 1 // restore an existing snapshot
)
// main is the entry point of the snapshot tool. The first argument selects the
// sub-command: "s"/"snap"/"snapshot" creates a snapshot, "r"/"restore"
// restores one. It changes into the repo root so the relative
// ./tools/backup_db scripts resolve, and exits with status 1 on any failure.
func main() {
	// Ensure there's a command specified.
	// TODO - as we add more commands, we should probably use a library like spf13/cobra.
	if len(os.Args) < 2 {
		fmt.Println("Please specify whether to (s)napshot or (r)estore.")
		os.Exit(1)
	}

	// Determine the command. Unknown input must stop here: otherwise command
	// would keep its zero value (CMD_SNAPSHOT) and a snapshot would be taken.
	var command Command
	switch os.Args[1] {
	case "s", "snap", "snapshot":
		command = CMD_SNAPSHOT
	case "r", "restore":
		command = CMD_RESTORE
	default:
		fmt.Printf("Unknown command %q. Please specify whether to (s)napshot or (r)estore.\n", os.Args[1])
		os.Exit(1)
	}

	// Determine the path to the top-level directory (where the Makefile resides).
	repoRoot, err := getRepoRoot()
	if err != nil {
		fmt.Printf("Error determining repo root: %v\n", err)
		os.Exit(1)
	}

	// Change the working directory to the repo root; the backup/restore
	// scripts are invoked via paths relative to it.
	if err := os.Chdir(repoRoot); err != nil {
		fmt.Printf("Error changing directory to repo root: %v\n", err)
		os.Exit(1)
	}

	// Get the home directory so we can get the snapshots dir.
	homedir, err := os.UserHomeDir()
	if err != nil {
		fmt.Printf("Could not determine home directory: %v\n", err)
		os.Exit(1)
	}

	// Run the command. The sub-commands print their own error details, so
	// only the exit status needs to reflect a failure here.
	var runErr error
	switch command {
	case CMD_SNAPSHOT:
		runErr = snapshot(homedir)
	case CMD_RESTORE:
		runErr = restore(homedir)
	}
	if runErr != nil {
		os.Exit(1)
	}
}
// restore asks the user to pick one of the snapshots stored under
// ~/.fleet/snapshots (newest first) and restores it by running
// ./tools/backup_db/restore.sh with the path of the snapshot's db.sql.gz file.
// The script path is relative, so the caller is expected to have changed into
// the repo root. On success the snapshot name is recorded in
// ~/.fleet/snapshots/last_snapshot; failing to write that file is reported but
// not returned as an error.
func restore(homedir string) error {
	snapshotsDir := filepath.Join(homedir, ".fleet", "snapshots")

	snapshots, err := findSnapshots(snapshotsDir)
	if err != nil {
		if os.IsNotExist(err) {
			fmt.Printf("You don't currently have any snapshots.\n")
		} else {
			fmt.Printf("Error reading snapshots directory (%s): %v\n", snapshotsDir, err)
		}
		return err
	}
	if len(snapshots) == 0 {
		// Nothing to choose from; don't show an empty selection prompt.
		fmt.Printf("You don't currently have any snapshots.\n")
		return fmt.Errorf("no snapshots found in %s", snapshotsDir)
	}

	// Show the most recent snapshots first.
	slices.SortFunc(snapshots, func(a, b Snapshot) int {
		return b.Date.Compare(a.Date)
	})

	// Set up and run the "Select snapshot" UI.
	templates := &promptui.SelectTemplates{
		Label:    " {{ .Name }}",
		Active:   "• {{ .Name }} ({{ .DateStr }})",
		Inactive: " {{ .Name }} ({{ .DateStr }})",
		Selected: " {{ .Name }} ({{ .DateStr }})",
	}
	prompt := promptui.Select{
		Label:     "Select snapshot to restore",
		Items:     snapshots,
		Templates: templates,
		Size:      10,
	}
	index, _, err := prompt.Run()
	if err != nil {
		fmt.Printf("Prompt failed %v\n", err)
		return err
	}
	selected := snapshots[index]

	// Run the restore script attached to this process's stdin/stdout/stderr so
	// its output goes straight to the user and it can read from the terminal.
	cmd := exec.Command("./tools/backup_db/restore.sh", selected.Path)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Printf("Error: %v\n", err)
		return err
	}

	// Record the restored snapshot (best-effort).
	lastSnapshotFile := filepath.Join(snapshotsDir, "last_snapshot")
	if err := os.WriteFile(lastSnapshotFile, []byte(selected.Name), 0o644); err != nil {
		fmt.Printf("Error writing last snapshot file: %v\n", err)
	}
	return nil
}

// findSnapshots returns one Snapshot per subdirectory of snapshotsDir that
// contains a regular db.sql.gz file, dated by that file's modification time.
// Other entries, such as the last_snapshot file, are ignored. The only error
// returned is the one from reading snapshotsDir itself.
func findSnapshots(snapshotsDir string) ([]Snapshot, error) {
	entries, err := os.ReadDir(snapshotsDir)
	if err != nil {
		return nil, err
	}

	var snapshots []Snapshot
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		dbBackupFile := filepath.Join(snapshotsDir, entry.Name(), "db.sql.gz")
		info, err := os.Stat(dbBackupFile)
		if err != nil || !info.Mode().IsRegular() {
			// No usable backup in this folder (e.g. an earlier backup did
			// not complete); it cannot be restored, so don't offer it.
			continue
		}
		snapshots = append(snapshots, Snapshot{
			Name: entry.Name(),
			Date: info.ModTime(),
			Path: dbBackupFile,
		})
	}
	return snapshots, nil
}
// snapshot prompts for a snapshot name and backs up the database into
// ~/.fleet/snapshots/<name>/db.sql.gz by running ./tools/backup_db/backup.sh.
// The script path is relative, so the caller is expected to have changed into
// the repo root. If a snapshot with that name already exists, the user is
// asked before its db.sql.gz is replaced. On success the name is recorded in
// ~/.fleet/snapshots/last_snapshot; failing to write that file is reported but
// not returned as an error.
func snapshot(homedir string) error {
	snapshotsDir := filepath.Join(homedir, ".fleet", "snapshots")

	// Ensure the snapshots directory exists. MkdirAll also creates ~/.fleet
	// when missing and succeeds when the directory is already present, so a
	// first-time run goes on to take the snapshot.
	if err := os.MkdirAll(snapshotsDir, 0o755); err != nil {
		fmt.Printf("Error creating snapshots directory (%s): %v\n", snapshotsDir, err)
		return err
	}

	// Prompt the user for a name for the snapshot.
	namePrompt := promptui.Prompt{
		Label: "Enter a name for the snapshot",
	}
	name, err := namePrompt.Run()
	if err != nil {
		fmt.Printf("Prompt failed %v\n", err)
		return err
	}
	// The name becomes a folder inside snapshotsDir, so it must not be able to
	// point elsewhere or clash with the last_snapshot file.
	if err := validateSnapshotName(name); err != nil {
		fmt.Printf("Invalid snapshot name %q: %v\n", name, err)
		return err
	}

	snapshotPath := filepath.Join(snapshotsDir, name)
	dbBackupFile := filepath.Join(snapshotPath, "db.sql.gz")

	// If the snapshot already exists, ask before replacing its backup.
	_, err = os.Lstat(snapshotPath)
	switch {
	case err == nil:
		overwritePrompt := promptui.Prompt{
			Label: "This snapshot already exists. Overwrite? (Y/n)",
		}
		answer, promptErr := overwritePrompt.Run()
		if promptErr != nil {
			fmt.Printf("Prompt failed %v\n", promptErr)
			return promptErr
		}
		if answer != "Y" && answer != "y" && answer != "" {
			return nil
		}
		// The folder may exist without a backup file (e.g. an earlier backup
		// did not complete); then there is nothing to remove.
		if err := os.Remove(dbBackupFile); err != nil && !os.IsNotExist(err) {
			fmt.Printf("Error removing existing snapshot (%s): %v\n", name, err)
			return err
		}
	case !os.IsNotExist(err):
		fmt.Printf("Error checking for existing snapshot (%s): %v\n", name, err)
		return err
	}

	// Create the snapshot directory (it may already exist when overwriting).
	if err := os.Mkdir(snapshotPath, 0o755); err != nil && !os.IsExist(err) {
		fmt.Printf("Error creating snapshot directory (%s): %v\n", snapshotPath, err)
		return err
	}

	// Run the backup script attached to this process's stdin/stdout/stderr so
	// its output goes straight to the user and it can read from the terminal.
	cmd := exec.Command("./tools/backup_db/backup.sh", dbBackupFile)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Printf("Error: %v\n", err)
		return err
	}

	// Record the new snapshot as the latest one (best-effort).
	lastSnapshotFile := filepath.Join(snapshotsDir, "last_snapshot")
	if err := os.WriteFile(lastSnapshotFile, []byte(name), 0o644); err != nil {
		fmt.Printf("Error writing last snapshot file: %v\n", err)
	}
	return nil
}

// validateSnapshotName rejects names that cannot safely be used as a single
// folder inside the snapshots directory: the empty name, "." and "..", names
// containing a path separator, and "last_snapshot", which would collide with
// the file of that name kept next to the snapshot folders.
func validateSnapshotName(name string) error {
	switch name {
	case "":
		return fmt.Errorf("name must not be empty")
	case ".", "..":
		return fmt.Errorf("name must not be %q", name)
	case "last_snapshot":
		return fmt.Errorf("name %q is reserved", name)
	}
	for i := 0; i < len(name); i++ {
		if os.IsPathSeparator(name[i]) {
			return fmt.Errorf("name must not contain path separators")
		}
	}
	return nil
}
// getRepoRoot returns the absolute path of the parent of the directory that
// holds the running binary (with symlinks to the binary resolved). That
// parent is taken to be the repo's top-level directory.
func getRepoRoot() (string, error) {
	self, err := os.Executable()
	if err != nil {
		return "", err
	}

	// Resolve symlinks so a linked binary still maps back to its real checkout.
	resolved, err := filepath.EvalSymlinks(self)
	if err != nil {
		return "", err
	}

	binDir := filepath.Dir(resolved)
	return filepath.Abs(filepath.Join(binDir, ".."))
}