Skip to content

Commit fa62831

Browse files
committed
switch to using dtype validation for datetime timezones
1 parent 830fff8 commit fa62831

File tree

3 files changed

+36
-122
lines changed

3 files changed

+36
-122
lines changed

dataframely/columns/datetime.py

+1 -11
Original file line number | Diff line number | Diff line change
@@ -273,7 +273,7 @@ def __init__(
273273
max: dt.datetime | None = None,
274274
max_exclusive: dt.datetime | None = None,
275275
resolution: str | None = None,
276-
time_zone: ZoneInfo | str | dt.timezone | None = None,
276+
time_zone: str | dt.tzinfo | None = None,
277277
check: Callable[[pl.Expr], pl.Expr] | None = None,
278278
alias: str | None = None,
279279
metadata: dict[str, Any] | None = None,
@@ -339,16 +339,6 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
339339
result = super().validation_rules(expr)
340340
if self.resolution is not None:
341341
result["resolution"] = expr.dt.truncate(self.resolution) == expr
342-
if self.time_zone is not None:
343-
time_zone = (
344-
self.time_zone.key
345-
if isinstance(self.time_zone, ZoneInfo)
346-
else self.time_zone
347-
)
348-
result["time_zone"] = pl.coalesce(
349-
expr == pl.selectors.datetime(time_unit="us", time_zone=time_zone),
350-
False,
351-
)
352342
return result
353343

354344
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:

dataframely/random.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import datetime as dt
55
from collections.abc import Sequence
66
from typing import TypeVar
7-
from zoneinfo import ZoneInfo
87

98
import numpy as np
109
import polars as pl
@@ -294,7 +293,7 @@ def sample_datetime(
294293
min: dt.datetime,
295294
max: dt.datetime | None,
296295
resolution: str | None = None,
297-
time_zone: ZoneInfo | str | dt.timezone | None = None,
296+
time_zone: str | dt.tzinfo | None = None,
298297
null_probability: float = 0.0,
299298
) -> pl.Series:
300299
"""Sample a list of datetimes in the provided range.

tests/column_types/test_datetime.py

+34 -109
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,16 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import datetime as dt
5+
import re
56
from typing import Any
6-
from zoneinfo import ZoneInfo
77

88
import polars as pl
99
import pytest
1010
from polars.testing import assert_frame_equal
1111

1212
import dataframely as dy
1313
from dataframely.columns import Column
14+
from dataframely.exc import DtypeValidationError
1415
from dataframely.random import Generator
1516
from dataframely.testing import evaluate_rules, rules_from_exprs
1617
from dataframely.testing.factory import create_schema
@@ -408,122 +409,46 @@ def test_validate_resolution(
408409

409410

410411
@pytest.mark.parametrize(
411-
("column", "values", "schema", "valid"),
412+
"column",
412413
[
413-
(
414-
dy.Datetime(),
415-
[
416-
dt.datetime(2020, 4, 5, tzinfo=dt.UTC),
417-
dt.datetime(2021, 1, 1, 12, tzinfo=dt.UTC),
418-
dt.datetime(2022, 7, 10, 0, 0, 1, tzinfo=dt.UTC),
419-
],
420-
None,
421-
{},
422-
),
423-
(
424-
dy.Datetime(time_zone="UTC"),
425-
[
426-
dt.datetime(2020, 4, 5),
427-
dt.datetime(2021, 1, 1, 12),
428-
dt.datetime(2022, 7, 10, 0, 0, 1),
429-
],
430-
None,
431-
{"time_zone": [False, False, False]},
432-
),
433-
(
434-
dy.Datetime(time_zone="UTC"),
435-
[
436-
dt.datetime(2020, 4, 5, tzinfo=dt.UTC),
437-
dt.datetime(2021, 1, 1, 12, tzinfo=dt.UTC),
438-
dt.datetime(2022, 7, 10, 0, 0, 1, tzinfo=dt.UTC),
439-
],
440-
None,
441-
{"time_zone": [True, True, True]},
442-
),
443-
(
444-
dy.Datetime(
445-
time_zone=dt.timezone(
446-
dt.timedelta(hours=-7), name="America/Los_Angeles"
447-
)
448-
),
449-
[
450-
dt.datetime(2020, 4, 5),
451-
dt.datetime(2021, 1, 1, 12),
452-
dt.datetime(2022, 7, 10, 0, 0, 1),
453-
],
454-
pl.Datetime(time_zone="America/Los_Angeles"),
455-
{"time_zone": [True, True, True]},
456-
),
457-
(
458-
dy.Datetime(
459-
time_zone=dt.timezone(
460-
dt.timedelta(hours=-7), name="America/Los_Angeles"
461-
)
462-
),
463-
[
464-
dt.datetime(2020, 4, 5),
465-
dt.datetime(2021, 1, 1, 12),
466-
dt.datetime(2022, 7, 10, 0, 0, 1),
467-
],
468-
None,
469-
{"time_zone": [False, False, False]},
470-
),
471-
(
472-
dy.Datetime(time_zone=dt.timezone(dt.timedelta(hours=-7))),
473-
[
474-
dt.datetime(2020, 4, 5),
475-
dt.datetime(2021, 1, 1, 12),
476-
dt.datetime(2022, 7, 10, 0, 0, 1),
477-
],
478-
pl.Datetime(time_zone="America/Los_Angeles"),
479-
{"time_zone": [False, False, False]},
480-
),
481-
(
482-
dy.Datetime(time_zone=ZoneInfo("Etc/UTC")),
483-
[
484-
dt.datetime(2020, 4, 5, tzinfo=ZoneInfo("Etc/UTC")),
485-
dt.datetime(2021, 1, 1, 12, tzinfo=ZoneInfo("Etc/UTC")),
486-
dt.datetime(2022, 7, 10, 0, 0, 1, tzinfo=ZoneInfo("Etc/UTC")),
487-
],
488-
None,
489-
{"time_zone": [True, True, True]},
490-
),
491-
(
492-
dy.Datetime(time_zone=ZoneInfo("Etc/UTC")),
493-
[
494-
dt.datetime(2020, 4, 5, tzinfo=ZoneInfo("America/New_York")),
495-
dt.datetime(2021, 1, 1, 12, tzinfo=ZoneInfo("America/New_York")),
496-
dt.datetime(2022, 7, 10, 0, 0, 1, tzinfo=ZoneInfo("America/New_York")),
497-
],
498-
None,
499-
{"time_zone": [False, False, False]},
414+
dy.Datetime(
415+
min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h"
500416
),
417+
dy.Datetime(time_zone="Etc/UTC"),
501418
],
502419
)
503-
def test_validate_datetime_timezone(
504-
column: Column,
505-
values: list[Any],
506-
schema: pl.Datetime | None,
507-
valid: dict[str, list[bool]],
508-
) -> None:
509-
lf = pl.LazyFrame(
510-
{"a": values}, schema={"a": schema} if schema is not None else None
511-
)
512-
actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
513-
expected = pl.LazyFrame(valid)
514-
assert_frame_equal(actual, expected)
420+
def test_sample(column: dy.Column) -> None:
421+
generator = Generator(seed=42)
422+
samples = column.sample(generator, n=10_000)
423+
schema = create_schema("test", {"a": column})
424+
schema.validate(samples.to_frame("a"))
515425

516426

517427
@pytest.mark.parametrize(
518-
"column",
428+
("dtype", "column", "error"),
519429
[
520-
dy.Datetime(
521-
min=dt.datetime(2020, 1, 1), max=dt.datetime(2021, 1, 1), resolution="1h"
522-
)
430+
(
431+
pl.Datetime(time_zone="America/New_York"),
432+
dy.Datetime(time_zone="Etc/UTC"),
433+
r"1 columns have an invalid dtype.*\n.*got dtype 'Datetime\(time_unit='us', time_zone='America/New_York'\)' but expected 'Datetime\(time_unit='us', time_zone='Etc/UTC'\)'",
434+
),
435+
(
436+
pl.Datetime(time_zone="Etc/UTC"),
437+
dy.Datetime(time_zone="Etc/UTC"),
438+
None,
439+
),
523440
],
524441
)
525-
def test_sample_resolution(column: dy.Column) -> None:
526-
generator = Generator(seed=42)
527-
samples = column.sample(generator, n=10_000)
442+
def test_dtype_validation(
443+
dtype: pl.DataType,
444+
column: dy.Column,
445+
error: str | None,
446+
) -> None:
447+
df = pl.DataFrame(schema={"a": dtype})
528448
schema = create_schema("test", {"a": column})
529-
schema.validate(samples.to_frame("a"))
449+
if error is None:
450+
schema.validate(df)
451+
else:
452+
with pytest.raises(DtypeValidationError) as exc:
453+
schema.validate(df)
454+
assert re.match(error, str(exc.value))

0 commit comments

Comments
 (0)