package dbx import ( "context" "fmt" "strings" "time" "gorm.io/driver/postgres" "gorm.io/gorm" ) func GetPostgresDB(cfg *DBConfig, dbName string) *gorm.DB { db, err := getPostgresInstance(cfg, dbName) if err != nil { panic("failed to connect to Postgres database: " + err.Error()) } return db } func getPostgresInstance(cfg *DBConfig, dbName string) (*gorm.DB, error) { targetDSN := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Password, dbName, cfg.Sslmode) // 先尝试直接连接目标库 if db, err := gorm.Open(postgres.Open(targetDSN), &gorm.Config{}); err == nil { return db, nil } // 连接 admin DB(postgres) adminDSN := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=postgres sslmode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Sslmode) adminDB, err := gorm.Open(postgres.Open(adminDSN), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("connect admin postgres failed: %w", err) } // 确保关闭底层连接 if sqlDB, e := adminDB.DB(); e == nil { defer sqlDB.Close() } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // 检查是否存在 var count int64 if err := adminDB.WithContext(ctx).Raw("SELECT count(*) FROM pg_database WHERE datname = ?", dbName).Scan(&count).Error; err != nil { return nil, fmt.Errorf("check database existence failed: %w", err) } if count == 0 { ident := escapePostgresIdentifier(dbName) createSQL := fmt.Sprintf("CREATE DATABASE %s", ident) if err := adminDB.WithContext(ctx).Exec(createSQL).Error; err != nil { return nil, fmt.Errorf("create database %s failed: %w", dbName, err) } } // 尝试连接目标数据库 db2, err := gorm.Open(postgres.Open(targetDSN), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("connect target after create failed: %w", err) } return db2, nil } // 辅助:安全转义 Postgres 标识符(用双引号并把 " 替换为 "") func escapePostgresIdentifier(s string) string { s = strings.ReplaceAll(s, `"`, `""`) return `"` + s + `"` }