Skip to content

Commit 0e0b907

Browse files
committed
update code style + docstrings to match standard usage
1 parent 4ba37c8 commit 0e0b907

File tree

1 file changed

+79
-53
lines changed

1 file changed

+79
-53
lines changed

control/phaseplot.py

Lines changed: 79 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
'phaseplot.separatrices_radius': 0.1 # initial radius for separatrices
4646
}
4747

48+
4849
def phase_plane_plot(
4950
sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
5051
plot_streamlines=None, plot_vectorfield=None, plot_streamplot=None,
@@ -131,17 +132,17 @@ def phase_plane_plot(
131132
'both' to flow both forward and backward. The amount of time to
132133
simulate in each direction is given by the `timedata` argument.
133134
plot_streamlines : bool or dict, optional
134-
If then plot streamlines based on the pointdata and gridtype. If set
135-
to a dict, pass on the key-value pairs in the dict as keywords to
136-
`streamlines`.
135+
If True then plot streamlines based on the pointdata and gridtype.
136+
If set to a dict, pass on the key-value pairs in the dict as
137+
keywords to `streamlines`.
137138
plot_vectorfield : bool or dict, optional
138-
If then plot the vector field based on the pointdata and gridtype.
139-
If set to a dict, pass on the key-value pairs in the dict as keywords
140-
to `phaseplot.vectorfield`.
139+
If True then plot the vector field based on the pointdata and
140+
gridtype. If set to a dict, pass on the key-value pairs in the
141+
dict as keywords to `phaseplot.vectorfield`.
141142
plot_streamplot : bool or dict, optional
142-
If then use :func:`matplotlib.axes.Axes.streamplot` function
143+
If True then use `matplotlib.axes.Axes.streamplot` function
143144
to plot the streamlines. If set to a dict, pass on the key-value
144-
pairs in the dict as keywords to :func:`~control.phaseplot.streamplot`.
145+
pairs in the dict as keywords to `phaseplot.streamplot`.
145146
plot_equilpoints : bool or dict, optional
146147
If True (default) then plot equilibrium points based in the phase
147148
plot boundary. If set to a dict, pass on the key-value pairs in the
@@ -158,6 +159,16 @@ def phase_plane_plot(
158159
title : str, optional
159160
Set the title of the plot. Defaults to plot type and system name(s).
160161
162+
Notes
163+
-----
164+
The default method for producing streamlines is determined based on which
165+
keywords are specified, with `plot_streamplot` serving as the generic
166+
default. If any of the `arrows`, `arrow_size`, `arrow_style`, or `dir`
167+
keywords are used and neither `plot_streamlines` nor `plot_streamplot` is
168+
set, then `plot_streamlines` will be set to True. If neither
169+
`plot_streamlines` nor `plot_vectorfield` set set to True, then
170+
`plot_streamplot` will be set to True.
171+
161172
"""
162173
# Check for legacy usage of plot_streamlines
163174
streamline_keywords = [
@@ -174,11 +185,8 @@ def phase_plane_plot(
174185
"falling back to streamlines")
175186
plot_streamlines = True
176187

177-
if (
178-
plot_streamlines is None
179-
and plot_vectorfield is None
180-
and plot_streamplot is None
181-
):
188+
if plot_streamlines is None and plot_vectorfield is None \
189+
and plot_streamplot is None:
182190
plot_streamplot = True
183191

184192
if plot_streamplot and not plot_streamlines and not plot_vectorfield:
@@ -219,9 +227,10 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
219227
out[0] += streamlines(
220228
sys, pointdata, timedata, _check_kwargs=False,
221229
suppress_warnings=suppress_warnings, **kwargs_local)
222-
230+
223231
new_zorder = max(elem.get_zorder() for elem in out[0])
224-
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
232+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder \
233+
else new_zorder
225234

226235
# Get rid of keyword arguments handled by streamlines
227236
for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
@@ -237,25 +246,28 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
237246
kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
238247
out[1] = vectorfield(
239248
sys, pointdata, _check_kwargs=False, **kwargs_local)
240-
249+
241250
new_zorder = out[1].get_zorder()
242-
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
251+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder \
252+
else new_zorder
243253

244254
# Get rid of keyword arguments handled by vectorfield
245255
for kw in ['color', 'params']:
246256
initial_kwargs.pop(kw, None)
247257

248258
if plot_streamplot:
249259
if gridtype not in [None, 'meshgrid']:
250-
raise ValueError("gridtype must be 'meshgrid' when using streamplot")
260+
raise ValueError(
261+
"gridtype must be 'meshgrid' when using streamplot")
251262

252263
kwargs_local = _create_kwargs(
253264
kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
254265
out[3] = streamplot(
255266
sys, pointdata, _check_kwargs=False, **kwargs_local)
256-
267+
257268
new_zorder = max(out[3].lines.get_zorder(), out[3].arrows.get_zorder())
258-
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder else new_zorder
269+
flow_zorder = max(flow_zorder, new_zorder) if flow_zorder \
270+
else new_zorder
259271

260272
# Get rid of keyword arguments handled by streamplot
261273
for kw in ['color', 'params']:
@@ -269,8 +281,9 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
269281
kwargs_local['zorder'] = kwargs_local.get('zorder', sep_zorder)
270282
out[0] += separatrices(
271283
sys, pointdata, _check_kwargs=False, **kwargs_local)
272-
273-
sep_zorder = max(elem.get_zorder() for elem in out[0]) if out[0] else None
284+
285+
sep_zorder = max(elem.get_zorder() for elem in out[0]) if out[0] \
286+
else None
274287

275288
# Get rid of keyword arguments handled by separatrices
276289
for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
@@ -340,9 +353,6 @@ def vectorfield(
340353
dict with key 'args' and value given by a tuple (passed to callable).
341354
color : matplotlib color spec, optional
342355
Plot the vector field in the given color.
343-
zorder : float, optional
344-
Set the zorder for the separatrices. In not specified, it will be
345-
automatically chosen by `matplotlib.axes.Axes.quiver`.
346356
ax : `matplotlib.axes.Axes`, optional
347357
Use the given axes for the plot, otherwise use the current axes.
348358
@@ -357,6 +367,9 @@ def vectorfield(
357367
Default is set by `config.defaults['ctrlplot.rcParams']`.
358368
suppress_warnings : bool, optional
359369
If set to True, suppress warning messages in generating trajectories.
370+
zorder : float, optional
371+
Set the zorder for the separatrices. In not specified, it will be
372+
automatically chosen by `matplotlib.axes.Axes.quiver`.
360373
361374
"""
362375
# Process keywords
@@ -427,23 +440,14 @@ def streamplot(
427440
dict with key 'args' and value given by a tuple (passed to callable).
428441
color : matplotlib color spec, optional
429442
Plot the vector field in the given color.
430-
vary_color : bool, optional
431-
If set to True, vary the color of the streamlines based on the magnitude
432-
vary_linewidth : bool, optional.
433-
If set to True, vary the linewidth of the streamlines based on the magnitude.
434-
cmap : str or Colormap, optional
435-
Colormap to use for varying the color of the streamlines.
436-
norm : `matplotlib.colors.Normalize`, optional
437-
An instance of Normalize to use for scaling the colormap and linewidths.
438-
zorder : float, optional
439-
Set the zorder for the separatrices. In not specified, it will be
440-
automatically chosen by `matplotlib.axes.Axes.streamplot`.
441443
ax : `matplotlib.axes.Axes`, optional
442444
Use the given axes for the plot, otherwise use the current axes.
443445
444446
Returns
445447
-------
446448
out : StreamplotSet
449+
Containter object with lines and arrows contained in the
450+
streamplot. See `matplotlib.axes.Axes.streamplot` for details.
447451
448452
Other Parameters
449453
----------------
@@ -452,6 +456,19 @@ def streamplot(
452456
Default is set by `config.default['ctrlplot.rcParams']`.
453457
suppress_warnings : bool, optional
454458
If set to True, suppress warning messages in generating trajectories.
459+
vary_color : bool, optional
460+
If set to True, vary the color of the streamlines based on the
461+
magnitude of the vector field.
462+
vary_linewidth : bool, optional.
463+
If set to True, vary the linewidth of the streamlines based on the
464+
magnitude of the vector field.
465+
cmap : str or Colormap, optional
466+
Colormap to use for varying the color of the streamlines.
467+
norm : `matplotlib.colors.Normalize`, optional
468+
Normalization map to use for scaling the colormap and linewidths.
469+
zorder : float, optional
470+
Set the zorder for the separatrices. In not specified, it will be
471+
automatically chosen by `matplotlib.axes.Axes.streamplot`.
455472
456473
"""
457474
# Process keywords
@@ -466,7 +483,8 @@ def streamplot(
466483
# Determine the points on which to generate the streamplot field
467484
points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
468485
grid_arr_shape = gridspec[::-1]
469-
xs, ys = points[:, 0].reshape(grid_arr_shape), points[:, 1].reshape(grid_arr_shape)
486+
xs = points[:, 0].reshape(grid_arr_shape)
487+
ys = points[:, 1].reshape(grid_arr_shape)
470488

471489
# Create axis if needed
472490
if ax is None:
@@ -484,25 +502,29 @@ def streamplot(
484502

485503
# Generate phase plane (quiver) data
486504
sys._update_params(params)
487-
us_flat, vs_flat = np.transpose([sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
505+
us_flat, vs_flat = np.transpose(
506+
[sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
488507
us, vs = us_flat.reshape(grid_arr_shape), vs_flat.reshape(grid_arr_shape)
489508

490509
magnitudes = np.linalg.norm([us, vs], axis=0)
491510
norm = norm or mpl.colors.Normalize()
492511
normalized = norm(magnitudes)
493-
cmap = plt.get_cmap(cmap)
512+
cmap = plt.get_cmap(cmap)
494513

495514
with plt.rc_context(rcParams):
496515
default_lw = plt.rcParams['lines.linewidth']
497516
min_lw, max_lw = 0.25*default_lw, 2*default_lw
498-
linewidths = normalized * (max_lw - min_lw) + min_lw if vary_linewidth else None
517+
linewidths = normalized * (max_lw - min_lw) + min_lw \
518+
if vary_linewidth else None
499519
color = magnitudes if vary_color else color
500520

501-
out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths,
502-
cmap=cmap, norm=norm, zorder=zorder)
521+
out = ax.streamplot(
522+
xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap,
523+
norm=norm, zorder=zorder)
503524

504525
return out
505526

527+
506528
def streamlines(
507529
sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
508530
zorder=None, ax=None, _check_kwargs=True, suppress_warnings=False,
@@ -548,9 +570,6 @@ def streamlines(
548570
dict with key 'args' and value given by a tuple (passed to callable).
549571
color : str
550572
Plot the streamlines in the given color.
551-
zorder : float, optional
552-
Set the zorder for the separatrices. In not specified, it will be
553-
automatically chosen by `matplotlib.axes.Axes.plot`.
554573
ax : `matplotlib.axes.Axes`, optional
555574
Use the given axes for the plot, otherwise use the current axes.
556575
@@ -574,6 +593,9 @@ def streamlines(
574593
Default is set by `config.defaults['ctrlplot.rcParams']`.
575594
suppress_warnings : bool, optional
576595
If set to True, suppress warning messages in generating trajectories.
596+
zorder : float, optional
597+
Set the zorder for the separatrices. In not specified, it will be
598+
automatically chosen by `matplotlib.axes.Axes.plot`.
577599
578600
"""
579601
# Process keywords
@@ -672,9 +694,6 @@ def equilpoints(
672694
dict with key 'args' and value given by a tuple (passed to callable).
673695
color : str
674696
Plot the equilibrium points in the given color.
675-
zorder : float, optional
676-
Set the zorder for the separatrices. In not specified, it will be
677-
automatically chosen by `matplotlib.axes.Axes.plot`.
678697
ax : `matplotlib.axes.Axes`, optional
679698
Use the given axes for the plot, otherwise use the current axes.
680699
@@ -687,6 +706,9 @@ def equilpoints(
687706
rcParams : dict
688707
Override the default parameters used for generating plots.
689708
Default is set by `config.defaults['ctrlplot.rcParams']`.
709+
zorder : float, optional
710+
Set the zorder for the separatrices. In not specified, it will be
711+
automatically chosen by `matplotlib.axes.Axes.plot`.
690712
691713
"""
692714
# Process keywords
@@ -720,7 +742,8 @@ def equilpoints(
720742
out = []
721743
for xeq in equilpts:
722744
with plt.rc_context(rcParams):
723-
out += ax.plot(xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
745+
out += ax.plot(
746+
xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
724747
return out
725748

726749

@@ -767,9 +790,6 @@ def separatrices(
767790
separatrices. If a tuple is given, the first element is used as
768791
the color specification for stable separatrices and the second
769792
element for unstable separatrices.
770-
zorder : float, optional
771-
Set the zorder for the separatrices. In not specified, it will be
772-
automatically chosen by `matplotlib.axes.Axes.plot`.
773793
ax : `matplotlib.axes.Axes`, optional
774794
Use the given axes for the plot, otherwise use the current axes.
775795
@@ -784,6 +804,9 @@ def separatrices(
784804
Default is set by `config.defaults['ctrlplot.rcParams']`.
785805
suppress_warnings : bool, optional
786806
If set to True, suppress warning messages in generating trajectories.
807+
zorder : float, optional
808+
Set the zorder for the separatrices. In not specified, it will be
809+
automatically chosen by `matplotlib.axes.Axes.plot`.
787810
788811
Notes
789812
-----
@@ -884,7 +907,8 @@ def separatrices(
884907
if traj.shape[1] > 1:
885908
with plt.rc_context(rcParams):
886909
out += ax.plot(
887-
traj[0], traj[1], color=color, linestyle=linestyle, zorder=zorder)
910+
traj[0], traj[1], color=color,
911+
linestyle=linestyle, zorder=zorder)
888912

889913
# Add arrows to the lines at specified intervals
890914
with plt.rc_context(rcParams):
@@ -984,6 +1008,7 @@ def circlegrid(centers, radius, num):
9841008
theta in np.linspace(0, 2 * math.pi, num, endpoint=False)])
9851009
return grid
9861010

1011+
9871012
#
9881013
# Internal utility functions
9891014
#
@@ -1004,6 +1029,7 @@ def _create_system(sys, params):
10041029
return NonlinearIOSystem(
10051030
_update, _output, states=2, inputs=0, outputs=0, name="_callable")
10061031

1032+
10071033
# Set axis limits for the plot
10081034
def _set_axis_limits(ax, pointdata):
10091035
# Get the current axis limits

0 commit comments

Comments
 (0)