pulse2percept.models.granley2021
BiphasicAxonMapModel, BiphasicAxonMapSpatial, [Granley2021]
Functions
cond_jit (fn[, static_argnums]) |
Conditional decorator for jax jit |
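As a rough illustration of what a conditional JIT decorator can look like, the sketch below applies jax.jit only when JAX can be imported and otherwise returns the function unchanged. The name cond_jit_sketch and the import-based condition are assumptions made for this example; the library's own cond_jit may decide differently.

```python
import functools

try:
    import jax
    HAS_JAX = True
except ImportError:  # JAX is treated as optional in this sketch
    HAS_JAX = False


def cond_jit_sketch(fn=None, static_argnums=None):
    """Hypothetical stand-in: jit-compile `fn` only if JAX is available."""
    if fn is None:
        # Used with arguments: @cond_jit_sketch(static_argnums=...)
        return functools.partial(cond_jit_sketch, static_argnums=static_argnums)
    if HAS_JAX:
        return jax.jit(fn, static_argnums=static_argnums)
    return fn  # fall back to the plain Python function


@cond_jit_sketch
def square(x):
    return x * x


print(float(square(3.0)))
```

Either way the decorated function is called the same way, so code written against it does not need to know whether JAX is installed.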
Classes
BiphasicAxonMapModel (**params) |
BiphasicAxonMapModel of [Granley2021] (standalone model) |
BiphasicAxonMapSpatial (**params) |
BiphasicAxonMapModel of [Granley2021] (spatial model) |
DefaultBrightModel (**params) |
Default model to be used for brightness scaling in BiphasicAxonMapModel. Implements Eq 4 from [Granley2021]. Fit using data from [Nanduri2012] and [Weitz2015]. |
DefaultSizeModel (rho[, engine]) |
Default model to be used for size (rho) scaling in BiphasicAxonMapModel. Implements Eq 5 from [Granley2021]. Fit using data from [Nanduri2012] and [Weitz2015]. |
DefaultStreakModel (axlambda[, engine]) |
Default model to be used for streak length (lambda) scaling in BiphasicAxonMapModel. Implements Eq 6 from [Granley2021]. Fit using data from [Weitz2015]. |
class pulse2percept.models.granley2021.BiphasicAxonMapModel(**params)
BiphasicAxonMapModel of [Granley2021] (standalone model)
An AxonMapModel where phosphene brightness, size, and streak length scale according to amplitude, frequency, and pulse duration.
All stimuli must be BiphasicPulseTrains.
This model differs from other spatial models in that it calculates one representative percept from all time steps of the stimulus.
Brightness, size, and streak length scaling are controlled by the parameters bright_model, size_model, and streak_model, respectively. By default, these are set to classes that implement Eqs 3-6 from [Granley2021]. Each can be customized individually by setting bright_model, size_model, or streak_model to any Python callable with signature f(freq, amp, pdur).
Note
Using this model in combination with a temporal model is not currently supported and will give unexpected results.
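To illustrate the callable signature f(freq, amp, pdur), here is a minimal sketch of a custom brightness model. The function name and its coefficients are hypothetical and chosen only for illustration; the commented-out line shows how such a callable would presumably be handed to the model through the bright_model parameter.

```python
def custom_bright(freq, amp, pdur):
    """Hypothetical brightness effect: linear in amplitude and frequency.

    freq: pulse frequency, amp: amplitude (as a factor of threshold),
    pdur: pulse duration (accepted to match the signature, unused here).
    The coefficients are made up for illustration.
    """
    return 1.5 * amp + 0.02 * freq


# Assumed usage, following the bright_model parameter documented here:
# model = BiphasicAxonMapModel(bright_model=custom_bright)

print(custom_bright(freq=20, amp=2, pdur=0.45))
```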
Parameters:
- bright_model (callable, optional) – Model used to modulate percept brightness with amplitude, frequency, and pulse duration
- size_model (callable, optional) – Model used to modulate percept size with amplitude, frequency, and pulse duration
- streak_model (callable, optional) – Model used to modulate percept streak length with amplitude, frequency, and pulse duration
- do_thresholding (boolean) – Use probabilistic sigmoid thresholding, default: False
- **params (dict, optional) – Arguments to be passed to AxonMapSpatial
- axlambda: double, optional
- Exponential decay constant along the axon (microns).
- rho: double, optional
- Exponential decay constant away from the axon (microns).
- eye: {‘RE’, ‘LE’}, optional
- Eye for which to generate the axon map.
- xrange : (x_min, x_max), optional
- A tuple indicating the range of x values to simulate (in degrees of visual angle). In a right eye, negative x values correspond to the temporal retina, and positive x values to the nasal retina. In a left eye, the opposite is true.
- yrange : tuple, (y_min, y_max)
- A tuple indicating the range of y values to simulate (in degrees of visual angle). Negative y values correspond to the superior retina, and positive y values to the inferior retina.
- xystep : int, double, tuple
- Step size for the range of (x,y) values to simulate (in degrees of visual angle). For example, to create a grid with x values [0, 0.5, 1], use xrange=(0, 1) and xystep=0.5.
- grid_type : {‘rectangular’, ‘hexagonal’}
- Whether to simulate points on a rectangular or hexagonal grid
- vfmap : VisualFieldMap, optional
- An instance of a VisualFieldMap object that provides retinotopic mappings. By default, Watson2014Map is used.
- n_gray : int, optional
- The number of gray levels to use. If an integer is given, k-means clustering is used to compress the color space of the percept into n_gray bins. If None, no compression is performed.
- noise : float or int, optional
- Adds salt-and-pepper noise to each percept frame. An integer will be interpreted as the number of pixels to subject to noise in each frame. A float between 0 and 1 will be interpreted as a ratio of pixels to subject to noise in each frame.
- loc_od: (x,y), optional
- Location of the optic disc in degrees of visual angle. Note that the optic disc in a left eye will be corrected to have a negative x coordinate.
- n_axons: int, optional
- Number of axons to generate.
- axons_range: (min, max), optional
- The range of angles (in degrees) at which axons exit the optic disc. This corresponds to the range of $\phi_0$ values used in [Jansonius2009].
- n_ax_segments: int, optional
- Number of segments an axon is made of.
- ax_segments_range: (min, max), optional
- Lower and upper bounds for the radial position values (polar coords) for each axon.
- min_ax_sensitivity: float, optional
- Axon segments whose contribution to brightness is smaller than this value will be pruned to improve computational efficiency. Set to a value between 0 and 1.
- axon_pickle: str, optional
- File name in which to store precomputed axon maps.
- ignore_pickle: bool, optional
- A flag indicating whether to ignore the pickle file in future calls to model.build().
- n_threads: int, optional
- Number of CPU threads to use during parallelization using OpenMP. Defaults to max number of user CPU cores.
build(**build_params)
Build the model
Performs expensive one-time calculations, such as building the spatial grid used to predict a percept.
Parameters: build_params (additional parameters to set) – You can overwrite parameters that are listed in get_default_params. Trying to add new class attributes outside of that will cause a FreezeError. Example: model.build(param1=val)
Returns: self
find_threshold(implant, bright_th, amp_range=(0, 999), amp_tol=1, bright_tol=0.1, max_iter=100, t_percept=None)
Find the threshold current for a certain stimulus
Estimates amp_th such that the output of model.predict_percept(stim(amp_th)) is approximately bright_th.
Parameters:
- implant (ProsthesisSystem) – The implant and its stimulus to use. Stimulus amplitude will be up and down regulated until amp_th is found.
- bright_th (float) – Model output (brightness) that’s considered “at threshold”.
- amp_range ((amp_lo, amp_hi), optional) – Range of amplitudes to search (uA).
- amp_tol (float, optional) – Search will stop if candidate range of amplitudes is within amp_tol
- bright_tol (float, optional) – Search will stop if model brightness is within bright_tol of bright_th
- max_iter (int, optional) – Search will stop after max_iter iterations
- t_percept (float or list of floats, optional) – The time points at which to output a percept (ms). If None, implant.stim.time is used.
Returns: amp_th (float) – Threshold current (uA), estimated so that the output of model.predict_percept(stim(amp_th)) is within bright_tol of bright_th.
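The stopping criteria above (amp_tol, bright_tol, max_iter) suggest a bracketing search over amplitude. The sketch below shows one way such a search could work, assuming plain bisection and a brightness that increases with amplitude. The function find_threshold_sketch and its brightness_fn argument are hypothetical stand-ins, not the library's implementation.

```python
def find_threshold_sketch(brightness_fn, bright_th, amp_range=(0, 999),
                          amp_tol=1, bright_tol=0.1, max_iter=100):
    """Bisection search for the amplitude whose brightness is ~bright_th.

    `brightness_fn` is a hypothetical stand-in for running the model on a
    stimulus with the given amplitude; it must increase with amplitude.
    """
    amp_lo, amp_hi = amp_range
    amp_mid = 0.5 * (amp_lo + amp_hi)
    for _ in range(max_iter):
        amp_mid = 0.5 * (amp_lo + amp_hi)
        bright = brightness_fn(amp_mid)
        if abs(bright - bright_th) <= bright_tol:
            break  # brightness is close enough to the target
        if bright < bright_th:
            amp_lo = amp_mid  # too dim: raise the lower bound
        else:
            amp_hi = amp_mid  # too bright: lower the upper bound
        if amp_hi - amp_lo <= amp_tol:
            break  # candidate amplitude range is narrow enough
    return amp_mid


# Toy brightness curve: brightness grows linearly with amplitude.
amp_th = find_threshold_sketch(lambda amp: 0.05 * amp, bright_th=5.0)
print(amp_th)
```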
has_space
Returns True if the model has a spatial component
has_time
Returns True if the model has a temporal component
is_built
Returns True if the build method has been called
predict_percept(implant, t_percept=None)
Predict a percept.
Overrides the base predict_percept to keep the desired time axes.
Important
You must call build before calling predict_percept.
Note: The stimuli should use amplitude as a factor of threshold, NOT raw amplitude in microamps.
Parameters:
- implant (ProsthesisSystem) – A valid prosthesis system. A stimulus can be passed via stim.
- t_percept (float or list of floats, optional) – The time points at which to output a percept (ms). If None, implant.stim.time is used.
Returns: percept (Percept) – A Percept object whose data container has dimensions Y x X x T. Will return None if implant.stim is None.
set_params(params)
Set model parameters
This is a convenience function to set parameters that might be part of the spatial model, the temporal model, or both.
Alternatively, you can set the parameter directly, e.g. model.spatial.verbose = True.
Note
If a parameter exists in both spatial and temporal models (e.g., verbose), both models will be updated.
Parameters: params (dict) – A dictionary of parameters to set.
class pulse2percept.models.granley2021.BiphasicAxonMapSpatial(**params)
BiphasicAxonMapModel of [Granley2021] (spatial model)
An AxonMapModel where phosphene brightness, size, and streak length scale according to amplitude, frequency, and pulse duration.
All stimuli must be BiphasicPulseTrains.
This model differs from other spatial models in that it calculates one representative percept from all time steps of the stimulus.
Brightness, size, and streak length scaling are controlled by the effects models bright_model, size_model, and streak_model, respectively. By default, these are set to classes that implement Eqs 3-6 from [Granley2021]. Each can be customized individually by setting bright_model, size_model, or streak_model to any Python callable with signature f(freq, amp, pdur).
Note
Using this model in combination with a temporal model is not currently supported and will give unexpected results.
Parameters:
- bright_model (callable, optional) – Model used to modulate percept brightness with amplitude, frequency, and pulse duration
- size_model (callable, optional) – Model used to modulate percept size with amplitude, frequency, and pulse duration
- streak_model (callable, optional) – Model used to modulate percept streak length with amplitude, frequency, and pulse duration
- **params (optional) – Additional params for AxonMapModel.
- axlambda: double, optional
- Exponential decay constant along the axon (microns).
- rho: double, optional
- Exponential decay constant away from the axon (microns).
- eye: {‘RE’, ‘LE’}, optional
- Eye for which to generate the axon map.
- xrange : (x_min, x_max), optional
- A tuple indicating the range of x values to simulate (in degrees of visual angle). In a right eye, negative x values correspond to the temporal retina, and positive x values to the nasal retina. In a left eye, the opposite is true.
- yrange : tuple, (y_min, y_max)
- A tuple indicating the range of y values to simulate (in degrees of visual angle). Negative y values correspond to the superior retina, and positive y values to the inferior retina.
- xystep : int, double, tuple
- Step size for the range of (x,y) values to simulate (in degrees of visual angle). For example, to create a grid with x values [0, 0.5, 1], use xrange=(0, 1) and xystep=0.5.
- grid_type : {‘rectangular’, ‘hexagonal’}
- Whether to simulate points on a rectangular or hexagonal grid
- vfmap : VisualFieldMap, optional
- An instance of a VisualFieldMap object that provides retinotopic mappings. By default, Watson2014Map is used.
- n_gray : int, optional
- The number of gray levels to use. If an integer is given, k-means clustering is used to compress the color space of the percept into n_gray bins. If None, no compression is performed.
- noise : float or int, optional
- Adds salt-and-pepper noise to each percept frame. An integer will be interpreted as the number of pixels to subject to noise in each frame. A float between 0 and 1 will be interpreted as a ratio of pixels to subject to noise in each frame.
- loc_od: (x,y), optional
- Location of the optic disc in degrees of visual angle. Note that the optic disc in a left eye will be corrected to have a negative x coordinate.
- n_axons: int, optional
- Number of axons to generate.
- axons_range: (min, max), optional
- The range of angles (in degrees) at which axons exit the optic disc. This corresponds to the range of $\phi_0$ values used in [Jansonius2009].
- n_ax_segments: int, optional
- Number of segments an axon is made of.
- ax_segments_range: (min, max), optional
- Lower and upper bounds for the radial position values (polar coords) for each axon.
- min_ax_sensitivity: float, optional
- Axon segments whose contribution to brightness is smaller than this value will be pruned to improve computational efficiency. Set to a value between 0 and 1.
- axon_pickle: str, optional
- File name in which to store precomputed axon maps.
- ignore_pickle: bool, optional
- A flag indicating whether to ignore the pickle file in future calls to model.build().
- n_threads: int, optional
- Number of CPU threads to use during parallelization using OpenMP. Defaults to max number of user CPU cores.
biphasic_axon_map_jax
Predicts the spatial response of BiphasicAxonMapModel using Jax
- eparams : jnp.array with shape (n_elecs, 3)
- Brightness, size, and streak length effect on each electrode
- x, y : jnp.array with shape (n_elecs)
- x and y coordinate of each electrode
- axon_segments : jnp.array with shape (n_points, n_ax_segments, 3)
- Closest axon segment to each simulated point, as returned by calc_axon_sensitivity
- rho : float
- The rho parameter of the axon map model: exponential decay constant (microns) away from the axon.
- axlambda : float
- The lambda parameter of the axon map model: exponential decay constant (microns) away from the cell body along the axon
- thresh_percept : float
- Spatial responses smaller than thresh_percept will be set to zero
build(**build_params)
Build the model
Performs expensive one-time calculations, such as building the spatial grid used to predict a percept. You must call build before calling predict_percept.
Important
Don’t override this method if you are building your own model. Customize _build instead.
Parameters: build_params (additional parameters to set) – You can overwrite parameters that are listed in get_default_params. Trying to add new class attributes outside of that will cause a FreezeError. Example: model.build(param1=val)
calc_axon_sensitivity(bundles, pad=False)
Calculate the sensitivity of each axon segment to electrical current
This function combines the x,y coordinates of each bundle segment with a sensitivity value that depends on the distance of the segment to the cell body and self.axlambda.
The number of bundles must equal the number of points on self.grid. The function will then assume that the i-th bundle passes through the i-th point on the grid. This is used to determine the bundle segment that is closest to the i-th point on the grid, and to cut off all segments that extend beyond the soma. This effectively transforms a bundle into an axon, where the first axon segment now corresponds with the i-th location of the grid.
After that, each axon segment gets a sensitivity value that depends on the distance of the segment to the soma (with decay rate self.axlambda). This is typically done during the build process, so that the only work left to do during run time is to multiply the sensitivity value with the current applied to each segment.
If pad is True (set when engine is ‘jax’), axons are padded to all have the same length as the longest axon.
Parameters: bundles (list of Nx2 arrays) – A list of bundles, where every bundle is an Nx2 array consisting of the x,y coordinates of each axon segment (retinal coords, microns). Note that each bundle will most likely have a different N.
Returns: axon_contrib (numpy array with shape (n_points, axon_length, 3)) – An array of axon segments and sensitivity values. Each entry in the array is a Nx3 array, where the first two columns contain the retinal coordinates of each axon segment (microns), and the third column contains the sensitivity of the segment to electrical current. The latter depends on self.axlambda. axon_length is set to the maximum length of any axon after being trimmed due to min_ax_sensitivity.
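A small sketch may help picture the sensitivity and padding steps. It assumes a Gaussian-style falloff, exp(-d^2 / (2 * axlambda^2)), with path length d measured from the soma, and it pads with zero-sensitivity segments. Both choices are assumptions for illustration, since the text above only states that sensitivity depends on distance and axlambda; the helper names are hypothetical.

```python
import math


def axon_sensitivity_sketch(axon, axlambda):
    """Attach a sensitivity value to each (x, y) segment of one axon.

    Assumes the first segment sits at the soma and that sensitivity decays
    as exp(-d**2 / (2 * axlambda**2)) with path length d (an assumption).
    """
    out, dist = [], 0.0
    for i, (x, y) in enumerate(axon):
        if i > 0:
            x_prev, y_prev = axon[i - 1]
            dist += math.hypot(x - x_prev, y - y_prev)  # path length so far
        out.append((x, y, math.exp(-dist ** 2 / (2 * axlambda ** 2))))
    return out


def pad_axons(axons):
    """Pad every axon with zero-sensitivity segments to a common length."""
    max_len = max(len(a) for a in axons)
    return [a + [(0.0, 0.0, 0.0)] * (max_len - len(a)) for a in axons]


axons = [
    axon_sensitivity_sketch([(0, 0), (100, 0), (200, 0)], axlambda=200),
    axon_sensitivity_sketch([(0, 50), (100, 50)], axlambda=200),
]
padded = pad_axons(axons)
print([len(a) for a in padded])
```

Padding to a common length gives a rectangular (n_points, axon_length, 3) layout, which is the shape the return value describes.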
calc_bundle_tangent(xc, yc)
Calculates orientation of fiber bundle tangent at (xc, yc)
Parameters: xc, yc – (x, y) retinal location of point at which to calculate bundle orientation in microns.
Returns: tangent (scalar) – An angle in radians
calc_bundle_tangent_fast(xc, yc, bundles=None)
Calculates orientation of fiber bundle tangent at (xc, yc).
This function supports multiple queries (xc and yc can be arrays) without growing the axon bundles again for each point (as calc_bundle_tangent does). It uses a cKDTree, which will be slower for single points, but significantly faster for multiple points.
Parameters: xc, yc – (x, y) retinal location of point at which to calculate bundle orientation in microns.
Returns: tangent (array of floats) – Angles in radians
find_closest_axon(bundles, xret=None, yret=None, return_index=False)
Finds the closest axon segment for a point on the retina
This function will search a number of nerve fiber bundles (bundles) and return the bundle that is closest to a particular point (or list of points) on the retinal surface (xret, yret).
Parameters:
- bundles (list of Nx2 arrays) – A list of bundles, where every bundle is an Nx2 array consisting of the x,y coordinates of each axon segment (retinal coords, microns). Note that each bundle will most likely have a different N
- xret, yret – The x,y location on the retina (in microns, where the fovea is the origin) for which to find the closest axon.
- return_index (bool, optional) – If True, the function will also return the index into bundles that represents the closest axon
Returns:
- axon (Nx2 array or list of Nx2 arrays) – For each point in (xret, yret), returns an Nx2 array that represents the closest axon to that point. Each row in the array contains the x,y retinal coordinates (microns) of a particular axon segment.
- idx_axon (scalar or list of scalars, optional) – If return_index is True, also returns the index in bundles of the closest axon (or list of closest axons).
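The search described here can be pictured with a brute-force version that checks every segment of every bundle. This is a sketch for a single query point only, written for clarity rather than speed; find_closest_bundle is a hypothetical name and not the library's implementation.

```python
import math


def find_closest_bundle(bundles, xret, yret):
    """Return (bundle, index) of the bundle with the segment nearest (xret, yret).

    Brute-force search for clarity; not the library's implementation.
    """
    best_idx, best_dist = None, math.inf
    for idx, bundle in enumerate(bundles):
        for x, y in bundle:
            dist = math.hypot(x - xret, y - yret)
            if dist < best_dist:
                best_idx, best_dist = idx, dist
    return bundles[best_idx], best_idx


bundles = [
    [(0, 0), (100, 0), (200, 0)],
    [(0, 300), (100, 300)],
]
axon, idx = find_closest_bundle(bundles, xret=90, yret=250)
print(idx)
```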
find_threshold(implant, bright_th, amp_range=(0, 999), amp_tol=1, bright_tol=0.1, max_iter=100)
Find the threshold current for a certain stimulus
Estimates amp_th such that the output of model.predict_percept(stim(amp_th)) is approximately bright_th.
Parameters:
- implant (ProsthesisSystem) – The implant and its stimulus to use. Stimulus amplitude will be up and down regulated until amp_th is found.
- bright_th (float) – Model output (brightness) that’s considered “at threshold”.
- amp_range ((amp_lo, amp_hi), optional) – Range of amplitudes to search (uA).
- amp_tol (float, optional) – Search will stop if candidate range of amplitudes is within amp_tol
- bright_tol (float, optional) – Search will stop if model brightness is within bright_tol of bright_th
- max_iter (int, optional) – Search will stop after max_iter iterations
Returns: amp_th (float) – Threshold current (uA), estimated so that the output of model.predict_percept(stim(amp_th)) is within bright_tol of bright_th.
grow_axon_bundles(n_bundles=None, prune=True)
Grow a number of axon bundles
This method generates the trajectory of a number of nerve fiber bundles based on the mathematical model described in [Beyeler2019], which is based on [Jansonius2009].
Bundles originate at the optic nerve head with initial angle phi0. The method generates n_bundles axon bundles whose phi0 values are linearly sampled from self.axons_range (polar coords). Each axon will consist of self.n_ax_segments segments that span self.ax_segments_range distance from the optic nerve head (polar coords).
Returns: bundles (list of Nx2 arrays) – A list of bundles, where every bundle is an Nx2 array consisting of the x,y coordinates of each axon segment (retinal coords, microns). Note that each bundle will most likely have a different N
is_built
A flag indicating whether the model has been built
plot(use_dva=False, style='hull', annotate=True, autoscale=True, ax=None, figsize=None)
Plot the axon map
Parameters:
- use_dva (bool, optional) – Uses degrees of visual angle (dva) if True, else retinal coordinates (microns)
- style ({'hull', 'scatter', 'cell'}, optional) – Grid plotting style:
  - ‘hull’: Show the convex hull of the grid (that is, the outline of the smallest convex set that contains all grid points).
  - ‘scatter’: Scatter plot all grid points
  - ‘cell’: Show the outline of each grid cell as a polygon. Note that this can be costly for a high-resolution grid.
- annotate (bool, optional) – Flag whether to label the four retinal quadrants
- autoscale (bool, optional) – Whether to adjust the x,y limits of the plot
- ax (matplotlib.axes._subplots.AxesSubplot, optional) – A Matplotlib axes object. If None, will either use the current axes (if exists) or create a new Axes object
- figsize ((float, float), optional) – Desired (width, height) of the figure in inches
predict_one_point_jax(axon, eparams, x, y, rho)
Predicts the brightness contribution from each axon segment for each pixel
predict_percept(implant, t_percept=None)
Predicts the spatial response.
Overrides the base predict_percept to have the desired timesteps and remove unnecessary computation.
Parameters:
- implant (ProsthesisSystem) – A valid prosthesis system. A stimulus can be passed via stim.
- t_percept (float or list of floats, optional) – The time points at which to output a percept (ms). If None, implant.stim.time is used.
Returns: percept (Percept) – A Percept object whose data container has dimensions Y x X x 1. Will return None if implant.stim is None.
predict_percept_batched(implant, stims, t_percept=None)
Batched version of predict_percept. Only supported with the jax engine.
Rather than batching ALL of your percepts in a single call, it is significantly faster to split them into chunks (128 - 256 percepts each) and call this method once per chunk. This is because jax has to compile on the first call, so repeated calls are much faster.
Parameters:
- implant (ProsthesisSystem) – A valid prosthesis system.
- stims (list of stimuli) – A percept will be predicted for each stimulus. Each stimulus must be a collection of BiphasicPulseTrains.
- t_percept (float or list of floats, optional) – The time points at which to output a percept (ms). If None, implant.stim.time is used.
Returns: percepts (list of Percept) – A list of Percept objects whose data container has dimensions Y x X x 1.
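The chunking advice can be sketched with a small helper that slices the stimulus list and calls a batch function once per slice. Here predict_batch is a hypothetical stand-in for a call such as model.predict_percept_batched(implant, chunk), and the chunk size of 128 is simply the lower end of the suggested range.

```python
def chunked(items, chunk_size):
    """Yield consecutive slices of `items` with at most `chunk_size` entries."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


def predict_in_chunks(predict_batch, stims, chunk_size=128):
    """Call `predict_batch` once per chunk and concatenate the results.

    `predict_batch` is a hypothetical stand-in for a call such as
    lambda chunk: model.predict_percept_batched(implant, chunk).
    """
    percepts = []
    for chunk in chunked(stims, chunk_size):
        percepts.extend(predict_batch(chunk))
    return percepts


# Toy stand-in that "predicts" by doubling each stimulus value.
results = predict_in_chunks(lambda chunk: [2 * s for s in chunk],
                            stims=list(range(300)), chunk_size=128)
print(len(results))
```

Because JAX generally specializes compiled functions on input shapes, a smaller final chunk may trigger one more compilation; padding the last chunk to the common size is one possible workaround.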
class pulse2percept.models.granley2021.DefaultBrightModel(**params)
Default model to be used for brightness scaling in BiphasicAxonMapModel. Implements Eq 4 from [Granley2021]. Fit using data from [Nanduri2012] and [Weitz2015].
- do_thresholding : bool, optional
- Set to true to enable probabilistic phosphene appearance at near-threshold amplitudes
- a0, a1 : float, optional
- Linear regression coefficients (slope and intercept) of pulse_duration vs threshold curve (Eq 3). Amplitude factor will be scaled by a0*pdur + a1.
- a2, a3, a4: float, optional
- Linear regression coefficients for brightness vs amplitude and frequency (Eq 4): F_bright = a2*scaled_amp + a3*freq + a4
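Taken together, the two coefficient groups describe a two-step computation: scale the amplitude by a0*pdur + a1 (Eq 3), then apply the linear brightness formula (Eq 4). The sketch below spells that out with placeholder coefficients, which are not the fitted values; it also leaves out the optional thresholding step, and the function names are hypothetical.

```python
def scale_threshold_sketch(pdur, a0, a1):
    """Eq 3: factor by which amplitude is scaled, a0 * pdur + a1."""
    return a0 * pdur + a1


def bright_effect(freq, amp, pdur, a0, a1, a2, a3, a4):
    """Eq 4: F_bright = a2 * scaled_amp + a3 * freq + a4."""
    scaled_amp = amp * scale_threshold_sketch(pdur, a0, a1)
    return a2 * scaled_amp + a3 * freq + a4


# Placeholder coefficients, NOT the fitted values from [Granley2021].
coeffs = dict(a0=0.5, a1=1.0, a2=2.0, a3=0.1, a4=0.0)
print(bright_effect(freq=20, amp=1.0, pdur=1.0, **coeffs))
```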
build(**build_params)
Build the model
Every model must have a build method, which is meant to perform all expensive one-time calculations. You must call build before calling predict_percept.
Important
Don’t override this method if you are building your own model. Customize _build instead.
Parameters: build_params (additional parameters to set) – You can overwrite parameters that are listed in get_default_params. Trying to add new class attributes outside of that will cause a FreezeError. Example: model.build(param1=val)
is_built
A flag indicating whether the model has been built
scale_threshold(pdur)
Based on Eq 3 in the paper, this function produces the factor that amplitude will be scaled by to produce a_tilde. Computes A_0 * t + A_1 (1/threshold).
Note
This equation has been updated from the original paper, and has been refit to data from Argus II users from Horsager et al. 2009.
class pulse2percept.models.granley2021.DefaultSizeModel(rho, engine='serial', **params)
Default model to be used for size (rho) scaling in BiphasicAxonMapModel. Implements Eq 5 from [Granley2021]. Fit using data from [Nanduri2012] and [Weitz2015].
- rho : float32
- Rho parameter of BiphasicAxonMapModel (spatial decay rate)
- a0, a1 : float, optional
- Linear regression coefficients (slope and intercept) of pulse_duration vs threshold curve (Eq 3). Amplitude factor will be scaled by a0*pdur + a1.
- a5, a6 : float, optional
- Linear regression coefficients for size vs amplitude (Eq 5): F_size = a5*scaled_amp + a6
build(**build_params)
Build the model
Every model must have a build method, which is meant to perform all expensive one-time calculations. You must call build before calling predict_percept.
Important
Don’t override this method if you are building your own model. Customize _build instead.
Parameters: build_params (additional parameters to set) – You can overwrite parameters that are listed in get_default_params. Trying to add new class attributes outside of that will cause a FreezeError. Example: model.build(param1=val)
is_built
A flag indicating whether the model has been built
scale_threshold(pdur)
Based on Eq 3 in the paper, this function produces the factor that amplitude will be scaled by to produce a_tilde. Computes A_0 * t + A_1 (1/threshold).
Note
This equation has been updated from the original paper, and has been refit to data from Argus II users from Horsager et al. 2009.
class pulse2percept.models.granley2021.DefaultStreakModel(axlambda, engine='serial', **params)
Default model to be used for streak length (lambda) scaling in BiphasicAxonMapModel. Implements Eq 6 from [Granley2021]. Fit using data from [Weitz2015].
- axlambda : float32
- Axlambda parameter of BiphasicAxonMapModel (axonal decay rate)
- a7, a8, a9: float, optional
- Regression coefficients for streak length vs pulse duration (Eq 6): F_streak = -a7*pdur^a8 + a9
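As a quick check of how Eq 6 behaves, the sketch below evaluates F_streak for a few pulse durations. The coefficients are placeholders chosen for easy arithmetic, not the fitted values, and streak_effect is a hypothetical name.

```python
def streak_effect(pdur, a7, a8, a9):
    """Eq 6: F_streak = -a7 * pdur**a8 + a9."""
    return -a7 * pdur ** a8 + a9


# Placeholder coefficients, NOT the fitted values from [Granley2021].
for pdur in (0.25, 1.0, 4.0):
    print(pdur, streak_effect(pdur, a7=0.5, a8=0.5, a9=2.0))
```

With positive a7 and a8, the factor shrinks as pulse duration grows, which follows directly from the sign of the first term.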
build(**build_params)
Build the model
Every model must have a build method, which is meant to perform all expensive one-time calculations. You must call build before calling predict_percept.
Important
Don’t override this method if you are building your own model. Customize _build instead.
Parameters: build_params (additional parameters to set) – You can overwrite parameters that are listed in get_default_params. Trying to add new class attributes outside of that will cause a FreezeError. Example: model.build(param1=val)
is_built
A flag indicating whether the model has been built