|
5 | 5 | package main
|
6 | 6 |
|
7 | 7 | import (
|
| 8 | + "bufio" |
8 | 9 | "flag"
|
9 | 10 | "fmt"
|
10 | 11 | "os"
|
| 12 | + "path/filepath" |
| 13 | + "strings" |
| 14 | + "syscall" |
11 | 15 |
|
12 | 16 | "github.com/aitjcize/Overlord/overlord"
|
| 17 | + "golang.org/x/term" |
13 | 18 | )
|
14 | 19 |
|
15 |
| -var bindAddr = flag.String("bind", "0.0.0.0", "specify alternate bind address") |
16 |
| -var port = flag.Int("port", 0, |
17 |
| - "alternate port listen instead of standard ports (http:80, https:443)") |
18 |
| -var lanDiscInterface = flag.String("lan-disc-iface", "", |
19 |
| - "the network interface used for broadcasting LAN discovery packets") |
20 |
| -var noLanDisc = flag.Bool("no-lan-disc", false, |
21 |
| - "disable LAN discovery broadcasting") |
22 |
| -var tlsCerts = flag.String("tls", "", |
23 |
| - "TLS certificates in the form of 'cert.pem,key.pem'. Empty to disable.") |
24 |
| -var noLinkTLS = flag.Bool("no-link-tls", false, |
25 |
| - "disable TLS between ghost and overlord. Only valid when TLS is enabled.") |
26 |
| -var htpasswdPath = flag.String("htpasswd-path", "overlord.htpasswd", |
27 |
| - "the path to the .htpasswd file. Required for authentication.") |
28 |
| -var jwtSecretPath = flag.String("jwt-secret-path", "jwt-secret", |
29 |
| - "Path to the file containing the JWT secret. Required for authentication.") |
| 20 | +var ( |
| 21 | + bindAddr = flag.String("bind", |
| 22 | + "0.0.0.0", "specify alternate bind address") |
| 23 | + port = flag.Int("port", |
| 24 | + 0, "alternate port listen instead of standard ports (http:80, https:443)") |
| 25 | + lanDiscInterface = flag.String("lan-disc-iface", |
| 26 | + "", "the network interface used for broadcasting LAN discovery packets") |
| 27 | + noLanDisc = flag.Bool("no-lan-disc", |
| 28 | + false, "disable LAN discovery broadcasting") |
| 29 | + tlsCerts = flag.String("tls", |
| 30 | + "", "TLS certificates in the form of 'cert.pem,key.pem'. Empty to disable.") |
| 31 | + noLinkTLS = flag.Bool("no-link-tls", |
| 32 | + false, "disable TLS between ghost and overlord. Only valid when TLS is enabled.") |
| 33 | + dbPath = flag.String("db-path", |
| 34 | + "overlord.db", "the path to the SQLite database file for user, group, and authentication data") |
| 35 | + initializeDB = flag.Bool("init", |
| 36 | + false, "Initialize the database with a custom admin user and password and generate a JWT secret") |
| 37 | + adminUser = flag.String("admin-user", |
| 38 | + "", "Admin username for database initialization (only used with -init)") |
| 39 | + adminPass = flag.String("admin-pass", |
| 40 | + "", "Admin password for database initialization (only used with -init)") |
| 41 | +) |
30 | 42 |
|
31 | 43 | func usage() {
|
32 |
| - fmt.Fprintf(os.Stderr, "Usage: overlordd [OPTIONS]\n") |
| 44 | + prog := filepath.Base(os.Args[0]) |
| 45 | + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", prog) |
| 46 | + fmt.Fprintln(os.Stderr, "Options:") |
33 | 47 | flag.PrintDefaults()
|
34 |
| - os.Exit(2) |
| 48 | + os.Exit(1) |
| 49 | +} |
| 50 | + |
| 51 | +func promptForInput(prompt string) string { |
| 52 | + reader := bufio.NewReader(os.Stdin) |
| 53 | + fmt.Print(prompt) |
| 54 | + input, _ := reader.ReadString('\n') |
| 55 | + return strings.TrimSpace(input) |
| 56 | +} |
| 57 | + |
| 58 | +func promptForPassword(prompt string) string { |
| 59 | + fmt.Print(prompt) |
| 60 | + password, err := term.ReadPassword(int(syscall.Stdin)) |
| 61 | + fmt.Println() // Add a newline after the password input |
| 62 | + |
| 63 | + if err != nil { |
| 64 | + panic(err) |
| 65 | + } |
| 66 | + return string(password) |
| 67 | +} |
| 68 | + |
| 69 | +func initializeDatabase(dbPath string) error { |
| 70 | + adminUsername := *adminUser |
| 71 | + adminPassword := *adminPass |
| 72 | + |
| 73 | + // If admin username is not provided via command line, prompt for it |
| 74 | + if adminUsername == "" { |
| 75 | + adminUsername = promptForInput("Enter admin username [admin]: ") |
| 76 | + if adminUsername == "" { |
| 77 | + adminUsername = "admin" |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + // If admin password is not provided via command line, prompt for it |
| 82 | + if adminPassword == "" { |
| 83 | + for { |
| 84 | + adminPassword = promptForPassword("Enter admin password: ") |
| 85 | + if adminPassword == "" { |
| 86 | + fmt.Println("Password cannot be empty, please try again.") |
| 87 | + continue |
| 88 | + } |
| 89 | + break |
| 90 | + } |
| 91 | + } else if adminPassword == "" { |
| 92 | + return fmt.Errorf("password cannot be empty") |
| 93 | + } |
| 94 | + |
| 95 | + dbManager := overlord.NewDatabaseManager(dbPath) |
| 96 | + |
| 97 | + err := dbManager.Initialize(adminUsername, adminPassword) |
| 98 | + if err != nil { |
| 99 | + return fmt.Errorf("failed to initialize database: %v", err) |
| 100 | + } |
| 101 | + |
| 102 | + fmt.Println("Database initialization complete.") |
| 103 | + fmt.Println("You can now start the server without the -init flag.") |
| 104 | + return nil |
| 105 | +} |
| 106 | + |
| 107 | +func checkDatabaseInitialized(dbPath string) bool { |
| 108 | + // Simple check - if the file exists and has data, consider it initialized |
| 109 | + info, err := os.Stat(dbPath) |
| 110 | + if err != nil || info.Size() == 0 { |
| 111 | + return false |
| 112 | + } |
| 113 | + return true |
35 | 114 | }
|
36 | 115 |
|
37 | 116 | func main() {
|
38 |
| - flag.Usage = usage |
39 | 117 | flag.Parse()
|
40 | 118 |
|
41 |
| - // Validate required flags |
42 |
| - if *htpasswdPath == "" { |
43 |
| - fmt.Fprintf(os.Stderr, "Error: -htpasswd-path is required\n") |
| 119 | + if len(flag.Args()) > 0 { |
| 120 | + fmt.Fprintf(os.Stderr, "Error: unknown argument: %s\n", flag.Args()[0]) |
44 | 121 | usage()
|
45 | 122 | }
|
46 |
| - if *jwtSecretPath == "" { |
47 |
| - fmt.Fprintf(os.Stderr, "Error: -jwt-secret-path is required\n") |
| 123 | + |
| 124 | + // Validate required flags |
| 125 | + if *dbPath == "" { |
| 126 | + fmt.Fprintf(os.Stderr, "Error: -db-path is required\n") |
48 | 127 | usage()
|
49 | 128 | }
|
50 | 129 |
|
| 130 | + // Initialize the database if requested |
| 131 | + if *initializeDB { |
| 132 | + if checkDatabaseInitialized(*dbPath) { |
| 133 | + fmt.Fprintf(os.Stderr, "Error: Database already initialized\n") |
| 134 | + os.Exit(1) |
| 135 | + } |
| 136 | + if err := initializeDatabase(*dbPath); err != nil { |
| 137 | + fmt.Fprintf(os.Stderr, "Error initializing database: %v\n", err) |
| 138 | + os.Exit(1) |
| 139 | + } |
| 140 | + os.Exit(0) |
| 141 | + } |
| 142 | + |
| 143 | + // Check if the database is initialized |
| 144 | + if !checkDatabaseInitialized(*dbPath) { |
| 145 | + fmt.Fprintf(os.Stderr, "Error: Database not initialized. Run with -init to initialize\n") |
| 146 | + os.Exit(1) |
| 147 | + } |
| 148 | + |
51 | 149 | overlord.StartOverlord(*bindAddr, *port, *lanDiscInterface, !*noLanDisc,
|
52 |
| - *tlsCerts, !*noLinkTLS, *htpasswdPath, *jwtSecretPath) |
| 150 | + *tlsCerts, !*noLinkTLS, *dbPath) |
53 | 151 | }
|
0 commit comments