diff --git a/src/cleanvision/utils/viz_manager.py b/src/cleanvision/utils/viz_manager.py index 5c2ca665..204df2a2 100644 --- a/src/cleanvision/utils/viz_manager.py +++ b/src/cleanvision/utils/viz_manager.py @@ -42,23 +42,25 @@ def plot_image_grid( ) -> None: nrows = math.ceil(len(images) / ncols) ncols = min(ncols, len(images)) + + # Convert list of images to a 4D Numpy array + arr = np.array([np.array(image) for image in images]) + fig, axes = plt.subplots( nrows, ncols, figsize=(cell_size[0] * ncols, cell_size[1] * nrows) ) - if nrows > 1: - idx = 0 - for i in range(nrows): - for j in range(ncols): - idx = i * ncols + j - if idx >= len(images): - axes[i, j].axis("off") - continue - set_image_on_axes(images[idx], axes[i, j], titles[idx]) - if idx >= len(images): - break - elif ncols > 1: - for i in range(min(ncols, len(images))): - set_image_on_axes(images[i], axes[i], titles[i]) - else: - set_image_on_axes(images[0], axes, titles[0]) + + # Create a 2D array of indices + idxs = np.arange(nrows * ncols).reshape(nrows, ncols) + + # Set axes properties + for ax in axes.flatten(): + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + + # Set images on axes using advanced indexing + axes[idxs[:len(images) // ncols + 1, :len(images) % ncols]] = arr + for i, title in enumerate(titles): + axes.flat[i].set_title(title, fontsize=7) + plt.show()