shutdown fixes
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
import tomllib
|
import tomllib
|
||||||
@@ -15,25 +16,16 @@ if __name__ == '__main__':
|
|||||||
raise Exception('this file is meant to be imported not executed')
|
raise Exception('this file is meant to be imported not executed')
|
||||||
|
|
||||||
|
|
||||||
ignorable_exceptions = [CancelledError]
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def _shutdown_coro(_sig, _loop, extra_shutdown):
|
async def _shutdown_coro(_sig, _loop):
|
||||||
log.info('shutting down')
|
this_task = asyncio.current_task()
|
||||||
if extra_shutdown is not None:
|
for task in asyncio.all_tasks():
|
||||||
extra_shutdown()
|
if task is not this_task:
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not
|
task.cancel()
|
||||||
asyncio.current_task()]
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
exceptions = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
for x in exceptions:
|
|
||||||
if x is not None and x.__class__ not in ignorable_exceptions:
|
|
||||||
print_exception(x)
|
|
||||||
|
|
||||||
|
|
||||||
def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=True):
|
def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=True):
|
||||||
|
# config
|
||||||
configured = False
|
configured = False
|
||||||
if parse_logging:
|
if parse_logging:
|
||||||
try:
|
try:
|
||||||
@@ -51,10 +43,14 @@ def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=Tru
|
|||||||
log.info('Logging configured to default')
|
log.info('Logging configured to default')
|
||||||
if parse_args:
|
if parse_args:
|
||||||
configuration.parse_args()
|
configuration.parse_args()
|
||||||
|
|
||||||
|
# loop setup
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
signals = Signals.SIGQUIT, Signals.SIGTERM, Signals.SIGINT
|
signals = Signals.SIGQUIT, Signals.SIGTERM, Signals.SIGINT
|
||||||
for s in signals:
|
for s in signals:
|
||||||
loop.add_signal_handler(s, lambda sig=s: asyncio.create_task(_shutdown_coro(sig, loop, shutdown), name=f'{s.name} handler'))
|
loop.add_signal_handler(s, lambda sig=s: asyncio.create_task(_shutdown_coro(sig, loop), name=f'{s.name} handler'))
|
||||||
|
|
||||||
|
# main
|
||||||
task = loop.create_task(main, name='main')
|
task = loop.create_task(main, name='main')
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(task)
|
loop.run_until_complete(task)
|
||||||
@@ -62,13 +58,31 @@ def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=Tru
|
|||||||
pass
|
pass
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
print_exception(x)
|
print_exception(x)
|
||||||
|
|
||||||
|
# shutdown tasks
|
||||||
|
log.info('shutdown')
|
||||||
|
if shutdown is not None:
|
||||||
|
sd = shutdown()
|
||||||
|
if inspect.isawaitable(sd):
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(sd)
|
||||||
|
except Exception as x:
|
||||||
|
log.error(f'Exception during shutdown: {x}')
|
||||||
|
print_exception(x)
|
||||||
|
|
||||||
|
# cancel anything remaining
|
||||||
try:
|
try:
|
||||||
remaining_tasks = asyncio.all_tasks()
|
remaining_tasks = asyncio.all_tasks()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
for t in remaining_tasks:
|
for task in remaining_tasks:
|
||||||
t.cancel()
|
task.cancel()
|
||||||
loop.run_until_complete(asyncio.gather(*remaining_tasks))
|
results = loop.run_until_complete(asyncio.gather(*remaining_tasks, return_exceptions=True))
|
||||||
|
for x in results:
|
||||||
|
if isinstance(x, Exception):
|
||||||
|
print_exception(x)
|
||||||
|
|
||||||
|
# end
|
||||||
loop.stop()
|
loop.stop()
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user