Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions test/asynchronous/test_async_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Async-only unit tests for network_layer.py."""
Comment thread
aclark4life marked this conversation as resolved.

from __future__ import annotations

import asyncio
import sys
from unittest.mock import AsyncMock, MagicMock, patch

sys.path[0:0] = [""]

from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
from test.asynchronous import AsyncUnitTest, unittest
from test.utils_shared import make_msg_header


def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol


class TestProcessHeader(AsyncUnitTest):
async def asyncSetUp(self):
self.protocol = _make_protocol()

def test_op_msg_returns_body_len_and_op_code(self):
self.protocol._header = memoryview(
bytearray(make_msg_header(length=32, request_id=1, response_to=99, op_code=2013))
)
body_len, op_code, response_to, expecting_compression = self.protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)

def test_op_compressed_sets_expecting_compression(self):
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
# length=35 → after compression sub-header: 26 → body: 10
self.protocol._header = memoryview(
bytearray(make_msg_header(length=35, request_id=1, response_to=0, op_code=2012))
)
body_len, op_code, _response_to, expecting_compression = self.protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)

def test_op_compressed_length_too_small_raises(self):
self.protocol._header = memoryview(
bytearray(make_msg_header(length=25, request_id=1, response_to=0, op_code=2012))
)
with self.assertRaisesRegex(ProtocolError, "not longer than standard OP_COMPRESSED"):
self.protocol.process_header()

def test_non_compressed_length_too_small_raises(self):
self.protocol._header = memoryview(
bytearray(make_msg_header(length=16, request_id=1, response_to=0, op_code=2013))
)
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header size"):
self.protocol.process_header()

def test_length_exceeds_max_raises(self):
self.protocol._header = memoryview(
bytearray(
make_msg_header(
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
)
)
)
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
self.protocol.process_header()


class TestClose(AsyncUnitTest):
async def asyncSetUp(self):
self.protocol = _make_protocol()

def test_close_aborts_transport(self):
self.protocol.close()
self.assertTrue(self.protocol.transport.abort.called)

async def test_close_propagates_exception_to_pending_read(self):
read_task = asyncio.create_task(
self.protocol.read(request_id=None, max_message_size=MAX_MESSAGE_SIZE)
)
await asyncio.sleep(0)
self.protocol.close(OSError("connection reset"))
with self.assertRaisesRegex(OSError, "connection reset"):
await read_task


class TestBufferUpdated(AsyncUnitTest):
async def asyncSetUp(self):
self.protocol = _make_protocol()

def test_zero_bytes_closes_connection(self):
self.protocol.buffer_updated(0)
self.assertTrue(self.protocol.transport.abort.called)

def test_protocol_error_closes_connection(self):
buf = self.protocol.get_buffer(16)
buf[:16] = make_msg_header(length=16, request_id=1, response_to=0, op_code=2013)
self.protocol.buffer_updated(16)
self.assertTrue(self.protocol.transport.abort.called)

async def test_resolves_pending_read(self):
read_task = asyncio.create_task(
self.protocol.read(request_id=None, max_message_size=MAX_MESSAGE_SIZE)
)
await asyncio.sleep(0)

# Feed a valid 32-byte OP_MSG header (16-byte header + 16-byte body).
header = make_msg_header(length=32, request_id=1, response_to=99, op_code=2013)
buf = self.protocol.get_buffer(16)
buf[:16] = header
self.protocol.buffer_updated(16)

self.assertFalse(self.protocol._expecting_header)
self.assertEqual(self.protocol._message_size, 16)

# Feed the 16-byte body.
buf = self.protocol.get_buffer(16)
buf[:16] = b"x" * 16
self.protocol.buffer_updated(16)

_data, op_code = await read_task
self.assertEqual(op_code, 2013)


class TestAsyncSocketReceive(AsyncUnitTest):
async def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# sock_recv_into returns 0.
mock_socket = MagicMock()
loop = asyncio.get_running_loop()

with patch.object(loop, "sock_recv_into", new=AsyncMock(return_value=0)):
with self.assertRaisesRegex(OSError, "connection closed"):
await _async_socket_receive(mock_socket, 10, loop)


if __name__ == "__main__":
unittest.main()
100 changes: 100 additions & 0 deletions test/test_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sync-only unit tests for network_layer.py.

These cover ``receive_message`` and ``receive_data``, which only exist on the
synchronous receive path (the async path uses ``PyMongoProtocol`` instead).
The async-only tests live in ``test/asynchronous/test_async_network_layer.py``.
"""

from __future__ import annotations

import sys
from unittest.mock import MagicMock, patch

sys.path[0:0] = [""]

from pymongo import network_layer
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from test import UnitTest, unittest
from test.utils_shared import make_msg_header


def _make_conn():
conn = MagicMock()
conn.conn.gettimeout.return_value = None
# On PyPy/Windows, receive_data() calls wait_for_read() before recv_into().
# wait_for_read() checks fileno() == -1 as an early-exit; without this mock,
# sock.fileno() returns a MagicMock and sock.pending() > 0 raises TypeError.
conn.conn.sock.fileno.return_value = -1
return conn


class TestReceiveMessage(UnitTest):
def test_request_id_mismatch_raises(self):
with patch.object(
network_layer,
"receive_data",
return_value=make_msg_header(length=32, request_id=0, response_to=99, op_code=2013),
):
with self.assertRaisesRegex(ProtocolError, "Got response id"):
network_layer.receive_message(_make_conn(), request_id=1)

def test_length_too_small_raises(self):
with patch.object(
network_layer,
"receive_data",
return_value=make_msg_header(length=16, request_id=0, response_to=0, op_code=2013),
):
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_length_exceeds_max_raises(self):
with patch.object(
network_layer,
"receive_data",
return_value=make_msg_header(
length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013
),
):
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_unknown_opcode_raises(self):
with patch.object(
network_layer,
"receive_data",
side_effect=[
make_msg_header(length=20, request_id=0, response_to=0, op_code=9999),
b"data",
],
):
with self.assertRaisesRegex(ProtocolError, "Got opcode"):
network_layer.receive_message(_make_conn(), request_id=None)


class TestReceiveData(UnitTest):
def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# recv_into returns 0.
conn = _make_conn()
conn.conn.recv_into.return_value = 0
with self.assertRaisesRegex(OSError, "connection closed"):
network_layer.receive_data(conn, 10, deadline=None)


if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions test/utils_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import random
import re
import shutil
import struct
import sys
import threading
import unittest
Expand Down Expand Up @@ -743,3 +744,8 @@ async def async_barrier_wait(barrier, timeout: float | None = None):

def barrier_wait(barrier, timeout: float | None = None):
barrier.wait(timeout=timeout)


def make_msg_header(length: int, request_id: int, response_to: int, op_code: int) -> bytes:
"""Pack a MongoDB wire-protocol message header."""
return struct.pack("<iiii", length, request_id, response_to, op_code)
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def async_only_test(f: Path) -> bool:
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
"test_async_loop_unblocked.py",
"test_async_network_layer.py",
]


Expand Down
Loading