From 2ed7e242f8919d2cc54c65f20170640fd878c142 Mon Sep 17 00:00:00 2001 From: "suzushi.tomori" Date: Tue, 19 Dec 2023 21:37:13 +0900 Subject: [PATCH] Fix bug related to TypedDict handling within Tuple --- mypy/checkexpr.py | 17 ++++++++++++ test-data/unit/check-inference.test | 3 ++- test-data/unit/check-literal.test | 4 +-- test-data/unit/check-unions.test | 40 +++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 626584bc3a20..c203f281dc75 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4918,6 +4918,19 @@ def tuple_context_matches(self, expr: TupleExpr, ctx: TupleType) -> bool: expr_star_index = next(i for i, lv in enumerate(expr.items) if isinstance(lv, StarExpr)) return len(expr.items) == len(ctx.items) and ctx_unpack_index == expr_star_index + def union_with_tuple_context_matches( + self, expr: TupleExpr, ctx: UnionType + ) -> TupleType | None: + for item in ctx.items: + item = get_proper_type(item) + if isinstance(item, TupleType) and self.tuple_context_matches(expr, item): + return item + elif isinstance(item, UnionType): + inner = self.union_with_tuple_context_matches(expr, item) + if inner: + return inner + return None + def visit_tuple_expr(self, e: TupleExpr) -> Type: """Type check a tuple expression.""" # Try to determine type context for type inference. @@ -4943,6 +4956,10 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: assert isinstance(type_context, Instance) if type_context.args: type_context_items = [type_context.args[0]] * len(e.items) + elif isinstance(type_context, UnionType): + inner_tuple = self.union_with_tuple_context_matches(e, type_context) + if inner_tuple: + type_context_items = inner_tuple.items # NOTE: it's possible for the context to have a different # number of items than e. In that case we use those context # items that match a position in e, and we'll worry about type diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 953855e502d6..4bf79932e457 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3440,7 +3440,8 @@ x: Iterable[List[Union[int, str]]] = (foo([1]), foo(["a"])) from typing import Dict, Iterable, Tuple, Union def foo(x: Union[Tuple[str, Dict[str, int], str], Iterable[object]]) -> None: ... -foo(("a", {"a": "b"}, "b")) +foo(("a", {"a": 1}, "b")) +foo(("a", {"a": "b"}, "b")) # E: Dict entry 0 has incompatible type "str": "str"; expected "str": "int" [builtins fixtures/dict.pyi] [case testUseSupertypeAsInferenceContext] diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index d9ad68385ad1..a77dc3a48b9d 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2891,9 +2891,9 @@ def invalid_literal_type() -> Tuple[Literal[1]]: def incorrect_return1() -> Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]: if x: - return (False, 5) # E: Incompatible return value type (got "Tuple[bool, int]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") + return (False, 5) # E: Incompatible return value type (got "Tuple[Literal[False], int]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") else: - return (True, 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") + return (True, 'oops') # E: Incompatible return value type (got "Tuple[Literal[True], str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]: if x: diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index d79ab14184c6..fd7b5bc1de50 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -1220,3 +1220,43 @@ nc: Union[Container[str], int] 'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]") [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] + +[case testUnionInnerTypedDict] +from typing import TypedDict, Tuple, Union +Point = TypedDict('Point', {'x': int, 'y': int}) +def f1(z: bool) -> Union[Tuple[Point, int], Tuple[Point, str]]: + if z: + return {'x': 1, 'y': 2}, 3 # OK + return {'x': 1, 'y': 2}, "a" # OK +def f2(z: bool) -> Union[Tuple[Point, int], Tuple[int, str]]: + if z: + return {'x': 1, 'y': 2}, 3 # OK + return 3, "a" # OK +def f3(z: bool) -> Union[Union[Tuple[Point, int], Tuple[Point, str]], int]: + if z: + return {'x': 1, 'y': 2}, 3 # OK + return 1 # OK +def f4(z: Union[Tuple[Point, int], Tuple[Point, str]]) -> None: ... +def f5(z: Union[Tuple[Point, int], Tuple[Point, str]]) -> Union[Tuple[Point, int], Tuple[Point, str]]: + return z +def f6(z: Union[Union[Tuple[Point, int], Tuple[Point, str], int]]) -> None: ... + +p = Point(x=42, y=1337) +f4((p, 1)) +f4(({'x': 42, 'y': 1337}, 1)) +f4((p, "a")) +f4(({'x': 42, 'y': 1337}, "a")) +f4(({'x': 42, 'y': 1337}, True)) +f4((p, 1.0)) # E: Argument 1 to "f4" has incompatible type "Tuple[Point, float]"; expected "Union[Tuple[Point, int], Tuple[Point, str]]" +f5((p, 1)) +f5(({'x': 42, 'y': 1337}, 1)) +f5((p, "a")) +f5(({'x': 42, 'y': 1337}, "a")) +f5(({'x': 42, 'y': 1337}, True)) +f5((p, 1.0)) # E: Argument 1 to "f5" has incompatible type "Tuple[Point, float]"; expected "Union[Tuple[Point, int], Tuple[Point, str]]" +f6((p, 1)) +f6(({'x': 42, 'y': 1337}, 1)) +f6(1) +f6((p, 1.0)) # E: Argument 1 to "f6" has incompatible type "Tuple[Point, float]"; expected "Union[Tuple[Point, int], Tuple[Point, str], int]" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi]