lodum.pickle
# SPDX-FileCopyrightText: 2025-present Michael R. Bernstein <zopemaven@gmail.com>
#
# SPDX-License-Identifier: Apache-2.0
import pickle
import builtins
import io
from typing import Any, Type, TypeVar

from .core import Dumper
from .internal import (
    dump as validate_lodum_structure,
    DEFAULT_MAX_SIZE,
)
from .exception import DeserializationError

T = TypeVar("T")

# --- Safe Encoding ---


class ValidationDumper(Dumper):
    """A no-op dumper used only for validation."""

    def dump_int(self, value: int) -> None:
        pass

    def dump_str(self, value: str) -> None:
        pass

    def dump_float(self, value: float) -> None:
        pass

    def dump_bool(self, value: bool) -> None:
        pass

    def dump_bytes(self, value: bytes) -> None:
        pass

    def dump_list(self, value: list[Any]) -> None:
        pass

    def dump_dict(self, value: dict[str, Any]) -> None:
        pass

    def begin_struct(self, cls: Type[Any]) -> dict[str, Any]:
        return {}  # Return a dummy dict

    def end_struct(self) -> None:
        pass


def dumps(obj: Any) -> bytes:
    """
    Encodes a Python object to a pickle byte string, ensuring it is safe.

    The object is first passed through ``validate_lodum_structure`` with a
    no-op ``ValidationDumper``; only if that call returns is it pickled.
    """
    validator = ValidationDumper()
    validate_lodum_structure(obj, validator)
    return pickle.dumps(obj)


# --- Safe Decoding ---


class SafeUnpickler(pickle.Unpickler):
    """
    A custom unpickler that only allows safe, lodum-enabled classes to be loaded.
    """

    # Modules that are refused outright, before any import is attempted.
    _FORBIDDEN_MODULES = frozenset({"os", "sys", "subprocess", "posix", "nt"})

    # Builtin names that may be resolved from the ``builtins`` module.
    _SAFE_BUILTINS = frozenset(
        {
            "int",
            "float",
            "str",
            "bool",
            "bytes",
            "bytearray",
            "list",
            "tuple",
            "dict",
            "set",
            "frozenset",
            "complex",
            "NoneType",
            "type",
        }
    )

    _SAFE_COLLECTIONS = frozenset({"defaultdict", "OrderedDict", "Counter"})
    _SAFE_ARRAY = frozenset({"array", "_array_reconstructor"})

    def find_class(self, module_name: str, class_name: str) -> Type:
        """
        Resolve a global referenced by the pickle stream.

        Only an allowlisted set of builtins, a few ``collections`` and
        ``array`` names, and objects carrying a truthy ``_lodum_enabled``
        attribute are returned; everything else raises
        ``pickle.UnpicklingError``.
        """
        # Match whole dotted components, not substrings: a substring test
        # would also reject unrelated modules such as "myapp.position".
        if not self._FORBIDDEN_MODULES.isdisjoint(module_name.split(".")):
            raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")

        if module_name == "builtins":
            # The hasattr guard rejects allowlisted names that are not
            # actually attributes of the builtins module.
            if class_name in self._SAFE_BUILTINS and hasattr(builtins, class_name):
                return getattr(builtins, class_name)
            raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")

        if module_name == "collections" and class_name in self._SAFE_COLLECTIONS:
            import collections

            return getattr(collections, class_name)

        if module_name == "array" and class_name in self._SAFE_ARRAY:
            import array

            return getattr(array, class_name)

        cls = super().find_class(module_name, class_name)

        if getattr(cls, "_lodum_enabled", False):
            return cls

        raise pickle.UnpicklingError(
            f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
        )


def loads(cls: Type[T], data: bytes, max_size: int = DEFAULT_MAX_SIZE) -> T:
    """
    Decodes a pickle byte string to a Python object, ensuring it is safe.

    Args:
        cls: The type the decoded object must be an instance of.
        data: The pickle byte string.
        max_size: Maximum accepted length of ``data`` in bytes.

    Raises:
        DeserializationError: If ``data`` is longer than ``max_size``, cannot
            be unpickled by ``SafeUnpickler``, or decodes to an object that
            is not an instance of ``cls``.
    """
    if len(data) > max_size:
        raise DeserializationError(
            f"Input size ({len(data)}) exceeds maximum allowed ({max_size})"
        )

    # NOTE(security): pickle is not designed for untrusted input; the only
    # protection here is the global allowlist enforced by SafeUnpickler.
    with io.BytesIO(data) as f:
        unpickler = SafeUnpickler(f)
        try:
            obj = unpickler.load()
        except (
            pickle.UnpicklingError,
            AttributeError,
            EOFError,  # empty or truncated stream
            ImportError,
            IndexError,
            KeyError,
            TypeError,
            ValueError,  # e.g. unsupported pickle protocol
        ) as e:
            raise DeserializationError(f"Failed to unpickle data: {e}") from e

    if not isinstance(obj, cls):
        raise DeserializationError(
            f"Deserialized object is of type {type(obj).__name__}, but expected {cls.__name__}"
        )

    return obj
