21
21
from matplotlib .collections import LineCollection
22
22
23
23
24
- def colored_line (x , y , c , ax , ** lc_kwargs ):
24
+ def colored_line (x , y , c , ax = None , ** lc_kwargs ):
25
25
"""
26
26
Plot a line with a color specified along the line by a third value.
27
27
@@ -36,8 +36,8 @@ def colored_line(x, y, c, ax, **lc_kwargs):
36
36
The horizontal and vertical coordinates of the data points.
37
37
c : array-like
38
38
The color values, which should be the same size as x and y.
39
- ax : Axes
40
- Axis object on which to plot the colored line .
39
+ ax : matplotlib.axes. Axes, optional
40
+ The axes to plot on. If not provided, the current axes will be used .
41
41
**lc_kwargs
42
42
Any additional arguments to pass to matplotlib.collections.LineCollection
43
43
constructor. This should not include the array keyword argument because
@@ -49,36 +49,32 @@ def colored_line(x, y, c, ax, **lc_kwargs):
49
49
The generated line collection representing the colored line.
50
50
"""
51
51
if "array" in lc_kwargs :
52
- warnings .warn ('The provided "array" keyword argument will be overridden' )
52
+ warnings .warn (
53
+ 'The provided "array" keyword argument will be overridden' ,
54
+ UserWarning ,
55
+ stacklevel = 2 ,
56
+ )
57
+
58
+ xy = np .stack ((x , y ), axis = - 1 )
59
+ xy_mid = np .concat (
60
+ (xy [0 , :][None , :], (xy [:- 1 , :] + xy [1 :, :]) / 2 , xy [- 1 , :][None , :]), axis = 0
61
+ )
62
+ segments = np .stack ((xy_mid [:- 1 , :], xy , xy_mid [1 :, :]), axis = - 2 )
63
+ # Note that
64
+ # segments[0, :, :] is [xy[0, :], xy[0, :], (xy[0, :] + xy[1, :]) / 2]
65
+ # segments[i, :, :] is [(xy[i - 1, :] + xy[i, :]) / 2, xy[i, :],
66
+ # (xy[i, :] + xy[i + 1, :]) / 2] if i not in {0, len(x) - 1}
67
+ # segments[-1, :, :] is [(xy[-2, :] + xy[-1, :]) / 2, xy[-1, :], xy[-1, :]]
68
+
69
+ lc_kwargs ["array" ] = c
70
+ lc = LineCollection (segments , ** lc_kwargs )
53
71
54
- # Default the capstyle to butt so that the line segments smoothly line up
55
- default_kwargs = {"capstyle" : "butt" }
56
- default_kwargs .update (lc_kwargs )
57
-
58
- # Compute the midpoints of the line segments. Include the first and last points
59
- # twice so we don't need any special syntax later to handle them.
60
- x = np .asarray (x )
61
- y = np .asarray (y )
62
- x_midpts = np .hstack ((x [0 ], 0.5 * (x [1 :] + x [:- 1 ]), x [- 1 ]))
63
- y_midpts = np .hstack ((y [0 ], 0.5 * (y [1 :] + y [:- 1 ]), y [- 1 ]))
64
-
65
- # Determine the start, middle, and end coordinate pair of each line segment.
66
- # Use the reshape to add an extra dimension so each pair of points is in its
67
- # own list. Then concatenate them to create:
68
- # [
69
- # [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70
- # [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
71
- # ...
72
- # ]
73
- coord_start = np .column_stack ((x_midpts [:- 1 ], y_midpts [:- 1 ]))[:, np .newaxis , :]
74
- coord_mid = np .column_stack ((x , y ))[:, np .newaxis , :]
75
- coord_end = np .column_stack ((x_midpts [1 :], y_midpts [1 :]))[:, np .newaxis , :]
76
- segments = np .concatenate ((coord_start , coord_mid , coord_end ), axis = 1 )
77
-
78
- lc = LineCollection (segments , ** default_kwargs )
79
- lc .set_array (c ) # set the colors of each segment
72
+ # Plot the line collection to the axes
73
+ ax = ax or plt .gca ()
74
+ ax .add_collection (lc )
75
+ ax .autoscale_view ()
80
76
81
- return ax . add_collection ( lc )
77
+ return lc
82
78
83
79
84
80
# -------------- Create and show plot --------------
@@ -93,11 +89,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
93
89
lines = colored_line (x , y , color , ax1 , linewidth = 10 , cmap = "plasma" )
94
90
fig1 .colorbar (lines ) # add a color legend
95
91
96
- # Set the axis limits and tick positions
97
- ax1 .set_xlim (- 1 , 1 )
98
- ax1 .set_ylim (- 1 , 1 )
99
- ax1 .set_xticks ((- 1 , 0 , 1 ))
100
- ax1 .set_yticks ((- 1 , 0 , 1 ))
101
92
ax1 .set_title ("Color at each point" )
102
93
103
94
plt .show ()
0 commit comments