[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

branch: models_igr: graph display + record selection #6

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
6fe2e9a
Updated drawing of network (network_status_graph.py; vertically cente…
igorpodolak Mar 20, 2020
d5df424
"selection of output in record.py"
igorpodolak Mar 22, 2020
2ddc41c
Merge branch 'master' of https://github.com/ziemowit-s/neuronpp into …
igorpodolak Mar 22, 2020
e5fe31f
"centered display of graph in network_status_graph.py; selection of o…
igorpodolak Mar 23, 2020
4fb520f
Display of true/predicted classes in record.py; true_labels added as …
igorpodolak Mar 24, 2020
7758ed8
"class labels passed to record:plot() --> plot_animate()"
igorpodolak Mar 25, 2020
06fc25d
Merge remote-tracking branch 'origin/master' into models_igr
igorpodolak Mar 26, 2020
3ed9930
Plotting of outputs: record.py: true_labels added; names changed
igorpodolak Mar 26, 2020
c8fb1e2
record.py: namedtuple as parameter method
igorpodolak Mar 27, 2020
597409e
record.py: simulation parameters to _plot_animate passed with a named…
igorpodolak Mar 27, 2020
cf742d9
minor errors (due to tragic namedtuple
igorpodolak Mar 27, 2020
ef7ae7f
Merge remote-tracking branch 'origin/models_igr'
ziemowit-s Mar 27, 2020
79f38d8
fixed Record to work with markers, not work yet
ziemowit-s Mar 27, 2020
34a7336
Corrected errors in record.py; changed its parameters a bit
igorpodolak Mar 27, 2020
387e21a
Merge remote-tracking branch 'origin/models_igr' into models_igr
igorpodolak Mar 27, 2020
f2d346e
minor errors in record.py, plot_animate()
igorpodolak Mar 27, 2020
8411e81
minor errors in inheritance in record.py
igorpodolak Mar 27, 2020
2ca9dec
Merge remote-tracking branch 'origin/master' into models_igr
ziemowit-s Apr 2, 2020
be385e8
changes to Record
ziemowit-s Apr 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions neuronpp/utils/graphs/network_status_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, cells, weight_name='w', plot_fixed_weight_edges=True):

self.correct_position = (0.1, 0.1)

self.population_sizes = dict()
self.population_names = self._get_population_names()
self.edges = self._get_edges(weight_name)

Expand Down Expand Up @@ -69,6 +70,10 @@ def update_spikes(self, sim_time):

def _get_edges(self, weight_name):
result = []
# compute the number of cells in each layer
for c in self.cells:
pop_name = c.name.split('[')[0]
self.population_sizes[pop_name] = self.population_sizes.get(pop_name, 0) + 1
for c in self.cells:
soma = c.filter_secs('soma')
if c._spike_detector is None:
Expand All @@ -81,11 +86,13 @@ def _get_edges(self, weight_name):
except ValueError:
x_pos = self.population_names.index(pop_name)

y_pos = int(split_name[-1][:-1])
# shift down the layer by half its size to vertically center graph
y_pos = int(split_name[-1][:-1]) - self.population_sizes[pop_name] // 2

if 'inh' in c.name:
self.colors.append('red')
y_pos -= 5
# todo center vertically by half width of the hid layer
y_pos -= 6
elif 'hid' in c.name:
self.colors.append('blue')
else:
Expand All @@ -110,7 +117,8 @@ def _find_target(self, c, x_pos, y_pos, weight_name):
except ValueError:
x_trg = self.population_names.index(pop_name)

y_trg = int(split_target[-1][:-1])
# center veritically
y_trg = int(split_target[-1][:-1]) - self.population_sizes[pop_name] // 2

weight = None
if self.plot_constant_connections and hasattr(nc.target.hoc, weight_name):
Expand All @@ -125,7 +133,8 @@ def _find_weights(c, weight_name):
for nc in c.ncs:
if "SpikeDetector" in nc.name:
continue
elif isinstance(nc.source, Seg) and isinstance(nc.target, PointProcess) and hasattr(nc.target.hoc, weight_name):
elif isinstance(nc.source, Seg) and isinstance(nc.target, PointProcess) and hasattr(nc.target.hoc,
weight_name):
weight = getattr(nc.target.hoc, weight_name)
targets.append(weight)
return targets
Expand Down
119 changes: 108 additions & 11 deletions neuronpp/utils/record.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from nrn import Segment, Section
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this concept of recording with labels should be moved to a separate class that inherits from Record, e.g. RecordWithLabels. In that case, the dt and true_labels params of the _plot_animate() method can be moved into the constructor.

from collections import defaultdict
from collections import defaultdict, namedtuple

import numpy as np
import pandas as pd
from neuron import h
import matplotlib.pyplot as plt
from neuronpp.core.hocwrappers.sec import Sec

from neuronpp.core.hocwrappers.point_process import PointProcess

from neuronpp.core.hocwrappers.seg import Seg

# Bundle of simulation parameters needed to place true/predicted class
# markers on the animated plots drawn by Record._plot_animate. Fields:
#   agent_class     - purpose not evident from the marker-drawing code (TODO: document)
#   agent_stepsize  - agent readout time step
#   dt              - agent integration time step
#   input_cell_num  - number of input cells
#   output_cell_num - number of output cells
#   output_labels   - one class label per recorded element, in plot order
# The typename given to namedtuple must equal the variable name: pickle looks
# the class up by that name, and repr() displays it.
MarkerParams = namedtuple("MarkerParams",
                          "agent_class agent_stepsize dt input_cell_num output_cell_num output_labels")


class Record:
def __init__(self, elements, variables='v'):
Expand Down Expand Up @@ -36,13 +41,20 @@ def __init__(self, elements, variables='v'):
for elem in elements:
for var in variables:
if isinstance(elem, Seg):
name = elem.parent.name
cell_name = elem.parent.parent.name
name = "%s_%s" % (cell_name, elem.name)
elif isinstance(elem, PointProcess):
cell_name = elem.cell.name
name = "%s_%s" % (cell_name, elem.name)
elif isinstance(elem, Sec):
raise TypeError("Record element cannot be of type Sec, however you can specify Seg eg. soma(0.5) and pass as element.")
else:
name = elem.name
try:
s = getattr(elem.hoc, "_ref_%s" % var)
except AttributeError:
raise AttributeError("there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var)
raise AttributeError(
"there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var)

rec = h.Vector().record(s)
self.recs[var].append((name, rec))
Expand Down Expand Up @@ -87,16 +99,19 @@ def _plot_static(self, position=None):
for i, (name, rec) in enumerate(section_recs):
rec_np = rec.as_numpy()
if np.max(np.isnan(rec_np)):
raise ValueError("Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name))
raise ValueError(
"Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name))

if position is not "merge":
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1)
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs),
index=i + 1)
ax.set_title("Variable: %s" % var_name)
ax.plot(self.t, rec, label=name)
ax.set(xlabel='t (ms)', ylabel=var_name)
ax.legend()

def _plot_animate(self, steps=10000, y_lim=None, position=None):
def _plot_animate(self, steps=10000, y_lim=None, position=None, true_class=None, pred_class=None,
show_true_predicted=False, marker_params: MarkerParams = None):
"""
Call each time you want to redraw plot.

Expand All @@ -109,8 +124,20 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None):
* position=(3,3) -> if you have 9 neurons and want to display 'v' on 3x3 matrix
* position='merge' -> it will display all figures on the same graph.
* position=None -> Default, each neuron has separated axis (row) on the figure.
:param true_class: list of true class labels in this window
:param pred_class: list of predicted class labels in window
:param show_true_predicted: whther to print true/predicted class' marks on the plot
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no point in using the show_true_predicted param, because you can just check whether true_class and pred_class are not None.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a matter of taste. Lots of code is written so that some behaviour can be switched on/off with a single switch.
But this is an important comment: I will move the boolean show_true_predicted out of the run_params tuple, especially since it is not actually a simulation parameter.

:param marker_params: MarkerParams namedtuple contains inner params:
:param agent_stepsize: agent readout time step
:param dt: agent integration time step
:param input_cell_num: number of input cells
:param output_cell_num: number of output cells
:param output_labels: list of true labels for the consecutive plots
:return:
"""
if show_true_predicted and marker_params is None:
raise ValueError(
"Running parameters run_params need to be passed if true/predicted markers are to be shown")
create_fig = False
for var_name, section_recs in self.recs.items():
if var_name not in self.figs:
Expand All @@ -119,21 +146,25 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None):
fig = self.figs[var_name]
if fig is None:
create_fig = True
fig = plt.figure()
fig = plt.figure(figsize=(16.5, 5.5))
fig.canvas.draw()
self.figs[var_name] = fig

if show_true_predicted:
if len(marker_params.output_labels) != len(section_recs):
raise ValueError(
"show_predicted is true but the number of true labels given is not equal to actual number of elemens to plot.")
for i, (name, rec) in enumerate(section_recs):
if create_fig:
if position == 'merge':
ax = fig.add_subplot(1, 1, 1)
else:
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1)
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs),
index=i + 1)

