Module catmaid_publish.utils

Expand source code
import logging
from collections.abc import Iterable
from copy import copy, deepcopy
from functools import lru_cache, wraps
from typing import Any, Callable, Hashable, Optional, TypeVar

import networkx as nx
import pymaid

logger = logging.getLogger(__name__)


def setup_logging(level=logging.DEBUG):
    """Sane default logging setup.

    Configures the root logger via ``logging.basicConfig`` and raises the
    threshold of known-noisy third-party loggers.
    Should only be called once, and only by a script.
    """
    logging.basicConfig(level=level)

    # Quieten packages with particularly noisy logging.
    noisy = {
        "urllib3": logging.INFO,
        "matplotlib": logging.WARNING,
    }
    for name, pkg_level in noisy.items():
        logging.getLogger(name).setLevel(pkg_level)


T = TypeVar("T")


def fill_in_dict(to_fill: dict[T, T], fill_with: Iterable[T], inplace=False):
    """Ensure every item of ``fill_with`` appears as a key in ``to_fill``.

    Items absent from the dict are mapped to themselves;
    existing entries are never overwritten.

    Parameters
    ----------
    to_fill : dict[T, T]
        Mapping to fill in.
    fill_with : Iterable[T]
        Items which should all appear as keys.
    inplace : bool, optional
        Whether to mutate ``to_fill`` rather than shallow-copying it,
        by default False.

    Returns
    -------
    dict[T, T]
        The filled mapping (``to_fill`` itself if ``inplace`` is True).
    """
    if not inplace:
        to_fill = dict(to_fill)

    for item in fill_with:
        # setdefault only inserts when the key is missing,
        # so existing mappings are preserved.
        to_fill.setdefault(item, item)

    return to_fill


def copy_cache(deep: bool = True, maxsize: Optional[int] = 128, typed: bool = False):
    """Decorator factory wrapping functools.lru_cache, which copies return values.

    Use this for functions with mutable return values.
    N.B. must be called with parentheses.

    Parameters
    ----------
    deep : bool, optional
        Whether to deepcopy return values, by default True
    maxsize : Optional[int], optional
        See lru_cache, by default 128.
        Use None for an unbounded cache (which is also faster).
    typed : bool, optional
        See lru_cache, by default False

    Returns
    -------
    Callable[[Callable], Callable]
        Returns a decorator.

    Examples
    --------
    >>> @copy_cache()  # must include parentheses, unlike some zero-arg decorators
    ... def my_function(key: str, value: int) -> dict:
    ...     return {key: value}
    ...
    >>> d1 = my_function("a", 1)
    >>> d1["b"] = 2
    >>> d2 = my_function("a", 1)
    >>> "b" in d2  # would be True with raw functools.lru_cache
    False
    """
    copier = deepcopy if deep else copy

    def wrapper(fn):
        wrapped = lru_cache(maxsize, typed)(fn)

        @wraps(fn)
        def copy_wrapped(*args, **kwargs):
            # Copy on the way out so callers can mutate the result
            # without corrupting the cached original.
            return copier(wrapped(*args, **kwargs))

        # Expose the full lru_cache introspection API on the outer wrapper
        # (the original forwarded only cache_info and cache_clear).
        copy_wrapped.cache_info = wrapped.cache_info
        copy_wrapped.cache_clear = wrapped.cache_clear
        copy_wrapped.cache_parameters = wrapped.cache_parameters

        return copy_wrapped

    return wrapper


def accept(node_id, node_data: dict[str, Any]) -> bool:
    """Node predicate which accepts every node, regardless of ID or data.

    Used as the default ``select_fn`` by :func:`descendants`.
    """
    return True


