Skip to content

Commit 697c240

Browse files
authored
Merge pull request matplotlib#29954 from 34j/perf/better-multicolored-line
Simplify `colored_line()` implementation in Multicolored lines example
2 parents 76a47d9 + 05c622b commit 697c240

File tree

1 file changed

+27
-36
lines changed

1 file changed

+27
-36
lines changed

galleries/examples/lines_bars_and_markers/multicolored_line.py

+27-36
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from matplotlib.collections import LineCollection
2222

2323

24-
def colored_line(x, y, c, ax, **lc_kwargs):
24+
def colored_line(x, y, c, ax=None, **lc_kwargs):
2525
"""
2626
Plot a line with a color specified along the line by a third value.
2727
@@ -36,8 +36,8 @@ def colored_line(x, y, c, ax, **lc_kwargs):
3636
The horizontal and vertical coordinates of the data points.
3737
c : array-like
3838
The color values, which should be the same size as x and y.
39-
ax : Axes
40-
Axis object on which to plot the colored line.
39+
ax : matplotlib.axes.Axes, optional
40+
The axes to plot on. If not provided, the current axes will be used.
4141
**lc_kwargs
4242
Any additional arguments to pass to matplotlib.collections.LineCollection
4343
constructor. This should not include the array keyword argument because
@@ -49,36 +49,32 @@ def colored_line(x, y, c, ax, **lc_kwargs):
4949
The generated line collection representing the colored line.
5050
"""
5151
if "array" in lc_kwargs:
52-
warnings.warn('The provided "array" keyword argument will be overridden')
52+
warnings.warn(
53+
'The provided "array" keyword argument will be overridden',
54+
UserWarning,
55+
stacklevel=2,
56+
)
57+
58+
xy = np.stack((x, y), axis=-1)
59+
xy_mid = np.concat(
60+
(xy[0, :][None, :], (xy[:-1, :] + xy[1:, :]) / 2, xy[-1, :][None, :]), axis=0
61+
)
62+
segments = np.stack((xy_mid[:-1, :], xy, xy_mid[1:, :]), axis=-2)
63+
# Note that
64+
# segments[0, :, :] is [xy[0, :], xy[0, :], (xy[0, :] + xy[1, :]) / 2]
65+
# segments[i, :, :] is [(xy[i - 1, :] + xy[i, :]) / 2, xy[i, :],
66+
# (xy[i, :] + xy[i + 1, :]) / 2] if i not in {0, len(x) - 1}
67+
# segments[-1, :, :] is [(xy[-2, :] + xy[-1, :]) / 2, xy[-1, :], xy[-1, :]]
68+
69+
lc_kwargs["array"] = c
70+
lc = LineCollection(segments, **lc_kwargs)
5371

54-
# Default the capstyle to butt so that the line segments smoothly line up
55-
default_kwargs = {"capstyle": "butt"}
56-
default_kwargs.update(lc_kwargs)
57-
58-
# Compute the midpoints of the line segments. Include the first and last points
59-
# twice so we don't need any special syntax later to handle them.
60-
x = np.asarray(x)
61-
y = np.asarray(y)
62-
x_midpts = np.hstack((x[0], 0.5 * (x[1:] + x[:-1]), x[-1]))
63-
y_midpts = np.hstack((y[0], 0.5 * (y[1:] + y[:-1]), y[-1]))
64-
65-
# Determine the start, middle, and end coordinate pair of each line segment.
66-
# Use the reshape to add an extra dimension so each pair of points is in its
67-
# own list. Then concatenate them to create:
68-
# [
69-
# [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70-
# [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
71-
# ...
72-
# ]
73-
coord_start = np.column_stack((x_midpts[:-1], y_midpts[:-1]))[:, np.newaxis, :]
74-
coord_mid = np.column_stack((x, y))[:, np.newaxis, :]
75-
coord_end = np.column_stack((x_midpts[1:], y_midpts[1:]))[:, np.newaxis, :]
76-
segments = np.concatenate((coord_start, coord_mid, coord_end), axis=1)
77-
78-
lc = LineCollection(segments, **default_kwargs)
79-
lc.set_array(c) # set the colors of each segment
72+
# Plot the line collection to the axes
73+
ax = ax or plt.gca()
74+
ax.add_collection(lc)
75+
ax.autoscale_view()
8076

81-
return ax.add_collection(lc)
77+
return lc
8278

8379

8480
# -------------- Create and show plot --------------
@@ -93,11 +89,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
9389
lines = colored_line(x, y, color, ax1, linewidth=10, cmap="plasma")
9490
fig1.colorbar(lines) # add a color legend
9591

96-
# Set the axis limits and tick positions
97-
ax1.set_xlim(-1, 1)
98-
ax1.set_ylim(-1, 1)
99-
ax1.set_xticks((-1, 0, 1))
100-
ax1.set_yticks((-1, 0, 1))
10192
ax1.set_title("Color at each point")
10293

10394
plt.show()

0 commit comments

Comments
 (0)