diff --git a/catalog/xiaomi/http.go b/catalog/xiaomi/http.go index f03503a..01b95ef 100644 --- a/catalog/xiaomi/http.go +++ b/catalog/xiaomi/http.go @@ -4,9 +4,10 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "strings" + + "github.com/GrayCodeAI/eyrie/internal/probehttp" ) // ProbeOpenAIModels GETs {baseURL}/models using api-key auth, then Bearer on 401. @@ -19,68 +20,49 @@ func ProbeOpenAIModels(ctx context.Context, baseURL, apiKey string) error { if apiKey == "" { return fmt.Errorf("xiaomi probe: missing API key") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) - if err != nil { - return err + url := baseURL + "/models" + commonHeaders := map[string]string{ + "Accept": "application/json", + "User-Agent": probehttp.UserAgent(), } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "eyrie-model-catalog/1.0") - setAPIKeyAuth(req, apiKey) - status, err := doProbe(req) + status, _, err := probehttp.DoGet(ctx, url, func() map[string]string { + h := map[string]string{} + for k, v := range commonHeaders { + h[k] = v + } + setAPIKeyAuthHeader(h, apiKey) + return h + }()) if err != nil { - return err + return fmt.Errorf("xiaomi probe: network error: %w", err) } if status == http.StatusUnauthorized { - req2, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) - if err != nil { - return err - } - req2.Header.Set("Accept", "application/json") - req2.Header.Set("User-Agent", "eyrie-model-catalog/1.0") - req2.Header.Set("Authorization", "Bearer "+apiKey) - status, err = doProbe(req2) + status, _, err = probehttp.DoGet(ctx, url, func() map[string]string { + h := map[string]string{} + for k, v := range commonHeaders { + h[k] = v + } + h["Authorization"] = "Bearer " + apiKey + return h + }()) if err != nil { - return err + return fmt.Errorf("xiaomi probe: network error: %w", err) } } - return probeStatusErr(status) -} - -func setAPIKeyAuth(req *http.Request, apiKey string) { - req.Header.Set("api-key", apiKey) -} - -func doProbe(req *http.Request) (int, error) { - resp, err := http.DefaultClient.Do(req) - if err != nil { - return 0, fmt.Errorf("xiaomi probe: network error: %w", err) - } - defer func() { _ = resp.Body.Close() }() - _, _ = io.Copy(io.Discard, resp.Body) - return resp.StatusCode, nil -} - -func probeStatusErr(status int) error { if status >= 200 && status < 300 { return nil } - switch status { - case http.StatusUnauthorized, http.StatusForbidden: - return fmt.Errorf("credential probe failed: invalid API key (HTTP %d)", status) - case http.StatusTooManyRequests: - return fmt.Errorf("credential probe failed: rate limited — try again shortly") - default: - if status >= 500 { - return fmt.Errorf("credential probe failed: provider unavailable (HTTP %d)", status) - } - return fmt.Errorf("credential probe failed: HTTP %d", status) - } + return probehttp.ProbeError(status) +} + +func setAPIKeyAuthHeader(h map[string]string, apiKey string) { + h["api-key"] = apiKey } // SetMimoRequestAuth applies MiMo-preferred auth (api-key header). func SetMimoRequestAuth(req *http.Request, apiKey string) { - setAPIKeyAuth(req, apiKey) + req.Header.Set("api-key", apiKey) } // FetchOpenAIModelsJSON GETs /models and returns raw model objects from the OpenAI list response. @@ -90,18 +72,28 @@ func FetchOpenAIModelsJSON(ctx context.Context, baseURL, apiKey string) ([]json. if baseURL == "" || apiKey == "" { return nil, fmt.Errorf("xiaomi: base URL and API key required") } - body, status, err := getModelsBody(ctx, baseURL, apiKey) + url := baseURL + "/models" + + headers := map[string]string{ + "Accept": "application/json", + "User-Agent": probehttp.UserAgent(), + } + setAPIKeyAuthHeader(headers, apiKey) + + status, body, err := probehttp.DoGet(ctx, url, headers) if err != nil { return nil, err } if status == http.StatusUnauthorized { - body, status, err = getModelsBodyBearer(ctx, baseURL, apiKey) + delete(headers, "api-key") + headers["Authorization"] = "Bearer " + apiKey + status, body, err = probehttp.DoGet(ctx, url, headers) if err != nil { return nil, err } } if status < 200 || status >= 300 { - return nil, probeStatusErr(status) + return nil, probehttp.ProbeError(status) } var payload struct { Data []json.RawMessage `json:"data"` @@ -112,41 +104,6 @@ func FetchOpenAIModelsJSON(ctx context.Context, baseURL, apiKey string) ([]json. return payload.Data, nil } -func getModelsBody(ctx context.Context, baseURL, apiKey string) ([]byte, int, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) - if err != nil { - return nil, 0, err - } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "eyrie-model-catalog/1.0") - SetMimoRequestAuth(req, apiKey) - return doModelsRequest(req) -} - -func getModelsBodyBearer(ctx context.Context, baseURL, apiKey string) ([]byte, int, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) - if err != nil { - return nil, 0, err - } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "eyrie-model-catalog/1.0") - req.Header.Set("Authorization", "Bearer "+apiKey) - return doModelsRequest(req) -} - -func doModelsRequest(req *http.Request) ([]byte, int, error) { - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, 0, err - } - defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, resp.StatusCode, err - } - return body, resp.StatusCode, nil -} - // IsRetryableHTTPStatus reports whether chat may retry via Anthropic compatibility. func IsRetryableHTTPStatus(status int) bool { switch status { diff --git a/config/credential/probe.go b/config/credential/probe.go index 11de4e2..80b85c2 100644 --- a/config/credential/probe.go +++ b/config/credential/probe.go @@ -3,8 +3,6 @@ package credential import ( "context" "fmt" - "io" - "net/http" "os" "strings" "time" @@ -12,6 +10,7 @@ import ( "github.com/GrayCodeAI/eyrie/catalog" "github.com/GrayCodeAI/eyrie/catalog/registry" "github.com/GrayCodeAI/eyrie/catalog/xiaomi" + "github.com/GrayCodeAI/eyrie/internal/probehttp" ) const credentialProbeTimeout = 8 * time.Second @@ -131,57 +130,41 @@ func probeOpenAIModels(ctx context.Context, baseURL, secret string) error { if baseURL == "" { return fmt.Errorf("credential probe: missing base URL") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) + status, _, err := probehttp.DoGet(ctx, baseURL+"/models", map[string]string{ + "Authorization": "Bearer " + secret, + }) if err != nil { - return err + return fmt.Errorf("credential probe: network error: %w", err) + } + if status >= 200 && status < 300 { + return nil } - req.Header.Set("Authorization", "Bearer "+secret) - return doProbeRequest(req) + return probehttp.ProbeError(status) } func probeAnthropic(ctx context.Context, secret string) error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.anthropic.com/v1/models", nil) + status, _, err := probehttp.DoGet(ctx, "https://api.anthropic.com/v1/models", map[string]string{ + "x-api-key": secret, + "anthropic-version": "2023-06-01", + }) if err != nil { - return err + return fmt.Errorf("credential probe: network error: %w", err) } - req.Header.Set("x-api-key", secret) - req.Header.Set("anthropic-version", "2023-06-01") - return doProbeRequest(req) -} - -func probeGemini(ctx context.Context, secret string) error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://generativelanguage.googleapis.com/v1beta/models", nil) - if err != nil { - return err + if status >= 200 && status < 300 { + return nil } - req.Header.Set("x-goog-api-key", secret) - return doProbeRequest(req) + return probehttp.ProbeError(status) } -func doProbeRequest(req *http.Request) error { - resp, err := http.DefaultClient.Do(req) +func probeGemini(ctx context.Context, secret string) error { + status, _, err := probehttp.DoGet(ctx, "https://generativelanguage.googleapis.com/v1beta/models", map[string]string{ + "x-goog-api-key": secret, + }) if err != nil { return fmt.Errorf("credential probe: network error: %w", err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - _, _ = io.Copy(io.Discard, resp.Body) + if status >= 200 && status < 300 { return nil } - _, _ = io.ReadAll(io.LimitReader(resp.Body, 512)) - return probeHTTPError(resp.StatusCode) -} - -func probeHTTPError(status int) error { - switch status { - case http.StatusUnauthorized, http.StatusForbidden: - return fmt.Errorf("credential probe failed: invalid API key (HTTP %d)", status) - case http.StatusTooManyRequests: - return fmt.Errorf("credential probe failed: rate limited — try again shortly") - default: - if status >= 500 { - return fmt.Errorf("credential probe failed: provider unavailable (HTTP %d)", status) - } - return fmt.Errorf("credential probe failed: HTTP %d", status) - } + return probehttp.ProbeError(status) } diff --git a/internal/probehttp/probehttp.go b/internal/probehttp/probehttp.go new file mode 100644 index 0000000..996e643 --- /dev/null +++ b/internal/probehttp/probehttp.go @@ -0,0 +1,95 @@ +// Package probehttp contains shared helpers for the eyrie credential-probe +// and catalog-probe call sites. It centralises the HTTP-client configuration +// and the HTTP-status-to-error mapping that probe code reaches for on every +// request. Keeping it in one place means timeout policy and error wording +// stay aligned across credential probes and live catalog probes. +package probehttp + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// DefaultRequestTimeout caps the time a single probe HTTP request can take. +// The probe context already carries a deadline, but the http.Client.Timeout +// is a second line of defence: it bounds the time spent in TLS, redirects, +// and the like even if the caller's context deadline is missing. +const DefaultRequestTimeout = 15 * time.Second + +// DefaultClient is the shared *http.Client used by probe code in the eyrie +// repo. Callers should reuse it instead of http.DefaultClient so the +// per-request timeout policy stays consistent. +var DefaultClient = &http.Client{Timeout: DefaultRequestTimeout} + +// ProbeError builds a credential-probe error message for a non-2xx response. +// The wording is part of the public surface that hawk surfaces to users when +// /config probe fails, so the strings here are stable. +// +// status is the HTTP status code returned by the provider. The function +// collapses 401/403 into a single "invalid key" message, distinguishes +// 429 (rate limited) from a hard 5xx (provider unavailable), and falls +// back to a generic HTTP-status message for everything else. +func ProbeError(status int) error { + switch { + case status == http.StatusUnauthorized || status == http.StatusForbidden: + return fmt.Errorf("credential probe failed: invalid API key (HTTP %d)", status) + case status == http.StatusTooManyRequests: + return fmt.Errorf("credential probe failed: rate limited — try again shortly") + case status >= 500: + return fmt.Errorf("credential probe failed: provider unavailable (HTTP %d)", status) + default: + return fmt.Errorf("credential probe failed: HTTP %d", status) + } +} + +// DoGet issues a GET against url with the given headers, returns the status +// code and body. The body is bounded to 1 MiB so a malicious or buggy +// provider cannot exhaust memory. The body is read and closed on the caller's +// behalf; callers only need to inspect (status, body, err). +// +// The request inherits the supplied context and the package-level +// DefaultClient, so a missing context deadline is still capped by the +// client Timeout. +func DoGet(ctx context.Context, url string, headers map[string]string) (int, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return 0, nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := DefaultClient.Do(req) + if err != nil { + return 0, nil, err + } + defer func() { _ = resp.Body.Close() }() + + const maxBody = 1 << 20 // 1 MiB + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBody)) + if err != nil { + return resp.StatusCode, nil, err + } + return resp.StatusCode, body, nil +} + +// UserAgent returns the standard eyrie User-Agent string for probe traffic. +func UserAgent() string { return "eyrie-probe/1.0" } + +// JoinURL trims a trailing slash from base and joins it with the supplied +// path. It's a tiny helper kept here so the various probe call sites stop +// re-implementing the trim/concat dance. +func JoinURL(base, path string) string { + base = strings.TrimRight(base, "/") + path = strings.TrimLeft(path, "/") + if path == "" { + return base + } + return base + "/" + path +} diff --git a/internal/probehttp/probehttp_test.go b/internal/probehttp/probehttp_test.go new file mode 100644 index 0000000..5d2dea8 --- /dev/null +++ b/internal/probehttp/probehttp_test.go @@ -0,0 +1,107 @@ +package probehttp + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestProbeError(t *testing.T) { + tests := []struct { + name string + status int + wantSubstr string + }{ + {"unauthorized", http.StatusUnauthorized, "invalid API key"}, + {"forbidden", http.StatusForbidden, "invalid API key"}, + {"rate limited", http.StatusTooManyRequests, "rate limited"}, + {"server 500", http.StatusInternalServerError, "provider unavailable"}, + {"bad gateway 502", http.StatusBadGateway, "provider unavailable"}, + {"client error 400", http.StatusBadRequest, "HTTP 400"}, + {"client error 404", http.StatusNotFound, "HTTP 404"}, + {"ok 200 still errs as HTTP 200", http.StatusOK, "HTTP 200"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ProbeError(tt.status) + if err == nil { + t.Fatalf("ProbeError(%d) returned nil; expected an error", tt.status) + } + if !strings.Contains(err.Error(), tt.wantSubstr) { + t.Errorf("ProbeError(%d) = %q; want substring %q", tt.status, err.Error(), tt.wantSubstr) + } + }) + } +} + +func TestDoGet_RespondsAndBoundsBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Test") != "ok" { + t.Errorf("missing X-Test header; got %q", r.Header.Get("X-Test")) + } + // Write 2 MiB; we expect DoGet to truncate to 1 MiB. + _, _ = w.Write(make([]byte, 2<<20)) + })) + defer srv.Close() + + status, body, err := DoGet(context.Background(), srv.URL+"/foo", map[string]string{"X-Test": "ok"}) + if err != nil { + t.Fatalf("DoGet: %v", err) + } + if status != http.StatusOK { + t.Errorf("status = %d; want 200", status) + } + if len(body) > (1<<20)+1024 { + t.Errorf("body len = %d; expected <= 1 MiB + a few bytes for safety", len(body)) + } +} + +func TestDoGet_RespectsContextDeadline(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + _, _ = w.Write([]byte("late")) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + _, _, err := DoGet(ctx, srv.URL, nil) + if err == nil { + t.Fatalf("DoGet: expected error from context deadline, got nil") + } +} + +func TestJoinURL(t *testing.T) { + tests := []struct { + base, path, want string + }{ + {"http://a", "b", "http://a/b"}, + {"http://a/", "b", "http://a/b"}, + {"http://a", "/b", "http://a/b"}, + {"http://a/", "/b", "http://a/b"}, + {"http://a/", "", "http://a"}, + {"http://a", "", "http://a"}, + } + for _, tt := range tests { + got := JoinURL(tt.base, tt.path) + if got != tt.want { + t.Errorf("JoinURL(%q, %q) = %q; want %q", tt.base, tt.path, got, tt.want) + } + } +} + +func TestUserAgent(t *testing.T) { + if got := UserAgent(); !strings.HasPrefix(got, "eyrie-") { + t.Errorf("UserAgent() = %q; want it to start with \"eyrie-\"", got) + } +} + +func TestDefaultClient_HasTimeout(t *testing.T) { + if DefaultClient.Timeout <= 0 { + t.Errorf("DefaultClient.Timeout = %v; expected > 0 so the client bounds requests", DefaultClient.Timeout) + } +}