diff --git a/pymongo/_telemetry.py b/pymongo/_telemetry.py index 471a013eb9..168d2a07fb 100644 --- a/pymongo/_telemetry.py +++ b/pymongo/_telemetry.py @@ -86,7 +86,7 @@ def _emit_log(self, message: _CommandStatusMessage, **extra: Any) -> None: commandName=self._name, databaseName=self._dbname, requestId=self._request_id, - operationId=self._request_id, + operationId=self._op_id if self._op_id is not None else self._request_id, driverConnectionId=self._conn.id, serverConnectionId=self._conn.server_connection_id, serverHost=self._conn.address[0], diff --git a/pymongo/asynchronous/command_runner.py b/pymongo/asynchronous/command_runner.py index 1f32cc744e..ac47aed798 100644 --- a/pymongo/asynchronous/command_runner.py +++ b/pymongo/asynchronous/command_runner.py @@ -284,6 +284,7 @@ async def run_cursor_command( more_to_come: bool = False, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, + op_id: Optional[int] = None, ) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Run a cursor ``find``/``getMore`` operation over ``conn``. @@ -304,6 +305,7 @@ async def run_cursor_command( :param unpack_res: A callable decoding the wire response; when ``None`` the reply's own ``unpack_response`` is used. :param cursor_id: The cursor id passed to ``unpack_res``. + :param op_id: The APM operation id; defaults to ``request_id``. """ topology_id = client._topology_id if client is not None else None return await _run_command( @@ -325,6 +327,7 @@ async def run_cursor_command( more_to_come=more_to_come, unpack_res=unpack_res, cursor_id=cursor_id, + op_id=op_id, ) @@ -348,6 +351,7 @@ async def run_command( user_fields: Optional[Mapping[str, Any]] = None, exhaust_allowed: bool = False, write_concern: Optional[WriteConcern] = None, + op_id: Optional[int] = None, ) -> _DocumentType: """Encode and execute a command over ``conn``, or raise socket.error. @@ -376,6 +380,7 @@ async def run_command( passed to ``bson._decode_all_selective``. 
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. :param write_concern: The write concern for this command. Applied via CSOT. + :param op_id: The APM operation id; defaults to ``request_id``. """ name = next(iter(spec)) @@ -428,6 +433,7 @@ async def run_command( codec_options=codec_options, user_fields=user_fields, orig=orig, + op_id=op_id, check=check, allowable_errors=allowable_errors, parse_write_concern_error=parse_write_concern_error, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a558f96356..bcbbc87e3d 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -95,7 +95,7 @@ _log_client_error, _log_or_warn, ) -from pymongo.message import _CursorAddress, _GetMore, _Query +from pymongo.message import _CursorAddress, _GetMore, _Query, _randint from pymongo.monitoring import ConnectionClosedReason, _EventListeners from pymongo.operations import ( DeleteMany, @@ -1837,6 +1837,8 @@ async def _select_server( be pinned to a mongos server address. - `address` (optional): Address when sending a message to a specific server, used for getMore. + - `operation_id` (optional): Stable operation id shared across retries, + used for command monitoring. """ try: topology = await self._get_topology() @@ -1932,6 +1934,7 @@ async def _run_operation( async with operation.conn_mgr._lock: async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) + operation.conn_mgr.conn.op_id = _randint() return await server.run_operation( operation.conn_mgr.conn, operation, @@ -2023,6 +2026,7 @@ async def _retry_internal( :param retryable: If the operation should be retried once, defaults to None :param is_run_command: If this is a runCommand operation, defaults to False :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. 
+ :param operation_id: Stable operation id shared across retries, defaults to None :return: Output of the calling func() """ @@ -2069,6 +2073,7 @@ async def _retryable_read( (may not always be supported even if supplied), defaults to False :param is_run_command: If this is a runCommand operation, defaults to False. :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. + :param operation_id: Stable operation id shared across retries, defaults to None """ # Ensure that the client supports retrying on reads and there is no session in @@ -2112,6 +2117,7 @@ async def _retryable_write( :param session: Client session we will use to execute write operation :param operation: The name of the operation that the server is being selected for :param bulk: bulk abstraction to execute operations in bulk, defaults to None + :param operation_id: Stable operation id shared across retries, defaults to None """ async with self._tmp_session(session) as s: return await self._retry_with_session(retryable, func, s, bulk, operation, operation_id) @@ -2795,7 +2801,7 @@ def __init__( self._server: Server = None # type: ignore self._deprioritized_servers: list[Server] = [] self._operation = operation - self._operation_id = operation_id + self._operation_id = operation_id if operation_id is not None else _randint() self._attempt_number = 0 self._is_run_command = is_run_command self._is_aggregate_write = is_aggregate_write @@ -3001,6 +3007,7 @@ async def _write(self) -> T: is_mongos = False self._server = await self._get_server() async with self._client._checkout(self._server, self._session) as conn: + conn.op_id = self._operation_id max_wire_version = conn.max_wire_version sessions_supported = ( self._session @@ -3040,6 +3047,7 @@ async def _read(self) -> T: conn, read_pref, ): + conn.op_id = self._operation_id if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: diff --git 
a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 4ed3b85dbf..a008246f40 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -177,6 +177,8 @@ def __init__( self.creation_time = time.monotonic() # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + # Stable operation id for the operation currently using this connection. + self.op_id: Optional[int] = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -416,6 +418,7 @@ async def command( user_fields=user_fields, exhaust_allowed=exhaust_allowed, write_concern=write_concern, + op_id=self.op_id, ) except (OperationFailure, NotPrimaryError): raise @@ -1319,6 +1322,7 @@ async def checkin(self, conn: AsyncConnection) -> None: txn = conn.pinned_txn cursor = conn.pinned_cursor conn.active = False + conn.op_id = None conn.pinned_txn = False conn.pinned_cursor = False self.__pinned_sockets.discard(conn) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 57158dfc44..bceb080e17 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -185,6 +185,7 @@ async def run_operation( more_to_come=more_to_come, unpack_res=unpack_res, cursor_id=operation.cursor_id, + op_id=conn.op_id, ) assert reply is not None diff --git a/pymongo/synchronous/command_runner.py b/pymongo/synchronous/command_runner.py index 077e0f9409..88574426f6 100644 --- a/pymongo/synchronous/command_runner.py +++ b/pymongo/synchronous/command_runner.py @@ -284,6 +284,7 @@ def run_cursor_command( more_to_come: bool = False, unpack_res: Optional[Callable[..., Any]] = None, cursor_id: Optional[int] = None, + op_id: Optional[int] = None, ) -> tuple[list[dict[str, Any]], Optional[_OpMsg], datetime.timedelta]: """Run a cursor ``find``/``getMore`` operation over ``conn``. 
@@ -304,6 +305,7 @@ def run_cursor_command( :param unpack_res: A callable decoding the wire response; when ``None`` the reply's own ``unpack_response`` is used. :param cursor_id: The cursor id passed to ``unpack_res``. + :param op_id: The APM operation id; defaults to ``request_id``. """ topology_id = client._topology_id if client is not None else None return _run_command( @@ -325,6 +327,7 @@ def run_cursor_command( more_to_come=more_to_come, unpack_res=unpack_res, cursor_id=cursor_id, + op_id=op_id, ) @@ -348,6 +351,7 @@ def run_command( user_fields: Optional[Mapping[str, Any]] = None, exhaust_allowed: bool = False, write_concern: Optional[WriteConcern] = None, + op_id: Optional[int] = None, ) -> _DocumentType: """Encode and execute a command over ``conn``, or raise socket.error. @@ -376,6 +380,7 @@ def run_command( passed to ``bson._decode_all_selective``. :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed. :param write_concern: The write concern for this command. Applied via CSOT. + :param op_id: The APM operation id; defaults to ``request_id``. """ name = next(iter(spec)) @@ -428,6 +433,7 @@ def run_command( codec_options=codec_options, user_fields=user_fields, orig=orig, + op_id=op_id, check=check, allowable_errors=allowable_errors, parse_write_concern_error=parse_write_concern_error, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5f321afe5c..42d912475a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -85,7 +85,7 @@ _log_client_error, _log_or_warn, ) -from pymongo.message import _CursorAddress, _GetMore, _Query +from pymongo.message import _CursorAddress, _GetMore, _Query, _randint from pymongo.monitoring import ConnectionClosedReason, _EventListeners from pymongo.operations import ( DeleteMany, @@ -1834,6 +1834,8 @@ def _select_server( be pinned to a mongos server address. 
- `address` (optional): Address when sending a message to a specific server, used for getMore. + - `operation_id` (optional): Stable operation id shared across retries, + used for command monitoring. """ try: topology = self._get_topology() @@ -1929,6 +1931,7 @@ def _run_operation( with operation.conn_mgr._lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) + operation.conn_mgr.conn.op_id = _randint() return server.run_operation( operation.conn_mgr.conn, operation, @@ -2020,6 +2023,7 @@ def _retry_internal( :param retryable: If the operation should be retried once, defaults to None :param is_run_command: If this is a runCommand operation, defaults to False :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. + :param operation_id: Stable operation id shared across retries, defaults to None :return: Output of the calling func() """ @@ -2066,6 +2070,7 @@ def _retryable_read( (may not always be supported even if supplied), defaults to False :param is_run_command: If this is a runCommand operation, defaults to False. :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. 
+ :param operation_id: Stable operation id shared across retries, defaults to None """ # Ensure that the client supports retrying on reads and there is no session in @@ -2109,6 +2114,7 @@ def _retryable_write( :param session: Client session we will use to execute write operation :param operation: The name of the operation that the server is being selected for :param bulk: bulk abstraction to execute operations in bulk, defaults to None + :param operation_id: Stable operation id shared across retries, defaults to None """ with self._tmp_session(session) as s: return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) @@ -2786,7 +2792,7 @@ def __init__( self._server: Server = None # type: ignore self._deprioritized_servers: list[Server] = [] self._operation = operation - self._operation_id = operation_id + self._operation_id = operation_id if operation_id is not None else _randint() self._attempt_number = 0 self._is_run_command = is_run_command self._is_aggregate_write = is_aggregate_write @@ -2992,6 +2998,7 @@ def _write(self) -> T: is_mongos = False self._server = self._get_server() with self._client._checkout(self._server, self._session) as conn: + conn.op_id = self._operation_id max_wire_version = conn.max_wire_version sessions_supported = ( self._session @@ -3031,6 +3038,7 @@ def _read(self) -> T: conn, read_pref, ): + conn.op_id = self._operation_id if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1006735444..422b07726b 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -177,6 +177,8 @@ def __init__( self.creation_time = time.monotonic() # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + # Stable operation id for the operation currently using this connection. 
+ self.op_id: Optional[int] = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -416,6 +418,7 @@ def command( user_fields=user_fields, exhaust_allowed=exhaust_allowed, write_concern=write_concern, + op_id=self.op_id, ) except (OperationFailure, NotPrimaryError): raise @@ -1315,6 +1318,7 @@ def checkin(self, conn: Connection) -> None: txn = conn.pinned_txn cursor = conn.pinned_cursor conn.active = False + conn.op_id = None conn.pinned_txn = False conn.pinned_cursor = False self.__pinned_sockets.discard(conn) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 09d8fb75e1..5ce6683f8c 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -185,6 +185,7 @@ def run_operation( more_to_come=more_to_come, unpack_res=unpack_res, cursor_id=operation.cursor_id, + op_id=conn.op_id, ) assert reply is not None diff --git a/test/__init__.py b/test/__init__.py index f4ae7fe948..5d40b8dbac 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -699,7 +699,7 @@ def require_failCommand_fail_point(self, func: Any) -> Any: func=func, ) - def require_failCommand_appName(self, func): + def require_failCommand_appName(self, func: Any) -> Any: """Run a test only if the server supports the failCommand appName.""" # SERVER-47195 and SERVER-49336. return self._require( diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0699451d7e..bee0a2b8ce 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -699,7 +699,7 @@ def require_failCommand_fail_point(self, func: Any) -> Any: func=func, ) - def require_failCommand_appName(self, func): + def require_failCommand_appName(self, func: Any) -> Any: """Run a test only if the server supports the failCommand appName.""" # SERVER-47195 and SERVER-49336. 
return self._require( diff --git a/test/asynchronous/test_operation_id_retry.py b/test/asynchronous/test_operation_id_retry.py new file mode 100644 index 0000000000..5941a2e7ad --- /dev/null +++ b/test/asynchronous/test_operation_id_retry.py @@ -0,0 +1,160 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that retry attempts reuse a single stable CommandStartedEvent.operation_id.""" + +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +import pymongo +from pymongo.errors import ConnectionFailure +from pymongo.operations import InsertOne +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils_shared import AllowListEventListener + +_IS_SYNC = False + +_APP_NAME = "operationIdRetryTest" + +# Each entry pairs the wire command an operation issues with a callable that runs it. +# These are all retryable; a stable operation_id must span every retry attempt. 
+_RETRYABLE_WRITES = [ + ("insert", lambda c: c.insert_one({"_id": 100})), + ("update", lambda c: c.update_one({"_id": 1}, {"$set": {"y": 1}})), + ("update", lambda c: c.replace_one({"_id": 2}, {"x": 9})), + ("delete", lambda c: c.delete_one({"_id": 3})), + ("findAndModify", lambda c: c.find_one_and_update({"_id": 4}, {"$set": {"y": 2}})), + ("insert", lambda c: c.bulk_write([InsertOne({"_id": 200}), InsertOne({"_id": 201})])), +] + + +_RETRYABLE_READS = [ + ("find", lambda c: c.find({"x": 1}).to_list()), + ("find", lambda c: c.find_one({"_id": 1})), + ("aggregate", lambda c: _agg(c)), + ("aggregate", lambda c: c.count_documents({"x": 1})), + ("distinct", lambda c: c.distinct("x")), + ("listIndexes", lambda c: _list_indexes(c)), +] + + +async def _agg(coll): + cursor = await coll.aggregate([{"$match": {"x": 1}}]) + return await cursor.to_list() + + +async def _list_indexes(coll): + cursor = await coll.list_indexes() + return await cursor.to_list() + + +class TestOperationIdRetry(AsyncIntegrationTest): + RETRIES = 5 # fail this many attempts; the (RETRIES + 1)th succeeds. 
+ + @async_client_context.require_failCommand_fail_point + @async_client_context.require_failCommand_appName + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + + async def _seed(self, coll): + await coll.drop() + await coll.insert_many([{"_id": i, "x": i % 3} for i in range(5)]) + await coll.create_index("x") + + async def _check_stable_operation_id(self, command_name, action, retries): + """Force ``retries`` retries of ``action`` and assert every command + event for ``command_name`` shares one integer operation_id.""" + listener = AllowListEventListener(command_name) + client = await self.async_rs_or_single_client(event_listeners=[listener], appname=_APP_NAME) + coll = client.pymongo_test.test_operation_id_retry + await self._seed(coll) + listener.reset() + + fail_point = { + "mode": {"times": retries}, + "data": { + "failCommands": [command_name], + "closeConnection": True, + "appName": _APP_NAME, + }, + } + async with self.fail_point(fail_point): + # A CSOT timeout lets a single operation retry more than once. 
+ with pymongo.timeout(60): + await action(coll) + + started = listener.started_events + failed = listener.failed_events + succeeded = listener.succeeded_events + op_ids = [e.operation_id for e in started + failed + succeeded] + + self.assertEqual(len(started), retries + 1, "expected one started event per attempt") + self.assertEqual(len(failed), retries) + self.assertEqual(len(succeeded), 1) + self.assertTrue(all(isinstance(op, int) for op in op_ids)) + self.assertEqual( + len(set(op_ids)), + 1, + f"operation_id not stable across retries for {command_name}: {op_ids}", + ) + + @async_client_context.require_no_standalone + async def test_retryable_writes_reuse_operation_id(self): + for command_name, action in _RETRYABLE_WRITES: + with self.subTest(command=command_name): + await self._check_stable_operation_id(command_name, action, self.RETRIES) + + async def test_retryable_reads_reuse_operation_id(self): + for command_name, action in _RETRYABLE_READS: + with self.subTest(command=command_name): + await self._check_stable_operation_id(command_name, action, self.RETRIES) + + @async_client_context.require_no_standalone + async def test_non_retryable_write_is_not_retried(self): + # Multi-document writes are not retryable: a single network error must + # surface immediately, with exactly one attempt. 
+ for command_name, action in [ + ("update", lambda c: c.update_many({"x": 1}, {"$set": {"z": 1}})), + ("delete", lambda c: c.delete_many({"x": 2})), + ]: + with self.subTest(command=command_name): + listener = AllowListEventListener(command_name) + client = await self.async_rs_or_single_client( + event_listeners=[listener], appname=_APP_NAME + ) + coll = client.pymongo_test.test_operation_id_retry + await self._seed(coll) + listener.reset() + + fail_point = { + "mode": {"times": 1}, + "data": { + "failCommands": [command_name], + "closeConnection": True, + "appName": _APP_NAME, + }, + } + async with self.fail_point(fail_point): + with self.assertRaises(ConnectionFailure): + await action(coll) + + self.assertEqual(len(listener.started_events), 1, "must not retry") + self.assertIsInstance(listener.started_events[0].operation_id, int) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_operation_id_retry.py b/test/test_operation_id_retry.py new file mode 100644 index 0000000000..313416af1a --- /dev/null +++ b/test/test_operation_id_retry.py @@ -0,0 +1,158 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test that retry attempts reuse a single stable CommandStartedEvent.operation_id.""" + +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +import pymongo +from pymongo.errors import ConnectionFailure +from pymongo.operations import InsertOne +from test import IntegrationTest, client_context, unittest +from test.utils_shared import AllowListEventListener + +_IS_SYNC = True + +_APP_NAME = "operationIdRetryTest" + +# Each entry pairs the wire command an operation issues with a callable that runs it. +# These are all retryable; a stable operation_id must span every retry attempt. +_RETRYABLE_WRITES = [ + ("insert", lambda c: c.insert_one({"_id": 100})), + ("update", lambda c: c.update_one({"_id": 1}, {"$set": {"y": 1}})), + ("update", lambda c: c.replace_one({"_id": 2}, {"x": 9})), + ("delete", lambda c: c.delete_one({"_id": 3})), + ("findAndModify", lambda c: c.find_one_and_update({"_id": 4}, {"$set": {"y": 2}})), + ("insert", lambda c: c.bulk_write([InsertOne({"_id": 200}), InsertOne({"_id": 201})])), +] + + +_RETRYABLE_READS = [ + ("find", lambda c: c.find({"x": 1}).to_list()), + ("find", lambda c: c.find_one({"_id": 1})), + ("aggregate", lambda c: _agg(c)), + ("aggregate", lambda c: c.count_documents({"x": 1})), + ("distinct", lambda c: c.distinct("x")), + ("listIndexes", lambda c: _list_indexes(c)), +] + + +def _agg(coll): + cursor = coll.aggregate([{"$match": {"x": 1}}]) + return cursor.to_list() + + +def _list_indexes(coll): + cursor = coll.list_indexes() + return cursor.to_list() + + +class TestOperationIdRetry(IntegrationTest): + RETRIES = 5 # fail this many attempts; the (RETRIES + 1)th succeeds. 
+ + @client_context.require_failCommand_fail_point + @client_context.require_failCommand_appName + def setUp(self) -> None: + super().setUp() + + def _seed(self, coll): + coll.drop() + coll.insert_many([{"_id": i, "x": i % 3} for i in range(5)]) + coll.create_index("x") + + def _check_stable_operation_id(self, command_name, action, retries): + """Force ``retries`` retries of ``action`` and assert every command + event for ``command_name`` shares one integer operation_id.""" + listener = AllowListEventListener(command_name) + client = self.rs_or_single_client(event_listeners=[listener], appname=_APP_NAME) + coll = client.pymongo_test.test_operation_id_retry + self._seed(coll) + listener.reset() + + fail_point = { + "mode": {"times": retries}, + "data": { + "failCommands": [command_name], + "closeConnection": True, + "appName": _APP_NAME, + }, + } + with self.fail_point(fail_point): + # A CSOT timeout lets a single operation retry more than once. + with pymongo.timeout(60): + action(coll) + + started = listener.started_events + failed = listener.failed_events + succeeded = listener.succeeded_events + op_ids = [e.operation_id for e in started + failed + succeeded] + + self.assertEqual(len(started), retries + 1, "expected one started event per attempt") + self.assertEqual(len(failed), retries) + self.assertEqual(len(succeeded), 1) + self.assertTrue(all(isinstance(op, int) for op in op_ids)) + self.assertEqual( + len(set(op_ids)), + 1, + f"operation_id not stable across retries for {command_name}: {op_ids}", + ) + + @client_context.require_no_standalone + def test_retryable_writes_reuse_operation_id(self): + for command_name, action in _RETRYABLE_WRITES: + with self.subTest(command=command_name): + self._check_stable_operation_id(command_name, action, self.RETRIES) + + def test_retryable_reads_reuse_operation_id(self): + for command_name, action in _RETRYABLE_READS: + with self.subTest(command=command_name): + self._check_stable_operation_id(command_name, action, 
self.RETRIES) + + @client_context.require_no_standalone + def test_non_retryable_write_is_not_retried(self): + # Multi-document writes are not retryable: a single network error must + # surface immediately, with exactly one attempt. + for command_name, action in [ + ("update", lambda c: c.update_many({"x": 1}, {"$set": {"z": 1}})), + ("delete", lambda c: c.delete_many({"x": 2})), + ]: + with self.subTest(command=command_name): + listener = AllowListEventListener(command_name) + client = self.rs_or_single_client(event_listeners=[listener], appname=_APP_NAME) + coll = client.pymongo_test.test_operation_id_retry + self._seed(coll) + listener.reset() + + fail_point = { + "mode": {"times": 1}, + "data": { + "failCommands": [command_name], + "closeConnection": True, + "appName": _APP_NAME, + }, + } + with self.fail_point(fail_point): + with self.assertRaises(ConnectionFailure): + action(coll) + + self.assertEqual(len(listener.started_events), 1, "must not retry") + self.assertIsInstance(listener.started_events[0].operation_id, int) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 8132167e71..9fa60f1120 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -253,6 +253,7 @@ def async_only_test(f: Path) -> bool: "test_monitoring.py", "test_mongos_load_balancing.py", "test_on_demand_csfle.py", + "test_operation_id_retry.py", "test_periodic_executor.py", "test_pooling.py", "test_raw_bson.py",