140 lines
4.6 KiB
Python
140 lines
4.6 KiB
Python
import asyncio
|
|
import inspect
|
|
import logging
|
|
import logging.config
|
|
import sys
|
|
import tomllib
|
|
from asyncio import CancelledError
|
|
from signal import Signals
|
|
from traceback import print_exception
|
|
from typing import Coroutine, Callable, Union, Any
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from dexorder import configuration, config
|
|
from dexorder.alert import init_alerts
|
|
from dexorder.configuration.schema import Config
|
|
from dexorder.metric.metric_startup import start_metrics_server
|
|
|
|
# This module only provides helpers (such as ``execute``); running it directly is an error.
if __name__ == '__main__':
    raise Exception('this file is meant to be imported not executed')


# Logger shared by every function in this module.
log = logging.getLogger(name=__name__)
|
|
|
|
async def _shutdown_coro(_sig, _loop):
|
|
this_task = asyncio.current_task()
|
|
for task in asyncio.all_tasks():
|
|
if task is not this_task:
|
|
task.cancel()
|
|
|
|
|
|
def split_args():
    """
    Partition ``sys.argv[1:]`` into OmegaConf overrides and everything else.

    An argument counts as an OmegaConf override when it looks like ``key=value``, does not start with ``--``,
    and ``key`` is an attribute of the ``Config`` schema class. Every other argument is a "regular" argument.
    Both lists keep the original command-line order.

    :return: tuple ``(omegaconf_args, regular_args)``
    """
    def is_config_override(arg):
        # De Morgan of the original test: reject `--flag=...` style args and args without '='.
        if arg.startswith('--') or '=' not in arg:
            return False
        name = arg.split('=', 1)[0]
        return hasattr(Config, name)

    overrides, rest = [], []
    for arg in sys.argv[1:]:
        if is_config_override(arg):
            overrides.append(arg)
        else:
            rest.append(arg)
    return overrides, rest
|
|
|
|
|
|
def _configure_logging(parse_logging):
    """Apply ``logging.toml`` via dictConfig when requested and present; otherwise use a basic stdout config."""
    if parse_logging:
        try:
            with open('logging.toml', 'rb') as file:
                dictconf = tomllib.load(file)
        except FileNotFoundError:
            pass
        else:
            logging.config.dictConfig(dictconf)
            log.info('Logging configured from logging.toml')
            return
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    log.setLevel(logging.DEBUG)
    log.info('Logging configured to default')


def _parse_command_line(parse_args):
    """Feed ``Config`` overrides to the global configuration and build the value later handed to ``main``."""
    if not parse_args:
        return None
    # NOTE: there is special command-line argument handling in config/load.py to get a config filename.
    # The -c/--config flag MUST BE FIRST if present.
    # The rest of the arguments are split by format into key=value for omegaconf and anything else is "regular args"
    omegaconf_args, regular_args = split_args()
    configuration.parse_args(omegaconf_args)
    # must check for `type` before `callable`, because types are also callables
    if isinstance(parse_args, type):
        return OmegaConf.merge(OmegaConf.structured(parse_args), OmegaConf.from_cli(regular_args))
    if callable(parse_args):
        return parse_args(regular_args)
    # just pass the regular args to main
    return regular_args


def _cancel_remaining_tasks(loop):
    """Cancel every unfinished task on ``loop``, wait for them to settle, and print any real errors they raised."""
    # The loop is not running here, so the loop must be passed explicitly: asyncio.all_tasks() with no
    # argument looks up the *running* loop and would raise RuntimeError.
    remaining_tasks = asyncio.all_tasks(loop)
    if not remaining_tasks:
        return
    for task in remaining_tasks:
        task.cancel()
    results = loop.run_until_complete(asyncio.gather(*remaining_tasks, return_exceptions=True))
    for result in results:
        # CancelledError is a BaseException (not Exception), so ordinary cancellations are not printed.
        if isinstance(result, Exception):
            print_exception(result)


def execute(main: Callable[..., Coroutine[Any, Any, Any]], shutdown=None, *, parse_logging=True,
            parse_args: Union[Callable[[list[str]], Any], type, bool] = True):
    """
    Configure logging and configuration, run ``main`` on an asyncio event loop, then clean up.

    :param main: coroutine function accepting either no parameters or exactly one. A one-parameter ``main``
        receives the extra configuration produced according to ``parse_args``.
    :param shutdown: optional zero-argument callable invoked after ``main`` ends. If it returns an awaitable,
        that awaitable is run to completion on the same loop. Exceptions from it are logged and printed.
    :param parse_logging: if true, configure logging from ``logging.toml`` in the current directory when that
        file exists. Otherwise (or when false), fall back to ``logging.basicConfig`` at INFO level on stdout.
    :param parse_args: controls command-line handling. When truthy, the command-line arguments are given to
        OmegaConf first: ``key=value`` arguments whose key is a ``Config`` attribute are stripped out and passed to
        ``configuration.parse_args``. The remaining "regular" arguments are then handled as follows:

        * a type: used as an OmegaConf structured schema, merged with ``OmegaConf.from_cli(regular_args)``;
        * any other callable: called as ``parse_args(regular_args)``;
        * any other truthy value (e.g. ``True``): the list itself is passed to ``main``.

        When falsy, no command-line parsing is done and a one-parameter ``main`` receives ``None``.
    :raises Exception: if ``main`` does not accept 0 or 1 parameters.

    SIGQUIT, SIGTERM and SIGINT are routed to a handler that cancels all other tasks, which ends ``main``.
    """
    # config
    _configure_logging(parse_logging)
    xconf = _parse_command_line(parse_args)

    init_alerts()

    if config.metrics_port:
        start_metrics_server()

    # Validate main()'s arity before any event loop exists, so a bad entry point cannot leak a loop.
    num_args = len(inspect.signature(main).parameters)
    if num_args not in (0, 1):
        raise Exception(f'main() must accept 0 or 1 arguments, not {num_args}')

    # loop setup: reuse a current loop if one is set; otherwise create one explicitly, because
    # asyncio.get_event_loop() without a current loop is deprecated (3.12) and raises RuntimeError (3.14).
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        for s in (Signals.SIGQUIT, Signals.SIGTERM, Signals.SIGINT):
            # `sig=s` binds the current signal; the task name must use `sig` too, not the late-bound loop var `s`.
            loop.add_signal_handler(
                s, lambda sig=s: asyncio.create_task(_shutdown_coro(sig, loop), name=f'{sig.name} handler'))

        # main
        coro = main() if num_args == 0 else main(xconf)
        task = loop.create_task(coro, name='main')
        try:
            loop.run_until_complete(task)
        except CancelledError:
            pass  # a signal handler cancelled main: normal shutdown path
        except Exception as x:
            print_exception(x)

        # shutdown tasks
        log.info('shutdown')
        if shutdown is not None:
            try:
                sd = shutdown()
                if inspect.isawaitable(sd):
                    loop.run_until_complete(sd)
            except Exception as x:
                log.error(f'Exception during shutdown: {x}')
                print_exception(x)

        # cancel anything remaining, then finalize async generators (as asyncio.run does)
        _cancel_remaining_tasks(loop)
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        # end: always release the loop, even if something above raised
        loop.close()
|