class
ValidationDumper(lodum.core.Dumper):
class ValidationDumper(Dumper):
    """Dumper whose hooks all do nothing; used only for validation.

    Every ``dump_*`` hook and ``end_struct`` discards its argument and
    returns ``None``; ``begin_struct`` hands back a fresh empty dict.
    """

    def dump_int(self, value: int) -> None:
        """Ignore an integer value."""
        return None

    def dump_str(self, value: str) -> None:
        """Ignore a string value."""
        return None

    def dump_float(self, value: float) -> None:
        """Ignore a float value."""
        return None

    def dump_bool(self, value: bool) -> None:
        """Ignore a boolean value."""
        return None

    def dump_bytes(self, value: bytes) -> None:
        """Ignore a bytes value."""
        return None

    def dump_list(self, value: list[Any]) -> None:
        """Ignore a list value."""
        return None

    def dump_dict(self, value: dict[str, Any]) -> None:
        """Ignore a dict value."""
        return None

    def begin_struct(self, cls: Type[Any]) -> dict[str, Any]:
        """Return a new empty placeholder dict for the struct."""
        placeholder: dict[str, Any] = dict()
        return placeholder

    def end_struct(self) -> None:
        """Do nothing when a struct ends."""
        return None
A no-op dumper used only for validation.
def
dumps(obj: Any) -> bytes:
def dumps(obj: Any) -> bytes:
    """
    Validate ``obj`` and return its pickle encoding.

    ``obj`` is first handed to ``validate_lodum_structure`` together with a
    throwaway ``ValidationDumper``; if that call returns normally, the object
    is pickled with the default protocol and the resulting bytes are returned.
    """
    # Validation pass only: the dumper produces no output.
    validate_lodum_structure(obj, ValidationDumper())
    encoded = pickle.dumps(obj)
    return encoded
Encodes a Python object to a pickle byte string, ensuring it is safe.
class
SafeUnpickler(_pickle.Unpickler):
class SafeUnpickler(pickle.Unpickler):
    """
    A custom unpickler that only allows safe, lodum-enabled classes to be loaded.

    Globals referenced by the pickle stream are resolved through
    ``find_class``, which accepts only an allowlisted set of builtins, a few
    ``collections`` and ``array`` names, and objects that carry a truthy
    ``_lodum_enabled`` attribute.
    """

    # Modules that are refused outright, before any import is attempted.
    _FORBIDDEN_MODULES = frozenset({"os", "sys", "subprocess", "posix", "nt"})

    # Builtin names that may be resolved from the ``builtins`` module.
    _SAFE_BUILTINS = frozenset(
        {
            "int",
            "float",
            "str",
            "bool",
            "bytes",
            "bytearray",
            "list",
            "tuple",
            "dict",
            "set",
            "frozenset",
            "complex",
            "NoneType",
            "type",
        }
    )

    _SAFE_COLLECTIONS = frozenset({"defaultdict", "OrderedDict", "Counter"})
    _SAFE_ARRAY = frozenset({"array", "_array_reconstructor"})

    def find_class(self, module_name: str, class_name: str) -> Type:
        """
        Resolve a global referenced by the pickle stream.

        Args:
            module_name: Dotted module path recorded in the pickle.
            class_name: Name to look up in that module.

        Returns:
            The resolved object when it is on one of the allowlists or has a
            truthy ``_lodum_enabled`` attribute.

        Raises:
            pickle.UnpicklingError: For forbidden modules, non-allowlisted
                builtins, and any other object without ``_lodum_enabled``.
        """
        # Match whole dotted components, not substrings: a substring test
        # would also reject unrelated modules such as "myapp.position".
        if not self._FORBIDDEN_MODULES.isdisjoint(module_name.split(".")):
            raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")

        if module_name == "builtins":
            # The hasattr guard rejects allowlisted names that are not
            # actually attributes of the builtins module.
            if class_name in self._SAFE_BUILTINS and hasattr(builtins, class_name):
                return getattr(builtins, class_name)
            raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")

        if module_name == "collections" and class_name in self._SAFE_COLLECTIONS:
            import collections

            return getattr(collections, class_name)

        if module_name == "array" and class_name in self._SAFE_ARRAY:
            import array

            return getattr(array, class_name)

        # Anything else is resolved by the base class and then accepted only
        # if it is explicitly marked as lodum-enabled.
        cls = super().find_class(module_name, class_name)

        if getattr(cls, "_lodum_enabled", False):
            return cls

        raise pickle.UnpicklingError(
            f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
        )
