diff --git a/ntp.py b/ntp.py index cf95549..147cc69 100755 --- a/ntp.py +++ b/ntp.py @@ -22,16 +22,20 @@ # https://datatracker.ietf.org/doc/html/rfc5905 # +# Type annotations, refactoring and docstrings by Tamas Nepusz import logging +from collections.abc import Callable, Iterable from datetime import datetime, timezone from ipaddress import IPv6Address, ip_address from random import random -from socket import AF_INET, AF_INET6, SOCK_DGRAM -from socket import socket, getaddrinfo, gaierror +from socket import AF_INET, AF_INET6, SOCK_DGRAM, gaierror, getaddrinfo, socket from struct import pack, unpack -from time import time, sleep +from time import sleep, time +from typing import TYPE_CHECKING, TypeAlias +if TYPE_CHECKING: + from argparse import ArgumentParser VERSION = 4 TOLERANCE = 15e-6 # 15 us/s (clock drift assumed in RFC) @@ -50,7 +54,17 @@ log = logging.getLogger("ntp") -def ntptime(t=None): +def ntptime(t: float | None = None) -> tuple[int, int]: + """Convert a Unix timestamp to an NTP timestamp tuple. + + Args: + t: Unix time in seconds. If omitted, the current system time is used. + + Returns: + A ``(seconds, fraction)`` tuple in NTP 64-bit fixed-point format, + where ``seconds`` are counted from the NTP epoch (1900-01-01) and + ``fraction`` contains the 32-bit fractional part. + """ if t is None: t = time() secs = int(t) @@ -60,45 +74,63 @@ def ntptime(t=None): class NtpError(Exception): + """Base class for NTP-related errors.""" + pass class NtpUnsynchronizedError(NtpError): + """Raised when synchronization data is unavailable or unusable.""" + pass class NtpDeniedError(NtpError): + """Raised when an NTP server denies the client's request.""" + pass class NtpThrottledError(NtpError): + """Raised when an NTP server asks the client to reduce its query rate.""" + pass class NtpPacketError(NtpError): + """Raised when an NTP packet is malformed or inconsistent.""" + pass class NtpMessage: + """Representation of an NTP packet and its associated local timing data.""" + def __init__( self, *, - delay=MAXDISP, - dispersion=MAXDISP, - leap=NOSYNC, - mode=3, - poll=MINPOLL, - precision=PRECISION, - reference=b"", - stratum=MAXSTRAT, - t=None, - t_dst=(0, 0), - t_org=(0, 0), - t_rec=(0, 0), - t_ref=(0, 0), - t_xmt=(0, 0), - version=VERSION + delay: float = MAXDISP, + dispersion: float = MAXDISP, + leap: int = NOSYNC, + mode: int = 3, + poll: int = MINPOLL, + precision: int = PRECISION, + reference: bytes = b"", + stratum: int = MAXSTRAT, + t: float | None = None, + t_dst: tuple[int, int] = (0, 0), + t_org: tuple[int, int] = (0, 0), + t_rec: tuple[int, int] = (0, 0), + t_ref: tuple[int, int] = (0, 0), + t_xmt: tuple[int, int] = (0, 0), + version: int = VERSION, ): + """Initialize an NTP message. + + The constructor arguments map directly to the NTP message fields used + by this implementation, including the timestamp tuples carried by the + packet and the local receive timestamp recorded in ``t``. + """ if t is None: t = time() @@ -119,41 +151,86 @@ def __init__( self.version = version @staticmethod - def to_short(x): + def to_short(x: int | float) -> tuple[int, int]: + """Convert a value to the NTP short fixed-point format. + + Args: + x: Delay or dispersion value in seconds, expressed as an integer or + floating-point number. + + Returns: + A ``(seconds, fraction)`` tuple representing the 16.16 fixed-point + NTP short format. + + Raises: + NtpError: If ``x`` is not an ``int`` or ``float``. + """ # Page 13, short format is 32bit, unsigned, fixed point. if isinstance(x, int): - return x & 0xffff, 0 + return x & 0xFFFF, 0 if isinstance(x, float): secs = int(x) frac = int(65536 * (x - secs)) - return secs & 0xffff, frac & 0xffff - raise NtpError("Invalid NTP shot format value") + return secs & 0xFFFF, frac & 0xFFFF + raise NtpError("Invalid NTP short format value") @staticmethod - def to_timestamp(x): + def to_timestamp(x: int | float) -> tuple[int, int]: + """Convert a value to the full NTP timestamp format. + + Args: + x: Timestamp value in seconds, expressed as an integer or + floating-point number. + + Returns: + A ``(seconds, fraction)`` tuple representing the 32.32 fixed-point + NTP timestamp format. + + Raises: + NtpError: If ``x`` is not an ``int`` or ``float``. + """ # Page 13, timestamp is 64bit, unsigned, fixed point. if isinstance(x, int): - return x & 0xffffffff, 0 + return x & 0xFFFFFFFF, 0 if isinstance(x, float): secs = int(x) frac = int(4294967296 * (x - secs)) - return secs & 0xffffffff, frac & 0xffffffff + return secs & 0xFFFFFFFF, frac & 0xFFFFFFFF raise NtpError("Invalid NTP timestamp value") @staticmethod - def from_short(secs, frac): + def from_short(secs: int, frac: int) -> float: + """Convert an NTP short fixed-point value to seconds. + + Args: + secs: Integer part of the 16.16 fixed-point value. + frac: Fractional part of the 16.16 fixed-point value. + + Returns: + The decoded value in seconds as a floating-point number. + """ return secs + frac / 65536 @staticmethod - def from_timestamp(secs, frac): + def from_timestamp(secs: int, frac: int) -> float: + """Convert an NTP timestamp value to seconds. + + Args: + secs: Integer part of the 32.32 fixed-point timestamp. + frac: Fractional part of the 32.32 fixed-point timestamp. + + Returns: + The decoded timestamp in seconds as a floating-point number. + """ return secs + frac / 4294967296 - def serialize(self): - b1 = ( - ((self.leap & 3) << 6) - | ((self.version & 7) << 3) - | ((self.mode & 7) << 0) - ) + def serialize(self) -> bytes: + """Serialize the message to the wire format used by NTP. + + Returns: + The 48-byte binary representation of the NTP packet. + """ + b1 = ((self.leap & 3) << 6) | ((self.version & 7) << 3) | ((self.mode & 7) << 0) delay_secs, delay_frac = self.to_short(self.delay) dispersion_secs, dispersion_frac = self.to_short(self.dispersion) t_ref_secs, t_ref_frac = self.t_ref @@ -184,7 +261,26 @@ def serialize(self): ) @staticmethod - def deserialize(b, t=None): + def deserialize(b: bytes, t: float | None = None) -> "NtpMessage": + """Parse an NTP server response from its wire representation. + + Args: + b: Raw packet payload received from the network. + t: Local Unix timestamp recorded when the packet was received. If + omitted, the current system time is used. + + Returns: + An ``NtpMessage`` populated from the received packet, with the local + destination timestamp stored in ``t_dst``. + + Raises: + NtpPacketError: If the packet length, version, mode, or required + timestamps are invalid. + NtpDeniedError: If the server returned a kiss-o'-death response that + denies service. + NtpThrottledError: If the server returned a kiss-o'-death response + requesting a lower query rate. + """ if t is None: t = time() @@ -255,15 +351,28 @@ def deserialize(b, t=None): class NtpState: + """Snapshot of the time-quality metrics derived from NTP samples.""" + def __init__( self, *, - delay=MAXDISP, - dispersion=MAXDISP, - jitter=0, - offset=0, - t=None + delay: float = MAXDISP, + dispersion: float = MAXDISP, + jitter: int = 0, + offset: float = 0, + t: float | None = None, ): + """Initialize a state snapshot. + + Args: + delay: Estimated round-trip network delay in seconds. + dispersion: Estimated maximum error of the sample in seconds. + jitter: Variation between recent offset samples, in seconds. + offset: Estimated correction to apply to the local clock, in + seconds. + t: Local Unix timestamp at which the state was computed. If + omitted, the current system time is used. + """ if t is None: t = time() @@ -283,16 +392,31 @@ def __str__(self): class NtpAssociation: + """Maintain polling state and statistics for a single NTP peer.""" + def __init__( self, *, - address, - port=123, - precision=PRECISION, - tolerance=TOLERANCE, - start_randomization=None, - max_poll=None + address: str, + port: int = 123, + precision: int = PRECISION, + tolerance: float = TOLERANCE, + start_randomization: float | None = None, + max_poll: int | None = None, ): + """Initialize an association with a single NTP server. + + Args: + address: IPv4 or IPv6 address of the remote NTP server. + port: UDP port of the remote NTP server. + precision: Local clock precision encoded as a base-2 exponent. + tolerance: Assumed maximum local clock drift, in seconds per + second. + start_randomization: Optional maximum initial random delay, in + seconds, used to stagger the first poll. + max_poll: Optional upper bound for the polling interval, in + seconds. + """ ip = ip_address(address) self.ipv6 = isinstance(ip, IPv6Address) self.address = (address, port) @@ -312,27 +436,36 @@ def __init__( self.poll = MINPOLL self.poll_t = t if start_randomization is not None: - self.poll_t += random() * start_randomization + self.poll_t += random() * start_randomization # noqa: S311 log.info("NTP association %s initialized", self) log.debug("%s Scheduled at %s", self, self.poll_t) - def __hash__(self): + def __hash__(self) -> int: return hash(self.address) - def __eq__(self, other): + def __eq__(self, other: object): if isinstance(other, NtpAssociation): return self.address == other.address if isinstance(other, tuple): return self.address == other return False - def __str__(self): + def __str__(self) -> str: return "%s" % self.address[0] - def __repr__(self): + def __repr__(self) -> str: return self.__str__() - def schedule_poll(self, t=None): + def schedule_poll(self, t: float | None = None) -> None: + """Schedule the next poll time for this peer. + + The next poll is jittered slightly around the current polling interval + to avoid synchronized bursts against multiple servers. + + Args: + t: Local Unix timestamp to use as the scheduling base. If omitted, + the current system time is used. + """ if t is None: t = time() @@ -341,12 +474,21 @@ def schedule_poll(self, t=None): self.poll = min(self.max_poll, self.poll) self.poll = max(MINPOLL, self.poll) - interval = self.poll + random() * self.poll / 2 - self.poll / 4 + interval = self.poll + random() * self.poll / 2 - self.poll / 4 # noqa: S311 self.poll_t = t + interval self.poll *= 1.5 log.debug("%s Scheduled in %s secs at %s", self, interval, self.poll_t) - def calculate_state(self, t=None): + def calculate_state(self, t: float | None = None) -> None: + """Recompute the aggregate peer state from the sample register. + + Args: + t: Local Unix timestamp to associate with the newly computed state. + If omitted, the current system time is used. + + Returns: + None. + """ if t is None: t = time() @@ -355,12 +497,10 @@ def calculate_state(self, t=None): offset = register[0].offset delay = register[0].delay - dispersion = sum( - r.dispersion / (2 ** i) for i, r in enumerate(register, 1) + dispersion = sum(r.dispersion / (2**i) for i, r in enumerate(register, 1)) + jitter = ( + sum((r.offset - offset) ** 2 for r in register) / (len(register) - 1) ** 0.5 ) - jitter = sum( - (r.offset - offset) ** 2 for r in register - ) / (len(register) - 1) ** 0.5 self.state = NtpState( offset=offset, @@ -370,7 +510,20 @@ def calculate_state(self, t=None): t=t, ) - def root_distance(self, t=None): + def root_distance(self, t: float | None = None) -> float: + """Estimate the peer's root distance. + + Root distance is the synchronization error bound used by NTP selection. + It combines network delay, dispersion, jitter, and accumulated drift + since the last update. + + Args: + t: Current local Unix timestamp. If omitted, the current system + time is used. + + Returns: + The estimated root distance in seconds. + """ if t is None: t = time() @@ -385,10 +538,29 @@ def root_distance(self, t=None): + self.tolerance * abs(t - incoming.t) ) - def merit_factor(self, t=None): + def merit_factor(self, t: float | None = None) -> float: + """Compute the peer's ranking score for clock selection. + + Lower scores are better. The score primarily favors lower stratum + servers and then uses root distance as a tie-breaker. + + Args: + t: Current local Unix timestamp. If omitted, the current system + time is used. + + Returns: + The merit factor used to sort candidate peers. + """ return self.incoming.stratum * MAXDIST + self.root_distance(t) - def is_synchronized(self): + def is_synchronized(self) -> bool: + """Check whether the peer reports itself as synchronized. + + Returns: + ``True`` if the peer's last response indicates a synchronized + server with an acceptable stratum and dispersion, otherwise + ``False``. + """ incoming = self.incoming return ( incoming.leap != NOSYNC @@ -396,10 +568,29 @@ def is_synchronized(self): and incoming.delay / 2 + incoming.dispersion < MAXDISP ) - def is_fit(self, t=None): + def is_fit(self, t: float | None = None) -> bool: + """Check whether the peer is suitable for clock selection. + + Args: + t: Current local Unix timestamp. If omitted, the current system + time is used. + + Returns: + ``True`` if the peer is synchronized and its root distance is below + the implementation's acceptance threshold, otherwise ``False``. + """ return self.is_synchronized() and self.root_distance(t) < MAXDISP - def prepare_request(self, t=None): + def prepare_request(self, t: float | None = None) -> bytes: + """Build the next client request packet for this peer. + + Args: + t: Local Unix timestamp to use for the transmit time. If omitted, + the current system time is used. + + Returns: + The serialized NTP client request packet. + """ if t is None: t = time() @@ -412,7 +603,18 @@ def prepare_request(self, t=None): ) return self.outgoing.serialize() - def response_error(self, error, t=None): + def response_error(self, error: Exception, t: float | None = None) -> None: + """Record a communication failure for this peer. + + Communication failures degrade the peer's sample history so that peer + selection naturally stops favoring it until valid replies arrive again. + + Args: + error: The communication error that occurred while sending or + receiving a packet. + t: Local Unix timestamp at which the error was observed. If + omitted, the current system time is used. + """ # Communication errors, including timeouts, cause degradation # of samples in the register and render the peer unfit. if t is None: @@ -429,11 +631,27 @@ def response_error(self, error, t=None): self.calculate_state(t) self.schedule_poll(t) - def process_response(self, payload, t=None): + def process_response(self, payload: bytes, t: float | None = None) -> None: + """Process a received NTP server response. + + Valid replies update the association state and sample register. Invalid + or unsynchronized replies are converted into degraded samples so the + peer becomes less likely to be selected. + + Args: + payload: Raw UDP payload received from the peer. + t: Local Unix timestamp recorded when the payload was received. If + omitted, the current system time is used. + + Raises: + AssertionError: If called before a request has been prepared and no + outgoing packet state is available. + """ if t is None: t = time() log.debug("%s Got a packet", self) + assert self.outgoing is not None try: r = NtpMessage.deserialize(payload, t) @@ -468,12 +686,8 @@ def process_response(self, payload, t=None): # in order, to be in sync with the peer. Therefore, # negative offset means our clock is running fast. offset = (t2 - t1 + t3 - t4) / 2 - delay = max(t4 - t1 - t3 + t2, 2 ** self.precision) - dispersion = ( - 2 ** r.precision - + 2 ** self.precision - + (t4 - t1) * self.tolerance - ) + delay = max(t4 - t1 - t3 + t2, 2**self.precision) + dispersion = 2**r.precision + 2**self.precision + (t4 - t1) * self.tolerance state = NtpState( offset=offset, delay=delay, @@ -483,35 +697,66 @@ def process_response(self, payload, t=None): self.register.append(state) self.calculate_state(t) log.debug("%s Update with %s", self, state) + except NtpUnsynchronizedError: self.register.append(NtpState(t=t)) self.calculate_state(t) + except NtpError as e: log.info("%s %s", self, e.args[0]) + finally: self.schedule_poll(t) +Edge: TypeAlias = tuple[float, int, int, NtpAssociation] +State: TypeAlias = tuple[int, float, float] + + class NtpArena: + """Manage a set of NTP peers and derive a combined clock estimate.""" + + peers: dict[tuple[str, int], NtpAssociation] + sockv4: socket | None + sockv6: socket | None + def __init__( self, *, - addresses, - socket_timeout=5.0, - precision=PRECISION, - tolerance=TOLERANCE, - start_randomization=None, - max_poll=None + addresses: Iterable[str], + socket_timeout: float = 5.0, + precision: int = PRECISION, + tolerance: float = TOLERANCE, + start_randomization: float | None = None, + max_poll: int | None = None, ): + """Initialize the arena and its peer sockets. + + Args: + addresses: IP addresses of the NTP servers to poll. + socket_timeout: Timeout for socket receive operations, in seconds. + precision: Local clock precision encoded as a base-2 exponent. + tolerance: Assumed maximum local clock drift, in seconds per + second. + start_randomization: Optional maximum initial random delay, in + seconds, used to stagger the first poll for each peer. + max_poll: Optional upper bound for each peer's polling interval, in + seconds. + + Raises: + ValueError: If no usable IPv4 or IPv6 addresses are provided. + """ needs_ipv4 = False needs_ipv6 = False self.peers = {} for i in set(addresses): - p = NtpAssociation(address=i, - precision=precision, - tolerance=tolerance, - start_randomization=start_randomization, - max_poll=max_poll) + p = NtpAssociation( + address=i, + precision=precision, + tolerance=tolerance, + start_randomization=start_randomization, + max_poll=max_poll, + ) self.peers[p.address] = p if p.ipv6: needs_ipv6 = True @@ -525,7 +770,7 @@ def __init__( if needs_ipv4: self.sockv4 = socket(AF_INET, SOCK_DGRAM) self.sockv4.settimeout(socket_timeout) - self.sockv4.bind(("0.0.0.0", 0)) + self.sockv4.bind(("0.0.0.0", 0)) # noqa: S104 log.debug("Created IPv4 socket") self.sockv6 = None @@ -538,10 +783,30 @@ def __init__( def query_peers( self, *, - query_limit=None, - time_limit=None, - response_callback=None - ): + query_limit: int | None = None, + time_limit: float | None = None, + response_callback: Callable[[], None] | None = None, + ) -> float: + """Poll due peers until the next wait period begins. + + The method sends requests to peers whose scheduled poll time has + arrived, waits for a matching response from each, and updates the peer + state accordingly. + + Args: + query_limit: Optional maximum number of peer queries to perform in + this call. + time_limit: Optional maximum wall-clock time, in seconds, to spend + inside this call. + response_callback: Optional callback invoked after each successful + response is processed. + + Returns: + The number of seconds until the next peer should be queried. + + Raises: + ValueError: If the arena has no peers to query. + """ log.debug("Query peers") i = 0 start = time() @@ -572,7 +837,10 @@ def query_peers( if time_limit is not None and t - start > time_limit: log.debug("Time limit reached") return diff + s = self.sockv6 if p.ipv6 else self.sockv4 + assert s is not None + try: s.sendto(p.prepare_request(), p.address) while True: @@ -586,7 +854,26 @@ def query_peers( except OSError as e: p.response_error(e) - def filter_clocks(self, edges, low, high): + def filter_clocks(self, edges: Iterable[Edge], low: float, high: float) -> State: + """Select survivor peers and combine them into a system clock state. + + Args: + edges: Interval boundary records derived from candidate peers, + containing lower bounds, midpoints, upper bounds, and peer + references. + low: Lower bound of the consensus interval. + high: Upper bound of the consensus interval. + + Returns: + A ``(leap, offset, jitter)`` tuple representing the combined system + state, where ``leap`` is the leap-second correction indicator, + ``offset`` is the estimated clock correction in seconds, and + ``jitter`` is the aggregate jitter in seconds. + + Raises: + NtpUnsynchronizedError: If no peers remain within the consensus + interval. + """ # Truechimers have their midpoint in the found interval. truechimers = set() for e in edges: @@ -602,6 +889,7 @@ def filter_clocks(self, edges, low, high): min_jitter = None max_jitter = None max_jitter_peer = None + for t in truechimers: offset = t.state.offset jitter = ( @@ -613,8 +901,14 @@ def filter_clocks(self, edges, low, high): if max_jitter is None or max_jitter < jitter: max_jitter = jitter max_jitter_peer = t + + assert min_jitter is not None + assert max_jitter is not None + if max_jitter < min_jitter: break + + assert max_jitter_peer is not None truechimers.remove(max_jitter_peer) t = time() @@ -645,7 +939,20 @@ def filter_clocks(self, edges, low, high): log.debug("offset=%g jitter=%g leap=%d", offset, jitter, leap) return leap, offset, jitter - def calculate_state(self): + def calculate_state(self) -> State: + """Compute the current combined clock state from all fit peers. + + This runs the NTP clock-selection and clustering steps over the peers + that are currently considered fit. + + Returns: + A ``(leap, offset, jitter)`` tuple representing the selected system + clock state. + + Raises: + NtpUnsynchronizedError: If there are no fit peers or if the peers + do not reach consensus. + """ t = time() fit = [p for p in self.peers.values() if p.is_fit(t)] @@ -653,7 +960,7 @@ def calculate_state(self): raise NtpUnsynchronizedError("No fit peers found") log.debug("Fit peers: %s", fit) - edges = [] + edges: list[Edge] = [] for p in fit: offset = p.state.offset distance = p.root_distance(t) @@ -686,18 +993,22 @@ def calculate_state(self): break midpoints += e[2] - if ( - midpoints <= i - and low is not None - and high is not None - and low < high - ): + if midpoints <= i and low is not None and high is not None and low < high: return self.filter_clocks(edges, low, high) raise NtpUnsynchronizedError("No consensus found") -def argv_parser(progname=None): +def argv_parser(progname: str | None = None) -> "ArgumentParser": + """Create the command-line argument parser for the NTP client. + + Args: + progname: Program name to display in help output. If omitted, ``ntp`` + is used. + + Returns: + The configured argument parser. + """ import argparse if progname is None: @@ -727,7 +1038,7 @@ def argv_parser(progname=None): help=( "how many times the output should be produced." " It defaults to zero, which means 'run forever'." - ) + ), ) parser.add_argument( "--output-interval", @@ -738,7 +1049,7 @@ def argv_parser(progname=None): " Defaults to after each batch of queries." " Zero means after every reply from an NTP server." " Subject to availability of synchronization data." - ) + ), ) parser.add_argument( "--output-format", @@ -748,7 +1059,7 @@ def argv_parser(progname=None): "defaults to '{Y:04}-{M:02}-{D:02}T{h:02}:{m:02}:{s:02}.{u:06}Z'" " Other variables available: count, offset, jitter, leap and time." " For example: 'offset={offset}'" - ) + ), ) parser.add_argument( "--socket-timeout", @@ -763,12 +1074,20 @@ def argv_parser(progname=None): help=( "max interval between queries to each NTP server (in seconds)." " By default it is capped at 1h +/- 15m." - ) + ), ) return parser -def main(): +def main() -> None: + """Run the command-line NTP client. + + The client resolves the requested servers, polls them repeatedly, and + prints synchronized time output according to the configured format. + + Returns: + None. + """ args = argv_parser().parse_args() log_level = getattr(logging, args.log_level.upper()) output_format = args.output_format