"""
Membrane potential plots. Default rendering is a heatmap (consistent
with the raster view for spikes); plot_potential_lines is available
for traditional per-neuron line plots.
"""
from __future__ import annotations
from typing import Any, Optional, Sequence, Tuple
import matplotlib.pyplot as plt
import numpy as np
from sanafe.data import potentials_to_dataframe
from sanafe.viz.styles import (
SANAFEStyle,
DEFAULT_COLORS,
create_figure,
get_default_style,
style_axis,
)
[docs]
def plot_potential(
source: Any,
neuron_ids: Optional[Sequence[str]] = None,
time_range: Optional[Tuple[int, int]] = None,
cmap: str = "viridis",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
show_colorbar: bool = True,
ax: Optional[plt.Axes] = None,
style: Optional[SANAFEStyle] = None,
figsize: Optional[Tuple[float, float]] = None,
title: Optional[str] = None,
xlabel: str = "Time-step",
ylabel: str = "Neuron",
**imshow_kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
style = style or get_default_style()
df = potentials_to_dataframe(source, neuron_ids=neuron_ids)
if time_range is not None:
df = df.iloc[time_range[0]:time_range[1]]
n_timesteps, n_neurons = df.shape
timesteps = np.asarray(df.index)
t_start = int(timesteps[0]) if len(timesteps) else 0
t_end = int(timesteps[-1]) + 1 if len(timesteps) else 1
if ax is None:
fig, ax = create_figure(figsize=figsize, style=style)
else:
fig = ax.get_figure()
imshow_defaults = {
"aspect": "auto",
"origin": "lower",
"cmap": cmap,
"extent": [t_start - 0.5, t_end - 0.5, -0.5, n_neurons - 0.5],
}
if vmin is not None:
imshow_defaults["vmin"] = vmin
if vmax is not None:
imshow_defaults["vmax"] = vmax
imshow_defaults.update(imshow_kwargs)
im = ax.imshow(df.values.T, **imshow_defaults)
if show_colorbar:
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Membrane Potential")
labels = list(df.columns)
if n_neurons <= 20:
ax.set_yticks(range(n_neurons))
ax.set_yticklabels(labels)
else:
step = max(1, n_neurons // 10)
tick_positions = list(range(0, n_neurons, step))
ax.set_yticks(tick_positions)
ax.set_yticklabels([labels[i] for i in tick_positions])
style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title)
if style.tight_layout:
fig.tight_layout()
return fig, ax
[docs]
def plot_potential_lines(
source: Any,
neuron_ids: Optional[Sequence[str]] = None,
time_range: Optional[Tuple[int, int]] = None,
colors: Optional[Sequence[str]] = None,
show_threshold: Optional[float] = None,
threshold_color: str = "#d62728",
threshold_linestyle: str = "--",
show_legend: bool = True,
ax: Optional[plt.Axes] = None,
style: Optional[SANAFEStyle] = None,
figsize: Optional[Tuple[float, float]] = None,
title: Optional[str] = None,
xlabel: str = "Time-step",
ylabel: str = "Membrane Potential",
**plot_kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
style = style or get_default_style()
df = potentials_to_dataframe(source, neuron_ids=neuron_ids)
if time_range is not None:
df = df.iloc[time_range[0]:time_range[1]]
n_neurons = df.shape[1]
if colors is None:
colors = [DEFAULT_COLORS[i % len(DEFAULT_COLORS)] for i in range(n_neurons)]
if ax is None:
fig, ax = create_figure(figsize=figsize, style=style)
else:
fig = ax.get_figure()
plot_defaults = {"linewidth": style.potential_line_width}
if style.potential_marker:
plot_defaults["marker"] = style.potential_marker
plot_defaults["markersize"] = style.potential_marker_size
plot_defaults.update(plot_kwargs)
timesteps = np.asarray(df.index)
for i, col in enumerate(df.columns):
ax.plot(timesteps, df[col].values, color=colors[i],
label=str(col), **plot_defaults)
if show_threshold is not None:
ax.axhline(y=show_threshold, color=threshold_color,
linestyle=threshold_linestyle,
linewidth=style.line_width * 0.8, label="Threshold", zorder=0)
if len(timesteps):
ax.set_xlim(timesteps[0] - 0.5, timesteps[-1] + 0.5)
style_axis(ax, style, xlabel=xlabel, ylabel=ylabel, title=title)
if show_legend:
ax.legend(loc="upper right", framealpha=0.9)
if style.tight_layout:
fig.tight_layout()
return fig, ax