A custom unpickler that only allows safe, lodum-enabled classes to be loaded.
def
find_class(self, module_name: str, class_name: str) -> Type:
def find_class(self, module_name: str, class_name: str) -> Type:
    """
    Resolve a global referenced by the pickle stream.

    Args:
        module_name: Dotted module path recorded in the pickle.
        class_name: Name to look up in that module.

    Returns:
        The resolved object when it is an allowlisted builtin, an allowlisted
        ``collections`` or ``array`` name, or has a truthy ``_lodum_enabled``
        attribute.

    Raises:
        pickle.UnpicklingError: For forbidden modules, non-allowlisted
            builtins, and any other object without ``_lodum_enabled``.
    """
    # Match whole dotted components, not substrings: a substring test would
    # also reject unrelated modules such as "myapp.position".
    forbidden_modules = {"os", "sys", "subprocess", "posix", "nt"}
    if not forbidden_modules.isdisjoint(module_name.split(".")):
        raise pickle.UnpicklingError(f"Unsafe module '{module_name}' is forbidden.")

    SAFE_BUILTINS = {
        "int",
        "float",
        "str",
        "bool",
        "bytes",
        "bytearray",
        "list",
        "tuple",
        "dict",
        "set",
        "frozenset",
        "complex",
        "NoneType",
        "type",
    }

    if module_name == "builtins":
        # The hasattr guard rejects allowlisted names that are not actually
        # attributes of the builtins module.
        if class_name in SAFE_BUILTINS and hasattr(builtins, class_name):
            return getattr(builtins, class_name)
        raise pickle.UnpicklingError(f"Unsafe builtin '{class_name}' is forbidden.")

    if module_name == "collections" and class_name in (
        "defaultdict",
        "OrderedDict",
        "Counter",
    ):
        import collections

        return getattr(collections, class_name)

    if module_name == "array" and class_name in ("array", "_array_reconstructor"):
        import array

        return getattr(array, class_name)

    # Anything else is resolved by the base class and then accepted only if
    # it is explicitly marked as lodum-enabled.
    cls = super().find_class(module_name, class_name)

    if getattr(cls, "_lodum_enabled", False):
        return cls

    raise pickle.UnpicklingError(
        f"Attempted to unpickle a non-lodum type: {module_name}.{class_name}"
    )
Return an object from a specified module.
If necessary, the module will be imported. Subclasses may override this method (e.g. to restrict unpickling of arbitrary classes and functions).
This method is called whenever a class or a function object is needed. Both arguments passed are str objects.
def
loads(cls: Type[~T], data: bytes, max_size: int = 10485760) -> ~T:
def loads(cls: Type[T], data: bytes, max_size: int = DEFAULT_MAX_SIZE) -> T:
    """
    Decodes a pickle byte string to a Python object, ensuring it is safe.

    Args:
        cls: The type the decoded object must be an instance of.
        data: The pickle byte string.
        max_size: Maximum accepted length of ``data`` in bytes.

    Returns:
        The unpickled object, which is an instance of ``cls``.

    Raises:
        DeserializationError: If ``data`` is longer than ``max_size``, cannot
            be unpickled by ``SafeUnpickler``, or decodes to an object that
            is not an instance of ``cls``.
    """
    if len(data) > max_size:
        raise DeserializationError(
            f"Input size ({len(data)}) exceeds maximum allowed ({max_size})"
        )

    # NOTE(security): pickle is not designed for untrusted input; the only
    # protection here is the global allowlist enforced by SafeUnpickler.
    with io.BytesIO(data) as f:
        unpickler = SafeUnpickler(f)
        try:
            obj = unpickler.load()
        except (
            pickle.UnpicklingError,
            AttributeError,
            EOFError,  # empty or truncated stream
            ImportError,
            IndexError,
            KeyError,
            TypeError,
            ValueError,  # e.g. unsupported pickle protocol
        ) as e:
            # Chain the cause so the original traceback is kept.
            raise DeserializationError(f"Failed to unpickle data: {e}") from e

    if not isinstance(obj, cls):
        raise DeserializationError(
            f"Deserialized object is of type {type(obj).__name__}, but expected {cls.__name__}"
        )

    return obj
Decodes a pickle byte string to a Python object, ensuring it is safe.