-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.go
More file actions
419 lines (377 loc) · 14.3 KB
/
main.go
File metadata and controls
419 lines (377 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
package main
import (
	"context"
	"embed"
	"errors"
	"flag"
	"fmt"
	"log/slog"
	"net"
	"net/url"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"windshift/internal/auth"
	"windshift/internal/database"
	"windshift/internal/logger"
	"windshift/internal/middleware"
	"windshift/internal/server"
	"windshift/internal/tui"

	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	wishbubbletea "github.com/charmbracelet/wish/bubbletea"
	"github.com/charmbracelet/wish/logging"
)
// frontendFiles is the frontend/dist directory tree embedded into the binary
// at build time. The "all:" prefix makes the embed also include files whose
// names begin with "." or "_". It is handed to the server through
// server.Config.FrontendFiles.
//
//go:embed all:frontend/dist
var frontendFiles embed.FS

// bannerArt is the content of banner.txt embedded at build time; printBanner
// writes it to stdout on startup.
//
//go:embed banner.txt
var bannerArt string
// ANSI escape sequences used to colorize the startup banner.
const (
	colorTeal  = "\033[38;5;37m" // 256-color foreground, palette index 37
	colorReset = "\033[0m"       // reset all terminal attributes
)
// printBanner writes the embedded logo to stdout wrapped in the teal color
// sequence, followed by the product name, the tagline, and a blank line.
func printBanner() {
	// One Println emits exactly what the separate Print calls did:
	// color-on, art, color-off, newline.
	fmt.Println(colorTeal + bannerArt + colorReset)
	fmt.Println(colorTeal + " W I N D S H I F T" + colorReset)
	fmt.Println(" Work Management Platform")
	fmt.Println()
}
// main is the process entry point. It parses command-line flags, layers
// environment-variable overrides on top of them, resolves the security
// configuration, starts the HTTP server and (optionally) the SSH TUI server,
// and then blocks until a shutdown signal arrives, at which point the SSH
// server and then the main server are shut down.
//
// Precedence: most environment variables (PORT, DB_PATH, LOG_LEVEL, SSH_PORT,
// ...) replace the flag value whenever they are set to a non-empty string,
// whereas ALLOWED_HOSTS, BASE_URL, PUBLIC_URL, LLM_PROVIDERS_FILE and
// AI_PROMPTS_DIR only fill in a value that is still empty. Boolean
// environment variables can only switch a feature on (value "true").
func main() {
	// Command line flags
	var port string
	var dbPath string
	var postgresConn string
	var attachmentPath string
	var disableCSRF bool
	var allowedHosts string
	var allowedPort string
	var useProxy bool
	var baseURL string
	var additionalProxies string
	var enableSSH bool
	var sshPort string
	var sshHost string
	var sshKeyPath string
	var maxReadConns int
	var maxWriteConns int
	var logLevel string
	var logFormat string
	var tlsCertPath string
	var tlsKeyPath string
	var disablePlugins bool
	var pluginDir string // no flag is registered for this; it is only set via PLUGIN_DIR
	var enableAdminFallback bool
	var llmProvidersFile string
	var aiPromptsDir string

	flag.StringVar(&port, "port", "8080", "Port to run the HTTP server on")
	flag.StringVar(&port, "p", "8080", "Port to run the HTTP server on (shorthand)")
	flag.StringVar(&dbPath, "db", "windshift.db", "Database file path (SQLite)")
	flag.StringVar(&postgresConn, "postgres-connection-string", "", "PostgreSQL connection string (e.g., postgresql://user:password@localhost:5432/windshift)")
	flag.StringVar(&postgresConn, "pg-conn", "", "PostgreSQL connection string (shorthand)")
	flag.StringVar(&attachmentPath, "attachment-path", "", "Path to store attachments (enables attachment feature if specified)")
	flag.BoolVar(&disableCSRF, "no-csrf", false, "Disable CSRF protection (for development only)")
	flag.StringVar(&allowedHosts, "allowed-hosts", "", "Comma-separated list of allowed hostnames for CSRF (e.g., 192.168.1.30,myserver.local)")
	flag.StringVar(&allowedPort, "allowed-port", "", "Port for CORS/WebAuthn trusted origins (defaults to server port, useful for reverse proxy setups)")
	flag.BoolVar(&useProxy, "use-proxy", false, "Enable proxy mode: trust X-Forwarded-Proto from private IPs. WARNING: Only enable when behind a reverse proxy that terminates TLS. Server must NOT be directly accessible from the internet.")
	flag.StringVar(&baseURL, "base-url", "", "Public URL for the server (used for email links, SSO redirects, calendar feeds)")
	flag.StringVar(&additionalProxies, "additional-proxies", "", "Additional proxy IPs to trust beyond private ranges (requires --use-proxy)")
	flag.BoolVar(&enableSSH, "ssh", false, "Enable SSH TUI server")
	flag.StringVar(&sshPort, "ssh-port", "23234", "Port to run the SSH server on")
	flag.StringVar(&sshHost, "ssh-host", "localhost", "Host for SSH server")
	flag.StringVar(&sshKeyPath, "ssh-key", ".ssh/windshift_host_key", "Path to SSH host key file")
	flag.IntVar(&maxReadConns, "max-read-conns", 120, "Maximum number of read connections (PocketBase default: 120)")
	flag.IntVar(&maxWriteConns, "max-write-conns", 1, "Maximum number of write connections (PocketBase default: 1)")
	flag.StringVar(&logLevel, "log-level", "info", "Log level (debug, info, warn, error)")
	flag.StringVar(&logFormat, "log-format", "text", "Log format (text, json, logfmt)")
	flag.StringVar(&tlsCertPath, "tls-cert", "", "Path to TLS certificate file (enables HTTPS)")
	flag.StringVar(&tlsKeyPath, "tls-key", "", "Path to TLS key file (enables HTTPS)")
	flag.BoolVar(&disablePlugins, "disable-plugins", false, "Disable the plugin system (prevents loading and uploading plugins)")
	flag.BoolVar(&enableAdminFallback, "enable-fallback", false, "Enable admin password fallback for restrictive auth policies (disabled by default for security)")
	flag.StringVar(&llmProvidersFile, "llm-providers", "", "Path to custom LLM providers JSON file (overrides built-in provider list)")
	flag.StringVar(&aiPromptsDir, "ai-prompts-dir", "", "Directory containing custom AI prompt override files")
	flag.Parse()

	// envOr returns the named environment variable, or fallback when it is
	// unset or empty.
	envOr := func(key, fallback string) string {
		if v := os.Getenv(key); v != "" {
			return v
		}
		return fallback
	}
	// envTrue reports whether the named environment variable is exactly "true".
	envTrue := func(key string) bool { return os.Getenv(key) == "true" }
	// envInt parses the named environment variable as an int. It returns
	// fallback when the variable is unset, and warns and returns fallback when
	// the value does not parse.
	envInt := func(key string, fallback int) int {
		raw := os.Getenv(key)
		if raw == "" {
			return fallback
		}
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			slog.Warn("invalid "+key+", using default", "value", raw, "error", err)
			return fallback
		}
		return parsed
	}
	// envDuration parses the named environment variable as a time.Duration.
	// It returns zero when the variable is unset, and warns and returns zero
	// when the value does not parse.
	envDuration := func(key string) time.Duration {
		raw := os.Getenv(key)
		if raw == "" {
			return 0
		}
		parsed, err := time.ParseDuration(raw)
		if err != nil {
			slog.Warn("invalid "+key+", using default", "value", raw, "error", err)
			return 0
		}
		return parsed
	}

	// Resolve the logging overrides first so the logger is initialized exactly
	// once, before anything is logged.
	logLevel = envOr("LOG_LEVEL", logLevel)
	logFormat = envOr("LOG_FORMAT", logFormat)
	logger.Init(logLevel, logFormat)

	// Print startup banner
	printBanner()

	// Check for environment variables (common in deployment environments)
	port = envOr("PORT", port)
	postgresConn = envOr("POSTGRES_CONNECTION_STRING", postgresConn)
	maxReadConns = envInt("MAX_READ_CONNS", maxReadConns)
	maxWriteConns = envInt("MAX_WRITE_CONNS", maxWriteConns)

	// Additional Docker environment variables (for scratch/distroless images without shell)
	dbPath = envOr("DB_PATH", dbPath)
	attachmentPath = envOr("ATTACHMENT_PATH", attachmentPath)

	// Build PostgreSQL connection from individual env vars if not already set.
	// The URL is assembled with net/url so that reserved characters in the
	// user name, password or database name are percent-encoded instead of
	// corrupting the connection string.
	if postgresConn == "" && os.Getenv("DB_TYPE") == "postgres" {
		pgURL := url.URL{
			Scheme:   "postgresql",
			User:     url.UserPassword(envOr("POSTGRES_USER", "windshift"), os.Getenv("POSTGRES_PASSWORD")),
			Host:     net.JoinHostPort(envOr("POSTGRES_HOST", "postgres"), envOr("POSTGRES_PORT", "5432")),
			Path:     "/" + envOr("POSTGRES_DB", "windshift"),
			RawQuery: "sslmode=disable",
		}
		postgresConn = pgURL.String()
	}

	if allowedHosts == "" {
		allowedHosts = os.Getenv("ALLOWED_HOSTS")
	}

	// Read BASE_URL (then PUBLIC_URL) from env if not set via flag
	if baseURL == "" {
		baseURL = os.Getenv("BASE_URL")
	}
	if baseURL == "" {
		baseURL = os.Getenv("PUBLIC_URL")
	}

	// SSH environment variables
	if envTrue("SSH_ENABLED") {
		enableSSH = true
	}
	sshPort = envOr("SSH_PORT", sshPort)
	sshHost = envOr("SSH_HOST", sshHost)

	// Proxy environment variables
	// Track whether proxy was explicitly set (flag or env) for auto-detection logic
	useProxyExplicit := false
	flag.Visit(func(f *flag.Flag) {
		if f.Name == "use-proxy" {
			useProxyExplicit = true
		}
	})
	if envTrue("USE_PROXY") {
		useProxy = true
		useProxyExplicit = true
	}
	additionalProxies = envOr("ADDITIONAL_PROXIES", additionalProxies)

	// Plugin system environment variables
	if envTrue("DISABLE_PLUGINS") {
		disablePlugins = true
	}
	pluginDir = envOr("PLUGIN_DIR", pluginDir)

	// Admin fallback environment variable
	if envTrue("ENABLE_ADMIN_FALLBACK") {
		enableAdminFallback = true
	}

	// LLM endpoint for AI features
	llmEndpoint := os.Getenv("LLM_ENDPOINT")

	// Logbook sidecar endpoint
	logbookEndpoint := os.Getenv("LOGBOOK_ENDPOINT")

	// LLM providers file and AI prompts directory: env only fills an empty flag
	if llmProvidersFile == "" {
		llmProvidersFile = os.Getenv("LLM_PROVIDERS_FILE")
	}
	if aiPromptsDir == "" {
		aiPromptsDir = os.Getenv("AI_PROMPTS_DIR")
	}

	// Notification tuning env vars; a zero value is passed through when the
	// variable is unset or invalid.
	notificationFlushInterval := envDuration("NOTIFICATION_FLUSH_INTERVAL")
	notificationBatchSize := envInt("NOTIFICATION_BATCH_SIZE", 0)
	notificationSyncInterval := envDuration("NOTIFICATION_SYNC_INTERVAL")

	// Setup signal handling for graceful shutdown. os.Interrupt is SIGINT, so
	// it is listed only once.
	shutdownChan := make(chan os.Signal, 1)
	signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)

	// Build server configuration
	cfg := server.Config{
		Port:                      port,
		DBPath:                    dbPath,
		PostgresConn:              postgresConn,
		DisableCSRF:               disableCSRF,
		AttachmentPath:            attachmentPath,
		AllowedHosts:              allowedHosts,
		AllowedPort:               allowedPort,
		UseProxy:                  useProxy,
		UseProxyExplicit:          useProxyExplicit,
		AdditionalProxies:         additionalProxies,
		MaxReadConns:              maxReadConns,
		MaxWriteConns:             maxWriteConns,
		TLSCertPath:               tlsCertPath,
		TLSKeyPath:                tlsKeyPath,
		DisablePlugins:            disablePlugins,
		PluginDir:                 pluginDir,
		EnableAdminFallback:       enableAdminFallback,
		BaseURL:                   baseURL,
		LLMEndpoint:               llmEndpoint,
		LLMProvidersFile:          llmProvidersFile,
		AIPromptsDir:              aiPromptsDir,
		LogbookEndpoint:           logbookEndpoint,
		SSHEnabled:                enableSSH,
		FrontendFiles:             frontendFiles,
		ShutdownChan:              shutdownChan,
		NotificationFlushInterval: notificationFlushInterval,
		NotificationBatchSize:     notificationBatchSize,
		NotificationSyncInterval:  notificationSyncInterval,
	}

	// Resolve security configuration: auto-detect proxy, derive CORS hosts/ports, validate
	resolved, err := server.ResolveSecurityConfig(cfg)
	if err != nil {
		slog.Error("security configuration error", "error", err)
		os.Exit(1)
	}
	resolved.LogDiagnostics()

	// Apply resolved values back to config
	cfg.UseProxy = resolved.UseProxy
	cfg.AllowedHosts = resolved.AllowedHosts
	cfg.AllowedPort = resolved.AllowedPort

	// Create and start the server
	srv, err := server.New(cfg)
	if err != nil {
		slog.Error("failed to create server", "error", err)
		os.Exit(1)
	}
	if err = srv.Start(); err != nil {
		slog.Error("failed to start server", "error", err)
		os.Exit(1)
	}

	// Setup SSH server if enabled. Failures here are logged but do not stop
	// the HTTP server.
	var sshServer *ssh.Server
	var sshDB database.Database // Declared at function scope to allow explicit cleanup
	if enableSSH {
		// NOTE(review): the TUI is always pointed at plain HTTP on localhost,
		// even when --tls-cert/--tls-key are set; confirm the API is reachable
		// that way in TLS deployments.
		apiURL := fmt.Sprintf("http://localhost:%d", srv.Port())

		// Split the comma-separated proxy list, trimming whitespace and
		// dropping empty entries so "a, b" and trailing commas are tolerated.
		var additionalProxyList []string
		for _, entry := range strings.Split(additionalProxies, ",") {
			if entry = strings.TrimSpace(entry); entry != "" {
				additionalProxyList = append(additionalProxyList, entry)
			}
		}
		enableHTTPS := tlsCertPath != "" && tlsKeyPath != ""

		// Create a separate DB connection for SSH auth, since the server's DB
		// is internal.
		if postgresConn != "" {
			sshDB, err = database.NewDatabase("postgres", postgresConn, maxReadConns, maxWriteConns)
		} else {
			sshDB, err = database.NewDatabase("sqlite3", dbPath, maxReadConns, maxWriteConns)
		}
		if err != nil {
			slog.Error("failed to create SSH database connection", "error", err)
		} else {
			sessionManager := auth.NewSessionManager(sshDB, enableHTTPS, useProxy, additionalProxyList, os.Getenv("SESSION_SECRET"))
			slog.Info("SSH server starting with public key authentication enabled")
			sshAuthMiddleware := middleware.NewSSHAuthMiddleware(sshDB)
			serverOptions := []ssh.Option{
				wish.WithAddress(net.JoinHostPort(sshHost, sshPort)),
				wish.WithHostKeyPath(sshKeyPath),
				wish.WithPublicKeyAuth(sshAuthMiddleware.PublicKeyHandler()),
				wish.WithMiddleware(
					wishbubbletea.Middleware(tui.NewTUIHandler(apiURL, sessionManager)),
					logging.Middleware(),
				),
			}
			s, newErr := wish.NewServer(serverOptions...)
			if newErr != nil {
				slog.Error("failed to create SSH server", "error", newErr)
			} else {
				sshServer = s
				slog.Info("SSH TUI server starting", "host", sshHost, "port", sshPort)
				// The goroutine ends when Shutdown below closes the listener.
				go func() {
					if serveErr := s.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, ssh.ErrServerClosed) {
						slog.Error("SSH server error", "error", serveErr)
					}
				}()
			}
		}
	}

	// Log startup info only when the SSH server was actually created.
	if sshServer != nil {
		slog.Info("SSH TUI available", "command", "ssh "+sshHost+" -p "+sshPort)
	}

	// Wait for shutdown signal
	<-shutdownChan
	slog.Info("shutdown signal received, starting graceful shutdown")

	// Shutdown SSH server first
	if sshServer != nil {
		slog.Info("shutting down SSH server")
		sshCtx, sshCancel := context.WithTimeout(context.Background(), 5*time.Second)
		if err := sshServer.Shutdown(sshCtx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
			slog.Error("SSH server shutdown error", "error", err)
		} else {
			slog.Info("SSH server shutdown complete")
		}
		sshCancel()
	}

	// closeSSHDB releases the SSH auth database connection, if one was opened.
	closeSSHDB := func() {
		if sshDB != nil {
			// Best effort: the process is exiting, so a close failure is not actionable.
			_ = sshDB.Close()
		}
	}

	// Shutdown the main server. cancel is called explicitly rather than
	// deferred because os.Exit does not run deferred functions.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	if err := srv.Shutdown(ctx); err != nil {
		slog.Error("server shutdown error", "error", err)
		cancel()
		closeSSHDB()
		os.Exit(1)
	}
	cancel()

	// Clean up SSH database connection
	closeSSHDB()
	slog.Info("all servers stopped successfully")
}