Skip to content

Commit 1c15013

Browse files
committed
Add support for static initialized defaults
1 parent d72482e commit 1c15013

6 files changed

Lines changed: 126 additions & 34 deletions

File tree

mypyc/codegen/emitclass.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
FuncIR,
3030
get_text_signature,
3131
)
32+
from mypyc.ir.module_ir import ModuleIR, get_module_top_level
3233
from mypyc.ir.rtypes import RTuple, RType, object_rprimitive
3334
from mypyc.namegen import NameGenerator
3435
from mypyc.sametype import is_same_type
@@ -193,7 +194,7 @@ def generate_class_type_decl(
193194
)
194195

195196

196-
def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None:
197+
def generate_class(cl: ClassIR, module: ModuleIR, emitter: Emitter) -> None:
197198
"""Generate C code for a class.
198199
199200
This is the main entry point to the module.
@@ -333,7 +334,7 @@ def emit_line() -> None:
333334
if cl.is_trait:
334335
generate_new_for_trait(cl, new_name, emitter)
335336

336-
generate_methods_table(cl, methods_name, emitter)
337+
generate_methods_table(cl, methods_name, emitter, module)
337338
emit_line()
338339

339340
flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"]
@@ -352,7 +353,7 @@ def emit_line() -> None:
352353
flags.append("Py_TPFLAGS_MANAGED_DICT")
353354
fields["tp_flags"] = " | ".join(flags)
354355

355-
fields["tp_doc"] = native_class_doc_initializer(cl)
356+
fields["tp_doc"] = native_class_doc_initializer(cl, get_module_top_level(module))
356357

357358
emitter.emit_line(f"static PyTypeObject {emitter.type_struct_name(cl)}_template_ = {{")
358359
emitter.emit_line("PyVarObject_HEAD_INIT(NULL, 0)")
@@ -837,7 +838,7 @@ def generate_finalize_for_class(
837838
emitter.emit_line("}")
838839

839840

840-
def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter) -> None:
841+
def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter, module: ModuleIR) -> None:
841842
emitter.emit_line(f"static PyMethodDef {name}[] = {{")
842843
for fn in cl.methods.values():
843844
if fn.decl.is_prop_setter or fn.decl.is_prop_getter or fn.internal:
@@ -850,7 +851,7 @@ def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter) -> None:
850851
elif fn.decl.kind == FUNC_CLASSMETHOD:
851852
flags.append("METH_CLASS")
852853

853-
doc = native_function_doc_initializer(fn)
854+
doc = native_function_doc_initializer(fn, get_module_top_level(module))
854855
emitter.emit_line(" {}, {}}},".format(" | ".join(flags), doc))
855856

856857
# Provide a default __getstate__ and __setstate__
@@ -1111,10 +1112,10 @@ def has_managed_dict(cl: ClassIR, emitter: Emitter) -> bool:
11111112
)
11121113

11131114

1114-
def native_class_doc_initializer(cl: ClassIR) -> str:
1115+
def native_class_doc_initializer(cl: ClassIR, module_body: FuncIR) -> str:
11151116
init_fn = cl.get_method("__init__")
11161117
if init_fn is not None:
1117-
text_sig = get_text_signature(init_fn, bound=True)
1118+
text_sig = get_text_signature(init_fn, module_body, bound=True)
11181119
if text_sig is None:
11191120
return "NULL"
11201121
text_sig = text_sig.replace("__init__", cl.name, 1)

mypyc/codegen/emitfunc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def native_function_header(fn: FuncDecl, emitter: Emitter) -> str:
113113
)
114114

115115

116-
def native_function_doc_initializer(func: FuncIR) -> str:
117-
text_sig = get_text_signature(func)
116+
def native_function_doc_initializer(func: FuncIR, module_body: FuncIR) -> str:
117+
text_sig = get_text_signature(func, module_body)
118118
if text_sig is None:
119119
return "NULL"
120120
docstring = f"{text_sig}\n--\n\n"

mypyc/codegen/emitmodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
)
5555
from mypyc.errors import Errors
5656
from mypyc.ir.func_ir import FuncIR
57-
from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules
57+
from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules, get_module_top_level
5858
from mypyc.ir.ops import DeserMaps, LoadLiteral
5959
from mypyc.ir.rtypes import RType
6060
from mypyc.irbuild.main import build_ir
@@ -567,7 +567,7 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
567567

568568
for cl in module.classes:
569569
if cl.is_ext_class:
570-
generate_class(cl, module_name, emitter)
570+
generate_class(cl, module, emitter)
571571

572572
# Generate Python extension module definitions and module initialization functions.
573573
self.generate_module_def(emitter, module_name, module)
@@ -919,7 +919,7 @@ def emit_module_methods(
919919
flag = "METH_FASTCALL"
920920
else:
921921
flag = "METH_VARARGS"
922-
doc = native_function_doc_initializer(fn)
922+
doc = native_function_doc_initializer(fn, get_module_top_level(module))
923923
emitter.emit_line(
924924
(
925925
'{{"{name}", (PyCFunction){prefix}{cname}, {flag} | METH_KEYWORDS, '

mypyc/ir/func_ir.py

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@
1313
AssignMulti,
1414
BasicBlock,
1515
Box,
16+
CallC,
1617
ControlOp,
1718
DeserMaps,
1819
Float,
20+
GetElementPtr,
21+
InitStatic,
1922
Integer,
23+
IntOp,
2024
LoadAddress,
2125
LoadLiteral,
26+
LoadMem,
27+
LoadStatic,
2228
Register,
29+
SetMem,
2330
TupleSet,
2431
Value,
2532
)
@@ -406,7 +413,7 @@ def all_values_full(args: list[Register], blocks: list[BasicBlock]) -> list[Valu
406413
_NOT_REPRESENTABLE = object()
407414

408415

409-
def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
416+
def get_text_signature(fn: FuncIR, module_body: FuncIR, *, bound: bool = False) -> str | None:
410417
"""Return a text signature in CPython's internal doc format, or None
411418
if the function's signature cannot be represented.
412419
"""
@@ -430,7 +437,7 @@ def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
430437
)
431438
default: object = inspect.Parameter.empty
432439
if arg.optional:
433-
default = _find_default_argument(arg.name, fn.blocks)
440+
default = _find_default_argument(arg.name, fn.blocks, module_body.blocks)
434441
if default is _NOT_REPRESENTABLE:
435442
# This default argument cannot be represented in a __text_signature__
436443
return None
@@ -444,16 +451,29 @@ def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
444451
return f"{fn.name}{inspect.Signature(parameters)}"
445452

446453

447-
def _find_default_argument(name: str, blocks: list[BasicBlock]) -> object:
454+
def _find_default_argument(
455+
name: str, blocks: list[BasicBlock], static_blocks: list[BasicBlock]
456+
) -> object:
448457
# Find assignment inserted by gen_arg_defaults. Assumed to be the first assignment.
449458
for block in blocks:
450459
for op in block.ops:
451460
if isinstance(op, Assign) and op.dest.name == name:
452-
return _extract_python_literal(op.src)
461+
if isinstance(op.src, LoadStatic):
462+
return _find_init_static(op.src.identifier, static_blocks)
463+
else:
464+
return _extract_python_literal(op.src, blocks)
465+
return _NOT_REPRESENTABLE
466+
467+
468+
def _find_init_static(fullname: str, blocks: list[BasicBlock]) -> object:
469+
for block in blocks:
470+
for op in block.ops:
471+
if isinstance(op, InitStatic) and op.identifier == fullname:
472+
return _extract_python_literal(op.value, blocks)
453473
return _NOT_REPRESENTABLE
454474

455475

456-
def _extract_python_literal(value: Value) -> object:
476+
def _extract_python_literal(value: Value, blocks: list[BasicBlock]) -> object:
457477
if isinstance(value, Integer):
458478
if is_none_rprimitive(value.type):
459479
return None
@@ -466,10 +486,66 @@ def _extract_python_literal(value: Value) -> object:
466486
elif isinstance(value, LoadLiteral):
467487
return value.value
468488
elif isinstance(value, Box):
469-
return _extract_python_literal(value.src)
489+
return _extract_python_literal(value.src, blocks)
470490
elif isinstance(value, TupleSet):
471-
items = tuple(_extract_python_literal(item) for item in value.items)
491+
items = tuple(_extract_python_literal(item, blocks) for item in value.items)
472492
if any(itm is _NOT_REPRESENTABLE for itm in items):
473493
return _NOT_REPRESENTABLE
474494
return items
495+
elif isinstance(value, CallC):
496+
if value.function_name == "PyList_New":
497+
assert len(value.args) == 1
498+
size = _extract_python_literal(value.args[0], blocks)
499+
if size == 0:
500+
return []
501+
return _extract_list(value, blocks)
502+
if value.function_name == "PyDict_New":
503+
return {}
504+
if value.function_name == "CPyDict_Build":
505+
args = [_extract_python_literal(arg, blocks) for arg in value.args]
506+
if any(arg is _NOT_REPRESENTABLE for arg in args):
507+
return _NOT_REPRESENTABLE
508+
return {k: v for k, v in zip(args[1::2], args[2::2])}
509+
if value.function_name == "PySet_New":
510+
result = _extract_set(value, blocks)
511+
if not result:
512+
return _NOT_REPRESENTABLE # set() isn't valid in __text_signature__
513+
return result
514+
elif isinstance(value, LoadAddress) and value.src == "_Py_EllipsisObject":
515+
return _EllipsisLiteral()
475516
return _NOT_REPRESENTABLE
517+
518+
519+
def _extract_list(value: CallC, blocks: list[BasicBlock]) -> object:
520+
result = []
521+
for block in blocks:
522+
for op in block.ops:
523+
if isinstance(op, SetMem):
524+
dest = op.dest.lhs if isinstance(op.dest, IntOp) else op.dest
525+
if (
526+
isinstance(dest, LoadMem)
527+
and isinstance(dest.src, GetElementPtr)
528+
and dest.src.src == value
529+
):
530+
item = _extract_python_literal(op.src, blocks)
531+
if item is _NOT_REPRESENTABLE:
532+
return _NOT_REPRESENTABLE
533+
result.append(item)
534+
return result
535+
536+
537+
def _extract_set(value: CallC, blocks: list[BasicBlock]) -> object:
538+
result = set()
539+
for block in blocks:
540+
for op in block.ops:
541+
if isinstance(op, CallC) and op.function_name == "PySet_Add" and op.args[0] == value:
542+
item = _extract_python_literal(op.args[1], blocks)
543+
if item is _NOT_REPRESENTABLE:
544+
return _NOT_REPRESENTABLE
545+
result.add(item)
546+
return result
547+
548+
549+
class _EllipsisLiteral:
550+
def __repr__(self) -> str:
551+
return "..."

mypyc/ir/module_ir.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from mypyc.common import JsonDict
5+
from mypyc.common import TOP_LEVEL_NAME, JsonDict
66
from mypyc.ir.class_ir import ClassIR
77
from mypyc.ir.func_ir import FuncDecl, FuncIR
88
from mypyc.ir.ops import DeserMaps
@@ -90,3 +90,11 @@ def deserialize_modules(data: dict[str, JsonDict], ctx: DeserMaps) -> dict[str,
9090
# ModulesIRs should also always be an *OrderedDict*, but if we
9191
# declared it that way we would need to put it in quotes everywhere...
9292
ModuleIRs = dict[str, ModuleIR]
93+
94+
95+
def get_module_top_level(module: ModuleIR) -> FuncIR:
96+
# Optimization: we tend to put the top level last, so reverse iterate
97+
for fn in reversed(module.functions):
98+
if fn.name == TOP_LEVEL_NAME:
99+
return fn
100+
assert False, f"module '{module.fullname}' missing '{TOP_LEVEL_NAME}' function"

mypyc/test-data/run-signatures.test

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ def default_tuple_empty(x=()): pass
4242
def default_tuple_literals(x=(1, "a", 1.0, False, True, None, (), (1,2,(3,4)))): pass
4343
def default_tuple_singleton(x=(1,)): pass
4444
def default_named_constant(x=A): pass
45+
def default_complex(x=1+2j): pass
46+
def default_list_empty(x=[]): pass
47+
def default_list_literals(x=[1, 2, 3]): pass
48+
def default_dict_empty(x={}): pass
49+
def default_dict_literals(x={'a': 1}): pass
50+
def default_set_literals(x={1, 2, 3}): pass
51+
def default_tuple_collections(x=([1], {2}, {3: 4})): pass
52+
def default_nested_collections(x={'list': [1], 'set': {2}, 'dict': {3: 4}}): pass
53+
def default_literal_fold(x=1+2): pass
54+
def default_ellipsis(x=...): pass
4555

4656
[file driver.py]
4757
import inspect
@@ -56,6 +66,16 @@ assert str(inspect.signature(default_none)) == "(x=None)"
5666
assert str(inspect.signature(default_tuple_empty)) == "(x=())"
5767
assert str(inspect.signature(default_tuple_literals)) == "(x=(1, 'a', 1.0, False, True, None, (), (1, 2, (3, 4))))"
5868
assert str(inspect.signature(default_named_constant)) == "(x=1)"
69+
assert str(inspect.signature(default_complex)) == "(x=(1+2j))"
70+
assert str(inspect.signature(default_list_empty)) == "(x=[])"
71+
assert str(inspect.signature(default_list_literals)) == "(x=[1, 2, 3])"
72+
assert str(inspect.signature(default_dict_empty)) == "(x={})"
73+
assert str(inspect.signature(default_dict_literals)) == "(x={'a': 1})"
74+
assert str(inspect.signature(default_set_literals)) == "(x={1, 2, 3})"
75+
assert str(inspect.signature(default_tuple_collections)) == "(x=([1], {2}, {3: 4}))"
76+
assert str(inspect.signature(default_nested_collections)) == "(x={'list': [1], 'set': {2}, 'dict': {3: 4}})"
77+
assert str(inspect.signature(default_literal_fold)) == "(x=3)"
78+
assert str(inspect.signature(default_ellipsis)) == "(x=Ellipsis)"
5979

6080
# Check __text_signature__ directly since inspect.signature produces
6181
# an incorrect signature for 1-tuple default arguments prior to
@@ -95,19 +115,6 @@ def bad_set_empty(x=set()): pass # supported by ast.literal_eval, but not by in
95115
def bad_nan(x=float("nan")): pass
96116
def bad_enum(x=Color.RED): pass
97117

98-
# TODO: Default arguments that could potentially be represented in a
99-
# __text_signature__, but which are not currently supported.
100-
# See 'inspect._signature_fromstr' for what default values are supported at runtime.
101-
def bad_complex(x=1+2j): pass
102-
def bad_list_empty(x=[]): pass
103-
def bad_list_literals(x=[1, 2, 3]): pass
104-
def bad_dict_empty(x={}): pass
105-
def bad_dict_literals(x={'a': 1}): pass
106-
def bad_set_literals(x={1, 2, 3}): pass
107-
def bad_tuple_literals(x=([1, 2, 3], {'a': 1}, {1, 2, 3})): pass
108-
def bad_ellipsis(x=...): pass
109-
def bad_literal_fold(x=1+2): pass
110-
111118
[file driver.py]
112119
import inspect
113120
from testutil import assertRaises
@@ -155,7 +162,7 @@ class HasInit:
155162
class InheritedInit(HasInit): pass
156163

157164
class HasInitBad:
158-
def __init__(self, x=[]) -> None: pass
165+
def __init__(self, x=object()) -> None: pass
159166

160167
[file driver.py]
161168
import inspect

0 commit comments

Comments
 (0)