# Source code for pymod.server

"""Modbus TCP server — callback-based, datastore-free.

The server owns no data. The host application supplies callbacks that
read from and write to whatever store it already has (BACnet object
cache, MQTT mirror, in-memory dict — pymod doesn't care).

Authorization is implicit in the area: `INPUT_REGISTER` and
`DISCRETE_INPUT` are read-only because no write FC targets them.
`HOLDING_REGISTER` and `COIL` are read-write. The host decides which data
to expose through which area; if no `on_write` callback is supplied at
all, every write FC returns ILLEGAL_FUNCTION.

Callback exception semantics:

* Raise a `ModbusExceptionResponse` subclass (e.g. `IllegalDataAddress`)
  to surface that specific Modbus exception code on the wire.
* Raise anything else and the server logs the exception at ERROR and
  responds with `SLAVE_DEVICE_FAILURE` (0x04).

Callbacks may be sync or async — `inspect.isawaitable` handles both.

Usage:

    async def on_read(area, address, count):
        return host_store.read(area, address, count)

    async def on_write(area, address, values):
        host_store.write(area, address, values)

    server = pymod.Server(host="0.0.0.0", port=502,
                          on_read=on_read, on_write=on_write)
    async with server:
        await asyncio.Future()  # run forever
"""

from __future__ import annotations

import asyncio
import inspect
import logging
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Self, cast

from ._types import Area
from .errors import (
    IllegalDataValue,
    IllegalFunction,
    ModbusError,
    ModbusExceptionResponse,
    ModbusProtocolError,
    SlaveDeviceFailure,
)
from .protocol.adu_tcp import (
    MBAP_MIN_PEEK_BYTES,
    decode_mbap_frame,
    encode_mbap_frame,
    expected_frame_length,
)
from .protocol.pdu import (
    FC_READ_COILS,
    FC_READ_DISCRETE_INPUTS,
    FC_READ_HOLDING_REGISTERS,
    FC_READ_INPUT_REGISTERS,
    FC_WRITE_MULTIPLE_COILS,
    FC_WRITE_MULTIPLE_REGISTERS,
    FC_WRITE_SINGLE_COIL,
    FC_WRITE_SINGLE_REGISTER,
    decode_request_read_coils,
    decode_request_read_discrete_inputs,
    decode_request_read_holding_registers,
    decode_request_read_input_registers,
    decode_request_write_multiple_coils,
    decode_request_write_multiple_registers,
    decode_request_write_single_coil,
    decode_request_write_single_register,
    encode_exception_response,
    encode_response_read_coils,
    encode_response_read_discrete_inputs,
    encode_response_read_holding_registers,
    encode_response_read_input_registers,
    encode_response_write_multiple_coils,
    encode_response_write_multiple_registers,
    encode_response_write_single_coil,
    encode_response_write_single_register,
)

# Two loggers: "pymod.server" for operational events (connections, errors)
# and "pymod.wire" for hex dumps of raw RX/TX frames, so frame tracing can be
# switched on independently of the operational log.
_logger: logging.Logger = logging.getLogger("pymod.server")
_wire: logging.Logger = logging.getLogger("pymod.wire")

# One value exchanged with the host: registers are ints, coil and
# discrete-input states are bools (read results are coerced accordingly).
_Value = int | bool

# on_read(area, address, count) -> exactly `count` values, returned either
# directly or through an awaitable.
ReadCallback = Callable[
    [Area, int, int],
    Awaitable[Sequence[_Value]] | Sequence[_Value],
]

# on_write(area, address, values) -> None, returned either directly or
# through an awaitable.
WriteCallback = Callable[
    [Area, int, Sequence[_Value]],
    Awaitable[None] | None,
]

# Modbus exception code 0x0B, "Gateway Target Device Failed to Respond".
# It is not raised by callbacks; Server._dispatch emits it directly as the
# reply to requests whose unit id does not match the configured unit_id.
_GATEWAY_TARGET_FAILED_TO_RESPOND = 0x0B