def descendants(
    g: nx.DiGraph,
    roots: list[str],
    max_depth=None,
    select_fn: Optional[Callable[[Hashable, dict[str, Any]], bool]] = None,
) -> set[str]:
    """Find all descendant nodes from given roots, to a given depth.

    Output includes given roots.

    Parameters
    ----------
    g : nx.DiGraph
    roots : list[str]
        Nodes to start the traversal from (must exist in ``g``).
    max_depth : int, optional
        Maximum number of edges to traverse from any root;
        None (default) means unlimited.
    select_fn : Callable[[Hashable, dict[str, Any]], bool], optional
        Predicate taking a node ID and its attribute dict; only nodes for
        which it returns True are included in the output (rejected nodes
        are still traversed through). Defaults to accepting every node.

    Returns
    -------
    set[str]
    """
    if select_fn is None:
        select_fn = accept

    out = set()
    # Track the shallowest depth at which each node has been reached.
    # A plain "visited" set is wrong with max_depth: a node first reached
    # at exactly max_depth is not expanded there, but may be reachable
    # again via a shorter path, in which case its successors are within
    # the depth limit and must still be explored.
    best_depth: dict = {}
    to_visit: list[tuple[str, int]] = [(r, 0) for r in roots]
    while to_visit:
        node, depth = to_visit.pop()

        prev = best_depth.get(node)
        if prev is not None and prev <= depth:
            continue
        first_visit = prev is None
        best_depth[node] = depth

        # Only evaluate the predicate once per node.
        if first_visit and select_fn(node, g.nodes[node]):
            out.add(node)

        if max_depth is not None and depth >= max_depth:
            continue

        to_visit.extend((n, depth + 1) for n in g.successors(node))

    return out


def remove_nodes(
    g: nx.Graph,
    node_fn: Callable[[Any, dict], bool],
) -> dict[Hashable, dict]:
    """Remove nodes according to a function.

    Parameters
    ----------
    g : nx.Graph
    node_fn : Callable[[Any, dict], bool]
        Callable which takes a node ID and attributes dict,
        and returns True if it should be deleted.

    Returns
    -------
    dict[Hashable, dict]
        Dict of removed node IDs to their attributes dict.
    """
    # Collect first, then remove, so the graph is never mutated
    # while being iterated.
    removed = {
        node_id: attrs
        for node_id, attrs in g.nodes(data=True)
        if node_fn(node_id, attrs)
    }
    g.remove_nodes_from(removed)
    return removed


def join_markdown(*strings: Optional[str]) -> str:
    """Join stripped markdown strings with thematic breaks.

    ``None`` and whitespace-only entries are skipped;
    the final section gets a trailing newline.

    Returns
    -------
    str
    """
    stripped = (s.strip() for s in strings if s is not None)
    sections = [s for s in stripped if s]

    if sections:
        # Terminate the document with a newline.
        sections[-1] += "\n"

    return "\n\n---\n\n".join(sections)


@copy_cache(maxsize=None)
def entity_graph() -> nx.DiGraph:
    """Fetch the entity graph of annotations, neurons, and volumes via pymaid.

    Cached unbounded for the life of the process; ``copy_cache`` hands back
    a deep copy on every call, so the returned graph may be mutated freely.
    NOTE(review): presumably makes a network request to the CATMAID server
    on first call — confirm against pymaid docs.
    """
    return pymaid.get_entity_graph(["annotation", "neuron", "volume"])

Functions

def accept(node_id, node_data: dict[str, typing.Any]) ‑> bool
Expand source code
def accept(node_id, node_data: dict[str, Any]) -> bool:
    return True
def copy_cache(deep: bool = True, maxsize: Optional[int] = 128, typed: bool = False)

Decorator factory wrapping functools.lru_cache, which copies return values.

Use this for functions with mutable return values. N.B. must be called with parentheses.

Parameters

deep : bool, optional
Whether to deepcopy return values, by default True
maxsize : Optional[int], optional
See lru_cache, by default 128. Use None for an unbounded cache (which is also faster).
typed : bool, optional
See lru_cache, by default False

Returns

Callable[[Callable], Callable]
Returns a decorator.

Examples

>>> @copy_cache()  # must include parentheses, unlike some zero-arg decorators
... def my_function(key: str, value: int) -> dict:
...     return {key: value}
...
>>> d1 = my_function("a", 1)
>>> d1["b"] = 2
>>> d2 = my_function("a", 1)
>>> "b" in d2  # would be True with raw functools.lru_cache
False
Expand source code
def copy_cache(deep: bool = True, maxsize: Optional[int] = 128, typed: bool = False):
    """Decorator factory wrapping functools.lru_cache, which copies return values.

    Use this for functions with mutable return values.
    N.B. must be called with parentheses.

    Parameters
    ----------
    deep : bool, optional
        Whether to deepcopy return values, by default True
    maxsize : Optional[int], optional
        See lru_cache, by default 128.
        Use None for an unbounded cache (which is also faster).
    typed : bool, optional
        See lru_cache, by default False

    Returns
    -------
    Callable[[Callable], Callable]
        Returns a decorator.

    Examples
    --------
    >>> @copy_cache()  # must include parentheses, unlike some zero-arg decorators
    ... def my_function(key: str, value: int) -> dict:
    ...     return {key: value}
    ...
    >>> d1 = my_function("a", 1)
    >>> d1["b"] = 2
    >>> d2 = my_function("a", 1)
    >>> "b" in d2  # would be True with raw functools.lru_cache
    False
    """
    copier = deepcopy if deep else copy

    def wrapper(fn):
        wrapped = lru_cache(maxsize, typed)(fn)

        @wraps(fn)
        def copy_wrapped(*args, **kwargs):
            out = wrapped(*args, **kwargs)
            return copier(out)

        copy_wrapped.cache_info = wrapped.cache_info
        copy_wrapped.cache_clear = wrapped.cache_clear

        return copy_wrapped

    return wrapper