if y_lim:
ax.set_ylim(y_lim[0], y_lim[1])
line, = ax.plot([], lw=1, label=name)
ax.set_title("Variable: %s" % var_name)
ax.set_ylabel(var_name)
ax.set_xlabel("t (ms)")
ax.legend()
Expand All @@ -146,17 +177,83 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None):

ax.set_xlim(t.min(), t.max())
if y_lim is None:
ax.set_ylim(r.min()-(np.abs(r.min()*0.05)), r.max()+(np.abs(r.max()*0.05)))
# compute per-plot OY limits if global are not given
current_y_lim = (r.min() - (np.abs(r.min() * 0.05)), r.max() + (np.abs(r.max() * 0.05)))
ax.set_ylim(current_y_lim)
else:
current_y_lim = y_lim

# update data
line.set_data(t, r)
if show_true_predicted:
# info draw markers for true and predicted classes
self._show_true_predicted_marks(ax=ax, label=marker_params.output_labels[i], true_class=true_class,
pred_class=pred_class,
t=t, y_limits=current_y_lim, marker_params=marker_params)
if create_fig and i == 0:
# draw legend only the first time and only on the uppermost graph
ax.legend()

# info join plots by removing labels and ticks from subplots that are not on the edge
if create_fig:
igorpodolak marked this conversation as resolved.
Show resolved Hide resolved
fig.subplots_adjust(left=0.09, bottom=0.075, right=0.99, top=0.98, wspace=None, hspace=0.00)
fig.canvas.draw()
fig.canvas.flush_events()