[docs] class Server: """Modbus TCP server with callback-based data access.""" def __init__( self, *, host: str = "0.0.0.0", port: int = 502, on_read: ReadCallback, on_write: WriteCallback | None = None, unit_id: int | None = None, max_connections: int = 32, ) -> None: if max_connections < 1: raise ValueError(f"max_connections must be >= 1, got {max_connections}") self._host = host self._port = port self._on_read = on_read self._on_write = on_write self._unit_id = unit_id self._max_connections = max_connections self._server: asyncio.Server | None = None self._connections: set[asyncio.StreamWriter] = set() self._client_tasks: set[asyncio.Task[None]] = set() # ---------- public API ----------
[docs] async def start(self) -> None: if self._server is not None: return self._server = await asyncio.start_server( self._on_client, self._host, self._port ) sockets = self._server.sockets if sockets: sockname = sockets[0].getsockname() _logger.info("listening on %s", sockname)
[docs] async def stop(self) -> None: # Drop active client connections. writers = list(self._connections) self._connections.clear() for w in writers: try: w.close() except Exception: pass for w in writers: try: await w.wait_closed() except Exception: pass # Cancel any in-flight handler tasks. tasks = list(self._client_tasks) for t in tasks: if not t.done(): t.cancel() for t in tasks: try: await t except (asyncio.CancelledError, Exception): pass self._client_tasks.clear() # Shut down listener. if self._server is not None: self._server.close() try: await self._server.wait_closed() except Exception: pass self._server = None
@property def is_running(self) -> bool: return self._server is not None @property def active_connections(self) -> int: return len(self._connections) @property def port(self) -> int: """Bound port. Useful when port=0 was used (ephemeral port).""" if self._server is None: return self._port sockets = self._server.sockets if not sockets: return self._port return cast(int, sockets[0].getsockname()[1]) async def __aenter__(self) -> Self: await self.start() return self async def __aexit__(self, *exc_info: object) -> None: await self.stop() # ---------- per-connection handler ---------- async def _on_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: peer = writer.get_extra_info("peername") if len(self._connections) >= self._max_connections: _logger.warning( "rejecting connection from %s: max_connections=%d reached", peer, self._max_connections, ) try: writer.close() await writer.wait_closed() except Exception: pass return self._connections.add(writer) task = asyncio.current_task() if task is not None: self._client_tasks.add(task) _logger.debug( "client connected: %s (active=%d)", peer, len(self._connections) ) try: while True: header = await reader.readexactly(MBAP_MIN_PEEK_BYTES) total = expected_frame_length(header) rest = await reader.readexactly(total - MBAP_MIN_PEEK_BYTES) frame = header + rest if _wire.isEnabledFor(logging.DEBUG): _wire.debug("RX from %s: %s", peer, frame.hex()) try: tid, uid, request_pdu = decode_mbap_frame(frame) except ModbusProtocolError: _logger.exception("malformed MBAP frame from %s; dropping", peer) continue response_pdu = await self._dispatch(uid, request_pdu) response_frame = encode_mbap_frame(tid, uid, response_pdu) if _wire.isEnabledFor(logging.DEBUG): _wire.debug("TX to %s: %s", peer, response_frame.hex()) writer.write(response_frame) await writer.drain() except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError): pass except asyncio.CancelledError: raise except Exception: 
_logger.exception("unexpected error in handler for %s", peer) finally: self._connections.discard(writer) if task is not None: self._client_tasks.discard(task) try: writer.close() except Exception: pass _logger.debug( "client disconnected: %s (active=%d)", peer, len(self._connections), ) # ---------- request dispatch ---------- async def _dispatch(self, uid: int, pdu: bytes) -> bytes: if len(pdu) == 0: return encode_exception_response(0, IllegalFunction.code) fc = pdu[0] if self._unit_id is not None and uid != self._unit_id: return encode_exception_response(fc, _GATEWAY_TARGET_FAILED_TO_RESPOND) try: return await self._handle(fc, pdu) except ModbusExceptionResponse as e: return encode_exception_response(fc, e.code) except asyncio.CancelledError: raise except Exception: _logger.exception("callback raised unhandled exception on FC %#04x", fc) return encode_exception_response(fc, SlaveDeviceFailure.code) async def _handle(self, fc: int, pdu: bytes) -> bytes: # Read FCs. if fc == FC_READ_COILS: start, count = decode_request_read_coils(pdu) values = await self._call_read(Area.COIL, start, count) return encode_response_read_coils(_to_bool_list(values, count)) if fc == FC_READ_DISCRETE_INPUTS: start, count = decode_request_read_discrete_inputs(pdu) values = await self._call_read(Area.DISCRETE_INPUT, start, count) return encode_response_read_discrete_inputs(_to_bool_list(values, count)) if fc == FC_READ_HOLDING_REGISTERS: start, count = decode_request_read_holding_registers(pdu) values = await self._call_read(Area.HOLDING_REGISTER, start, count) return encode_response_read_holding_registers(_to_int_list(values, count)) if fc == FC_READ_INPUT_REGISTERS: start, count = decode_request_read_input_registers(pdu) values = await self._call_read(Area.INPUT_REGISTER, start, count) return encode_response_read_input_registers(_to_int_list(values, count)) # Write FCs require on_write. 
if fc in ( FC_WRITE_SINGLE_COIL, FC_WRITE_SINGLE_REGISTER, FC_WRITE_MULTIPLE_COILS, FC_WRITE_MULTIPLE_REGISTERS, ): if self._on_write is None: return encode_exception_response(fc, IllegalFunction.code) if fc == FC_WRITE_SINGLE_COIL: coil_addr, coil_value = decode_request_write_single_coil(pdu) await self._call_write(Area.COIL, coil_addr, [coil_value]) return encode_response_write_single_coil(coil_addr, coil_value) if fc == FC_WRITE_SINGLE_REGISTER: reg_addr, reg_value = decode_request_write_single_register(pdu) await self._call_write(Area.HOLDING_REGISTER, reg_addr, [reg_value]) return encode_response_write_single_register(reg_addr, reg_value) if fc == FC_WRITE_MULTIPLE_COILS: start, bits = decode_request_write_multiple_coils(pdu) await self._call_write(Area.COIL, start, bits) return encode_response_write_multiple_coils(start, len(bits)) if fc == FC_WRITE_MULTIPLE_REGISTERS: start, regs = decode_request_write_multiple_registers(pdu) await self._call_write(Area.HOLDING_REGISTER, start, regs) return encode_response_write_multiple_registers(start, len(regs)) # Unknown / unsupported function code. return encode_exception_response(fc, IllegalFunction.code) # ---------- callback adapters ---------- async def _call_read( self, area: Area, address: int, count: int, ) -> Sequence[int | bool]: result: Any = self._on_read(area, address, count) if inspect.isawaitable(result): result = await result return cast("Sequence[int | bool]", result) async def _call_write( self, area: Area, address: int, values: Sequence[int | bool], ) -> None: assert self._on_write is not None result: Any = self._on_write(area, address, values) if inspect.isawaitable(result): await result
# ---------- helpers ----------


