lodum.pickle

  1# SPDX-FileCopyrightText: 2025-present Michael R. Bernstein <zopemaven@gmail.com>
  2#
  3# SPDX-License-Identifier: Apache-2.0
  4import pickle
  5import builtins
  6import io
  7from typing import Any, Type, TypeVar
  8
  9from .core import Dumper
 10from .internal import (
 11    dump as validate_lodum_structure,
 12    DEFAULT_MAX_SIZE,
 13)
 14from .exception import DeserializationError
 15
 16T = TypeVar("T")
 17
 18# --- Safe Encoding ---
 19
 20
 21class ValidationDumper(Dumper):
 22    """A no-op dumper used only for validation."""
 23
 24    def dump_int(self, value: int) -> None:
 25        pass
 26
 27    def dump_str(self, value: str) -> None:
 28        pass
 29
 30    def dump_float(self, value: float) -> None:
 31        pass
 32
 33    def dump_bool(self, value: bool) -> None:
 34        pass
 35
 36    def dump_bytes(self, value: bytes) -> None:
 37        pass
 38
 39    def dump_list(self, value: list[Any]) -> None:
 40        pass
 41
 42    def dump_dict(self, value: dict[str, Any]) -> None:
 43        pass
 44
 45    def begin_struct(self, cls: Type[Any]) -> dict[str, Any]:
 46        return {}  # Return a dummy dict
 47
 48    def end_struct(self) -> None:
 49        pass
 50
 51
 52def dumps(obj: Any) -> bytes:
 53    """
 54    Encodes a Python object to a pickle byte string, ensuring it is safe.
 55    """
 56    validator = ValidationDumper()
 57    validate_lodum_structure(obj, validator)
 58    return pickle.dumps(obj)
 59
 60
 61# --- Safe Decoding ---
 62
 63
 64class SafeUnpickler(pickle.Unpickler):
 65    """
 66    A custom unpickler that only allows safe, lodum-enabled classes to be loaded.
 67    """
 68
 69    def find_class(self, module_name: str, class_name: str) -> Type:
 70        if "os" in module_name or "sys" in module_name or "subprocess" in module_name:
 71            raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")
 72
 73        SAFE_BUILTINS = {
 74            "int",
 75            "float",
 76            "str",
 77            "bool",
 78            "bytes",
 79            "bytearray",
 80            "list",
 81            "tuple",
 82            "dict",
 83            "set",
 84            "frozenset",
 85            "complex",
 86            "NoneType",
 87            "type",
 88        }
 89
 90        if module_name == "builtins":
 91            if class_name in SAFE_BUILTINS and hasattr(builtins, class_name):
 92                return getattr(builtins, class_name)
 93            raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")
 94
 95        if module_name == "collections" and class_name in (
 96            "defaultdict",
 97            "OrderedDict",
 98            "Counter",
 99        ):
100            import collections
101
102            return getattr(collections, class_name)
103
104        if module_name == "array" and class_name in ("array", "_array_reconstructor"):
105            import array
106
107            return getattr(array, class_name)
108
109        cls = super().find_class(module_name, class_name)
110
111        if getattr(cls, "_lodum_enabled", False):
112            return cls
113
114        raise pickle.UnpicklingError(
115            f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
116        )
117
118
def loads(cls: Type[T], data: bytes, max_size: int = DEFAULT_MAX_SIZE) -> T:
    """
    Decode pickle bytes into an instance of *cls*, refusing unsafe content.

    Args:
        cls: The class the decoded object must be an instance of.
        data: The pickle byte string to decode.
        max_size: Maximum accepted length of *data* in bytes.

    Returns:
        The decoded object.

    Raises:
        DeserializationError: If *data* is longer than *max_size*, is
            malformed or truncated, references a global that
            ``SafeUnpickler`` does not allow, or decodes to something that
            is not an instance of *cls*.
    """
    if len(data) > max_size:
        raise DeserializationError(
            f"Input size ({len(data)}) exceeds maximum allowed ({max_size})"
        )

    with io.BytesIO(data) as f:
        unpickler = SafeUnpickler(f)
        try:
            obj = unpickler.load()
        except (
            pickle.UnpicklingError,
            AttributeError,
            ImportError,
            IndexError,
            TypeError,
            # Empty/truncated input raises EOFError; an unsupported protocol
            # or bad opcode raises ValueError; a dangling memo reference may
            # raise KeyError. All of these are decoding failures.
            EOFError,
            ValueError,
            KeyError,
        ) as e:
            raise DeserializationError(f"Failed to unpickle data: {e}") from e

    if not isinstance(obj, cls):
        raise DeserializationError(
            f"Deserialized object is of type {type(obj).__name__}, but expected {cls.__name__}"
        )

    return obj
class ValidationDumper(lodum.core.Dumper):
class ValidationDumper(Dumper):
    """
    Dumper whose hooks all discard their input.

    ``dumps`` passes an instance of this class to lodum's structure walk purely
    to validate the object; none of these hooks produce any output.
    """

    def dump_int(self, value: int) -> None:
        """Discard an int."""

    def dump_str(self, value: str) -> None:
        """Discard a str."""

    def dump_float(self, value: float) -> None:
        """Discard a float."""

    def dump_bool(self, value: bool) -> None:
        """Discard a bool."""

    def dump_bytes(self, value: bytes) -> None:
        """Discard a bytes value."""

    def dump_list(self, value: list[Any]) -> None:
        """Discard a list."""

    def dump_dict(self, value: dict[str, Any]) -> None:
        """Discard a dict."""

    def begin_struct(self, cls: Type[Any]) -> dict[str, Any]:
        """Return a fresh, empty placeholder dict for the struct."""
        return dict()

    def end_struct(self) -> None:
        """Nothing to finalise."""

A no-op dumper used only for validation.

def dump_int(self, value: int) -> None:
25    def dump_int(self, value: int) -> None:
26        pass
def dump_str(self, value: str) -> None:
28    def dump_str(self, value: str) -> None:
29        pass
def dump_float(self, value: float) -> None:
31    def dump_float(self, value: float) -> None:
32        pass
def dump_bool(self, value: bool) -> None:
34    def dump_bool(self, value: bool) -> None:
35        pass
def dump_bytes(self, value: bytes) -> None:
37    def dump_bytes(self, value: bytes) -> None:
38        pass
def dump_list(self, value: list[typing.Any]) -> None:
40    def dump_list(self, value: list[Any]) -> None:
41        pass
def dump_dict(self, value: dict[str, typing.Any]) -> None:
43    def dump_dict(self, value: dict[str, Any]) -> None:
44        pass
def begin_struct(self, cls: Type[Any]) -> dict[str, typing.Any]:
46    def begin_struct(self, cls: Type[Any]) -> dict[str, Any]:
47        return {}  # Return a dummy dict
def end_struct(self) -> None:
49    def end_struct(self) -> None:
50        pass
def dumps(obj: Any) -> bytes:
def dumps(obj: Any) -> bytes:
    """
    Serialize *obj* to pickle bytes after validating its structure.

    The object is first walked by lodum's ``internal.dump`` (imported here as
    ``validate_lodum_structure``) using a no-op ``ValidationDumper``. Only if
    that walk completes without raising is *obj* passed to ``pickle.dumps``.
    """
    validate_lodum_structure(obj, ValidationDumper())
    return pickle.dumps(obj)

Encodes a Python object to a pickle byte string, ensuring it is safe.

class SafeUnpickler(_pickle.Unpickler):
 65class SafeUnpickler(pickle.Unpickler):
 66    """
 67    A custom unpickler that only allows safe, lodum-enabled classes to be loaded.
 68    """
 69
 70    def find_class(self, module_name: str, class_name: str) -> Type:
 71        if "os" in module_name or "sys" in module_name or "subprocess" in module_name:
 72            raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")
 73
 74        SAFE_BUILTINS = {
 75            "int",
 76            "float",
 77            "str",
 78            "bool",
 79            "bytes",
 80            "bytearray",
 81            "list",
 82            "tuple",
 83            "dict",
 84            "set",
 85            "frozenset",
 86            "complex",
 87            "NoneType",
 88            "type",
 89        }
 90
 91        if module_name == "builtins":
 92            if class_name in SAFE_BUILTINS and hasattr(builtins, class_name):
 93                return getattr(builtins, class_name)
 94            raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")
 95
 96        if module_name == "collections" and class_name in (
 97            "defaultdict",
 98            "OrderedDict",
 99            "Counter",
100        ):
101            import collections
102
103            return getattr(collections, class_name)
104
105        if module_name == "array" and class_name in ("array", "_array_reconstructor"):
106            import array
107
108            return getattr(array, class_name)
109
110        cls = super().find_class(module_name, class_name)
111
112        if getattr(cls, "_lodum_enabled", False):
113            return cls
114
115        raise pickle.UnpicklingError(
116            f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
117        )

A custom unpickler that only allows safe, lodum-enabled classes to be loaded.

def find_class(self, module_name: str, class_name: str) -> Type:
 70    def find_class(self, module_name: str, class_name: str) -> Type:
 71        if "os" in module_name or "sys" in module_name or "subprocess" in module_name:
 72            raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")
 73
 74        SAFE_BUILTINS = {
 75            "int",
 76            "float",
 77            "str",
 78            "bool",
 79            "bytes",
 80            "bytearray",
 81            "list",
 82            "tuple",
 83            "dict",
 84            "set",
 85            "frozenset",
 86            "complex",
 87            "NoneType",
 88            "type",
 89        }
 90
 91        if module_name == "builtins":
 92            if class_name in SAFE_BUILTINS and hasattr(builtins, class_name):
 93                return getattr(builtins, class_name)
 94            raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")
 95
 96        if module_name == "collections" and class_name in (
 97            "defaultdict",
 98            "OrderedDict",
 99            "Counter",
100        ):
101            import collections
102
103            return getattr(collections, class_name)
104
105        if module_name == "array" and class_name in ("array", "_array_reconstructor"):
106            import array
107
108            return getattr(array, class_name)
109
110        cls = super().find_class(module_name, class_name)
111
112        if getattr(cls, "_lodum_enabled", False):
113            return cls
114
115        raise pickle.UnpicklingError(
116            f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
117        )

Return an object from a specified module.

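If necessary, the module will be imported. Subclasses may override this method (e.g. to restrict unpickling of arbitrary classes and functions). SafeUnpickler overrides it to resolve only a small allow-list of builtins, selected collections and array helpers, and classes marked _lodum_enabled, and rejects every other global.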

This method is called whenever a class or a function object is needed. Both arguments passed are str objects.

def loads(cls: Type[~T], data: bytes, max_size: int = 10485760) -> ~T:
def loads(cls: Type[T], data: bytes, max_size: int = DEFAULT_MAX_SIZE) -> T:
    """
    Decode pickle bytes into an instance of *cls*, refusing unsafe content.

    Args:
        cls: The class the decoded object must be an instance of.
        data: The pickle byte string to decode.
        max_size: Maximum accepted length of *data* in bytes.

    Returns:
        The decoded object.

    Raises:
        DeserializationError: If *data* is longer than *max_size*, is
            malformed or truncated, references a global that
            ``SafeUnpickler`` does not allow, or decodes to something that
            is not an instance of *cls*.
    """
    if len(data) > max_size:
        raise DeserializationError(
            f"Input size ({len(data)}) exceeds maximum allowed ({max_size})"
        )

    with io.BytesIO(data) as f:
        unpickler = SafeUnpickler(f)
        try:
            obj = unpickler.load()
        except (
            pickle.UnpicklingError,
            AttributeError,
            ImportError,
            IndexError,
            TypeError,
            # Empty/truncated input raises EOFError; an unsupported protocol
            # or bad opcode raises ValueError; a dangling memo reference may
            # raise KeyError. All of these are decoding failures.
            EOFError,
            ValueError,
            KeyError,
        ) as e:
            raise DeserializationError(f"Failed to unpickle data: {e}") from e

    if not isinstance(obj, cls):
        raise DeserializationError(
            f"Deserialized object is of type {type(obj).__name__}, but expected {cls.__name__}"
        )

    return obj

Decodes a pickle byte string to a Python object, ensuring it is safe.