if create_fig:
plt.show(block=False)

def _show_true_predicted_marks(self, ax, label, true_class, pred_class, t, y_limits, marker_params):
    """
    Draw triangle markers for the true and predicted classes of one plotted output.

    Time stamps where `label` is the true class are drawn as upward orange
    triangles near the bottom of the y range. Time stamps where it is the
    predicted class are drawn as downward magenta triangles near the top.
    This method is called on every redraw. Markers previously drawn on `ax`
    for the same `label` are therefore removed first, so each redraw replaces
    them instead of stacking ever more scatter collections on the axes.

    :param ax: the matplotlib axes to draw on
    :param label: the class label (id) associated with this plot
    :param true_class: list of true class labels in this window
    :param pred_class: list of predicted class labels in this window
    :param t: the time steps of the plotted window
    :param y_limits: (y_min, y_max) limits of this axes' y axis; the markers
        are placed relative to them
    :param marker_params: MarkerParams namedtuple. Its output_labels must not
        be None; agent_stepsize and dt are used (via _get_labels_timestamps)
        to map class entries onto time stamps.
    :raises ValueError: if marker_params.output_labels is None
    """
    if marker_params.output_labels is None:
        raise ValueError("True_labels parameter need to be given if show_true_prediction is True")

    true_x, pred_x = self._get_labels_timestamps(label=label,
                                                 true_class=true_class,
                                                 pred_class=pred_class, t=t,
                                                 marker_params=marker_params)

    # Tag the collections per label so that only this plot's markers are
    # replaced, even when several outputs share one axes (position='merge').
    true_gid = "true_marks_%s" % label
    pred_gid = "pred_marks_%s" % label
    for coll in list(ax.collections):  # copy: removal mutates ax.collections
        if coll.get_gid() in (true_gid, pred_gid):
            coll.remove()

    true_y = [y_limits[0] + np.abs(y_limits[0]) * 0.09] * len(true_x)
    pred_y = [y_limits[1] - np.abs(y_limits[1] * 0.12)] * len(pred_x)
    ax.scatter(true_x, true_y, c="orange", marker="^", alpha=0.95, label="true", gid=true_gid)
    ax.scatter(pred_x, pred_y, c="magenta", marker="v", alpha=0.95, label="predicted", gid=pred_gid)

@staticmethod
def _get_labels_timestamps(label, true_class, pred_class, t, marker_params):
    """
    Find the time stamps at which `label` is the true and/or the predicted class.

    The time axis `t` is subsampled with a stride of
    2 * marker_params.agent_stepsize / marker_params.dt samples. The last
    len(true_class) of those samples are then paired, in order, with the
    entries of true_class and pred_class.

    :param label: the label id (an int) to look for
    :param true_class: list of true classes for the whole time region
    :param pred_class: list of predicted labels (class ids) for the whole time
        region; must have at least len(true_class) entries (only the first
        len(true_class) are used)
    :param t: the region time steps (any sliceable sequence, e.g. a numpy array)
    :param marker_params: MarkerParams-like object; only its agent_stepsize
        and dt fields are read
    :return: (true_x, pred_x) - lists of the time stamps where the true class,
        respectively the predicted class, equals `label`
    :raises ValueError: if the stride is smaller than one sample, or if `t` or
        `pred_class` is too short to supply an entry for every true class
    """
    import math

    n = len(true_class)
    if n == 0:
        # Nothing to mark. Returning early also avoids t[...][-0:], which
        # would select the whole array.
        return [], []

    ratio = 2 * marker_params.agent_stepsize / marker_params.dt
    # Snap ratios that are integers up to float noise (e.g. 2 * 0.3 / 0.1 ==
    # 5.999999999999999) to the exact integer; otherwise keep the truncation.
    nearest = round(ratio)
    stride = int(nearest) if math.isclose(ratio, nearest) else int(ratio)
    if stride < 1:
        raise ValueError(
            "Marker sampling stride must be at least one sample, got 2 * agent_stepsize / dt = %r" % ratio)

    x = t[::stride][-n:]
    if len(x) < n:
        raise ValueError(
            "Time axis provides only %d marker time stamps for %d true class entries" % (len(x), n))
    if len(pred_class) < n:
        raise ValueError(
            "pred_class has %d entries but true_class has %d" % (len(pred_class), n))

    true_x = [xk for xk, cls in zip(x, true_class) if cls == label]
    pred_x = [xk for xk, cls in zip(x, pred_class) if cls == label]
    return true_x, pred_x

def to_csv(self, filename):
cols = ['time']
data = [self.t.as_numpy().tolist()]
Expand Down