diff --git a/keepcurrent.go b/keepcurrent.go index 9d816ae..f5b1381 100644 --- a/keepcurrent.go +++ b/keepcurrent.go @@ -7,7 +7,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "time" ) @@ -103,8 +102,10 @@ func (runner *Runner) syncOnce(from Source, chStop chan struct{}) { return } if err == nil { - // Read ahead to surface any error reading from the source - data, err = ioutil.ReadAll(rc) + // Read ahead to surface any error reading from the source. readAll + // pre-sizes its buffer when the source reports a length, avoiding the + // reallocation churn io.ReadAll incurs on large payloads. + data, err = readAll(rc) rc.Close() } if err == nil { diff --git a/read.go b/read.go new file mode 100644 index 0000000..6e4fc78 --- /dev/null +++ b/read.go @@ -0,0 +1,62 @@ +package keepcurrent + +import ( + "bytes" + "io" +) + +// maxPreAlloc caps how much readAll will allocate up front from a reader's +// self-reported size. It guards against a bogus or hostile size (e.g. a wildly +// inflated HTTP Content-Length) turning into a giant make() that panics or +// OOMs; anything larger falls back to io.ReadAll, which grows against the bytes +// actually delivered. The bound sits comfortably above the payloads keepcurrent +// syncs in practice (a MaxMind database is well under 100MB). +const maxPreAlloc = 256 << 20 // 256 MiB + +// readAll reads r to EOF into a single buffer. It differs from io.ReadAll only +// in that, when r can report how many bytes remain, it allocates that buffer up +// front. io.ReadAll grows its buffer by repeatedly appending and reallocating, +// so reading an N-byte payload churns through a sequence of ever-larger backing +// arrays (N/2, 3N/4, N, ...) that all become garbage. For the multi-megabyte +// payloads keepcurrent is built to sync (e.g. a ~75MB MaxMind database) that +// transient churn dominates memory on small hosts. Pre-sizing turns the read +// into a single allocation. +func readAll(r io.Reader) ([]byte, error) { + if n, ok := knownSize(r); ok && n >= 0 && n <= maxPreAlloc { + // +bytes.MinRead leaves room for the final zero-byte read that signals + // EOF, so an exactly-sized payload never forces ReadFrom to reallocate. + buf := bytes.NewBuffer(make([]byte, 0, int(n)+bytes.MinRead)) + _, err := buf.ReadFrom(r) + return buf.Bytes(), err + } + return io.ReadAll(r) +} + +// knownSize reports the number of bytes remaining in r when r can tell us. It +// recognises the sized readers keepcurrent constructs internally (sizedReadCloser) +// as well as the standard in-memory readers (*bytes.Reader, *bytes.Buffer, +// *strings.Reader) whose Len() reports the unread remainder. +func knownSize(r io.Reader) (int64, bool) { + switch v := r.(type) { + case interface{ size() int64 }: + return v.size(), true + case interface{ Len() int }: + return int64(v.Len()), true + } + return 0, false +} + +// sizedReadCloser couples a ReadCloser with the total number of bytes it will +// yield, so readAll can pre-size its buffer. keepcurrent wraps HTTP bodies +// (Content-Length) and extracted archive entries in this. +type sizedReadCloser struct { + io.ReadCloser + n int64 +} + +func (s sizedReadCloser) size() int64 { return s.n } + +// bytesReadCloser wraps an in-memory payload in a size-aware ReadCloser. +func bytesReadCloser(b []byte) io.ReadCloser { + return sizedReadCloser{ReadCloser: io.NopCloser(bytes.NewReader(b)), n: int64(len(b))} +} diff --git a/read_test.go b/read_test.go new file mode 100644 index 0000000..7bdc2f1 --- /dev/null +++ b/read_test.go @@ -0,0 +1,73 @@ +package keepcurrent + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// unsizedReader hides any size its wrapped reader might otherwise expose (it +// implements only Read), forcing readAll down the io.ReadAll fallback path. +type unsizedReader struct{ r io.Reader } + +func (u *unsizedReader) Read(p []byte) (int, error) { return u.r.Read(p) } + +func TestReadAllPreSizesKnownReaders(t *testing.T) { + payload := bytes.Repeat([]byte("x"), 1<<20) // 1 MiB + + // *bytes.Reader reports Len(), so readAll should allocate exactly once and + // not overshoot the way io.ReadAll's doubling does. + got, err := readAll(bytes.NewReader(payload)) + require.NoError(t, err) + assert.Equal(t, payload, got) + assert.Equalf(t, len(payload)+bytes.MinRead, cap(got), + "buffer for a sized reader should be pre-allocated to the payload size, not grown") + + // sizedReadCloser (what the web/tar.gz sources return) is also recognised. + got, err = readAll(bytesReadCloser(payload)) + require.NoError(t, err) + assert.Equal(t, payload, got) + assert.Equal(t, len(payload)+bytes.MinRead, cap(got)) +} + +func TestReadAllFallsBackForUnsizedReaders(t *testing.T) { + payload := bytes.Repeat([]byte("y"), 4096) + // An opaque reader exposes no size; readAll must still return the full data. + got, err := readAll(&unsizedReader{bytes.NewReader(payload)}) + require.NoError(t, err) + assert.Equal(t, payload, got) +} + +func TestReadAllFallsBackWhenSizeExceedsCap(t *testing.T) { + // A reader that reports a huge size (e.g. a bogus/hostile Content-Length) + // but only delivers a small payload. readAll must not attempt the giant + // pre-allocation; it should fall back to io.ReadAll and still return the + // real bytes. + payload := bytes.Repeat([]byte("z"), 1024) + r := sizedReadCloser{ReadCloser: io.NopCloser(bytes.NewReader(payload)), n: maxPreAlloc + 1} + + n, ok := knownSize(r) + require.True(t, ok) + require.Greater(t, n, int64(maxPreAlloc)) + + got, err := readAll(r) + require.NoError(t, err) + assert.Equal(t, payload, got) + assert.LessOrEqual(t, cap(got), maxPreAlloc, "must not pre-allocate the reported (bogus) size") +} + +func TestKnownSize(t *testing.T) { + n, ok := knownSize(bytes.NewReader(make([]byte, 42))) + assert.True(t, ok) + assert.EqualValues(t, 42, n) + + n, ok = knownSize(bytesReadCloser(make([]byte, 7))) + assert.True(t, ok) + assert.EqualValues(t, 7, n) + + _, ok = knownSize(&unsizedReader{}) + assert.False(t, ok) +} diff --git a/sink.go b/sink.go index 7dbdb59..38c899d 100644 --- a/sink.go +++ b/sink.go @@ -89,7 +89,12 @@ func (s *byteChannel) UpdateFrom(r io.Reader) (err error) { panic(rec) } }() - b, err := ioutil.ReadAll(r) + // readAll pre-sizes from the reader's length (the Runner hands us a + // *bytes.Reader), so this copy is a single allocation. We deliberately copy + // rather than forward the Runner's buffer: ToChannel's contract is that each + // delivered slice is independently owned, so consumers may retain or mutate + // it freely. + b, err := readAll(r) if err != nil { return err } diff --git a/source.go b/source.go index ea0bd5a..a0251cf 100644 --- a/source.go +++ b/source.go @@ -1,7 +1,6 @@ package keepcurrent import ( - "bytes" "context" "errors" "fmt" @@ -22,6 +21,14 @@ type webSource struct { client *http.Client } +// drainClose discards any remaining body and closes it. net/http only returns a +// connection to the keep-alive pool once its response body has been read to EOF, +// so error/not-modified responses we don't hand to the caller must be drained. +func drainClose(rc io.ReadCloser) { + _, _ = io.Copy(io.Discard, rc) + _ = rc.Close() +} + // FromWeb constructs a source from the given URL. func FromWeb(url string) Source { return FromWebWithClient(url, http.DefaultClient) @@ -49,15 +56,25 @@ func (s *webSource) Fetch(ifNewerThan time.Time) (io.ReadCloser, error) { return nil, err } if resp.StatusCode == http.StatusNotModified { + // Drain to EOF then close so net/http can return the connection to the + // pool for keep-alive reuse (it won't reuse one whose body wasn't fully + // read). 304 carries no body, so this is effectively just a close here. + // We hand the body to the caller only on the success path below. + drainClose(resp.Body) return nil, ErrUnmodified } if resp.StatusCode != http.StatusOK { + drainClose(resp.Body) return nil, fmt.Errorf("unexpected HTTP status %v", resp.StatusCode) } etag := resp.Header.Get("ETag") if etag != "" { s.setETag(etag) } + if resp.ContentLength >= 0 { + // Surface the Content-Length so the Runner can pre-size its read buffer. + return sizedReadCloser{ReadCloser: resp.Body, n: resp.ContentLength}, nil + } return resp.Body, nil } @@ -112,7 +129,11 @@ func (s *tarGzSource) Fetch(ifNewerThan time.Time) (io.ReadCloser, error) { return err } defer f.Close() - buf, err = io.ReadAll(f) + // Wrap the entry in a size-aware reader so readAll pre-sizes the + // buffer from the archive entry's uncompressed size — extracting a + // large file (e.g. a ~75MB mmdb) becomes a single allocation rather + // than the reallocation churn of io.ReadAll. + buf, err = readAll(sizedReadCloser{ReadCloser: f, n: info.Size()}) if err != nil { return err } @@ -122,7 +143,8 @@ func (s *tarGzSource) Fetch(ifNewerThan time.Time) (io.ReadCloser, error) { }) if errors.Is(err, errFound) { - return io.NopCloser(bytes.NewReader(buf)), nil + // Return a size-aware reader so the Runner's read can also be pre-sized. + return bytesReadCloser(buf), nil } if err != nil { return nil, err