def _require_count(values: Sequence[int | bool], expected_count: int) -> None:
    """Raise SlaveDeviceFailure unless *values* has exactly *expected_count* items."""
    got = len(values)
    if got != expected_count:
        raise SlaveDeviceFailure(
            f"on_read returned {got} values, expected {expected_count}"
        )


def _to_bool_list(values: Sequence[int | bool], expected_count: int) -> list[bool]:
    """Validate an on_read result for a bit area and coerce each item to bool.

    Raises:
        SlaveDeviceFailure: If the number of values is not *expected_count*.
    """
    _require_count(values, expected_count)
    return list(map(bool, values))


def _to_int_list(values: Sequence[int | bool], expected_count: int) -> list[int]:
    """Validate an on_read result for a register area and coerce it to ints.

    Raises:
        SlaveDeviceFailure: If the number of values is not *expected_count*,
            or a value falls outside the 16-bit range 0..0xFFFF.
    """
    _require_count(values, expected_count)
    words: list[int] = []
    for raw in values:
        # Convert and range-check item by item, so the first bad item decides
        # which error is raised.
        word = int(raw)
        if word < 0 or word > 0xFFFF:
            raise SlaveDeviceFailure(f"on_read returned out-of-range register value {word}")
        words.append(word)
    return words


# Public API. ModbusError and IllegalDataValue are re-exported from .errors so
# callers can catch / subclass them without importing the errors module.
__all__ = [
    "ReadCallback",
    "Server",
    "WriteCallback",
    "ModbusError",
    "IllegalDataValue",
]