Files
ai/sandbox/dexorder/tools/research_harness.py

151 lines
5.1 KiB
Python

#!/usr/bin/env python3
"""
Research script harness - executes implementation.py inside this process (which
python_tools.py launches as a subprocess), with API initialization, stdout/stderr
capture, and matplotlib figure capture.
This file is written to disk and invoked by python_tools.py rather than
being passed inline via `python -c`, so the harness code is inspectable and
not regenerated on every call.
Usage:
python -m dexorder.tools.research_harness <implementation_path>
Output (JSON to stdout):
{
"stdout": "captured user stdout",
"stderr": "captured user stderr",
"images": [{"format": "png", "data": "<base64>"}],
"error": false
}
"""
import sys
import io
import os
import base64
import json
from pathlib import Path
# Non-interactive matplotlib backend (must be set before importing pyplot)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Ensure dexorder package is importable
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
# ---------------------------------------------------------------------------
# Initialize API from config files so research scripts can call get_api()
#
# This runs at import time, before main() redirects stdout/stderr, so the
# warning printed on failure goes to the real stderr rather than into the
# captured script output. Failure is deliberately non-fatal.
# ---------------------------------------------------------------------------
try:
    import yaml

    config_path = os.environ.get("CONFIG_PATH", "/app/config/config.yaml")
    secrets_path = os.environ.get("SECRETS_PATH", "/app/config/secrets.yaml")
    config_data = {}
    secrets_data = {}
    # Binary mode lets PyYAML detect the file encoding (UTF-8 / UTF-16 BOM)
    # instead of depending on the container's locale.
    if Path(config_path).exists():
        with open(config_path, "rb") as f:
            config_data = yaml.safe_load(f) or {}
    if Path(secrets_path).exists():
        with open(secrets_path, "rb") as f:
            secrets_data = yaml.safe_load(f) or {}
    # `or {}` (not a .get() default) also covers a key that is present but
    # null in YAML, e.g. a `data:` section whose children are all commented
    # out; a .get() default alone would return None and the chained .get()
    # calls would raise AttributeError, abandoning API setup entirely.
    data_cfg = config_data.get("data") or {}
    iceberg_cfg = data_cfg.get("iceberg") or {}
    relay_cfg = data_cfg.get("relay") or {}

    from dexorder.api import set_api, API
    from dexorder.impl.charting_api_impl import ChartingAPIImpl
    from dexorder.impl.data_api_impl import DataAPIImpl

    # Values in config.yaml win; S3 credentials fall back to secrets.yaml.
    _data_api = DataAPIImpl(
        iceberg_catalog_uri=iceberg_cfg.get("catalog_uri", "http://iceberg-catalog:8181"),
        relay_endpoint=relay_cfg.get("endpoint", "tcp://relay:5559"),
        notification_endpoint=relay_cfg.get("notification_endpoint", "tcp://relay:5558"),
        namespace=iceberg_cfg.get("namespace", "trading"),
        s3_endpoint=iceberg_cfg.get("s3_endpoint") or secrets_data.get("s3_endpoint"),
        s3_access_key=iceberg_cfg.get("s3_access_key") or secrets_data.get("s3_access_key"),
        s3_secret_key=iceberg_cfg.get("s3_secret_key") or secrets_data.get("s3_secret_key"),
    )
    # NOTE: We intentionally do NOT call asyncio.run(_data_api.start()) here.
    # DataAPIImpl.historical_ohlc() auto-starts on first use, which ensures the
    # ZMQ context and notification listener are created inside the user's own
    # asyncio.run() event loop — avoiding cross-loop lifecycle issues.
    set_api(API(charting=ChartingAPIImpl(), data=_data_api))
except Exception as e:
    print(f"WARNING: API initialization failed: {e}", file=sys.stderr)
# ---------------------------------------------------------------------------
# Make user-defined indicators callable as df.ta.<name>() from research
# scripts. Like the API setup above, a failure here only emits a warning on
# the real stderr; the harness keeps going.
# ---------------------------------------------------------------------------
try:
    from dexorder.tools.python_tools import setup_custom_indicators
    # DATA_DIR overrides the default data root used for indicator discovery.
    _data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
    setup_custom_indicators(_data_dir)
except Exception as exc:
    sys.stderr.write(f"WARNING: Custom indicator registration failed: {exc}\n")
def _run_implementation(impl_path):
    """Execute the script at *impl_path* as ``__main__`` while capturing its output.

    Returns a ``(stdout_text, stderr_text, error_occurred)`` tuple. The real
    ``sys.stdout``, ``sys.stderr`` and ``sys.argv`` are restored in a
    ``finally`` clause, so they are put back even if the script raises
    ``SystemExit`` or ``KeyboardInterrupt``.
    """
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    saved_stdout, saved_stderr, saved_argv = sys.stdout, sys.stderr, sys.argv
    # Give the script the environment it would have under `python script.py`:
    # `if __name__ == "__main__":` guards run, `__file__` resolves, and argv
    # holds only the script path. With an empty globals dict, `__name__` would
    # resolve via builtins to "builtins" and such guards would silently skip.
    script_globals = {"__name__": "__main__", "__file__": str(impl_path)}
    error_occurred = False
    sys.stdout, sys.stderr = out_buf, err_buf
    sys.argv = [str(impl_path)]
    try:
        # compile() on bytes honours PEP 263 coding cookies (UTF-8 by default)
        # instead of decoding with the locale's encoding.
        code = compile(impl_path.read_bytes(), str(impl_path), 'exec')
        exec(code, script_globals)
    except SystemExit as exc:
        # sys.exit()/exit() inside the script must not kill the harness before
        # the JSON report is written; only a non-zero status counts as an error.
        if exc.code not in (None, 0):
            print(f"ERROR: script exited with status {exc.code!r}", file=err_buf)
            error_occurred = True
    except Exception as e:
        import traceback
        # Write to the buffers directly so this still works if the script
        # reassigned sys.stderr.
        print(f"ERROR: {e}", file=err_buf)
        traceback.print_exc(file=err_buf)
        error_occurred = True
    finally:
        sys.stdout, sys.stderr, sys.argv = saved_stdout, saved_stderr, saved_argv
    return out_buf.getvalue(), err_buf.getvalue(), error_occurred


def _capture_figures():
    """Render every open matplotlib figure to a base64 PNG, then close them all.

    Returns ``(images, render_errors)``. *images* is a list of
    ``{"format": "png", "data": <base64 str>}`` dicts. *render_errors* holds one
    message per figure whose ``savefig`` raised.
    """
    images = []
    render_errors = []
    try:
        for fig_num in plt.get_fignums():
            fig = plt.figure(fig_num)
            with io.BytesIO() as buf:
                try:
                    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
                except Exception as exc:
                    # One broken figure must not cost the caller the script's
                    # output or the remaining figures.
                    render_errors.append(f"WARNING: could not render figure {fig_num}: {exc}")
                    continue
                img_b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
            images.append({"format": "png", "data": img_b64})
    finally:
        plt.close('all')
    return images, render_errors


def main():
    """Run the implementation script named in ``sys.argv[1]`` and print a JSON report.

    With no path argument, a usage line goes to stderr and the process exits
    with status 2. A missing file produces an ``"error": true`` report and
    exit status 0. Otherwise the script runs with its output captured, open
    matplotlib figures are rendered, and one JSON object with ``stdout``,
    ``stderr``, ``images`` and ``error`` keys is printed to the real stdout.
    """
    if len(sys.argv) < 2:
        print("Usage: research_harness.py <implementation_path>", file=sys.stderr)
        sys.exit(2)

    impl_path = Path(sys.argv[1])
    if not impl_path.exists():
        print(json.dumps({
            "stdout": "",
            "stderr": f"Implementation file not found: {impl_path}",
            "images": [],
            "error": True,
        }))
        sys.exit(0)

    stdout_output, stderr_output, error_occurred = _run_implementation(impl_path)
    images, render_errors = _capture_figures()
    if render_errors:
        # Keep each warning on its own line even if the script's stderr did not
        # end with a newline.
        if stderr_output and not stderr_output.endswith("\n"):
            stderr_output += "\n"
        stderr_output += "\n".join(render_errors) + "\n"

    # Output results as JSON to the real stdout
    result = {
        "stdout": stdout_output,
        "stderr": stderr_output,
        "images": images,
        "error": error_occurred,
    }
    print(json.dumps(result))
# Script entry point: run the harness when executed directly
# (e.g. `python -m dexorder.tools.research_harness <implementation_path>`).
if __name__ == "__main__":
    main()