Skip to content

fix: Image Feature in Datasets Library Fails to Handle bytearray Objects from Spark DataFrames (#7517) #7521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Audio:
def __call__(self):
return self.pa_type

def encode_example(self, value: Union[str, bytes, dict]) -> dict:
def encode_example(self, value: Union[str, bytes, bytearray, dict]) -> dict:
"""Encode example into a format for Arrow.

Args:
Expand All @@ -90,7 +90,7 @@ def encode_example(self, value: Union[str, bytes, dict]) -> dict:
raise ImportError("To support encoding audio data, please install 'soundfile'.") from err
if isinstance(value, str):
return {"bytes": None, "path": value}
elif isinstance(value, bytes):
elif isinstance(value, (bytes, bytearray)):
return {"bytes": value, "path": None}
elif "array" in value:
# convert the audio array to wav bytes
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/features/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class Image:
def __call__(self):
return self.pa_type

def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.Image"]) -> dict:
def encode_example(self, value: Union[str, bytes, bytearray, dict, np.ndarray, "PIL.Image.Image"]) -> dict:
"""Encode example into a format for Arrow.

Args:
Expand All @@ -111,7 +111,7 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "PIL.Image.I

if isinstance(value, str):
return {"path": value, "bytes": None}
elif isinstance(value, bytes):
elif isinstance(value, (bytes, bytearray)):
return {"path": None, "bytes": value}
elif isinstance(value, np.ndarray):
# convert the image array to PNG/TIFF bytes
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/features/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class Pdf:
def __call__(self):
return self.pa_type

def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -> dict:
def encode_example(self, value: Union[str, bytes, bytearray, dict, "pdfplumber.pdf.PDF"]) -> dict:
"""Encode example into a format for Arrow.

Args:
Expand All @@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -

if isinstance(value, str):
return {"path": value, "bytes": None}
elif isinstance(value, bytes):
elif isinstance(value, (bytes, bytearray)):
return {"path": None, "bytes": value}
elif pdfplumber is not None and isinstance(value, pdfplumber.pdf.PDF):
# convert the pdfplumber.pdf.PDF to bytes
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/features/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Video:
def __call__(self):
return self.pa_type

def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoReader"]) -> Example:
def encode_example(self, value: Union[str, bytes, bytearray, Example, np.ndarray, "VideoReader"]) -> Example:
"""Encode example into a format for Arrow.

Args:
Expand All @@ -92,7 +92,7 @@ def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoRea

if isinstance(value, str):
return {"path": value, "bytes": None}
elif isinstance(value, bytes):
elif isinstance(value, (bytes, bytearray)):
return {"path": None, "bytes": value}
elif isinstance(value, np.ndarray):
# convert the video array to bytes
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/keyhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
from huggingface_hub.utils import insecure_hashlib


def _as_bytes(hash_data: Union[str, int, bytes]) -> bytes:
def _as_bytes(hash_data: Union[str, int, bytes, bytearray]) -> bytes:
"""
Returns the input hash_data in its bytes form

Args:
hash_data: the hash salt/key to be converted to bytes
"""
if isinstance(hash_data, bytes):
if isinstance(hash_data, (bytes, bytearray)):
# Data already in bytes, returns as it as
return hash_data
elif isinstance(hash_data, str):
Expand Down
37 changes: 37 additions & 0 deletions tests/packaged_modules/test_spark.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from unittest.mock import patch

import numpy as np
import pyspark
import pytest

from datasets import Features, Image, IterableDataset
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.spark.spark import (
Expand Down Expand Up @@ -131,3 +133,38 @@ def test_repartition_df_if_needed_max_num_df_rows():
spark_builder._repartition_df_if_needed(max_shard_size=1)
# The new number of partitions should not be greater than the number of rows.
assert spark_builder.df.rdd.getNumPartitions() == 100


@require_not_windows
@require_dill_gt_0_3_2
def test_iterable_image_features():
    """A Spark binary column streamed through an undecoded Image feature keeps its raw bytes.

    Spark hands binary column values to the feature encoder as ``bytearray``;
    this checks the encoded example still carries the original payload under
    the ``{"path": ..., "bytes": ...}`` layout with ``decode=False``.
    """
    session = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    # Raw (not PNG-encoded) pixel bytes are fine here since decoding is disabled.
    raw = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
    spark_df = session.createDataFrame([(raw,)], "image: binary")
    dset = IterableDataset.from_spark(spark_df, features=Features({"image": Image(decode=False)}))
    first = next(iter(dset))
    assert set(first) == {"image"}
    assert first == {"image": {"path": None, "bytes": raw}}


@require_not_windows
@require_dill_gt_0_3_2
def test_iterable_image_features_decode():
    """A Spark binary column holding PNG bytes decodes to a PIL image when ``Image()`` decodes.

    Complements the undecoded test: with the default ``decode=True`` the
    iterable dataset should yield ``PIL.Image.Image`` instances rather than
    the raw ``{"path", "bytes"}`` mapping.
    """
    from io import BytesIO

    import PIL.Image

    session = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
    # Serialize a small black image to real PNG bytes so decoding can succeed.
    png_buffer = BytesIO()
    PIL.Image.fromarray(np.zeros((10, 10, 3), dtype=np.uint8), "RGB").save(png_buffer, format="PNG")
    spark_df = session.createDataFrame([(bytes(png_buffer.getvalue()),)], "image: binary")
    dset = IterableDataset.from_spark(spark_df, features=Features({"image": Image()}))
    first = next(iter(dset))
    assert set(first) == {"image"}
    assert isinstance(first["image"], PIL.Image.Image)