Skip to content

Commit be5fee4

Browse files
committed
switch to using dtype validation for datetime timezones
1 parent 830fff8 commit be5fee4

File tree

3 files changed

+37
-15
lines changed

3 files changed

+37
-15
lines changed

dataframely/columns/datetime.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __init__(
273273
max: dt.datetime | None = None,
274274
max_exclusive: dt.datetime | None = None,
275275
resolution: str | None = None,
276-
time_zone: ZoneInfo | str | dt.timezone | None = None,
276+
time_zone: str | dt.tzinfo | None = None,
277277
check: Callable[[pl.Expr], pl.Expr] | None = None,
278278
alias: str | None = None,
279279
metadata: dict[str, Any] | None = None,
@@ -339,16 +339,6 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
339339
result = super().validation_rules(expr)
340340
if self.resolution is not None:
341341
result["resolution"] = expr.dt.truncate(self.resolution) == expr
342-
if self.time_zone is not None:
343-
time_zone = (
344-
self.time_zone.key
345-
if isinstance(self.time_zone, ZoneInfo)
346-
else self.time_zone
347-
)
348-
result["time_zone"] = pl.coalesce(
349-
expr == pl.selectors.datetime(time_unit="us", time_zone=time_zone),
350-
False,
351-
)
352342
return result
353343

354344
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:

dataframely/random.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import datetime as dt
55
from collections.abc import Sequence
66
from typing import TypeVar
7-
from zoneinfo import ZoneInfo
87

98
import numpy as np
109
import polars as pl
@@ -294,7 +293,7 @@ def sample_datetime(
294293
min: dt.datetime,
295294
max: dt.datetime | None,
296295
resolution: str | None = None,
297-
time_zone: ZoneInfo | str | dt.timezone | None = None,
296+
time_zone: str | dt.tzinfo | None = None,
298297
null_probability: float = 0.0,
299298
) -> pl.Series:
300299
"""Sample a list of datetimes in the provided range.

tests/column_types/test_datetime.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import datetime as dt
5+
import re
56
from typing import Any
67
from zoneinfo import ZoneInfo
78

@@ -11,6 +12,7 @@
1112

1213
import dataframely as dy
1314
from dataframely.columns import Column
15+
from dataframely.exc import DtypeValidationError
1416
from dataframely.random import Generator
1517
from dataframely.testing import evaluate_rules, rules_from_exprs
1618
from dataframely.testing.factory import create_schema
@@ -519,11 +521,42 @@ def test_validate_datetime_timezone(
519521
[
520522
dy.Datetime(
521523
min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h"
522-
)
524+
),
525+
dy.Datetime(time_zone="Etc/UTC"),
523526
],
524527
)
525-
def test_sample_resolution(column: dy.Column) -> None:
528+
def test_sample(column: dy.Column) -> None:
526529
generator = Generator(seed=42)
527530
samples = column.sample(generator, n=10_000)
528531
schema = create_schema("test", {"a": column})
529532
schema.validate(samples.to_frame("a"))
533+
534+
535+
@pytest.mark.parametrize(
536+
("dtype", "column", "error"),
537+
[
538+
(
539+
pl.Datetime(time_zone="America/New_York"),
540+
dy.Datetime(time_zone="Etc/UTC"),
541+
r"1 columns have an invalid dtype.*\n.*got dtype 'Datetime\(time_unit='us', time_zone='America/New_York'\)' but expected 'Datetime\(time_unit='us', time_zone='Etc/UTC'\)'",
542+
),
543+
(
544+
pl.Datetime(time_zone="Etc/UTC"),
545+
dy.Datetime(time_zone="Etc/UTC"),
546+
None,
547+
),
548+
],
549+
)
550+
def test_dtype_validation(
551+
dtype: pl.DataType,
552+
column: dy.Column,
553+
error: str | None,
554+
) -> None:
555+
df = pl.DataFrame(schema={"a": dtype})
556+
schema = create_schema("test", {"a": column})
557+
if error is None:
558+
schema.validate(df)
559+
else:
560+
with pytest.raises(DtypeValidationError) as exc:
561+
schema.validate(df)
562+
assert re.match(error, str(exc.value))

0 commit comments

Comments
 (0)