Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def visit_func_def(self, o: FuncDef) -> None:
if init in self.method_names:
# Can't have both an attribute and a method/property with the same name.
continue
init_code = self.get_init(init, value, annotation)
init_code = self.get_init(init, value, annotation, self_init=True)
if init_code:
self.add(init_code)

Expand Down Expand Up @@ -1288,7 +1288,11 @@ def visit_import(self, o: Import) -> None:
self.record_name(target_name)

def get_init(
self, lvalue: str, rvalue: Expression, annotation: Type | None = None
self,
lvalue: str,
rvalue: Expression,
annotation: Type | None = None,
self_init: bool = False,
) -> str | None:
"""Return initializer for a variable.

Expand Down Expand Up @@ -1320,7 +1324,10 @@ def get_init(
return f"{self._indent}{lvalue} = ...\n"
else:
typename = self.get_str_type_of_node(rvalue)
initializer = self.get_assign_initializer(rvalue)
if self_init:
initializer = ""
else:
initializer = self.get_assign_initializer(rvalue)
return f"{self._indent}{lvalue}: {typename}{initializer}\n"

def get_assign_initializer(self, rvalue: Expression) -> str:
Expand All @@ -1344,6 +1351,10 @@ def get_assign_initializer(self, rvalue: Expression) -> str:
return f" = {rvalue.accept(p)}"
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
if self.processing_enum:
return ""
if not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important

# By default, no initializer is required:
Expand Down
26 changes: 13 additions & 13 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class C:
x = 1
[out]
class C:
x: int
x: int = ...

[case testInitTypeAnnotationPreserved]
class C:
Expand Down Expand Up @@ -267,7 +267,7 @@ class C:
x: int

class C:
x: int
x: int = ...
def __init__(self) -> None: ...

[case testEmptyClass]
Expand Down Expand Up @@ -319,7 +319,7 @@ class A:
_x: int

class A:
_y: int
_y: int = ...

[case testSpecialInternalVar]
__all__ = []
Expand Down Expand Up @@ -706,7 +706,7 @@ class A:
__all__ = ['A']

class A:
x: int
x: int = ...
def f(self) -> None: ...

[case testSkipMultiplePrivateDefs]
Expand Down Expand Up @@ -972,16 +972,16 @@ class A(NamedTuple):

class B(A):
z1: str
z2: int
z3: str
z2: int = ...
z3: str = ...

class RegularClass:
x: int
y: str
y: str = ...
class NestedNamedTuple(NamedTuple):
x: int
y: str = ...
z: str
z: str = ...


[case testNestedClassInNamedTuple_semanal-xfail]
Expand Down Expand Up @@ -1337,7 +1337,7 @@ class A:
[out]
class A:
class B:
x: int
x: int = ...
def f(self) -> None: ...
def g(self) -> None: ...

Expand Down Expand Up @@ -1388,7 +1388,7 @@ class A:
from _typeshed import Incomplete

class A:
x: Incomplete
x: Incomplete = ...
def __init__(self, a=None) -> None: ...
def method(self, a=None) -> None: ...

Expand Down Expand Up @@ -4342,21 +4342,21 @@ def create_model(cls): ...

class X:
a: int
b: str
b: str = ...

@typing_extensions.dataclass_transform(kw_only_default=True)
class ModelBase: ...

class Y(ModelBase):
a: int
b: str
b: str = ...

@typing_extensions.dataclass_transform(kw_only_default=True)
class DCMeta(type): ...

class Z(metaclass=DCMeta):
a: int
b: str
b: str = ...

[case testDataclassTransformDecorator_semanal]
import typing_extensions
Expand Down