
Support TypedDict unpacking in ParamSpec specifications #16120

Open
@tmke8

Description

Feature

Related to

TypedDict unpacking in a ParamSpec specification would work the same way it already works in `Callable`.

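from typing import Callable, Generic, ParamSpec, TypedDict, Unpack

P = ParamSpec("P")

class C(Generic[P]):
    def __init__(self, f: Callable[P, None]): ...

class Args(TypedDict):
    x: int
    y: str

def f(*, x: int, y: str) -> None: ...

c: C[[Unpack[Args]]] = C(f)  # OK: `f` takes `x` and `y` as keyword arguments
d: C[[int, str]] = C(f)  # error: `[int, str]` means two positional arguments, but `f` expects keyword arguments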
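To make the difference between the two annotations concrete, here is a small runtime sketch that is not part of the proposal. The hypothetical helper `kwonly_signature` builds the keyword-only signature that `Unpack[Args]` is meant to stand for (assuming every key is required, as in a total TypedDict) and compares it with `f`'s actual signature using `inspect`.

```python
import inspect
from typing import TypedDict, get_type_hints

class Args(TypedDict):
    x: int
    y: str

def f(*, x: int, y: str) -> None: ...

def kwonly_signature(td: type) -> inspect.Signature:
    """Hypothetical helper: one keyword-only parameter per TypedDict key.

    Ignores totality / NotRequired, so it only models TypedDicts whose keys
    are all required.
    """
    params = [
        inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, annotation=tp)
        for name, tp in get_type_hints(td).items()
    ]
    return inspect.Signature(params, return_annotation=None)

expected = kwonly_signature(Args)  # what `[Unpack[Args]]` describes
actual = inspect.signature(f)      # keyword-only, so `[int, str]` cannot match it
```

Every parameter of `f` is keyword-only, which is why a positional specification like `[int, str]` is rejected.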

Pitch

To express a callable type with keyword arguments, you can use a callback protocol (a `Protocol` with a `__call__` method), but that doesn't help with other classes that are generic over a ParamSpec. For example, in PyTorch, network layers have to inherit from `Module`, which would be typed roughly like this (using Python 3.12 generic syntax):

from abc import abstractmethod

class Module[T, **P]:
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff (e.g. run hooks), then delegate to forward()
        return self.forward(*args, **kwargs)

But what if I want to override `forward` with an optional keyword argument?

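from torch import Tensor

# A type checker should reject this: `[Tensor, bool]` describes a second
# *positional* parameter, but `with_dropout` is keyword-only (and optional).
class Dense(Module[Tensor, [Tensor, bool]]):
    def forward(self, x: Tensor, *, with_dropout: bool = False) -> Tensor:
        # my implementation
        return x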
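As the pitch notes, a callback protocol can already describe this signature when the callable stands on its own. A minimal sketch, with a hypothetical `DenseCall` protocol, a toy `dense` function, and `float` in place of `Tensor`: a type checker accepts the assignment to `fn`, but there is still no way to pass `DenseCall`'s parameter list as the `P` argument of `Module`.

```python
from typing import Protocol

class DenseCall(Protocol):
    def __call__(self, x: float, *, with_dropout: bool = False) -> float: ...

def dense(x: float, *, with_dropout: bool = False) -> float:
    # Toy stand-in for a layer: "dropout" simply zeroes the input.
    return 0.0 if with_dropout else x

fn: DenseCall = dense  # structurally matches the protocol
out = (fn(1.5), fn(1.5, with_dropout=True))
```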

With TypedDict unpacking in ParamSpec, the same layer could be declared like this:

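from typing import NotRequired, TypedDict, Unpack

class ExtraArgs(TypedDict):
    with_dropout: NotRequired[bool]  # optional, matching the default in forward()

class Dense(Module[Tensor, [Tensor, Unpack[ExtraArgs]]]):
    def forward(self, x: Tensor, *, with_dropout: bool = False) -> Tensor:
        # my implementation
        return x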