def descendants(g: networkx.classes.digraph.DiGraph, roots: list[str], max_depth=None, select_fn: Optional[Callable[[Hashable, dict[str, Any]], bool]] = None) ‑> set[str]

Find all descendant nodes from given roots, to a given depth.

Output includes given roots.

Expand source code
def descendants(
    g: nx.DiGraph,
    roots: list[str],
    max_depth=None,
    select_fn: Optional[Callable[[Hashable, dict[str, Any]], bool]] = None,
) -> set[str]:
    """Find all descendant nodes from given roots, to a given depth.

    Output includes given roots.
    """
    if select_fn is None:
        select_fn = accept

    out = set()
    visited = set()
    to_visit: list[tuple[str, int]] = [(r, 0) for r in roots]
    while to_visit:
        node, depth = to_visit.pop()

        if node in visited:
            continue

        visited.add(node)

        if select_fn(node, g.nodes[node]):
            out.add(node)

        if max_depth is not None and depth >= max_depth:
            continue

        to_visit.extend((n, depth + 1) for n in g.successors(node))

    return out
def entity_graph() ‑> networkx.classes.digraph.DiGraph
Expand source code
@copy_cache(maxsize=None)
def entity_graph() -> nx.DiGraph:
    return pymaid.get_entity_graph(["annotation", "neuron", "volume"])
def fill_in_dict(to_fill: dict[~T, ~T], fill_with: collections.abc.Iterable[~T], inplace=False)
Expand source code
def fill_in_dict(to_fill: dict[T, T], fill_with: Iterable[T], inplace=False):
    if not inplace:
        to_fill = dict(to_fill)

    for f in fill_with:
        if f not in to_fill:
            to_fill[f] = f

    return to_fill
def join_markdown(*strings: Optional[str]) ‑> str

Join stripped markdown strings with thematic breaks.

Returns

str
 
Expand source code
def join_markdown(*strings: Optional[str]) -> str:
    """Join stripped markdown strings with thematic breaks.

    Returns
    -------
    str
    """
    out = []
    for s in strings:
        if s is None:
            continue
        s = s.strip()
        if s:
            out.append(s)

    if out:
        out.append(out.pop() + "\n")

    return "\n\n---\n\n".join(out)
def remove_nodes(g: networkx.classes.graph.Graph, node_fn: Callable[[Any, dict], bool]) ‑> dict[typing.Hashable, dict]

Remove nodes according to a function.

Parameters

g : nx.Graph
 
node_fn : Callable[[Any, dict], bool]
Callable which takes a node ID and attributes dict, and returns True if it should be deleted.

Returns

dict[Hashable, dict]
Dict of removed node IDs to their attributes dict.
Expand source code
def remove_nodes(
    g: nx.Graph,
    node_fn: Callable[[Any, dict], bool],
) -> dict[Hashable, dict]:
    """Remove nodes according to a function.

    Parameters
    ----------
    g : nx.Graph
    node_fn : Callable[[Any, dict], bool]
        Callable which takes a node ID and attributes dict,
        and returns True if it should be deleted.

    Returns
    -------
    dict[Hashable, dict]
        Dict of removed node IDs to their attributes dict.
    """
    out = dict()
    for idx, data in g.nodes(data=True):
        if node_fn(idx, data):
            out[idx] = data
    g.remove_nodes_from(out)
    return out
def setup_logging(level=10)

Sane default logging setup.

Should only be called once, and only by a script.

Expand source code
def setup_logging(level=logging.DEBUG):
    """Sane default logging setup.

    Should only be called once, and only by a script.
    """
    logging.basicConfig(level=level)

    # these are packages with particularly noisy logging
    logging.getLogger("urllib3").setLevel(logging.INFO)
    logging.getLogger("matplotlib").setLevel(logging.WARNING)