package store import ( "context" "database/sql" "fmt" "io/fs" "sort" "strings" _ "modernc.org/sqlite" "embed" ) //go:embed migrations/*.sql var migrationsFS embed.FS // Open opens (or creates) the SQLite database at path and runs migrations. func Open(path string) (*sql.DB, error) { dsn := fmt.Sprintf( "file:%s?_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)&_pragma=foreign_keys(ON)&_pragma=synchronous(NORMAL)", path, ) db, err := sql.Open("sqlite", dsn) if err != nil { return nil, fmt.Errorf("open sqlite: %w", err) } db.SetMaxOpenConns(1) // SQLite WAL: single writer if err := migrate(db); err != nil { db.Close() return nil, fmt.Errorf("migrate: %w", err) } return db, nil } // migrate applies all SQL migration files in migrations/ in filename order. // user_version tracks the last applied migration index (1-based). func migrate(db *sql.DB) error { var version int if err := db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { return err } entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return fmt.Errorf("read migrations dir: %w", err) } // Sort by name so 001_*, 002_* apply in order. sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) ctx := context.Background() for i, entry := range entries { migrationNum := i + 1 // 1-based if version >= migrationNum { continue } data, err := migrationsFS.ReadFile("migrations/" + entry.Name()) if err != nil { return fmt.Errorf("read %s: %w", entry.Name(), err) } if err := applySchema(db, ctx, string(data)); err != nil { return fmt.Errorf("migration %d (%s): %w", migrationNum, entry.Name(), err) } if _, err := db.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %d", migrationNum)); err != nil { return fmt.Errorf("set user_version = %d: %w", migrationNum, err) } } return nil } func applySchema(db *sql.DB, ctx context.Context, sql string) error { for _, stmt := range splitStatements(sql) { stmt = strings.TrimSpace(stmt) if stmt == "" { continue } if _, err := db.ExecContext(ctx, stmt); err != nil { return fmt.Errorf("exec %q: %w", stmt[:min(len(stmt), 60)], err) } } return nil } func splitStatements(sql string) []string { // Only process statements up to the "-- +migrate Down" marker. var lines []string for _, line := range strings.Split(sql, "\n") { trimmed := strings.TrimSpace(line) if trimmed == "-- +migrate Down" { break } if strings.HasPrefix(trimmed, "-- +migrate") { continue } lines = append(lines, line) } joined := strings.Join(lines, "\n") // Split on semicolons parts := strings.Split(joined, ";") return parts } func min(a, b int) int { if a < b { return a } return b }