Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
)

const maxRequestBodyBytes = 1 << 20
Expand Down Expand Up @@ -74,9 +76,19 @@ func NewWithAPIKey(addr, apiKey string) *Server {

// registerRoutes sets up the HTTP endpoints.
func (s *Server) registerRoutes() {
s.mux.HandleFunc("GET /health", s.handleHealth)
s.mux.HandleFunc("GET /version", s.handleVersion)
s.mux.HandleFunc("POST /chat", s.auth(s.handleChat))
s.mux.HandleFunc("GET /health", securityHeaders(s.handleHealth))
s.mux.HandleFunc("GET /version", securityHeaders(s.handleVersion))
s.mux.HandleFunc("POST /chat", securityHeaders(s.auth(s.handleChat)))
}

// securityHeaders sets standard HTTP security headers on every response.
func securityHeaders(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Cache-Control", "no-store")
next(w, r)
}
}

func (s *Server) auth(next http.HandlerFunc) http.HandlerFunc {
Expand Down Expand Up @@ -110,6 +122,35 @@ func constantTimeEqual(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

// validateAuthConfig refuses to start the server with no API key on a
// non-loopback bind. The auth middleware silently allows every request when
// the API key is empty, so a misconfigured server would be wide open. The
// only safe no-key mode is loopback bind.
func (s *Server) validateAuthConfig() error {
if s.apiKey != "" {
return nil
}
host, _, err := net.SplitHostPort(s.addr)
if err != nil {
return fmt.Errorf("api: invalid bind address %q: %w", s.addr, err)
}
if !isLoopbackHost(host) {
return fmt.Errorf("api: apiKey is empty and bind address %q is not loopback; refusing to start. Set apiKey or bind to 127.0.0.1", s.addr)
}
return nil
}

// isLoopbackHost reports whether host is a loopback address.
func isLoopbackHost(host string) bool {
if host == "" || host == "localhost" {
return host == "localhost" // "" is unsafe; "localhost" is loopback
}
if ip := net.ParseIP(host); ip != nil {
return ip.IsLoopback()
}
return false
}

func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
dec := json.NewDecoder(r.Body)
Expand All @@ -127,10 +168,17 @@ func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool {

// Start starts the HTTP server. It blocks until the context is cancelled or an error occurs.
func (s *Server) Start(ctx context.Context) error {
if err := s.validateAuthConfig(); err != nil {
return err
}
s.mu.Lock()
s.server = &http.Server{
Addr: s.addr,
Handler: s.mux,
Addr: s.addr,
Handler: s.mux,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
s.mu.Unlock()

Expand All @@ -151,7 +199,7 @@ func (s *Server) Start(ctx context.Context) error {
return err
}

// Stop gracefully shuts down the HTTP server.
// Stop gracefully shuts down the HTTP server with a 15-second timeout.
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
srv := s.server
Expand All @@ -160,7 +208,12 @@ func (s *Server) Stop(ctx context.Context) error {
if srv == nil {
return nil
}
return srv.Shutdown(ctx)
// Use a bounded timeout so Stop cannot hang indefinitely if a
// client keeps a connection open. The caller's ctx is respected
// if it has a shorter deadline.
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
return srv.Shutdown(shutdownCtx)
}

// Handler returns the underlying http.Handler for testing purposes.
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (g *DiscordGateway) fetchMessagesREST(ctx context.Context, channelID, after
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("discord messages: HTTP %d: %s", resp.StatusCode, string(data))
}
var msgs []discordMessage
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func forwardToHawk(ctx context.Context, client *http.Client, daemonAddr, apiKey,
return "", err
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var chatResp struct {
Response string `json:"response"`
}
Expand Down
19 changes: 13 additions & 6 deletions internal/daemon/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,23 @@ func (tg *TelegramGateway) handleMessage(ctx context.Context, msg *TelegramMessa
response = fmt.Sprintf("Error: %v", err)
}

// Format for Telegram (truncate if too long)
if len(response) > 4000 {
response = response[:4000] + "\n\n... (truncated)"
// Format for Telegram (truncate if too long, at rune boundary)
if len([]rune(response)) > 4000 {
response = string([]rune(response)[:4000]) + "\n\n... (truncated)"
}

_ = tg.sendMessage(ctx, msg.Chat.ID, response)
}

func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (string, error) {
payload := fmt.Sprintf(`{"prompt":%q}`, prompt)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(payload))
// Use json.Marshal for safe JSON encoding instead of fmt.Sprintf
// with %q, which does not handle all JSON edge cases (e.g., control
// characters, surrogate pairs).
payload, err := json.Marshal(map[string]string{"prompt": prompt})
if err != nil {
return "", fmt.Errorf("encode prompt: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(string(payload)))
if err != nil {
return "", err
}
Expand All @@ -177,7 +183,8 @@ func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (st
}
defer func() { _ = resp.Body.Close() }()

body, _ := io.ReadAll(resp.Body)
// Limit response body to 1 MiB to prevent memory exhaustion.
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var chatResp struct {
Response string `json:"response"`
}
Expand Down
1 change: 1 addition & 0 deletions internal/engine/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type MemoryRecaller interface {
// SnapshotTracker abstracts the snapshot system so engine doesn't import snapshot directly.
type SnapshotTracker interface {
Track(message string) (string, error)
TrackCtx(ctx context.Context, message string) (string, error)
}

// Session manages a conversation with an LLM via eyrie.
Expand Down
12 changes: 10 additions & 2 deletions internal/engine/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,11 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) {
"reason": retryReason,
"error": streamErr.Error(),
})
retryTimer := time.NewTimer(time.Duration(streamAttempt+1) * time.Second)
select {
case <-time.After(time.Duration(streamAttempt+1) * time.Second):
case <-retryTimer.C:
case <-ctx.Done():
retryTimer.Stop()
ch <- StreamEvent{Type: "error", Content: "stream retry cancelled: " + ctx.Err().Error()}
result.Close()
return
Expand Down Expand Up @@ -730,7 +732,13 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) {
}
}
if len(writeNames) > 0 {
go func() { _, _ = s.Snapshots.Track(strings.Join(writeNames, ", ")) }()
go func() {
// Bound the snapshot so a slow filesystem doesn't
// leak a goroutine after the session ends.
snapCtx, snapCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer snapCancel()
_, _ = s.Snapshots.TrackCtx(snapCtx, strings.Join(writeNames, ", "))
}()
}
}

Expand Down
76 changes: 55 additions & 21 deletions internal/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"os/exec"
"strings"
"sync"
Expand All @@ -24,17 +25,18 @@ func SetClientVersion(v string) { clientVersion = v }

// Server represents a connected MCP server.
type Server struct {
Name string
Command string
Args []string
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
mu sync.Mutex
nextID int
reader *bufio.Scanner
pending map[int]chan json.RawMessage // response channels keyed by request ID
pendMu sync.Mutex
Name string
Command string
Args []string
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
mu sync.Mutex
nextID int
reader *bufio.Scanner
pending map[int]chan json.RawMessage // response channels keyed by request ID
pendErrors map[int]string // error details keyed by request ID
pendMu sync.Mutex
}

// Tool is a tool exposed by an MCP server.
Expand Down Expand Up @@ -91,14 +93,15 @@ func Connect(ctx context.Context, name, command string, args ...string) (*Server
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB buffer

s := &Server{
Name: name,
Command: command,
Args: args,
cmd: cmd,
stdin: stdin,
stdout: stdout,
reader: scanner,
pending: make(map[int]chan json.RawMessage),
Name: name,
Command: command,
Args: args,
cmd: cmd,
stdin: stdin,
stdout: stdout,
reader: scanner,
pending: make(map[int]chan json.RawMessage),
pendErrors: make(map[int]string),
}

// Start background reader to dispatch responses and notifications
Expand Down Expand Up @@ -142,6 +145,11 @@ func (s *Server) readLoop() {
s.pendMu.Unlock()
if ok {
if msg.Error != nil {
// Store error details so the caller can include them
// in the returned error instead of a generic message.
s.pendMu.Lock()
s.pendErrors[msg.ID] = fmt.Sprintf("code %d: %s", msg.Error.Code, msg.Error.Message)
s.pendMu.Unlock()
ch <- nil // signal error via nil
} else {
ch <- msg.Result
Expand All @@ -152,11 +160,19 @@ func (s *Server) readLoop() {
}
// Otherwise it's a notification — ignore for now
}
// Scanner done — close all pending channels
// Scanner done — log the cause if it was an error (e.g., oversized
// response exceeding the 1MB buffer), then close all pending channels.
if err := s.reader.Err(); err != nil {
slog.Warn("mcp: stdout reader stopped", "server", s.Name, "error", err)
}
s.pendMu.Lock()
for id, ch := range s.pending {
close(ch)
delete(s.pending, id)
// Clean up pendErrors only for requests that will never be
// answered. Entries for already-signaled requests (no longer in
// s.pending) are left for the caller to reap.
delete(s.pendErrors, id)
}
s.pendMu.Unlock()
}
Expand Down Expand Up @@ -309,6 +325,7 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte
if err != nil {
s.pendMu.Lock()
delete(s.pending, id)
delete(s.pendErrors, id)
s.pendMu.Unlock()
return nil, fmt.Errorf("write: %w", err)
}
Expand All @@ -319,23 +336,40 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte
timeout = time.Until(deadline)
}

// Use time.NewTimer + Stop instead of time.After to avoid leaking
// the timer in the runtime when the response arrives or ctx is
// cancelled before the timeout fires.
timer := time.NewTimer(timeout)
select {
case result, ok := <-ch:
timer.Stop()
if !ok {
return nil, fmt.Errorf("mcp: connection closed")
}
if result == nil {
// Include the server's error code and message if available,
// instead of a generic "server returned error" with no detail.
s.pendMu.Lock()
errMsg := s.pendErrors[id]
delete(s.pendErrors, id)
s.pendMu.Unlock()
if errMsg != "" {
return nil, fmt.Errorf("mcp: server error: %s", errMsg)
}
return nil, fmt.Errorf("mcp: server returned error")
}
return result, nil
case <-time.After(timeout):
case <-timer.C:
s.pendMu.Lock()
delete(s.pending, id)
delete(s.pendErrors, id)
s.pendMu.Unlock()
return nil, fmt.Errorf("mcp: call %s timed out after %s", method, timeout)
case <-ctx.Done():
timer.Stop()
s.pendMu.Lock()
delete(s.pending, id)
delete(s.pendErrors, id)
s.pendMu.Unlock()
return nil, ctx.Err()
}
Expand Down
7 changes: 5 additions & 2 deletions internal/permissions/guardian.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,12 @@ func isBase64Injection(s string) bool {
if len(s) < minBase64Len {
return false
}
// Check if the line is mostly base64 characters (letters, digits, +, /, =)
// Count base64-legal bytes (all ASCII). Using byte iteration instead
// of rune iteration keeps the count consistent with len(s) (which is
// a byte count), so the ratio is correct for multi-byte UTF-8 input.
b64Chars := 0
for _, c := range s {
for i := 0; i < len(s); i++ {
c := s[i]
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' {
b64Chars++
}
Expand Down
4 changes: 3 additions & 1 deletion internal/resilience/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ func (l *Limiter) Wait(ctx context.Context) error {
}
l.mu.Unlock()

timer := time.NewTimer(waitTime)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-time.After(waitTime):
case <-timer.C:
// Try again
}
}
Expand Down
Loading
Loading