# Source code for pytomography.io.PET.prd._dtypes
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
from types import GenericAlias
import sys
if sys.version_info >= (3, 10):
from types import UnionType
from typing import Any, Callable, Union, cast, get_args, get_origin
import numpy as np
from . import yardl_types as yardl
def make_get_dtype_func(
    dtype_map: dict[
        Union[type, GenericAlias],
        Union[np.dtype[Any], Callable[[tuple[type, ...]], np.dtype[Any]]],
    ]
) -> Callable[[Union[type, GenericAlias]], np.dtype[Any]]:
    """Build a function that resolves a type annotation to a NumPy dtype.

    ``dtype_map`` is populated **in place** with the built-in scalar,
    date/time and string mappings listed below (existing entries under the
    same keys are overwritten). The returned function keeps a reference to
    that same dictionary, so entries added to it after this call are seen
    by later lookups.

    Args:
        dtype_map: Registry keyed by a type or generic alias. A value is
            either a concrete ``np.dtype`` or a callable that receives the
            tuple of generic type arguments and returns an ``np.dtype``.

    Returns:
        A one-argument function ``t -> np.dtype``. It resolves ``t`` in this
        order:

        1. A union: if it has exactly two members and the second one is
           ``NoneType``, a structured, aligned dtype with fields
           ``has_value`` (bool) and ``value`` (dtype of the first member);
           any other union maps to the object dtype.
        2. A direct hit in ``dtype_map``: the stored dtype, or a
           ``RuntimeError`` if the stored value is a callable (a generic
           type used without its type arguments).
        3. A hit on ``get_origin(t)`` whose stored value is a callable:
           the result of calling it with ``get_args(t)``.
        4. Otherwise a ``RuntimeError``.
    """
    builtin_entries = (
        (bool, np.dtype(np.bool_)),
        (yardl.Int8, np.dtype(np.int8)),
        (yardl.UInt8, np.dtype(np.uint8)),
        (yardl.Int16, np.dtype(np.int16)),
        (yardl.UInt16, np.dtype(np.uint16)),
        (yardl.Int32, np.dtype(np.int32)),
        (yardl.UInt32, np.dtype(np.uint32)),
        (yardl.Int64, np.dtype(np.int64)),
        (yardl.UInt64, np.dtype(np.uint64)),
        (yardl.Size, np.dtype(np.uint64)),
        (yardl.Float32, np.dtype(np.float32)),
        (yardl.Float64, np.dtype(np.float64)),
        (yardl.ComplexFloat, np.dtype(np.complex64)),
        (yardl.ComplexDouble, np.dtype(np.complex128)),
        (datetime.date, np.dtype("datetime64[D]")),
        (yardl.Time, np.dtype("timedelta64[ns]")),
        (yardl.DateTime, np.dtype("datetime64[ns]")),
        (str, np.dtype(np.object_)),
        # Add the Python types to the dictionary too, but these may not be
        # correct since they map to several dtypes
        (int, np.dtype(np.int64)),
        (float, np.dtype(np.float64)),
        (complex, np.dtype(np.complex128)),
    )
    # Assign item by item so the caller's dictionary object itself is updated.
    for key, builtin_dtype in builtin_entries:
        dtype_map[key] = builtin_dtype

    none_type = cast(type, type(None))

    def union_dtype(members: tuple[type, ...]) -> np.dtype[Any]:
        """Return the dtype for a union with the given member types."""
        looks_optional = len(members) == 2 and members[1] == none_type
        if not looks_optional:
            return np.dtype(np.object_)

        # Optional[T]: a presence flag followed by the wrapped value.
        value_dtype = resolve(members[0])
        return np.dtype(
            [("has_value", np.bool_), ("value", value_dtype)], align=True
        )

    def resolve(t: Union[type, GenericAlias]) -> np.dtype[Any]:
        """Look up the dtype registered for ``t`` (see the outer docstring)."""
        origin = get_origin(t)

        # The version test short-circuits, so UnionType is only touched on
        # interpreters where it has been imported.
        is_union = origin == Union or (
            sys.version_info >= (3, 10) and isinstance(t, UnionType)
        )
        if is_union:
            return union_dtype(get_args(t))

        # A direct hit means t is either a plain type or a generic type
        # that was given without its type arguments.
        direct = dtype_map.get(t)
        if direct is not None:
            if callable(direct):
                raise RuntimeError(f"Generic type arguments not provided for {t}")
            return direct

        # From here on t is either unregistered or a generic alias whose
        # type arguments were supplied; only a callable entry can handle it.
        if origin is not None:
            factory = dtype_map.get(origin)
            if callable(factory):
                return factory(get_args(t))

        raise RuntimeError(f"Cannot find dtype for {t}")

    return resolve