diff --git a/external/eyrie b/external/eyrie index 9da8423a..138b60dd 160000 --- a/external/eyrie +++ b/external/eyrie @@ -1 +1 @@ -Subproject commit 9da8423ad70f82947b3b400acd447082d33cb336 +Subproject commit 138b60dd6af0afd2841deda4ba73d79146a52f84 diff --git a/internal/api/server.go b/internal/api/server.go index cd2a5c50..604c6ebf 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,11 +5,13 @@ import ( "context" "crypto/subtle" "encoding/json" + "fmt" "io" "net" "net/http" "strings" "sync" + "time" ) const maxRequestBodyBytes = 1 << 20 @@ -74,9 +76,19 @@ func NewWithAPIKey(addr, apiKey string) *Server { // registerRoutes sets up the HTTP endpoints. func (s *Server) registerRoutes() { - s.mux.HandleFunc("GET /health", s.handleHealth) - s.mux.HandleFunc("GET /version", s.handleVersion) - s.mux.HandleFunc("POST /chat", s.auth(s.handleChat)) + s.mux.HandleFunc("GET /health", securityHeaders(s.handleHealth)) + s.mux.HandleFunc("GET /version", securityHeaders(s.handleVersion)) + s.mux.HandleFunc("POST /chat", securityHeaders(s.auth(s.handleChat))) +} + +// securityHeaders sets standard HTTP security headers on every response. +func securityHeaders(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Cache-Control", "no-store") + next(w, r) + } } func (s *Server) auth(next http.HandlerFunc) http.HandlerFunc { @@ -110,6 +122,35 @@ func constantTimeEqual(a, b string) bool { return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 } +// validateAuthConfig refuses to start the server with no API key on a +// non-loopback bind. The auth middleware silently allows every request when +// the API key is empty, so a misconfigured server would be wide open. The +// only safe no-key mode is loopback bind. +func (s *Server) validateAuthConfig() error { + if s.apiKey != "" { + return nil + } + host, _, err := net.SplitHostPort(s.addr) + if err != nil { + return fmt.Errorf("api: invalid bind address %q: %w", s.addr, err) + } + if !isLoopbackHost(host) { + return fmt.Errorf("api: apiKey is empty and bind address %q is not loopback; refusing to start. Set apiKey or bind to 127.0.0.1", s.addr) + } + return nil +} + +// isLoopbackHost reports whether host is a loopback address. +func isLoopbackHost(host string) bool { + if host == "" || host == "localhost" { + return host == "localhost" // "" is unsafe; "localhost" is loopback + } + if ip := net.ParseIP(host); ip != nil { + return ip.IsLoopback() + } + return false +} + func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool { r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes) dec := json.NewDecoder(r.Body) @@ -127,10 +168,17 @@ func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool { // Start starts the HTTP server. It blocks until the context is cancelled or an error occurs. func (s *Server) Start(ctx context.Context) error { + if err := s.validateAuthConfig(); err != nil { + return err + } s.mu.Lock() s.server = &http.Server{ - Addr: s.addr, - Handler: s.mux, + Addr: s.addr, + Handler: s.mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, } s.mu.Unlock() @@ -151,7 +199,7 @@ func (s *Server) Start(ctx context.Context) error { return err } -// Stop gracefully shuts down the HTTP server. +// Stop gracefully shuts down the HTTP server with a 15-second timeout. func (s *Server) Stop(ctx context.Context) error { s.mu.Lock() srv := s.server @@ -160,7 +208,12 @@ func (s *Server) Stop(ctx context.Context) error { if srv == nil { return nil } - return srv.Shutdown(ctx) + // Use a bounded timeout so Stop cannot hang indefinitely if a + // client keeps a connection open. The caller's ctx is respected + // if it has a shorter deadline. + shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + return srv.Shutdown(shutdownCtx) } // Handler returns the underlying http.Handler for testing purposes. diff --git a/internal/daemon/discord.go b/internal/daemon/discord.go index 54e1c18a..323a8748 100644 --- a/internal/daemon/discord.go +++ b/internal/daemon/discord.go @@ -201,7 +201,7 @@ func (g *DiscordGateway) fetchMessagesREST(ctx context.Context, channelID, after } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - data, _ := io.ReadAll(resp.Body) + data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return nil, fmt.Errorf("discord messages: HTTP %d: %s", resp.StatusCode, string(data)) } var msgs []discordMessage diff --git a/internal/daemon/gateway.go b/internal/daemon/gateway.go index 9a2f0974..1773d7ca 100644 --- a/internal/daemon/gateway.go +++ b/internal/daemon/gateway.go @@ -29,7 +29,7 @@ func forwardToHawk(ctx context.Context, client *http.Client, daemonAddr, apiKey, return "", err } defer func() { _ = resp.Body.Close() }() - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) var chatResp struct { Response string `json:"response"` } diff --git a/internal/daemon/telegram.go b/internal/daemon/telegram.go index b049bb55..4f0d70ef 100644 --- a/internal/daemon/telegram.go +++ b/internal/daemon/telegram.go @@ -152,17 +152,23 @@ func (tg *TelegramGateway) handleMessage(ctx context.Context, msg *TelegramMessa response = fmt.Sprintf("Error: %v", err) } - // Format for Telegram (truncate if too long) - if len(response) > 4000 { - response = response[:4000] + "\n\n... (truncated)" + // Format for Telegram (truncate if too long, at rune boundary) + if len([]rune(response)) > 4000 { + response = string([]rune(response)[:4000]) + "\n\n... (truncated)" } _ = tg.sendMessage(ctx, msg.Chat.ID, response) } func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (string, error) { - payload := fmt.Sprintf(`{"prompt":%q}`, prompt) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(payload)) + // Use json.Marshal for safe JSON encoding instead of fmt.Sprintf + // with %q, which does not handle all JSON edge cases (e.g., control + // characters, surrogate pairs). + payload, err := json.Marshal(map[string]string{"prompt": prompt}) + if err != nil { + return "", fmt.Errorf("encode prompt: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(string(payload))) if err != nil { return "", err } @@ -177,7 +183,8 @@ func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (st } defer func() { _ = resp.Body.Close() }() - body, _ := io.ReadAll(resp.Body) + // Limit response body to 1 MiB to prevent memory exhaustion. + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) var chatResp struct { Response string `json:"response"` } diff --git a/internal/engine/session.go b/internal/engine/session.go index c57417a4..b9d6a043 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -35,6 +35,7 @@ type MemoryRecaller interface { // SnapshotTracker abstracts the snapshot system so engine doesn't import snapshot directly. type SnapshotTracker interface { Track(message string) (string, error) + TrackCtx(ctx context.Context, message string) (string, error) } // Session manages a conversation with an LLM via eyrie. diff --git a/internal/engine/stream.go b/internal/engine/stream.go index b62cbe3f..ed32462e 100644 --- a/internal/engine/stream.go +++ b/internal/engine/stream.go @@ -460,9 +460,11 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) { "reason": retryReason, "error": streamErr.Error(), }) + retryTimer := time.NewTimer(time.Duration(streamAttempt+1) * time.Second) select { - case <-time.After(time.Duration(streamAttempt+1) * time.Second): + case <-retryTimer.C: case <-ctx.Done(): + retryTimer.Stop() ch <- StreamEvent{Type: "error", Content: "stream retry cancelled: " + ctx.Err().Error()} result.Close() return @@ -730,7 +732,13 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) { } } if len(writeNames) > 0 { - go func() { _, _ = s.Snapshots.Track(strings.Join(writeNames, ", ")) }() + go func() { + // Bound the snapshot so a slow filesystem doesn't + // leak a goroutine after the session ends. + snapCtx, snapCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer snapCancel() + _, _ = s.Snapshots.TrackCtx(snapCtx, strings.Join(writeNames, ", ")) + }() } } diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 8d388daa..b9e25c1e 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os/exec" "strings" "sync" @@ -24,17 +25,18 @@ func SetClientVersion(v string) { clientVersion = v } // Server represents a connected MCP server. type Server struct { - Name string - Command string - Args []string - cmd *exec.Cmd - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - nextID int - reader *bufio.Scanner - pending map[int]chan json.RawMessage // response channels keyed by request ID - pendMu sync.Mutex + Name string + Command string + Args []string + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + nextID int + reader *bufio.Scanner + pending map[int]chan json.RawMessage // response channels keyed by request ID + pendErrors map[int]string // error details keyed by request ID + pendMu sync.Mutex } // Tool is a tool exposed by an MCP server. @@ -91,14 +93,15 @@ func Connect(ctx context.Context, name, command string, args ...string) (*Server scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB buffer s := &Server{ - Name: name, - Command: command, - Args: args, - cmd: cmd, - stdin: stdin, - stdout: stdout, - reader: scanner, - pending: make(map[int]chan json.RawMessage), + Name: name, + Command: command, + Args: args, + cmd: cmd, + stdin: stdin, + stdout: stdout, + reader: scanner, + pending: make(map[int]chan json.RawMessage), + pendErrors: make(map[int]string), } // Start background reader to dispatch responses and notifications @@ -142,6 +145,11 @@ func (s *Server) readLoop() { s.pendMu.Unlock() if ok { if msg.Error != nil { + // Store error details so the caller can include them + // in the returned error instead of a generic message. + s.pendMu.Lock() + s.pendErrors[msg.ID] = fmt.Sprintf("code %d: %s", msg.Error.Code, msg.Error.Message) + s.pendMu.Unlock() ch <- nil // signal error via nil } else { ch <- msg.Result @@ -152,11 +160,19 @@ func (s *Server) readLoop() { } // Otherwise it's a notification — ignore for now } - // Scanner done — close all pending channels + // Scanner done — log the cause if it was an error (e.g., oversized + // response exceeding the 1MB buffer), then close all pending channels. + if err := s.reader.Err(); err != nil { + slog.Warn("mcp: stdout reader stopped", "server", s.Name, "error", err) + } s.pendMu.Lock() for id, ch := range s.pending { close(ch) delete(s.pending, id) + // Clean up pendErrors only for requests that will never be + // answered. Entries for already-signaled requests (no longer in + // s.pending) are left for the caller to reap. + delete(s.pendErrors, id) } s.pendMu.Unlock() } @@ -309,6 +325,7 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte if err != nil { s.pendMu.Lock() delete(s.pending, id) + delete(s.pendErrors, id) s.pendMu.Unlock() return nil, fmt.Errorf("write: %w", err) } @@ -319,23 +336,40 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte timeout = time.Until(deadline) } + // Use time.NewTimer + Stop instead of time.After to avoid leaking + // the timer in the runtime when the response arrives or ctx is + // cancelled before the timeout fires. + timer := time.NewTimer(timeout) select { case result, ok := <-ch: + timer.Stop() if !ok { return nil, fmt.Errorf("mcp: connection closed") } if result == nil { + // Include the server's error code and message if available, + // instead of a generic "server returned error" with no detail. + s.pendMu.Lock() + errMsg := s.pendErrors[id] + delete(s.pendErrors, id) + s.pendMu.Unlock() + if errMsg != "" { + return nil, fmt.Errorf("mcp: server error: %s", errMsg) + } return nil, fmt.Errorf("mcp: server returned error") } return result, nil - case <-time.After(timeout): + case <-timer.C: s.pendMu.Lock() delete(s.pending, id) + delete(s.pendErrors, id) s.pendMu.Unlock() return nil, fmt.Errorf("mcp: call %s timed out after %s", method, timeout) case <-ctx.Done(): + timer.Stop() s.pendMu.Lock() delete(s.pending, id) + delete(s.pendErrors, id) s.pendMu.Unlock() return nil, ctx.Err() } diff --git a/internal/permissions/guardian.go b/internal/permissions/guardian.go index 3618fa16..94a51219 100644 --- a/internal/permissions/guardian.go +++ b/internal/permissions/guardian.go @@ -256,9 +256,12 @@ func isBase64Injection(s string) bool { if len(s) < minBase64Len { return false } - // Check if the line is mostly base64 characters (letters, digits, +, /, =) + // Count base64-legal bytes (all ASCII). Using byte iteration instead + // of rune iteration keeps the count consistent with len(s) (which is + // a byte count), so the ratio is correct for multi-byte UTF-8 input. b64Chars := 0 - for _, c := range s { + for i := 0; i < len(s); i++ { + c := s[i] if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' { b64Chars++ } diff --git a/internal/resilience/ratelimit/ratelimit.go b/internal/resilience/ratelimit/ratelimit.go index 7cb99328..20384276 100644 --- a/internal/resilience/ratelimit/ratelimit.go +++ b/internal/resilience/ratelimit/ratelimit.go @@ -70,10 +70,12 @@ func (l *Limiter) Wait(ctx context.Context) error { } l.mu.Unlock() + timer := time.NewTimer(waitTime) select { case <-ctx.Done(): + timer.Stop() return ctx.Err() - case <-time.After(waitTime): + case <-timer.C: // Try again } } diff --git a/internal/resilience/retry/retry.go b/internal/resilience/retry/retry.go index b10ec49e..78876de7 100644 --- a/internal/resilience/retry/retry.go +++ b/internal/resilience/retry/retry.go @@ -76,10 +76,12 @@ func Do(ctx context.Context, cfg Config, fn func() error) error { return err } delay := backoff(i, cfg.BaseDelay, cfg.MaxDelay, cfg.Multiplier) + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return ctx.Err() - case <-time.After(delay): + case <-timer.C: } } return err @@ -102,10 +104,12 @@ func DoWithResult[T any](ctx context.Context, cfg Config, fn func() (T, error)) return result, err } delay := backoff(i, cfg.BaseDelay, cfg.MaxDelay, cfg.Multiplier) + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return result, ctx.Err() - case <-time.After(delay): + case <-timer.C: } } return result, err diff --git a/internal/sandbox/netproxy.go b/internal/sandbox/netproxy.go index d61a60c0..fcf4810f 100644 --- a/internal/sandbox/netproxy.go +++ b/internal/sandbox/netproxy.go @@ -108,7 +108,11 @@ func (np *NetworkProxy) Start(ctx context.Context) (string, error) { }) np.server = &http.Server{ - Handler: mux, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 5 * time.Minute, // long for CONNECT tunnels + IdleTimeout: 120 * time.Second, } go func() { diff --git a/internal/session/persist.go b/internal/session/persist.go index fe3be2a8..a004fbb4 100644 --- a/internal/session/persist.go +++ b/internal/session/persist.go @@ -21,13 +21,26 @@ func SaveMessages(path string, messages []Message) error { return fmt.Errorf("marshal messages: %w", err) } - // Atomic write: temp file + rename to avoid partial writes. - tmp := path + ".tmp" - if err := os.WriteFile(tmp, data, 0o644); err != nil { + // Atomic write: temp file → sync → rename to avoid partial writes. + tmp, err := os.CreateTemp(dir, ".hawk-session-*.tmp") + if err != nil { + return fmt.Errorf("create session temp file: %w", err) + } + tmpPath := tmp.Name() + defer func() { _ = os.Remove(tmpPath) }() // cleanup if rename fails + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() return fmt.Errorf("write session temp file: %w", err) } - if err := os.Rename(tmp, path); err != nil { - _ = os.Remove(tmp) + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return fmt.Errorf("sync session temp file: %w", err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("close session temp file: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { return fmt.Errorf("rename session file: %w", err) } return nil diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 9d675ecf..6b450060 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -128,12 +128,22 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) { return nil, fmt.Errorf("open sqlite: %w", err) } + // SQLite serializes writes; a single connection avoids "database is + // locked" errors under concurrent access. + db.SetMaxOpenConns(1) + // Enable WAL mode for better concurrent read performance. if _, err := db.ExecContext(context.Background(), "PRAGMA journal_mode=WAL"); err != nil { _ = db.Close() return nil, fmt.Errorf("set WAL mode: %w", err) } + // Set a busy timeout so concurrent writers wait instead of failing. + if _, err := db.ExecContext(context.Background(), "PRAGMA busy_timeout=5000"); err != nil { + _ = db.Close() + return nil, fmt.Errorf("set busy timeout: %w", err) + } + // Enable foreign keys. if _, err := db.ExecContext(context.Background(), "PRAGMA foreign_keys=ON"); err != nil { _ = db.Close() diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index faebb22d..30e92979 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -83,6 +83,12 @@ func (t *Tracker) Init() error { // Track takes a snapshot of the current project state. Returns the commit hash. func (t *Tracker) Track(message string) (string, error) { + return t.TrackCtx(context.Background(), message) +} + +// TrackCtx is like Track but respects ctx, allowing callers to bound slow +// git operations (e.g., a hung filesystem) and avoid goroutine leaks. +func (t *Tracker) TrackCtx(ctx context.Context, message string) (string, error) { t.mu.Lock() defer t.mu.Unlock() @@ -91,23 +97,23 @@ func (t *Tracker) Track(message string) (string, error) { } // Add all files from project dir - if err := t.gitWork("add", "--all", t.projectDir); err != nil { + if err := t.gitWorkCtx(ctx, "add", "--all", t.projectDir); err != nil { return "", fmt.Errorf("add: %w", err) } // Check if there are changes to commit - if err := t.gitWork("diff", "--cached", "--quiet"); err == nil { + if err := t.gitWorkCtx(ctx, "diff", "--cached", "--quiet"); err == nil { // No changes — return current HEAD - out, _ := t.gitWorkOutput("rev-parse", "--short", "HEAD") + out, _ := t.gitWorkOutputCtx(ctx, "rev-parse", "--short", "HEAD") return strings.TrimSpace(out), nil } // Commit - if err := t.gitWork("commit", "-m", message, "--allow-empty"); err != nil { + if err := t.gitWorkCtx(ctx, "commit", "-m", message, "--allow-empty"); err != nil { return "", fmt.Errorf("commit: %w", err) } - out, err := t.gitWorkOutput("rev-parse", "--short", "HEAD") + out, err := t.gitWorkOutputCtx(ctx, "rev-parse", "--short", "HEAD") if err != nil { return "", err } @@ -215,7 +221,11 @@ func (t *Tracker) Cleanup(maxAge time.Duration) error { } func (t *Tracker) gitWork(args ...string) error { - cmd := exec.CommandContext(context.Background(), "git", args...) + return t.gitWorkCtx(context.Background(), args...) +} + +func (t *Tracker) gitWorkCtx(ctx context.Context, args ...string) error { + cmd := exec.CommandContext(ctx, "git", args...) cmd.Dir = t.shadowDir cmd.Env = append( os.Environ(), @@ -230,7 +240,11 @@ func (t *Tracker) gitWork(args ...string) error { } func (t *Tracker) gitWorkOutput(args ...string) (string, error) { - cmd := exec.CommandContext(context.Background(), "git", args...) + return t.gitWorkOutputCtx(context.Background(), args...) +} + +func (t *Tracker) gitWorkOutputCtx(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) cmd.Dir = t.shadowDir cmd.Env = append( os.Environ(), diff --git a/internal/tool/bash.go b/internal/tool/bash.go index 34e8403c..af97c898 100644 --- a/internal/tool/bash.go +++ b/internal/tool/bash.go @@ -1,6 +1,7 @@ package tool import ( + "bytes" "context" "encoding/json" "fmt" @@ -130,6 +131,29 @@ func ContainerExecutorFromContext(ctx context.Context) ContainerExecutor { return nil } +// limitedWriter is an io.Writer that caps the total bytes stored in its +// internal buffer. Once the limit is reached, subsequent writes are silently +// discarded (but still counted). This prevents unbounded memory growth from +// commands that produce huge output (e.g., cat /dev/urandom, yes) — the +// command continues to run and is eventually killed by the context timeout, +// but the process memory stays bounded. +type limitedWriter struct { + buf bytes.Buffer + maxBytes int +} + +func (w *limitedWriter) Write(p []byte) (int, error) { + if w.buf.Len() >= w.maxBytes { + return len(p), nil // silently discard, keep the command unblocked + } + remaining := w.maxBytes - w.buf.Len() + if len(p) > remaining { + w.buf.Write(p[:remaining]) + return len(p), nil + } + return w.buf.Write(p) +} + type BashTool struct{} func (BashTool) Name() string { return "Bash" } @@ -585,10 +609,24 @@ func (BashTool) Execute(ctx context.Context, input json.RawMessage) (string, err } cmd := exec.CommandContext(ctx, execName, execArgs...) - out, err := cmd.CombinedOutput() - result := string(out) - - // Apply safety output truncation (50KB). + // Use a limitedWriter to cap output at maxOutputBytes instead of + // CombinedOutput, which buffers the entire output in memory. A command + // like `yes` or `cat /dev/urandom` can produce GBs before the timeout + // kills it; the limitedWriter keeps memory bounded while the command + // continues to run (writes are silently discarded after the cap). + var lw limitedWriter + // Cap one byte above maxOutputBytes so that TruncateOutput's > branch + // fires when the cap is reached. At exactly maxOutputBytes (no discard) + // TruncateOutput returns unchanged, which is correct. + lw.maxBytes = maxOutputBytes + 1 + cmd.Stdout = &lw + cmd.Stderr = &lw + err := cmd.Run() + result := lw.buf.String() + + // Apply safety output truncation (50KB) — the limitedWriter may have + // captured up to maxOutputBytes (500KB), so we still truncate for the + // final result returned to the model. result = TruncateOutput(result) result = strings.TrimRight(result, "\n") diff --git a/internal/tool/download.go b/internal/tool/download.go index 05338ac6..b10e581a 100644 --- a/internal/tool/download.go +++ b/internal/tool/download.go @@ -47,7 +47,7 @@ func (DownloadTool) Execute(ctx context.Context, input json.RawMessage) (string, if err := validatePathAllowed(ctx, p.Destination); err != nil { return "", err } - pinnedURL, err := validateURLPublic(ctx, p.URL) + pinnedURL, origHost, err := validateURLPublic(ctx, p.URL) if err != nil { return "", err } @@ -57,6 +57,10 @@ func (DownloadTool) Execute(ctx context.Context, input json.RawMessage) (string, if err != nil { return "", fmt.Errorf("create request: %w", err) } + // Preserve the original Host header for virtual-host routing. + if origHost != "" { + req.Host = origHost + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("download failed: %w", err) diff --git a/internal/tool/file_write.go b/internal/tool/file_write.go index b09ea442..87973fe2 100644 --- a/internal/tool/file_write.go +++ b/internal/tool/file_write.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "os" "path/filepath" ) @@ -59,15 +60,41 @@ func (FileWriteTool) Execute(ctx context.Context, input json.RawMessage) (string // Backup existing file before overwriting if _, statErr := os.Stat(path); statErr == nil { if _, backupErr := BackupFile(path); backupErr != nil { - // Best-effort backup — log but don't block the write - _ = backupErr + // Log the backup failure so the user knows the original may + // be lost on a bad write. Previously this was silently dropped. + slog.Warn("file write: backup failed, proceeding with overwrite", "path", path, "error", backupErr) } } if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { return "", fmt.Errorf("mkdir: %w", err) } - if err := os.WriteFile(path, []byte(p.Content), 0o644); err != nil { - return "", fmt.Errorf("write: %w", err) + // Write atomically: temp file in the same directory → sync → rename. + // This prevents file corruption if the process crashes mid-write, + // which os.WriteFile (truncate-then-write) cannot guarantee. + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, ".hawk-write-*") + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmp.Name() + defer func() { _ = os.Remove(tmpPath) }() // cleanup if rename fails + + if _, err := tmp.Write([]byte(p.Content)); err != nil { + _ = tmp.Close() + return "", fmt.Errorf("write temp: %w", err) + } + if err := tmp.Sync(); err != nil { + _ = tmp.Close() + return "", fmt.Errorf("sync temp: %w", err) + } + if err := tmp.Close(); err != nil { + return "", fmt.Errorf("close temp: %w", err) + } + if err := os.Chmod(tmpPath, 0o644); err != nil { + return "", fmt.Errorf("chmod temp: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + return "", fmt.Errorf("rename: %w", err) } if autoCommitEnabled(ctx) { _ = AutoCommit(ctx, path, "Write", "wrote file") diff --git a/internal/tool/retry.go b/internal/tool/retry.go index b8a33ed8..5854107d 100644 --- a/internal/tool/retry.go +++ b/internal/tool/retry.go @@ -94,9 +94,11 @@ func RetryExecutor(ctx context.Context, t Tool, input []byte, policy RetryPolicy break } // Wait, respecting ctx cancellation. + timer := time.NewTimer(delay) select { - case <-time.After(delay): + case <-timer.C: case <-ctx.Done(): + timer.Stop() return out, ctx.Err() } delay *= 2 diff --git a/internal/tool/safety.go b/internal/tool/safety.go index bb28fdc8..e1c28745 100644 --- a/internal/tool/safety.go +++ b/internal/tool/safety.go @@ -385,10 +385,13 @@ var privateIPBlocks []*net.IPNet func init() { for _, cidr := range []string{ "127.0.0.0/8", // loopback + "0.0.0.0/8", // "this network" (RFC 1122) "10.0.0.0/8", // private "172.16.0.0/12", // private "192.168.0.0/16", // private "169.254.0.0/16", // link-local / cloud metadata + "100.64.0.0/10", // CGN (RFC 6598) + "198.18.0.0/15", // benchmark testing (RFC 2544) "::1/128", // IPv6 loopback "fc00::/7", // IPv6 unique local "fe80::/10", // IPv6 link-local @@ -411,37 +414,41 @@ func WithSSRFSkip(ctx context.Context) context.Context { // validateURLPublic rejects URLs that resolve to private/link-local IP ranges // to prevent SSRF attacks (e.g., fetching AWS metadata at 169.254.169.254). -// Returns the validated URL with the resolved IP pinned as the host, preventing -// DNS rebinding attacks where the second resolution returns a private IP. -func validateURLPublic(ctx context.Context, rawURL string) (string, error) { +// Returns the validated URL with the resolved IP pinned as the host (preventing +// DNS rebinding) and the original hostname (so callers can preserve the Host +// header for virtual-host routing). +func validateURLPublic(ctx context.Context, rawURL string) (pinnedURL, originalHost string, err error) { if ctx.Value(ssrfSkipKey{}) != nil { - return rawURL, nil + return rawURL, "", nil } u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) + return "", "", fmt.Errorf("invalid URL: %w", err) } if u.Scheme != "http" && u.Scheme != "https" { - return "", fmt.Errorf("blocked: only http/https URLs are allowed") + return "", "", fmt.Errorf("blocked: only http/https URLs are allowed") } host := u.Hostname() if host == "" { - return "", fmt.Errorf("blocked: URL has no host") + return "", "", fmt.Errorf("blocked: URL has no host") } // Resolve the hostname to check against private ranges. addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { // DNS failure — block the request rather than allowing potential SSRF bypass. - return "", fmt.Errorf("blocked: DNS resolution failed for %q: %w", host, err) + return "", "", fmt.Errorf("blocked: DNS resolution failed for %q: %w", host, err) } var safeIP string for _, addr := range addrs { ip := addr.IP + // net.IPNet.Contains calls ip.To4() internally, so IPv4-mapped IPv6 + // addresses (::ffff:a.b.c.d) are correctly checked against IPv4 CIDR + // blocks — no separate handling needed. for _, block := range privateIPBlocks { if block.Contains(ip) { - return "", fmt.Errorf("blocked: URL %q resolves to private IP %s", rawURL, ip) + return "", "", fmt.Errorf("blocked: URL %q resolves to private IP %s", rawURL, ip) } } if safeIP == "" { @@ -449,17 +456,18 @@ func validateURLPublic(ctx context.Context, rawURL string) (string, error) { } } if safeIP == "" { - return "", fmt.Errorf("blocked: URL %q resolved to no addresses", rawURL) + return "", "", fmt.Errorf("blocked: URL %q resolved to no addresses", rawURL) } // Pin the IP to prevent DNS rebinding: replace host with the validated IP. - // Preserve the original Host header via a separate mechanism if needed. + // The caller should set req.Host to originalHost to preserve virtual-host + // routing (most web servers route by Host header, not by IP). if u.Port() != "" { u.Host = net.JoinHostPort(safeIP, u.Port()) } else { u.Host = safeIP } - return u.String(), nil + return u.String(), host, nil } // ssrfSafeClient returns an http.Client that validates redirect targets @@ -476,7 +484,7 @@ func ssrfSafeClient(ctx context.Context, timeout time.Duration) *http.Client { if ctx.Value(ssrfSkipKey{}) != nil { return nil } - pinned, err := validateURLPublic(ctx, req.URL.String()) + pinned, origHost, err := validateURLPublic(ctx, req.URL.String()) if err != nil { return err } @@ -487,6 +495,11 @@ func ssrfSafeClient(ctx context.Context, timeout time.Duration) *http.Client { return parseErr } req.URL = parsed + // Preserve the original Host header so virtual-host routing + // works correctly on the redirect target. + if origHost != "" { + req.Host = origHost + } return nil }, } diff --git a/internal/tool/safety_test.go b/internal/tool/safety_test.go index dba49048..4164e7f3 100644 --- a/internal/tool/safety_test.go +++ b/internal/tool/safety_test.go @@ -654,7 +654,7 @@ func TestResolvePath_Symlink(t *testing.T) { func TestValidateURLPublic_SkipContext(t *testing.T) { ctx := WithSSRFSkip(t.Context()) - got, err := validateURLPublic(ctx, "http://127.0.0.1/metadata") + got, _, err := validateURLPublic(ctx, "http://127.0.0.1/metadata") if err != nil { t.Fatalf("expected no error with SSRF skip, got: %v", err) } @@ -665,7 +665,7 @@ func TestValidateURLPublic_SkipContext(t *testing.T) { func TestValidateURLPublic_InvalidURL(t *testing.T) { ctx := t.Context() - _, err := validateURLPublic(ctx, "://invalid") + _, _, err := validateURLPublic(ctx, "://invalid") if err == nil { t.Error("expected error for invalid URL") } @@ -679,7 +679,7 @@ func TestValidateURLPublic_BlockedScheme(t *testing.T) { "javascript:alert(1)", } for _, u := range cases { - _, err := validateURLPublic(ctx, u) + _, _, err := validateURLPublic(ctx, u) if err == nil { t.Errorf("expected error for URL scheme %q", u) } @@ -688,7 +688,7 @@ func TestValidateURLPublic_BlockedScheme(t *testing.T) { func TestValidateURLPublic_NoHost(t *testing.T) { ctx := t.Context() - _, err := validateURLPublic(ctx, "http:///path") + _, _, err := validateURLPublic(ctx, "http:///path") if err == nil { t.Error("expected error for URL with no host") } diff --git a/internal/tool/web_fetch.go b/internal/tool/web_fetch.go index e6539e2d..bf5c8e06 100644 --- a/internal/tool/web_fetch.go +++ b/internal/tool/web_fetch.go @@ -44,7 +44,7 @@ func (WebFetchTool) Execute(ctx context.Context, input json.RawMessage) (string, if p.URL == "" { return "", fmt.Errorf("url is required") } - pinnedURL, err := validateURLPublic(ctx, p.URL) + pinnedURL, origHost, err := validateURLPublic(ctx, p.URL) if err != nil { return "", err } @@ -57,6 +57,13 @@ func (WebFetchTool) Execute(ctx context.Context, input json.RawMessage) (string, return "", err } req.Header.Set("User-Agent", "hawk/0.1.0") + // Preserve the original Host header so virtual-host routing works + // correctly. validateURLPublic pins the connection to the validated + // IP (preventing DNS rebinding), but replaces the URL host with the IP. + // Setting req.Host restores the original hostname for the Host header. + if origHost != "" { + req.Host = origHost + } client := ssrfSafeClient(ctx, 30*time.Second) resp, err := client.Do(req)