nnconvexity

This page documents the nnconvexity Python package. See also the code accompanying the paper, which contains a demo of how to use this package.

If you use our code, please cite the paper: Tětková, L., Brüsch, T., Dorszewski, T. et al. On convex decision regions in deep network representations. Nat Commun 16, 5419 (2025). https://doi.org/10.1038/s41467-025-60809-y.

We support two ways to compute convexity:

  • Euclidean convexity, using nnconvexity.euclidean
  • graph convexity, using nnconvexity.graph

nnconvexity.euclidean.compute_euclidean_convexity(representations, labels, predict_from_middle, layer, n_pairs=5000, n_sampled=10)

Compute Euclidean convexity (in %) for given representations and labels.

Parameters:
  • representations (ndarray) –

    (n_data, n_tokens, n_features) Latent representations.

  • labels (ndarray) –

    (n_data,) Labels.

  • predict_from_middle (function) –

    Function that maps intermediate features to predictions. Inputs: features of shape (n_interpolated_data, n_tokens, n_features) and layer (int). Output: predictions of shape (n_interpolated_data,).

  • layer (int) –

    Layer to compute Euclidean convexity for.

  • n_pairs (int, default: 5000) –

    Maximum number of pairs within a concept used for evaluation.

  • n_sampled (int, default: 10) –

    Number of points sampled on the line segment between each pair.
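The predict_from_middle callback is supplied by the user. A minimal sketch of one, assuming a hypothetical toy network: the names BLOCKS, HEAD, N_LAYERS and the convention that features are the output of block `layer` are illustrative only and not part of nnconvexity. It runs the remaining blocks token-wise, mean-pools over tokens, and returns one predicted label per interpolated point.

```python
import numpy as np

rng = np.random.default_rng(0)
N_TOKENS, N_FEATURES, N_CLASSES, N_LAYERS = 4, 8, 3, 3

# Hypothetical toy network: one token-wise linear+ReLU block per layer, then a linear head.
BLOCKS = [rng.normal(size=(N_FEATURES, N_FEATURES)) for _ in range(N_LAYERS)]
HEAD = rng.normal(size=(N_FEATURES, N_CLASSES))


def predict_from_middle(features: np.ndarray, layer: int) -> np.ndarray:
    """Run the hypothetical toy network from block `layer` onward.

    features: (n_interpolated_data, n_tokens, n_features), assumed to be the
              output of block `layer`.
    Returns:  (n_interpolated_data,) integer class predictions.
    """
    h = features
    for w in BLOCKS[layer + 1:]:        # remaining blocks, applied per token
        h = np.maximum(h @ w, 0.0)
    pooled = h.mean(axis=1)             # (n, n_features): average over tokens
    logits = pooled @ HEAD              # (n, n_classes)
    return logits.argmax(axis=1)        # one label per interpolated point
```

The predictions must live in the same label space as the labels argument, since each sampled point's prediction is compared against a concept label.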

Returns:
  • float –

    Mean Euclidean convexity (in %) over all pairs.

  • Dict[int, float] –

    Dictionary mapping each concept label to the Euclidean convexity (in %) of that concept.
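To make the parameters above concrete, here is a minimal NumPy sketch of a Euclidean convexity computation. It is an illustrative re-implementation under stated assumptions, not the nnconvexity source: the name euclidean_convexity_sketch and the seed argument are hypothetical; when a concept has more than n_pairs pairs, pairs are drawn uniformly without replacement; the n_sampled points are spaced evenly strictly between the two endpoints; and a pair's score is the percentage of its sampled points that predict_from_middle assigns to the pair's own label.

```python
from typing import Callable, Dict, Tuple

import numpy as np


def euclidean_convexity_sketch(
    representations: np.ndarray,
    labels: np.ndarray,
    predict_from_middle: Callable[[np.ndarray, int], np.ndarray],
    layer: int,
    n_pairs: int = 5000,
    n_sampled: int = 10,
    seed: int = 0,  # hypothetical: controls pair subsampling
) -> Tuple[float, Dict[int, float]]:
    rng = np.random.default_rng(seed)
    # Interior interpolation weights; endpoints excluded (assumption).
    ts = np.linspace(0.0, 1.0, n_sampled + 2)[1:-1]
    all_scores = []
    per_label: Dict[int, float] = {}
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)
        if len(members) < 2:
            continue  # no pairs within this concept
        # All pairs (i < j) within the concept, subsampled to at most n_pairs.
        i, j = np.triu_indices(len(members), k=1)
        if len(i) > n_pairs:
            pick = rng.choice(len(i), size=n_pairs, replace=False)
            i, j = i[pick], j[pick]
        a = representations[members[i]]  # (P, n_tokens, n_features)
        b = representations[members[j]]
        w = ts[None, :, None, None]      # (1, S, 1, 1) for broadcasting
        points = (1.0 - w) * a[:, None] + w * b[:, None]  # (P, S, T, F)
        flat = points.reshape(-1, *representations.shape[1:])
        preds = np.asarray(predict_from_middle(flat, layer)).reshape(len(i), n_sampled)
        # Percentage of sampled points classified as the concept label, per pair.
        scores = 100.0 * (preds == c).mean(axis=1)
        all_scores.extend(scores.tolist())
        per_label[int(c)] = float(scores.mean())
    overall = float(np.mean(all_scores)) if all_scores else float("nan")
    return overall, per_label
```

A predictor whose decision regions are convex, such as nearest-centroid on mean-pooled tokens, scores 100% for every concept whose points it classifies correctly, which makes a handy sanity check.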

nnconvexity.graph.compute_graph_convexity(representations, labels, n_neighbors=10, max_n_paths=5000, n_parallel=1)

Compute graph convexity (in %) for given representations and labels.

Parameters:
  • representations (ndarray) –

    (n_data, n_features) Latent representations.

  • labels (ndarray) –

    (n_data,) Labels.

  • n_neighbors (int, default: 10) –

    Number of nearest neighbors used to build the graph.

  • max_n_paths (int, default: 5000) –

    Maximum number of paths used for evaluation.

  • n_parallel (int, default: 1) –

    Number of parallel jobs.

Returns:
  • float –

    Mean graph convexity (in %) over all paths.

  • Dict[int, float] –

    Dictionary mapping each concept label to the graph convexity (in %) of that concept.
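Graph convexity can be sketched in the same spirit. This is an illustrative re-implementation under assumptions, not the nnconvexity source: graph_convexity_sketch and its helpers are hypothetical names; the graph is assumed to be an undirected k-nearest-neighbor graph with Euclidean edge weights, searched with Dijkstra's algorithm; max_n_paths is treated as a per-concept cap, like n_pairs; a path's score is the percentage of its interior nodes that share the endpoints' label; pairs in different connected components are skipped; and n_parallel is omitted, so everything runs serially.

```python
import heapq
import math
from typing import Dict, List, Optional, Tuple

import numpy as np


def _knn_graph(x: np.ndarray, n_neighbors: int) -> List[Dict[int, float]]:
    """Undirected kNN graph as adjacency dicts with Euclidean edge weights."""
    dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # never pick a point as its own neighbor
    nbrs = np.argsort(dists, axis=1)[:, :n_neighbors]
    adj: List[Dict[int, float]] = [dict() for _ in range(len(x))]
    for i, row in enumerate(nbrs):
        for j in row:
            w = float(dists[i, j])
            adj[i][int(j)] = w
            adj[int(j)][i] = w  # symmetrize the graph
    return adj


def _shortest_path(adj: List[Dict[int, float]], src: int, dst: int) -> Optional[List[int]]:
    """Dijkstra from src, stopping at dst; returns the node list or None."""
    dist = {src: 0.0}
    prev: Dict[int, int] = {}
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if u == dst:
            break
        if d > dist[u]:
            continue  # stale heap entry
        for v, w in adj[u].items():
            nd = d + w
            if nd < dist.get(v, math.inf):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(heap, (nd, v))
    if dst not in dist:
        return None  # dst unreachable from src
    path = [dst]
    while path[-1] != src:
        path.append(prev[path[-1]])
    return path[::-1]


def graph_convexity_sketch(
    representations: np.ndarray,
    labels: np.ndarray,
    n_neighbors: int = 10,
    max_n_paths: int = 5000,
    seed: int = 0,  # hypothetical: controls path subsampling
) -> Tuple[float, Dict[int, float]]:
    rng = np.random.default_rng(seed)
    adj = _knn_graph(representations, n_neighbors)
    all_scores: List[float] = []
    per_label: Dict[int, float] = {}
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)
        if len(members) < 2:
            continue
        i, j = np.triu_indices(len(members), k=1)
        if len(i) > max_n_paths:
            pick = rng.choice(len(i), size=max_n_paths, replace=False)
            i, j = i[pick], j[pick]
        scores: List[float] = []
        for a, b in zip(members[i], members[j]):
            path = _shortest_path(adj, int(a), int(b))
            if path is None:
                continue  # different connected components: skipped (assumption)
            interior = path[1:-1]
            if not interior:
                scores.append(100.0)  # direct neighbors: trivially convex (assumption)
            else:
                scores.append(100.0 * float(np.mean(labels[interior] == c)))
        if scores:
            all_scores.extend(scores)
            per_label[int(c)] = float(np.mean(scores))
    overall = float(np.mean(all_scores)) if all_scores else float("nan")
    return overall, per_label
```

The dense distance matrix costs O(n_data²) memory, so this sketch only suits small inputs; larger datasets call for a dedicated nearest-neighbor index and graph library.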