diff --git a/stdlib/@tests/test_cases/check_math.py b/stdlib/@tests/test_cases/check_math.py new file mode 100644 index 000000000000..d637c15ff178 --- /dev/null +++ b/stdlib/@tests/test_cases/check_math.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from decimal import Decimal +from fractions import Fraction +from math import prod +from typing import Any, Literal, Union +from typing_extensions import assert_type + + +class SupportsMul: + def __mul__(self, other: Any) -> SupportsMul: + return SupportsMul() + + +class SupportsRMul: + def __rmul__(self, other: Any) -> SupportsRMul: + return SupportsRMul() + + +class SupportsMulAndRMul: + def __mul__(self, other: Any) -> SupportsMulAndRMul: + return SupportsMulAndRMul() + + def __rmul__(self, other: Any) -> SupportsMulAndRMul: + return SupportsMulAndRMul() + + +literal_list: list[Literal[0, 1]] = [0, 1, 1] + +assert_type(prod([2, 4]), int) +assert_type(prod([3, 5], start=4), int) + +assert_type(prod([True, False]), int) +assert_type(prod([True, False], start=True), int) +assert_type(prod(literal_list), int) + +assert_type(prod([SupportsMul(), SupportsMul()], start=SupportsMul()), SupportsMul) +assert_type(prod([SupportsMulAndRMul(), SupportsMulAndRMul()]), Union[SupportsMulAndRMul, Literal[1]]) + +assert_type(prod([5.6, 3.2]), Union[float, Literal[1]]) +assert_type(prod([5.6, 3.2], start=3), Union[float, int]) + +assert_type(prod([Fraction(7, 2), Fraction(3, 5)]), Union[Fraction, Literal[1]]) +assert_type(prod([Fraction(7, 2), Fraction(3, 5)], start=Fraction(1)), Fraction) +assert_type(prod([Decimal("3.14"), Decimal("2.71")]), Union[Decimal, Literal[1]]) +assert_type(prod([Decimal("3.14"), Decimal("2.71")], start=Decimal("1.00")), Decimal) +assert_type(prod([complex(7, 2), complex(3, 5)]), Union[complex, Literal[1]]) +assert_type(prod([complex(7, 2), complex(3, 5)], start=complex(1, 0)), complex) + + +# mypy and pyright infer the types differently for these, so we can't use assert_type +# Just test that no error is emitted for any of these +prod([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]` +prod([2.5, 5.8], start=5) # mypy: `float`; pyright: `float | int` + +# These all fail at runtime +prod([SupportsMul(), SupportsMul()]) # type: ignore +prod([SupportsRMul(), SupportsRMul()], start=SupportsRMul()) # type: ignore +prod([SupportsRMul(), SupportsRMul()]) # type: ignore + +# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error: +# prod([3, Fraction(7, 22), complex(8, 0), 9.83]) +# prod([3, Decimal("0.98")]) diff --git a/stdlib/_typeshed/__init__.pyi b/stdlib/_typeshed/__init__.pyi index 7201819b25ed..2849fc642612 100644 --- a/stdlib/_typeshed/__init__.pyi +++ b/stdlib/_typeshed/__init__.pyi @@ -117,6 +117,12 @@ class SupportsSub(Protocol[_T_contra, _T_co]): class SupportsRSub(Protocol[_T_contra, _T_co]): def __rsub__(self, x: _T_contra, /) -> _T_co: ... +class SupportsMul(Protocol[_T_contra, _T_co]): + def __mul__(self, x: _T_contra, /) -> _T_co: ... + +class SupportsRMul(Protocol[_T_contra, _T_co]): + def __rmul__(self, x: _T_contra, /) -> _T_co: ... + class SupportsDivMod(Protocol[_T_contra, _T_co]): def __divmod__(self, other: _T_contra, /) -> _T_co: ... diff --git a/stdlib/math.pyi b/stdlib/math.pyi index 86f71f27580a..f73429cf6940 100644 --- a/stdlib/math.pyi +++ b/stdlib/math.pyi @@ -1,6 +1,7 @@ import sys +from _typeshed import SupportsMul, SupportsRMul from collections.abc import Iterable -from typing import Final, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload +from typing import Any, Final, Literal, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload from typing_extensions import TypeAlias _T = TypeVar("_T") @@ -99,10 +100,29 @@ elif sys.version_info >= (3, 9): def perm(n: SupportsIndex, k: SupportsIndex | None = None, /) -> int: ... def pow(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ... + +_PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25] +_NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20] +_LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed + +_MultiplicableT1 = TypeVar("_MultiplicableT1", bound=SupportsMul[Any, Any]) +_MultiplicableT2 = TypeVar("_MultiplicableT2", bound=SupportsMul[Any, Any]) + +class _SupportsProdWithNoDefaultGiven(SupportsMul[Any, Any], SupportsRMul[int, Any], Protocol): ... + +_SupportsProdNoDefaultT = TypeVar("_SupportsProdNoDefaultT", bound=_SupportsProdWithNoDefaultGiven) + +# This stub is based on the type stub for `builtins.sum`. +# Like `builtins.sum`, it cannot be precisely represented in a type stub +# without introducing many false positives. +# For more details on its limitations and false positives, see #13572. +# Instead, just like `builtins.sum`, we explicitly handle several useful cases. +@overload +def prod(iterable: Iterable[bool | _LiteralInteger], /, *, start: int = 1) -> int: ... # type: ignore[overload-overlap] @overload -def prod(iterable: Iterable[SupportsIndex], /, *, start: SupportsIndex = 1) -> int: ... # type: ignore[overload-overlap] +def prod(iterable: Iterable[_SupportsProdNoDefaultT], /) -> _SupportsProdNoDefaultT | Literal[1]: ... @overload -def prod(iterable: Iterable[_SupportsFloatOrIndex], /, *, start: _SupportsFloatOrIndex = 1) -> float: ... +def prod(iterable: Iterable[_MultiplicableT1], /, *, start: _MultiplicableT2) -> _MultiplicableT1 | _MultiplicableT2: ... def radians(x: _SupportsFloatOrIndex, /) -> float: ... def remainder(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ... def sin(x: _SupportsFloatOrIndex, /) -> float: ...