composable cli config

This commit is contained in:
tim
2025-02-27 17:51:07 -04:00
parent c132f40164
commit e868ea5a4b
7 changed files with 44 additions and 10 deletions

View File

@@ -74,4 +74,4 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
execute(main()) execute(main)

View File

@@ -37,4 +37,4 @@ if __name__ == '__main__':
time = parse_date(sys.argv[1], ignoretz=True).replace(tzinfo=timezone.utc) time = parse_date(sys.argv[1], ignoretz=True).replace(tzinfo=timezone.utc)
seconds_per_block = float(sys.argv[2]) seconds_per_block = float(sys.argv[2])
sys.argv = [sys.argv[0], *sys.argv[3:]] sys.argv = [sys.argv[0], *sys.argv[3:]]
execute(main()) execute(main)

View File

@@ -7,10 +7,11 @@ import tomllib
from asyncio import CancelledError from asyncio import CancelledError
from signal import Signals from signal import Signals
from traceback import print_exception from traceback import print_exception
from typing import Coroutine from typing import Coroutine, Callable, Union, Any
from dexorder import configuration, config from dexorder import configuration, config
from dexorder.alert import init_alerts from dexorder.alert import init_alerts
from dexorder.configuration.schema import Config
from dexorder.metric.metric_startup import start_metrics_server from dexorder.metric.metric_startup import start_metrics_server
if __name__ == '__main__': if __name__ == '__main__':
@@ -25,7 +26,25 @@ async def _shutdown_coro(_sig, _loop):
if task is not this_task: if task is not this_task:
task.cancel() task.cancel()
def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=True):
def split_args():
omegaconf_args = []
regular_args = []
for arg in sys.argv[1:]:
if '=' in arg:
key, value = arg.split('=', 1)
if hasattr(Config, key):
omegaconf_args.append(arg)
else:
regular_args.append(arg)
return omegaconf_args, regular_args
def execute(main:Callable[...,Coroutine[Any,Any,Any]], shutdown=None, *, parse_logging=True, parse_args: Union[Callable[[list[str]],Any],bool]=True):
"""
if parse_args is a function, then the command-line arguments are given to OmegaConf first, and any args parsed by
OmegaConf are stripped from the args list. The remaining args are then passed to parse_args(args)
"""
# config # config
configured = False configured = False
if parse_logging: if parse_logging:
@@ -42,10 +61,18 @@ def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=Tru
logging.basicConfig(level=logging.INFO, stream=sys.stdout) logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
log.info('Logging configured to default') log.info('Logging configured to default')
xconf = None
if parse_args: if parse_args:
if callable(parse_args):
omegaconf_args, regular_args = split_args()
else:
omegaconf_args = None
# NOTE: there is special command-line argument handling in config/load.py to get a config filename. # NOTE: there is special command-line argument handling in config/load.py to get a config filename.
# The -c/--config flag MUST BE FIRST if present. # The -c/--config flag MUST BE FIRST if present.
configuration.parse_args() configuration.parse_args(omegaconf_args)
if callable(parse_args):
# noinspection PyUnboundLocalVariable
xconf = parse_args(regular_args)
init_alerts() init_alerts()
@@ -59,7 +86,14 @@ def execute(main:Coroutine, shutdown=None, *, parse_logging=True, parse_args=Tru
loop.add_signal_handler(s, lambda sig=s: asyncio.create_task(_shutdown_coro(sig, loop), name=f'{s.name} handler')) loop.add_signal_handler(s, lambda sig=s: asyncio.create_task(_shutdown_coro(sig, loop), name=f'{s.name} handler'))
# main # main
task = loop.create_task(main, name='main') num_args = len(inspect.signature(main).parameters)
if num_args == 0:
coro = main()
elif num_args == 1:
coro = main(xconf)
else:
raise Exception(f'main() must accept 0 or 1 arguments, not {num_args}')
task = loop.create_task(coro, name='main')
try: try:
loop.run_until_complete(task) loop.run_until_complete(task)
except CancelledError: except CancelledError:

View File

@@ -62,4 +62,4 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
execute(main()) execute(main)

View File

@@ -138,4 +138,4 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
execute(main()) execute(main)

View File

@@ -216,4 +216,4 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
execute(main()) execute(main)

View File

@@ -28,5 +28,5 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
execute(main()) execute(main)