11# Adapted with permission from the EdgeDB project;
22# license: PSFL.
33
4-
4+ import sys
5+ import gc
56import asyncio
67import contextvars
78import contextlib
1112
1213from test .test_asyncio .utils import await_without_task
1314
14-
1515# To prevent a warning "test altered the execution environment"
1616def tearDownModule ():
1717 asyncio .set_event_loop_policy (None )
@@ -29,6 +29,15 @@ def get_error_types(eg):
2929 return {type (exc ) for exc in eg .exceptions }
3030
3131
32+ def no_other_refs ():
33+ # due to gh-124392 coroutines now refer to their locals
34+ coro = asyncio .current_task ().get_coro ()
35+ frame = sys ._getframe (1 )
36+ while coro .cr_frame != frame :
37+ coro = coro .cr_await
38+ return [coro ]
39+
40+
3241class TestTaskGroup (unittest .IsolatedAsyncioTestCase ):
3342
3443 async def test_taskgroup_01 (self ):
@@ -899,6 +908,95 @@ async def outer():
899908
900909 await outer ()
901910
911+ async def test_exception_refcycles_direct (self ):
912+ """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
913+ tg = asyncio .TaskGroup ()
914+ exc = None
915+
916+ class _Done (Exception ):
917+ pass
918+
919+ try :
920+ async with tg :
921+ raise _Done
922+ except ExceptionGroup as e :
923+ exc = e
924+
925+ self .assertIsNotNone (exc )
926+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs ())
927+
928+
929+ async def test_exception_refcycles_errors (self ):
930+ """Test that TaskGroup deletes self._errors, and __aexit__ args"""
931+ tg = asyncio .TaskGroup ()
932+ exc = None
933+
934+ class _Done (Exception ):
935+ pass
936+
937+ try :
938+ async with tg :
939+ raise _Done
940+ except* _Done as excs :
941+ exc = excs .exceptions [0 ]
942+
943+ self .assertIsInstance (exc , _Done )
944+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs ())
945+
946+
947+ async def test_exception_refcycles_parent_task (self ):
948+ """Test that TaskGroup deletes self._parent_task"""
949+ tg = asyncio .TaskGroup ()
950+ exc = None
951+
952+ class _Done (Exception ):
953+ pass
954+
955+ async def coro_fn ():
956+ async with tg :
957+ raise _Done
958+
959+ try :
960+ async with asyncio .TaskGroup () as tg2 :
961+ tg2 .create_task (coro_fn ())
962+ except* _Done as excs :
963+ exc = excs .exceptions [0 ].exceptions [0 ]
964+
965+ self .assertIsInstance (exc , _Done )
966+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs ())
967+
968+ async def test_exception_refcycles_propagate_cancellation_error (self ):
969+ """Test that TaskGroup deletes propagate_cancellation_error"""
970+ tg = asyncio .TaskGroup ()
971+ exc = None
972+
973+ try :
974+ async with asyncio .timeout (- 1 ):
975+ async with tg :
976+ await asyncio .sleep (0 )
977+ except TimeoutError as e :
978+ exc = e .__cause__
979+
980+ self .assertIsInstance (exc , asyncio .CancelledError )
981+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs ())
982+
983+ async def test_exception_refcycles_base_error (self ):
984+ """Test that TaskGroup deletes self._base_error"""
985+ class MyKeyboardInterrupt (KeyboardInterrupt ):
986+ pass
987+
988+ tg = asyncio .TaskGroup ()
989+ exc = None
990+
991+ try :
992+ async with tg :
993+ raise MyKeyboardInterrupt
994+ except MyKeyboardInterrupt as e :
995+ exc = e
996+
997+ self .assertIsNotNone (exc )
998+ self .assertListEqual (gc .get_referrers (exc ), no_other_refs ())
999+
9021000
9031001if __name__ == "__main__" :
9041002 unittest .main ()
0 commit comments