Compare commits
11 Commits: 311df9aeda ... main

| SHA1 |
|------|
| c76887ab92 |
| f6bd22a8ef |
| b9cc397e05 |
| e99ef5d2dd |
| 8ff277c8c6 |
| a50955558e |
| 185fa42caa |
| bf7af2b426 |
| f4da40706c |
| 3ffce97b3e |
| 3b29096dab |
175
.agents/skills/better-auth-best-practices/SKILL.md
Normal file
@@ -0,0 +1,175 @@
---
name: better-auth-best-practices
description: Configure Better Auth server and client, set up database adapters, manage sessions, add plugins, and handle environment variables. Use when users mention Better Auth, betterauth, auth.ts, or need to set up TypeScript authentication with email/password, OAuth, or plugin configuration.
---

# Better Auth Integration Guide

**Always consult [better-auth.com/docs](https://better-auth.com/docs) for code examples and the latest API.**

---

## Setup Workflow

1. Install: `npm install better-auth`
2. Set env vars: `BETTER_AUTH_SECRET` and `BETTER_AUTH_URL`
3. Create `auth.ts` with database + config
4. Create a route handler for your framework
5. Run `npx @better-auth/cli@latest migrate`
6. Verify: call `GET /api/auth/ok` — should return `{ status: "ok" }`
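The workflow above boils down to one small file. A minimal sketch, assuming a local `better-sqlite3` database (swap in your own connection or adapter):

```typescript
// lib/auth.ts — minimal sketch, assuming a local SQLite database
import { betterAuth } from "better-auth";
import Database from "better-sqlite3";

export const auth = betterAuth({
  // BETTER_AUTH_SECRET and BETTER_AUTH_URL are read from the environment
  database: new Database("./sqlite.db"),
  emailAndPassword: { enabled: true },
});
```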
---

## Quick Reference

### Environment Variables
- `BETTER_AUTH_SECRET` - Encryption secret (min 32 chars). Generate: `openssl rand -base64 32`
- `BETTER_AUTH_URL` - Base URL (e.g., `https://example.com`)

Only define `baseURL`/`secret` in config if the env vars are NOT set.

### File Location
The CLI looks for `auth.ts` in `./`, `./lib`, `./utils`, or under `./src`. Use `--config` for a custom path.

### CLI Commands
- `npx @better-auth/cli@latest migrate` - Apply schema (built-in adapter)
- `npx @better-auth/cli@latest generate` - Generate schema for Prisma/Drizzle
- `npx @better-auth/cli mcp --cursor` - Add MCP to AI tools

**Re-run after adding or changing plugins.**

---

## Core Config Options

| Option | Notes |
|--------|-------|
| `appName` | Optional display name |
| `baseURL` | Only if `BETTER_AUTH_URL` not set |
| `basePath` | Default `/api/auth`. Set `/` for root. |
| `secret` | Only if `BETTER_AUTH_SECRET` not set |
| `database` | Required for most features. See adapters docs. |
| `secondaryStorage` | Redis/KV for sessions & rate limits |
| `emailAndPassword` | `{ enabled: true }` to activate |
| `socialProviders` | `{ google: { clientId, clientSecret }, ... }` |
| `plugins` | Array of plugins |
| `trustedOrigins` | CSRF whitelist |

---

## Database

**Direct connections:** Pass a `pg.Pool`, `mysql2` pool, `better-sqlite3`, or `bun:sqlite` instance.

**ORM adapters:** Import from `better-auth/adapters/drizzle`, `better-auth/adapters/prisma`, `better-auth/adapters/mongodb`.

**Critical:** Better Auth uses adapter model names, NOT underlying table names. If your Prisma model is `User` mapped to the table `users`, use `modelName: "user"` (the Prisma reference), not `"users"`.
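The model-name rule, sketched with the Prisma adapter (the `User`/`users` mapping is illustrative):

```typescript
// lib/auth.ts — sketch; `prisma` is your generated PrismaClient instance
import { betterAuth } from "better-auth";
import { prismaAdapter } from "better-auth/adapters/prisma";
import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

export const auth = betterAuth({
  database: prismaAdapter(prisma, { provider: "postgresql" }),
  // Prisma model `User` maps to table `users`; Better Auth wants the
  // model reference (lower-cased by convention), not the table name.
  user: { modelName: "user" },
});
```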
---

## Session Management

**Storage priority:**
1. If `secondaryStorage` is defined → sessions go there (not the DB)
2. Set `session.storeSessionInDatabase: true` to also persist them to the DB
3. No database + `cookieCache` → fully stateless mode

**Cookie cache strategies:**
- `compact` (default) - Base64url + HMAC. Smallest payload.
- `jwt` - Standard JWT. Readable but signed.
- `jwe` - Encrypted. Maximum security.

**Key options:** `session.expiresIn` (default 7 days), `session.updateAge` (refresh interval), `session.cookieCache.maxAge`, `session.cookieCache.version` (change to invalidate all sessions).
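The key options above, sketched with illustrative values:

```typescript
// Sketch of the session options; durations are examples, not recommendations
import { betterAuth } from "better-auth";

export const auth = betterAuth({
  session: {
    expiresIn: 60 * 60 * 24 * 7, // 7 days (the default)
    updateAge: 60 * 60 * 24,     // refresh the expiry at most once per day
    cookieCache: {
      enabled: true,
      maxAge: 60 * 5, // trust the cookie for 5 minutes before re-checking storage
    },
  },
});
```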
---

## User & Account Config

**User:** `user.modelName`, `user.fields` (column mapping), `user.additionalFields`, `user.changeEmail.enabled` (disabled by default), `user.deleteUser.enabled` (disabled by default).

**Account:** `account.modelName`, `account.accountLinking.enabled`, `account.storeAccountCookie` (for stateless OAuth).

**Required for registration:** `email` and `name` fields.

---

## Email Flows

- `emailVerification.sendVerificationEmail` - Must be defined for verification to work
- `emailVerification.sendOnSignUp` / `sendOnSignIn` - Auto-send triggers
- `emailAndPassword.sendResetPassword` - Password reset email handler

---

## Security

**In `advanced`:**
- `useSecureCookies` - Force HTTPS-only cookies
- `disableCSRFCheck` - ⚠️ Security risk
- `disableOriginCheck` - ⚠️ Security risk
- `crossSubDomainCookies.enabled` - Share cookies across subdomains
- `ipAddress.ipAddressHeaders` - Custom IP headers for proxies
- `database.generateId` - Custom ID generation, or `"serial"`/`"uuid"`/`false`

**Rate limiting:** `rateLimit.enabled`, `rateLimit.window`, `rateLimit.max`, `rateLimit.storage` (`"memory" | "database" | "secondary-storage"`).
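The rate-limit options above, sketched (the window/max values are illustrative):

```typescript
// Sketch of the rateLimit block; tune window/max for your traffic
import { betterAuth } from "better-auth";

export const auth = betterAuth({
  rateLimit: {
    enabled: true,
    window: 60,        // window length in seconds
    max: 100,          // max requests per window per client
    storage: "memory", // or "database" / "secondary-storage"
  },
});
```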
---

## Hooks

**Endpoint hooks:** `hooks.before` / `hooks.after` - Array of `{ matcher, handler }`. Use `createAuthMiddleware`. Access `ctx.path`, `ctx.context.returned` (after), `ctx.context.session`.

**Database hooks:** `databaseHooks.user.create.before/after`; the same shape exists for `session` and `account`. Useful for adding default values or post-creation actions.

**Hook context (`ctx.context`):** `session`, `secret`, `authCookies`, `password.hash()`/`verify()`, `adapter`, `internalAdapter`, `generateId()`, `tables`, `baseURL`.
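A database-hook sketch: set a default value before each user insert (the `role` field is illustrative and would come from `user.additionalFields`):

```typescript
// Sketch of databaseHooks; `role` is a hypothetical additional field
import { betterAuth } from "better-auth";

export const auth = betterAuth({
  databaseHooks: {
    user: {
      create: {
        before: async (user) => {
          // Return { data } to modify the row about to be inserted
          return { data: { ...user, role: "member" } };
        },
        after: async (user) => {
          console.log(`created user ${user.id}`);
        },
      },
    },
  },
});
```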
---

## Plugins

**Import from dedicated paths for tree-shaking:**

```ts
import { twoFactor } from "better-auth/plugins/two-factor";
```

NOT from the barrel export `"better-auth/plugins"`.

**Popular plugins:** `twoFactor`, `organization`, `passkey`, `magicLink`, `emailOtp`, `username`, `phoneNumber`, `admin`, `apiKey`, `bearer`, `jwt`, `multiSession`, `sso`, `oauthProvider`, `oidcProvider`, `openAPI`, `genericOAuth`.

Client plugins go in `createAuthClient({ plugins: [...] })`.

---

## Client

Import from `better-auth/client` (vanilla), `better-auth/react`, `better-auth/vue`, `better-auth/svelte`, or `better-auth/solid`.

Key methods: `signUp.email()`, `signIn.email()`, `signIn.social()`, `signOut()`, `useSession()`, `getSession()`, `revokeSession()`, `revokeSessions()`.

---

## Type Safety

Infer types: `typeof auth.$Infer.Session`, `typeof auth.$Infer.Session.user`.

For separate client/server projects: `createAuthClient<typeof auth>()`.

---

## Common Gotchas

1. **Model vs table name** - Config uses the ORM model name, not the DB table name
2. **Plugin schema** - Re-run the CLI after adding plugins
3. **Secondary storage** - Sessions go there by default, not to the DB
4. **Cookie cache** - Custom session fields are NOT cached; they are always re-fetched
5. **Stateless mode** - No DB means the session lives in the cookie only; users are logged out when the cache expires
6. **Change email flow** - Sends to the current email first, then to the new email

---

## Resources

- [Docs](https://better-auth.com/docs)
- [Options Reference](https://better-auth.com/docs/reference/options)
- [LLMs.txt](https://better-auth.com/llms.txt)
- [GitHub](https://github.com/better-auth/better-auth)
- [Init Options Source](https://github.com/better-auth/better-auth/blob/main/packages/core/src/types/init-options.ts)
321
.agents/skills/create-auth-skill/SKILL.md
Normal file
@@ -0,0 +1,321 @@
---
name: create-auth-skill
description: Scaffold and implement authentication in TypeScript/JavaScript apps using Better Auth. Detect frameworks, configure database adapters, set up route handlers, add OAuth providers, and create auth UI pages. Use when users want to add login, sign-up, or authentication to a new or existing project with Better Auth.
---

# Create Auth Skill

Guide for adding authentication to TypeScript/JavaScript applications using Better Auth.

**For code examples and syntax, see [better-auth.com/docs](https://better-auth.com/docs).**

---

## Phase 1: Planning (REQUIRED before implementation)

Before writing any code, gather requirements by scanning the project and asking the user structured questions. This ensures the implementation matches their needs.

### Step 1: Scan the project

Analyze the codebase to auto-detect:
- **Framework** — Look for `next.config`, `svelte.config`, `nuxt.config`, `astro.config`, `vite.config`, or Express/Hono entry files.
- **Database/ORM** — Look for `prisma/schema.prisma`, `drizzle.config`, or `package.json` deps (`pg`, `mysql2`, `better-sqlite3`, `mongoose`, `mongodb`).
- **Existing auth** — Look for existing auth libraries (`next-auth`, `lucia`, `clerk`, `supabase/auth`, `firebase/auth`) in `package.json` or imports.
- **Package manager** — Check for `pnpm-lock.yaml`, `yarn.lock`, `bun.lockb`, or `package-lock.json`.

Use what you find to pre-fill defaults and skip questions you can already answer.

### Step 2: Ask planning questions

Use the `AskQuestion` tool to ask the user **all applicable questions in a single call**. Skip any question you already have a confident answer for from the scan. Group them under a title like "Auth Setup Planning".

**Questions to ask:**

1. **Project type** (skip if detected)
   - Prompt: "What type of project is this?"
   - Options: New project from scratch | Adding auth to existing project | Migrating from another auth library

2. **Framework** (skip if detected)
   - Prompt: "Which framework are you using?"
   - Options: Next.js (App Router) | Next.js (Pages Router) | SvelteKit | Nuxt | Astro | Express | Hono | SolidStart | Other

3. **Database & ORM** (skip if detected)
   - Prompt: "Which database setup will you use?"
   - Options: PostgreSQL (Prisma) | PostgreSQL (Drizzle) | PostgreSQL (pg driver) | MySQL (Prisma) | MySQL (Drizzle) | MySQL (mysql2 driver) | SQLite (Prisma) | SQLite (Drizzle) | SQLite (better-sqlite3 driver) | MongoDB (Mongoose) | MongoDB (native driver)

4. **Authentication methods** (always ask, allow multiple)
   - Prompt: "Which sign-in methods do you need?"
   - Options: Email & password | Social OAuth (Google, GitHub, etc.) | Magic link (passwordless email) | Passkey (WebAuthn) | Phone number
   - `allow_multiple: true`

5. **Social providers** (only if they selected Social OAuth above — ask in a follow-up call)
   - Prompt: "Which social providers do you need?"
   - Options: Google | GitHub | Apple | Microsoft | Discord | Twitter/X
   - `allow_multiple: true`

6. **Email verification** (only if Email & password was selected above — ask in a follow-up call)
   - Prompt: "Do you want to require email verification?"
   - Options: Yes | No

7. **Email provider** (only if email verification is Yes, or if Password reset is selected in features — ask in a follow-up call)
   - Prompt: "How do you want to send emails?"
   - Options: Resend | Mock it for now (console.log)

8. **Features & plugins** (always ask, allow multiple)
   - Prompt: "Which additional features do you need?"
   - Options: Two-factor authentication (2FA) | Organizations / teams | Admin dashboard | API bearer tokens | Password reset | None of these
   - `allow_multiple: true`

9. **Auth pages** (always ask, allow multiple — pre-select based on earlier answers)
   - Prompt: "Which auth pages do you need?"
   - Options vary based on previous answers:
     - Always available: Sign in | Sign up
     - If Email & password selected: Forgot password | Reset password
     - If email verification enabled: Email verification
   - `allow_multiple: true`

10. **Auth UI style** (always ask)
    - Prompt: "What style do you want for the auth pages? Pick one or describe your own."
    - Options: Minimal & clean | Centered card with background | Split layout (form + hero image) | Floating / glassmorphism | Other (I'll describe)

### Step 3: Summarize the plan

After collecting answers, present a concise implementation plan as a markdown checklist. Example:

```
## Auth Implementation Plan

- **Framework:** Next.js (App Router)
- **Database:** PostgreSQL via Prisma
- **Auth methods:** Email/password, Google OAuth, GitHub OAuth
- **Plugins:** 2FA, Organizations, Email verification
- **UI:** Custom forms

### Steps
1. Install `better-auth` and `@better-auth/cli`
2. Create `lib/auth.ts` with server config
3. Create `lib/auth-client.ts` with React client
4. Set up route handler at `app/api/auth/[...all]/route.ts`
5. Configure Prisma adapter and generate schema
6. Add Google & GitHub OAuth providers
7. Enable `twoFactor` and `organization` plugins
8. Set up email verification handler
9. Run migrations
10. Create sign-in / sign-up pages
```

Ask the user to confirm the plan before proceeding to Phase 2.

---

## Phase 2: Implementation

Only proceed here after the user confirms the plan from Phase 1.

Follow the decision tree below, guided by the answers collected above.

```
Is this a new/empty project?
├─ YES → New project setup
│   1. Install better-auth (+ scoped packages per plan)
│   2. Create auth.ts with all planned config
│   3. Create auth-client.ts with framework client
│   4. Set up route handler
│   5. Set up environment variables
│   6. Run CLI migrate/generate
│   7. Add plugins from plan
│   8. Create auth UI pages
│
├─ MIGRATING → Migration from existing auth
│   1. Audit current auth for gaps
│   2. Plan incremental migration
│   3. Install better-auth alongside existing auth
│   4. Migrate routes, then session logic, then UI
│   5. Remove old auth library
│   6. See migration guides in docs
│
└─ ADDING → Add auth to existing project
    1. Analyze project structure
    2. Install better-auth
    3. Create auth config matching plan
    4. Add route handler
    5. Run schema migrations
    6. Integrate into existing pages
    7. Add planned plugins and features
```

At the end of implementation, guide users thoroughly on the remaining next steps (e.g., setting up OAuth app credentials, deploying env vars, testing flows).

---

## Installation

**Core:** `npm install better-auth`

**Scoped packages (as needed):**

| Package | Use case |
|---------|----------|
| `@better-auth/passkey` | WebAuthn/Passkey auth |
| `@better-auth/sso` | SAML/OIDC enterprise SSO |
| `@better-auth/stripe` | Stripe payments |
| `@better-auth/scim` | SCIM user provisioning |
| `@better-auth/expo` | React Native/Expo |

---

## Environment Variables

```env
BETTER_AUTH_SECRET=<32+ chars, generate with: openssl rand -base64 32>
BETTER_AUTH_URL=http://localhost:3000
DATABASE_URL=<your database connection string>
```

Add OAuth secrets as needed: `GITHUB_CLIENT_ID`, `GITHUB_CLIENT_SECRET`, `GOOGLE_CLIENT_ID`, etc.

---

## Server Config (auth.ts)

**Location:** `lib/auth.ts` or `src/lib/auth.ts`

**Minimal config needs:**
- `database` - Connection or adapter
- `emailAndPassword: { enabled: true }` - For email/password auth

**Standard config adds:**
- `socialProviders` - OAuth providers (google, github, etc.)
- `emailVerification.sendVerificationEmail` - Email verification handler
- `emailAndPassword.sendResetPassword` - Password reset handler

**Full config adds:**
- `plugins` - Array of feature plugins
- `session` - Expiry, cookie cache settings
- `account.accountLinking` - Multi-provider linking
- `rateLimit` - Rate limiting config

**Export types:** `export type Session = typeof auth.$Infer.Session`
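The "standard" tier above, sketched. The `sendEmail` helper and the GitHub credentials are placeholders for your own email function and env vars:

```typescript
// lib/auth.ts — sketch of a standard-tier config; sendEmail is hypothetical
import { betterAuth } from "better-auth";
import { prismaAdapter } from "better-auth/adapters/prisma";
import { PrismaClient } from "@prisma/client";
import { sendEmail } from "./email"; // your email sending function

const prisma = new PrismaClient();

export const auth = betterAuth({
  database: prismaAdapter(prisma, { provider: "postgresql" }),
  emailAndPassword: {
    enabled: true,
    sendResetPassword: async ({ user, url }) => {
      await sendEmail({ to: user.email, subject: "Reset your password", text: url });
    },
  },
  socialProviders: {
    github: {
      clientId: process.env.GITHUB_CLIENT_ID!,
      clientSecret: process.env.GITHUB_CLIENT_SECRET!,
    },
  },
});

export type Session = typeof auth.$Infer.Session;
```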
---

## Client Config (auth-client.ts)

**Import by framework:**

| Framework | Import |
|-----------|--------|
| React/Next.js | `better-auth/react` |
| Vue | `better-auth/vue` |
| Svelte | `better-auth/svelte` |
| Solid | `better-auth/solid` |
| Vanilla JS | `better-auth/client` |

**Client plugins** go in `createAuthClient({ plugins: [...] })`.

**Common exports:** `signIn`, `signUp`, `signOut`, `useSession`, `getSession`
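A sketch of the client file for a React app (the `baseURL` value is illustrative and only needed when the API lives on a different origin):

```typescript
// lib/auth-client.ts — sketch for a React/Next.js app
import { createAuthClient } from "better-auth/react";

export const authClient = createAuthClient({
  baseURL: "http://localhost:3000", // omit when same-origin
});

// Re-export the common methods for convenient imports elsewhere
export const { signIn, signUp, signOut, useSession, getSession } = authClient;
```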
---

## Route Handler Setup

| Framework | File | Handler |
|-----------|------|---------|
| Next.js App Router | `app/api/auth/[...all]/route.ts` | `toNextJsHandler(auth)` → export `{ GET, POST }` |
| Next.js Pages | `pages/api/auth/[...all].ts` | `toNextJsHandler(auth)` → default export |
| Express | Any file | `app.all("/api/auth/*", toNodeHandler(auth))` |
| SvelteKit | `src/hooks.server.ts` | `svelteKitHandler(auth)` |
| SolidStart | Route file | `solidStartHandler(auth)` |
| Hono | Route file | `auth.handler(c.req.raw)` |

**Next.js Server Components:** Add the `nextCookies()` plugin to the auth config.
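The App Router row from the table, sketched (assumes `auth` is exported from `lib/auth.ts` under the `@/` alias):

```typescript
// app/api/auth/[...all]/route.ts — sketch for Next.js App Router
import { auth } from "@/lib/auth";
import { toNextJsHandler } from "better-auth/next-js";

export const { GET, POST } = toNextJsHandler(auth);
```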
---

## Database Migrations

| Adapter | Command |
|---------|---------|
| Built-in Kysely | `npx @better-auth/cli@latest migrate` (applies directly) |
| Prisma | `npx @better-auth/cli@latest generate --output prisma/schema.prisma`, then `npx prisma migrate dev` |
| Drizzle | `npx @better-auth/cli@latest generate --output src/db/auth-schema.ts`, then `npx drizzle-kit push` |

**Re-run after adding plugins.**

---

## Database Adapters

| Database | Setup |
|----------|-------|
| SQLite | Pass a `better-sqlite3` or `bun:sqlite` instance directly |
| PostgreSQL | Pass a `pg.Pool` instance directly |
| MySQL | Pass a `mysql2` pool directly |
| Prisma | `prismaAdapter(prisma, { provider: "postgresql" })` from `better-auth/adapters/prisma` |
| Drizzle | `drizzleAdapter(db, { provider: "pg" })` from `better-auth/adapters/drizzle` |
| MongoDB | `mongodbAdapter(db)` from `better-auth/adapters/mongodb` |

---

## Common Plugins

| Plugin | Server Import | Client Import | Purpose |
|--------|---------------|---------------|---------|
| `twoFactor` | `better-auth/plugins` | `twoFactorClient` | 2FA with TOTP/OTP |
| `organization` | `better-auth/plugins` | `organizationClient` | Teams/orgs |
| `admin` | `better-auth/plugins` | `adminClient` | User management |
| `bearer` | `better-auth/plugins` | - | API token auth |
| `openAPI` | `better-auth/plugins` | - | API docs |
| `passkey` | `@better-auth/passkey` | `passkeyClient` | WebAuthn |
| `sso` | `@better-auth/sso` | - | Enterprise SSO |

**Plugin pattern:** Add the server plugin, add the matching client plugin, then run migrations.

---

## Auth UI Implementation

**Sign-in flow:**
1. `signIn.email({ email, password })` or `signIn.social({ provider, callbackURL })`
2. Handle `error` in the response
3. Redirect on success

**Session check (client):** the `useSession()` hook returns `{ data: session, isPending }`

**Session check (server):** `auth.api.getSession({ headers: await headers() })`

**Protected routes:** Check the session and redirect to `/sign-in` if it is null.

---

## Security Checklist

- [ ] `BETTER_AUTH_SECRET` set (32+ chars)
- [ ] `advanced.useSecureCookies: true` in production
- [ ] `trustedOrigins` configured
- [ ] Rate limits enabled
- [ ] Email verification enabled
- [ ] Password reset implemented
- [ ] 2FA for sensitive apps
- [ ] CSRF protection NOT disabled
- [ ] `account.accountLinking` reviewed

---

## Troubleshooting

| Issue | Fix |
|-------|-----|
| "Secret not set" | Add the `BETTER_AUTH_SECRET` env var |
| "Invalid Origin" | Add the domain to `trustedOrigins` |
| Cookies not setting | Check that `baseURL` matches the domain; enable secure cookies in prod |
| OAuth callback errors | Verify redirect URIs in the provider dashboard |
| Type errors after adding a plugin | Re-run CLI generate/migrate |

---

## Resources

- [Docs](https://better-auth.com/docs)
- [Examples](https://github.com/better-auth/examples)
- [Plugins](https://better-auth.com/docs/concepts/plugins)
- [CLI](https://better-auth.com/docs/concepts/cli)
- [Migration Guides](https://better-auth.com/docs/guides)
212
.agents/skills/email-and-password-best-practices/SKILL.md
Normal file
@@ -0,0 +1,212 @@
---
name: email-and-password-best-practices
description: Configure email verification, implement password reset flows, set password policies, and customise hashing algorithms for Better Auth email/password authentication. Use when users need to set up login, sign-in, sign-up, credential authentication, or password security with Better Auth.
---

## Quick Start

1. Enable email/password auth: `emailAndPassword: { enabled: true }`
2. Configure `emailVerification.sendVerificationEmail`
3. Add `sendResetPassword` for password reset flows
4. Run `npx @better-auth/cli@latest migrate`
5. Verify: attempt a sign-up and confirm the verification email triggers

---

## Email Verification Setup

Configure `emailVerification.sendVerificationEmail` to verify user email addresses.

```ts
import { betterAuth } from "better-auth";
import { sendEmail } from "./email"; // your email sending function

export const auth = betterAuth({
  emailVerification: {
    sendVerificationEmail: async ({ user, url, token }, request) => {
      await sendEmail({
        to: user.email,
        subject: "Verify your email address",
        text: `Click the link to verify your email: ${url}`,
      });
    },
  },
});
```

**Note**: The `url` parameter contains the full verification link. The `token` is available if you need to build a custom verification URL.

### Requiring Email Verification

For stricter security, enable `emailAndPassword.requireEmailVerification` to block sign-in until the user verifies their email. When enabled, unverified users receive a new verification email on each sign-in attempt.

```ts
export const auth = betterAuth({
  emailAndPassword: {
    requireEmailVerification: true,
  },
});
```

**Note**: This requires `sendVerificationEmail` to be configured and only applies to email/password sign-ins.

## Client-Side Validation

Implement client-side validation for immediate user feedback and reduced server load.
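A minimal sketch of such a check, run before calling `authClient.signUp.email`. The 8/128 bounds mirror Better Auth's default password length limits; adjust them if you change `minPasswordLength`/`maxPasswordLength`:

```typescript
// Client-side pre-validation sketch; the server still enforces its own rules
export function validateCredentials(email: string, password: string): string[] {
  const errors: string[] = [];
  if (!/^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(email)) {
    errors.push("Enter a valid email address");
  }
  if (password.length < 8) {
    errors.push("Password must be at least 8 characters");
  }
  if (password.length > 128) {
    errors.push("Password must be at most 128 characters");
  }
  return errors;
}
```

Show the returned messages next to the form fields and only call the sign-up endpoint when the array is empty.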
## Callback URLs

Always use absolute URLs (including the origin) for callback URLs in sign-up and sign-in requests. This prevents Better Auth from needing to infer the origin, which can cause issues when your backend and frontend are on different domains.

```ts
const { data, error } = await authClient.signUp.email({
  name: "John Doe",                            // required
  email: "john.doe@example.com",               // required
  password: "secure-password",                 // required
  callbackURL: "https://example.com/callback", // absolute URL with origin
});
```

## Password Reset Flows

Provide `sendResetPassword` in the email and password config to enable password resets.

```ts
import { betterAuth } from "better-auth";
import { sendEmail } from "./email"; // your email sending function

export const auth = betterAuth({
  emailAndPassword: {
    enabled: true,
    // Custom email sending function to send the reset-password email
    sendResetPassword: async ({ user, url, token }, request) => {
      await sendEmail({
        to: user.email,
        subject: "Reset your password",
        text: `Click the link to reset your password: ${url}`,
      });
    },
    // Optional event hook
    onPasswordReset: async ({ user }, request) => {
      // your logic here
      console.log(`Password for user ${user.email} has been reset.`);
    },
  },
});
```

### Security Considerations

Built-in protections: background email sending (timing-attack prevention), dummy operations on invalid requests, and constant response messages regardless of whether the user exists.

On serverless platforms, configure a background task handler:

```ts
export const auth = betterAuth({
  advanced: {
    backgroundTasks: {
      handler: (promise) => {
        // Use platform-specific methods like waitUntil
        waitUntil(promise);
      },
    },
  },
});
```

#### Token Security

Tokens expire after 1 hour by default. Configure this with `resetPasswordTokenExpiresIn` (in seconds):

```ts
export const auth = betterAuth({
  emailAndPassword: {
    enabled: true,
    resetPasswordTokenExpiresIn: 60 * 30, // 30 minutes
  },
});
```

Tokens are single-use — deleted immediately after a successful reset.

#### Session Revocation

Enable `revokeSessionsOnPasswordReset` to invalidate all existing sessions on password reset:

```ts
export const auth = betterAuth({
  emailAndPassword: {
    enabled: true,
    revokeSessionsOnPasswordReset: true,
  },
});
```

#### Password Requirements

Password length limits (configurable):

```ts
export const auth = betterAuth({
  emailAndPassword: {
    enabled: true,
    minPasswordLength: 12,
    maxPasswordLength: 256,
  },
});
```

### Sending the Password Reset

Call `requestPasswordReset` to send the reset link. This triggers the `sendResetPassword` function from your config.

```ts
const data = await auth.api.requestPasswordReset({
  body: {
    email: "john.doe@example.com", // required
    redirectTo: "https://example.com/reset-password",
  },
});
```

Or via the auth client:

```ts
const { data, error } = await authClient.requestPasswordReset({
  email: "john.doe@example.com", // required
  redirectTo: "https://example.com/reset-password",
});
```

**Note**: While only the `email` is required, we also recommend configuring `redirectTo` for a smoother user experience.
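To complete the flow, the page at `redirectTo` reads the token from the URL and calls `resetPassword`. A sketch (the new-password value would come from your form):

```typescript
// On the /reset-password page — sketch; the reset link appends ?token=...
const token = new URLSearchParams(window.location.search).get("token");

const { data, error } = await authClient.resetPassword({
  newPassword: "new-secure-password", // from the user's form input
  token: token!,
});
```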
## Password Hashing

Default: `scrypt` (Node.js native, no external dependencies).

### Custom Hashing Algorithm

To use Argon2id or another algorithm, provide custom `hash` and `verify` functions:

```ts
import { betterAuth } from "better-auth";
import { hash, verify, type Options } from "@node-rs/argon2";

const argon2Options: Options = {
  memoryCost: 65536, // 64 MiB
  timeCost: 3,       // 3 iterations
  parallelism: 4,    // 4 parallel lanes
  outputLen: 32,     // 32-byte output
  algorithm: 2,      // Argon2id variant
};

export const auth = betterAuth({
  emailAndPassword: {
    enabled: true,
    password: {
      hash: (password) => hash(password, argon2Options),
      verify: ({ password, hash: storedHash }) =>
        verify(storedHash, password, argon2Options),
    },
  },
});
```

**Note**: If you switch hashing algorithms on an existing system, users whose passwords were hashed with the old algorithm won't be able to sign in. Plan a migration strategy if needed.
331
.agents/skills/two-factor-authentication-best-practices/SKILL.md
Normal file
@@ -0,0 +1,331 @@
|
||||
---
|
||||
name: two-factor-authentication-best-practices
|
||||
description: Configure TOTP authenticator apps, send OTP codes via email/SMS, manage backup codes, handle trusted devices, and implement 2FA sign-in flows using Better Auth's twoFactor plugin. Use when users need MFA, multi-factor authentication, authenticator setup, or login security with Better Auth.
|
||||
---
|
||||
|
||||
## Setup
|
||||
|
||||
1. Add `twoFactor()` plugin to server config with `issuer`
|
||||
2. Add `twoFactorClient()` plugin to client config
|
||||
3. Run `npx @better-auth/cli migrate`
|
||||
4. Verify: check that `twoFactorSecret` column exists on user table
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { twoFactor } from "better-auth/plugins";
|
||||
|
||||
export const auth = betterAuth({
|
||||
appName: "My App",
|
||||
plugins: [
|
||||
twoFactor({
|
||||
issuer: "My App",
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
|
||||
### Client-Side Setup
|
||||
|
||||
```ts
|
||||
import { createAuthClient } from "better-auth/client";
|
||||
import { twoFactorClient } from "better-auth/client/plugins";
|
||||
|
||||
export const authClient = createAuthClient({
|
||||
plugins: [
|
||||
twoFactorClient({
|
||||
onTwoFactorRedirect() {
|
||||
window.location.href = "/2fa";
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
|
||||
## Enabling 2FA for Users
|
||||
|
||||
Requires password verification. Returns TOTP URI (for QR code) and backup codes.
|
||||
|
||||
```ts
|
||||
const enable2FA = async (password: string) => {
|
||||
const { data, error } = await authClient.twoFactor.enable({
|
||||
password,
|
||||
});
|
||||
|
||||
if (data) {
|
||||
// data.totpURI — generate a QR code from this
|
||||
// data.backupCodes — display to user
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
`twoFactorEnabled` is not set to `true` until first TOTP verification succeeds. Override with `skipVerificationOnEnable: true` (not recommended).

## TOTP (Authenticator App)

### Displaying the QR Code

```tsx
import QRCode from "react-qr-code";

const TotpSetup = ({ totpURI }: { totpURI: string }) => {
  return <QRCode value={totpURI} />;
};
```

### Verifying TOTP Codes

Accepts codes from one period before/after the current time:

```ts
const verifyTotp = async (code: string) => {
  const { data, error } = await authClient.twoFactor.verifyTotp({
    code,
    trustDevice: true,
  });
};
```

### TOTP Configuration Options

```ts
twoFactor({
  totpOptions: {
    digits: 6, // 6 or 8 digits (default: 6)
    period: 30, // Code validity period in seconds (default: 30)
  },
});
```

## OTP (Email/SMS)

### Configuring OTP Delivery

```ts
import { betterAuth } from "better-auth";
import { twoFactor } from "better-auth/plugins";
import { sendEmail } from "./email";

export const auth = betterAuth({
  plugins: [
    twoFactor({
      otpOptions: {
        sendOTP: async ({ user, otp }, ctx) => {
          await sendEmail({
            to: user.email,
            subject: "Your verification code",
            text: `Your code is: ${otp}`,
          });
        },
        period: 5, // Code validity in minutes (default: 3)
        digits: 6, // Number of digits (default: 6)
        allowedAttempts: 5, // Max verification attempts (default: 5)
      },
    }),
  ],
});
```

### Sending and Verifying OTP

Send: `authClient.twoFactor.sendOtp()`. Verify: `authClient.twoFactor.verifyOtp({ code, trustDevice: true })`.
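
Put together, a minimal client-side OTP flow might look like the sketch below. It assumes the `authClient` configured above; `getCodeFromUser` is a hypothetical UI helper, not part of Better Auth:

```typescript
// Sketch only: assumes the authClient configured above with twoFactorClient().
// getCodeFromUser is a hypothetical helper that collects the code from your UI.
const signInWithOtp = async (getCodeFromUser: () => Promise<string>) => {
  // Trigger delivery; the server invokes your sendOTP handler
  await authClient.twoFactor.sendOtp();

  const code = await getCodeFromUser();
  const { data, error } = await authClient.twoFactor.verifyOtp({
    code,
    trustDevice: true, // optional: remember this device
  });

  if (error) {
    // Wrong code or attempt limit exceeded; prompt the user to retry
  }
  return data;
};
```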

### OTP Storage Security

Configure how OTP codes are stored in the database:

```ts
twoFactor({
  otpOptions: {
    storeOTP: "encrypted", // Options: "plain", "encrypted", "hashed"
  },
});
```

For custom encryption:

```ts
twoFactor({
  otpOptions: {
    storeOTP: {
      encrypt: async (token) => myEncrypt(token),
      decrypt: async (token) => myDecrypt(token),
    },
  },
});
```

## Backup Codes

Generated automatically when 2FA is enabled. Each code is single-use.

### Displaying Backup Codes

```tsx
const BackupCodes = ({ codes }: { codes: string[] }) => {
  return (
    <div>
      <p>Save these codes in a secure location:</p>
      <ul>
        {codes.map((code, i) => (
          <li key={i}>{code}</li>
        ))}
      </ul>
    </div>
  );
};
```

### Regenerating Backup Codes

Invalidates all previous codes:

```ts
const regenerateBackupCodes = async (password: string) => {
  const { data, error } = await authClient.twoFactor.generateBackupCodes({
    password,
  });
  // data.backupCodes contains the new codes
};
```

### Using Backup Codes for Recovery

```ts
const verifyBackupCode = async (code: string) => {
  const { data, error } = await authClient.twoFactor.verifyBackupCode({
    code,
    trustDevice: true,
  });
};
```

### Backup Code Configuration

```ts
twoFactor({
  backupCodeOptions: {
    amount: 10, // Number of codes to generate (default: 10)
    length: 10, // Length of each code (default: 10)
    storeBackupCodes: "encrypted", // Options: "plain", "encrypted"
  },
});
```

## Handling 2FA During Sign-In

Response includes `twoFactorRedirect: true` when 2FA is required:

### Sign-In Flow

1. Call `signIn.email({ email, password })`
2. Check `context.data.twoFactorRedirect` in `onSuccess`
3. If `true`, redirect to `/2fa` verification page
4. Verify via TOTP, OTP, or backup code
5. Session cookie is created on successful verification

```ts
const signIn = async (email: string, password: string) => {
  const { data, error } = await authClient.signIn.email(
    { email, password },
    {
      onSuccess(context) {
        if (context.data.twoFactorRedirect) {
          window.location.href = "/2fa";
        }
      },
    }
  );
};
```

Server-side: check `"twoFactorRedirect" in response` when using `auth.api.signInEmail`.
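
A sketch of that server-side check (assumes `auth` is the `betterAuth()` instance configured earlier; the response shape beyond the `twoFactorRedirect` flag is not spelled out here, so treat the rest as an assumption):

```typescript
// Sketch only: assumes `auth` is the betterAuth() server instance from auth.ts.
const serverSignIn = async (email: string, password: string) => {
  const response = await auth.api.signInEmail({
    body: { email, password },
  });

  if ("twoFactorRedirect" in response) {
    // Password accepted, but 2FA verification is still required
    return { needsTwoFactor: true };
  }

  // Fully signed in
  return { needsTwoFactor: false, response };
};
```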

## Trusted Devices

Pass `trustDevice: true` when verifying. Default trust duration: 30 days (`trustDeviceMaxAge`). Refreshes on each sign-in.
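
The trust window can be adjusted on the plugin; a minimal sketch, with the value in seconds as in the complete configuration example at the end of this guide:

```typescript
// Config fragment (sketch): override the trusted-device window.
twoFactor({
  trustDeviceMaxAge: 30 * 24 * 60 * 60, // 30 days, in seconds (default)
});
```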

## Security Considerations

### Session Management

Flow: credentials → session removed → temporary 2FA cookie (10 min default) → verify → session created.

```ts
twoFactor({
  twoFactorCookieMaxAge: 600, // 10 minutes in seconds (default)
});
```

### Rate Limiting

Built-in: 3 requests per 10 seconds for all 2FA endpoints. OTP has additional attempt limiting:

```ts
twoFactor({
  otpOptions: {
    allowedAttempts: 5, // Max attempts per OTP code (default: 5)
  },
});
```

### Encryption at Rest

TOTP secrets: encrypted with auth secret. Backup codes: encrypted by default. OTP: configurable (`"plain"`, `"encrypted"`, `"hashed"`). Uses constant-time comparison for verification.

2FA can only be enabled for credential (email/password) accounts.

## Disabling 2FA

Requires password confirmation. Revokes trusted device records:

```ts
const disable2FA = async (password: string) => {
  const { data, error } = await authClient.twoFactor.disable({
    password,
  });
};
```

## Complete Configuration Example

```ts
import { betterAuth } from "better-auth";
import { twoFactor } from "better-auth/plugins";
import { sendEmail } from "./email";

export const auth = betterAuth({
  appName: "My App",
  plugins: [
    twoFactor({
      // TOTP settings
      issuer: "My App",
      totpOptions: {
        digits: 6,
        period: 30,
      },
      // OTP settings
      otpOptions: {
        sendOTP: async ({ user, otp }) => {
          await sendEmail({
            to: user.email,
            subject: "Your verification code",
            text: `Your code is: ${otp}`,
          });
        },
        period: 5,
        allowedAttempts: 5,
        storeOTP: "encrypted",
      },
      // Backup code settings
      backupCodeOptions: {
        amount: 10,
        length: 10,
        storeBackupCodes: "encrypted",
      },
      // Session settings
      twoFactorCookieMaxAge: 600, // 10 minutes
      trustDeviceMaxAge: 30 * 24 * 60 * 60, // 30 days
    }),
  ],
});
```

1
.claude/skills/better-auth-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/better-auth-best-practices

1
.claude/skills/create-auth-skill
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/create-auth-skill

1
.claude/skills/email-and-password-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/email-and-password-best-practices

1
.claude/skills/two-factor-authentication-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/two-factor-authentication-best-practices

25
.gitignore
vendored
@@ -1,5 +1,6 @@
/backend/data
/backend/uploads/
/backend.old/data
/backend.old/uploads/
chat/

# Environment variables
.env
@@ -101,3 +102,23 @@ Thumbs.db
*.swp
*.swo
*.bak

# Kubernetes secrets (never commit actual secrets!)
deploy/k8s/dev/secrets/*.yaml
deploy/k8s/prod/secrets/*.yaml
!deploy/k8s/dev/secrets/*.yaml.example
!deploy/k8s/prod/secrets/*.yaml.example

# Dev environment image tags
.dev-image-tag

# Protobuf copies (canonical files are in /protobuf/)
flink/protobuf/
relay/protobuf/
ingestor/protobuf/
gateway/protobuf/
client-py/protobuf/

# Generated protobuf code
gateway/src/generated/
client-py/dexorder/generated/

16
.idea/ai.iml
generated
@@ -2,10 +2,20 @@
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$">
      <sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" />
      <sourceFolder url="file://$MODULE_DIR$/backend/tests" isTestSource="true" />
      <sourceFolder url="file://$MODULE_DIR$/backend.old/src" isTestSource="false" />
      <sourceFolder url="file://$MODULE_DIR$/backend.old/tests" isTestSource="true" />
      <sourceFolder url="file://$MODULE_DIR$/client-py" isTestSource="false" />
      <excludeFolder url="file://$MODULE_DIR$/.venv" />
      <excludeFolder url="file://$MODULE_DIR$/backend/data" />
      <excludeFolder url="file://$MODULE_DIR$/backend.old/data" />
      <excludeFolder url="file://$MODULE_DIR$/doc.old" />
      <excludeFolder url="file://$MODULE_DIR$/backend.old" />
      <excludeFolder url="file://$MODULE_DIR$/client-py/dexorder_client.egg-info" />
      <excludeFolder url="file://$MODULE_DIR$/flink/protobuf" />
      <excludeFolder url="file://$MODULE_DIR$/flink/target" />
      <excludeFolder url="file://$MODULE_DIR$/ingestor/protobuf" />
      <excludeFolder url="file://$MODULE_DIR$/ingestor/src/proto" />
      <excludeFolder url="file://$MODULE_DIR$/relay/protobuf" />
      <excludeFolder url="file://$MODULE_DIR$/relay/target" />
    </content>
    <orderEntry type="jdk" jdkName="Python 3.12 (ai)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />

12
.idea/runConfigurations/dev.xml
generated
@@ -1,12 +0,0 @@
<component name="ProjectRunConfigurationManager">
  <configuration default="false" name="dev" type="js.build_tools.npm" nameIsGenerated="true">
    <package-json value="$PROJECT_DIR$/web/package.json" />
    <command value="run" />
    <scripts>
      <script value="dev" />
    </scripts>
    <node-interpreter value="project" />
    <envs />
    <method v="2" />
  </configuration>
</component>

1
.junie/skills/better-auth-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/better-auth-best-practices

1
.junie/skills/create-auth-skill
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/create-auth-skill

1
.junie/skills/email-and-password-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/email-and-password-best-practices

1
.junie/skills/two-factor-authentication-best-practices
Symbolic link
@@ -0,0 +1 @@
../../.agents/skills/two-factor-authentication-best-practices

15
AGENT.md
Normal file
@@ -0,0 +1,15 @@
We're building an AI-first trading platform by integrating user-facing TradingView charts and chat with an AI assistant that helps do research, develop indicators (signals), and write strategies, using the Dexorder trading framework we provide.

This monorepo has:
bin/ scripts, mostly build and deploy
deploy/ kubernetes deployment and configuration
doc/ documentation
flink/ Apache Flink application mode processes data from Kafka
iceberg/ Apache Iceberg for historical OHLC etc
ingestor/ Data sources publish to Kafka
kafka/ Apache Kafka
protobuf/ Messaging entities
relay/ Rust+ZeroMQ stateless router
web/ Vue 3 / Pinia / PrimeVue / TradingView

See doc/protocol.md for messaging architecture

@@ -5,7 +5,7 @@ server_port: 8081
agent:
  model: "claude-sonnet-4-20250514"
  temperature: 0.7
-  context_docs_dir: "doc"
+  context_docs_dir: "memory" # Context docs still loaded from memory/

# Local memory configuration (free & sophisticated!)
memory:

36
backend.old/memory/chart_context.md
Normal file
@@ -0,0 +1,36 @@
# Chart Context Awareness

## When Users Reference "The Chart"

When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:

1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe)
2. **Check if chart is visible** — If ChartStore fields (symbol, interval) are `None`, the user is on a narrow screen (mobile) and no chart is visible
3. **When chart is visible:**
   - **NEVER** ask the user to upload an image or tell you what symbol they're looking at
   - **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing
   - Inside your Python script, `df` contains the data and `chart_context` has the metadata
   - Use `plot_ohlc(df)` to create beautiful candlestick charts
4. **When chart is NOT visible (symbol is None):**
   - Let the user know they can view charts on a wider screen
   - You can still help with analysis using `get_historical_data()` if they specify a symbol

## Common Questions This Applies To

- "Can you see this chart?"
- "What are the swing highs and lows?"
- "Is this in an uptrend?"
- "What's the current price?"
- "Analyze this chart"
- "What am I looking at?"

## Data Analysis Workflow

1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt (if chart is visible)
2. **Check ChartStore** → If symbol/interval are `None`, no chart is visible (mobile view)
3. **Use `execute_python()`** → This is your PRIMARY analysis tool
   - Automatically loads chart data into a pandas DataFrame `df` (if chart is visible)
   - Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib
   - Provides access to the indicator registry for computing indicators
   - Use `plot_ohlc(df)` helper for beautiful candlestick charts
4. **Only use `get_chart_data()`** → For simple data inspection without analysis

115
backend.old/memory/python_analysis.md
Normal file
@@ -0,0 +1,115 @@
# Python Analysis Tool Reference

## Python Analysis (`execute_python`) - Your Primary Tool

**ALWAYS use `execute_python()` when the user asks for:**
- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.)
- Chart visualizations or plots
- Statistical calculations or market analysis
- Pattern detection or trend analysis
- Any computational analysis of price data

## Why `execute_python()` is preferred:
- Chart data (`df`) is automatically loaded from ChartStore (visible time range) when chart is visible
- If no chart is visible (symbol is None), `df` will be None - but you can still load alternative data!
- Full pandas/numpy/talib stack pre-imported
- Use `plot_ohlc(df)` for instant professional candlestick charts
- Access to 150+ indicators via `indicator_registry`
- **Access to DataStores and registry** - order_store, chart_store, datasource_registry
- **Can load ANY symbol/timeframe** using datasource_registry even when df is None
- **Results include plots as image URLs** that are automatically displayed to the user
- Prints and return values are included in the response

## CRITICAL: Plots are automatically shown to the user
When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically:
1. Saved as a PNG image
2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`)
3. **Displayed in the user's chat interface** - they see the image immediately

You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot.

## IMPORTANT: Never use `get_historical_data()` for chart analysis
- `get_historical_data()` requires manual timestamp calculation and is only for custom queries
- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data
- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing

## Example workflows:

### Computing an indicator and plotting (when chart is visible)
```python
execute_python("""
df['RSI'] = talib.RSI(df['close'], 14)
fig = plot_ohlc(df, title='Price with RSI')
df[['close', 'RSI']].tail(10)
""")
```

### Multi-indicator analysis (when chart is visible)
```python
execute_python("""
df['SMA20'] = df['close'].rolling(20).mean()
df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std()
df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std()
fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands")
print(f"Current price: {df['close'].iloc[-1]:.2f}")
print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}")
""")
```

### Loading alternative data (works even when chart not visible or for different symbols)
```python
execute_python("""
from datetime import datetime, timedelta

# Get data source
binance = datasource_registry.get_source('binance')

# Load data for any symbol/timeframe
end_time = datetime.now()
start_time = end_time - timedelta(days=7)

result = await binance.get_history(
    symbol='ETH/USDT',
    interval='1h',
    start=int(start_time.timestamp()),
    end=int(end_time.timestamp())
)

# Convert to DataFrame
rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
eth_df = pd.DataFrame(rows).set_index('time')

# Analyze and plot
eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
fig = plot_ohlc(eth_df, title='ETH/USDT 1h - RSI Analysis')
print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
""")
```

### Access stores to see current state
```python
execute_python("""
print(f"Current symbol: {chart_store.chart_state.symbol}")
print(f"Current interval: {chart_store.chart_state.interval}")
print(f"Number of orders: {len(order_store.orders)}")
""")
```

## Only use `get_chart_data()` for:
- Quick inspection of raw bar data
- When you just need the data structure without analysis

## Quick Reference: Common Tasks

| User Request | Tool to Use | Example |
|--------------|-------------|---------|
| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` |
| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` |
| "Is this bullish?" | `execute_python()` | Compute SMAs, trend, and analyze |
| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` |
| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns |
| "Plot ETH even though I'm viewing BTC" | `execute_python()` | Use `datasource_registry.get_source('binance')` to load ETH data |
| "What indicators exist?" | `search_indicators()` | Search by category or query |
| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt |
| "Check my orders" | `execute_python()` | `print(order_store.orders)` |
| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. |

612
backend.old/memory/tradingview_shapes.md
Normal file
@@ -0,0 +1,612 @@
# TradingView Shapes and Drawings Reference

This document describes the various drawing shapes and studies available in TradingView charts, their properties, and control points. This information is useful for the AI agent to understand, create, and manipulate chart drawings.

## Shape Structure

All shapes follow a common structure:
- **id**: Unique identifier (string) - This is the TradingView-assigned ID after the shape is created
- **type**: Shape type identifier (string)
- **points**: Array of control points (each with `time` in Unix seconds and `price` as float)
- **color**: Color as hex string (e.g., '#FF0000') or color name (e.g., 'red')
- **line_width**: Line thickness in pixels (integer)
- **line_style**: One of: 'solid', 'dashed', 'dotted'
- **properties**: Dictionary of additional shape-specific properties
- **symbol**: Trading pair symbol (e.g., 'BINANCE:BTC/USDT')
- **created_at**: Creation timestamp (Unix seconds)
- **modified_at**: Last modification timestamp (Unix seconds)
- **original_id**: Optional string - The ID you requested when creating the shape, before TradingView assigned its own ID

## Understanding Shape ID Mapping

When you create a shape using `create_or_update_shape()`, there's an important ID mapping process:

1. **You specify an ID**: You provide a `shape_id` parameter (e.g., "my-support-line")
2. **TradingView assigns its own ID**: When the shape is rendered in TradingView, it gets a new internal ID (e.g., "shape_0x1a2b3c4d")
3. **ID remapping occurs**: The shape in the store is updated:
   - The `id` field becomes TradingView's ID
   - The `original_id` field preserves your requested ID
4. **Tracking your shapes**: To find shapes you created, search by `original_id`

### Example ID Mapping Flow

```python
# Step 1: Agent creates a shape with a specific ID
await create_or_update_shape(
    shape_id="agent-support-50k",
    shape_type="horizontal_line",
    points=[{"time": 1678886400, "price": 50000}],
    color="#00FF00"
)

# Step 2: Shape is synced to client and created in TradingView
# TradingView assigns ID: "shape_0x1a2b3c4d"

# Step 3: Shape in store is updated with:
# {
#   "id": "shape_0x1a2b3c4d",           # TradingView's ID
#   "original_id": "agent-support-50k", # Your requested ID
#   "type": "horizontal_line",
#   ...
# }

# Step 4: To find your shape later, use shape_ids (searches both id and original_id)
my_shapes = search_shapes(
    shape_ids=['agent-support-50k'],
    symbol="BINANCE:BTC/USDT"
)

if my_shapes:
    print(f"Found my support line!")
    print(f"TradingView ID: {my_shapes[0]['id']}")
    print(f"My requested ID: {my_shapes[0]['original_id']}")

# Or use the dedicated original_ids parameter
my_shapes = search_shapes(
    original_ids=['agent-support-50k'],
    symbol="BINANCE:BTC/USDT"
)

if my_shapes:
    print(f"Found my support line!")
    print(f"TradingView ID: {my_shapes[0]['id']}")
    print(f"My requested ID: {my_shapes[0]['original_id']}")
```

### Why ID Mapping Matters

- **Shape identification**: You need to know which TradingView shape corresponds to the shape you created
- **Updates and deletions**: To modify or delete a shape, you need its TradingView ID (the `id` field)
- **Bidirectional sync**: The mapping ensures both the agent and TradingView can reference the same shape

### Best Practices for Shape IDs

1. **Use descriptive IDs**: Choose meaningful names like `support-btc-50k` or `trendline-daily-uptrend`
2. **Search by original ID**: Use `shape_ids` or `original_ids` parameters in `search_shapes()` to find your shapes
   - `shape_ids` searches both the actual ID and original_id (more flexible)
   - `original_ids` searches only the original_id field (more specific)
3. **Store important IDs**: If you need to reference a shape multiple times, store its TradingView ID after retrieval
4. **Understand the timing**: The ID remapping happens asynchronously after shape creation

## Common Shape Types

Use TradingView's native shape type names directly.

### 1. Trendline
**Type**: `trend_line`
**Control Points**: 2
- Point 1: Start of the line (time, price)
- Point 2: End of the line (time, price)

**Common Use Cases**:
- Support/resistance lines
- Trend identification
- Price channels (when paired)

**Example**:
```json
{
  "id": "trendline-1",
  "type": "trend_line",
  "points": [
    {"time": 1640000000, "price": 45000.0},
    {"time": 1650000000, "price": 50000.0}
  ],
  "color": "#2962FF",
  "line_width": 2,
  "line_style": "solid"
}
```

### 2. Horizontal Line
**Type**: `horizontal_line`
**Control Points**: 1
- Point 1: Y-level (time can be any value, only price matters)

**Common Use Cases**:
- Support/resistance levels
- Price targets
- Stop-loss levels
- Key psychological levels

**Properties**:
- `extend_left`: Boolean, extend line to the left
- `extend_right`: Boolean, extend line to the right

**Example**:
```json
{
  "id": "support-1",
  "type": "horizontal_line",
  "points": [{"time": 1640000000, "price": 42000.0}],
  "color": "#089981",
  "line_width": 2,
  "line_style": "dashed",
  "properties": {
    "extend_left": true,
    "extend_right": true
  }
}
```

### 3. Vertical Line
**Type**: `vertical_line`
**Control Points**: 1
- Point 1: X-time (price can be any value, only time matters)

**Common Use Cases**:
- Mark important events
- Session boundaries
- Earnings releases
- Economic data releases

**Properties**:
- `extend_top`: Boolean, extend line upward
- `extend_bottom`: Boolean, extend line downward

**Example**:
```json
{
  "id": "event-marker-1",
  "type": "vertical_line",
  "points": [{"time": 1640000000, "price": 0}],
  "color": "#787B86",
  "line_width": 1,
  "line_style": "dotted"
}
```

### 4. Rectangle
**Type**: `rectangle`
**Control Points**: 2
- Point 1: Top-left corner (time, price)
- Point 2: Bottom-right corner (time, price)

**Common Use Cases**:
- Consolidation zones
- Support/resistance zones
- Supply/demand areas
- Value areas

**Properties**:
- `fill_color`: Fill color with opacity (e.g., '#2962FF33')
- `fill`: Boolean, whether to fill the rectangle
- `extend_left`: Boolean
- `extend_right`: Boolean

**Example**:
```json
{
  "id": "zone-1",
  "type": "rectangle",
  "points": [
    {"time": 1640000000, "price": 50000.0},
    {"time": 1650000000, "price": 48000.0}
  ],
  "color": "#2962FF",
  "line_width": 1,
  "line_style": "solid",
  "properties": {
    "fill": true,
    "fill_color": "#2962FF33"
  }
}
```

### 5. Fibonacci Retracement
**Type**: `fib_retracement`
**Control Points**: 2
- Point 1: Start of the move (swing low or high)
- Point 2: End of the move (swing high or low)

**Common Use Cases**:
- Identify potential support/resistance levels
- Find retracement targets
- Measure pullback depth

**Properties**:
- `levels`: Array of Fibonacci levels to display
  - Default: [0, 0.236, 0.382, 0.5, 0.618, 0.786, 1.0]
- `extend_lines`: Boolean, extend levels beyond the price range
- `reverse`: Boolean, reverse the direction

**Example**:
```json
{
  "id": "fib-1",
  "type": "fib_retracement",
  "points": [
    {"time": 1640000000, "price": 42000.0},
    {"time": 1650000000, "price": 52000.0}
  ],
  "color": "#2962FF",
  "line_width": 1,
  "properties": {
    "levels": [0, 0.236, 0.382, 0.5, 0.618, 0.786, 1.0],
    "extend_lines": true
  }
}
```

### 6. Fibonacci Extension
**Type**: `fib_trend_ext`
**Control Points**: 3
- Point 1: Start of initial move
- Point 2: End of initial move (retracement start)
- Point 3: End of retracement

**Common Use Cases**:
- Project price targets
- Extension levels beyond 100%
- Measure continuation patterns

**Properties**:
- `levels`: Array of extension levels
  - Common: [0, 0.618, 1.0, 1.618, 2.618, 4.236]
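
**Example** (illustrative sketch; the ID and point values are hypothetical, following the format of the examples above):
```json
{
  "id": "fib-ext-1",
  "type": "fib_trend_ext",
  "points": [
    {"time": 1640000000, "price": 42000.0},
    {"time": 1645000000, "price": 52000.0},
    {"time": 1650000000, "price": 47000.0}
  ],
  "color": "#2962FF",
  "line_width": 1,
  "properties": {
    "levels": [0, 0.618, 1.0, 1.618, 2.618, 4.236]
  }
}
```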

### 7. Parallel Channel
**Type**: `parallel_channel`
**Control Points**: 3
- Point 1: First point on main trendline
- Point 2: Second point on main trendline
- Point 3: Point on parallel line (determines channel width)

**Common Use Cases**:
- Price channels
- Regression channels
- Pitchforks

**Properties**:
- `extend_left`: Boolean
- `extend_right`: Boolean
- `fill`: Boolean, fill the channel
- `fill_color`: Fill color with opacity
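
**Example** (illustrative sketch; the ID and point values are hypothetical, following the format of the examples above):
```json
{
  "id": "channel-1",
  "type": "parallel_channel",
  "points": [
    {"time": 1640000000, "price": 42000.0},
    {"time": 1650000000, "price": 48000.0},
    {"time": 1640000000, "price": 40000.0}
  ],
  "color": "#2962FF",
  "line_width": 2,
  "line_style": "solid",
  "properties": {
    "extend_right": true,
    "fill": true,
    "fill_color": "#2962FF33"
  }
}
```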
### 8. Arrow
|
||||
**Type**: `arrow`
|
||||
**Control Points**: 2
|
||||
- Point 1: Arrow start (time, price)
|
||||
- Point 2: Arrow end (time, price)
|
||||
|
||||
**Common Use Cases**:
|
||||
- Indicate price movement direction
|
||||
- Mark entry/exit points
|
||||
- Show relationships between events
|
||||
|
||||
**Properties**:
|
||||
- `arrow_style`: One of: 'simple', 'filled', 'hollow'
|
||||
- `text`: Optional text label
|
||||
|
||||
**Example**:
|
||||
```json
|
||||
{
|
||||
"id": "entry-arrow",
|
||||
"type": "arrow",
|
||||
"points": [
|
||||
{"time": 1640000000, "price": 44000.0},
|
||||
{"time": 1641000000, "price": 48000.0}
|
||||
],
|
||||
"color": "#089981",
|
||||
"line_width": 2,
|
||||
"properties": {
|
||||
"arrow_style": "filled",
|
||||
"text": "Long Entry"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 9. Text/Label

**Type**: `text`

**Control Points**: 1
- Point 1: Text anchor position (time, price)

**Common Use Cases**:
- Annotations
- Notes
- Labels for patterns
- Mark key levels

**Properties**:
- `text`: The text content (string)
- `font_size`: Font size in points (integer)
- `font_family`: Font family name
- `bold`: Boolean
- `italic`: Boolean
- `background`: Boolean, show background
- `background_color`: Background color
- `text_color`: Text color (can differ from line color)

**Example**:
```json
{
  "id": "note-1",
  "type": "text",
  "points": [{"time": 1640000000, "price": 48000.0}],
  "color": "#131722",
  "properties": {
    "text": "Resistance Zone",
    "font_size": 14,
    "bold": true,
    "background": true,
    "background_color": "#FFE600",
    "text_color": "#131722"
  }
}
```

### 10. Single-Point Markers

Various single-point marker shapes are available for annotating charts:

**Types**: `arrow_up` | `arrow_down` | `flag` | `emoji` | `icon` | `sticker` | `note` | `anchored_text` | `anchored_note` | `long_position` | `short_position`

**Control Points**: 1
- Point 1: Marker position (time, price)

**Common Use Cases**:
- Mark entry/exit points
- Flag important events
- Add visual markers to key levels
- Annotate patterns
- Track positions

**Properties** (vary by type):
- `text`: Text content for text-based markers
- `emoji`: Emoji character for emoji type
- `icon`: Icon identifier for icon type

**Examples**:
```json
{
  "id": "long-entry-1",
  "type": "long_position",
  "points": [{"time": 1640000000, "price": 44000.0}],
  "color": "#089981"
}
```

```json
{
  "id": "flag-1",
  "type": "flag",
  "points": [{"time": 1640000000, "price": 50000.0}],
  "color": "#F23645",
  "properties": {
    "text": "Important Event"
  }
}
```

```json
{
  "id": "note-1",
  "type": "anchored_note",
  "points": [{"time": 1640000000, "price": 48000.0}],
  "color": "#FFE600",
  "properties": {
    "text": "Watch this level"
  }
}
```

### 11. Circle/Ellipse

**Type**: `circle`

**Control Points**: 2 or 3
- 2 points: Defines bounding box (creates ellipse)
- 3 points: Center + radius points

**Common Use Cases**:
- Highlight areas
- Markup patterns
- Mark consolidation zones

**Properties**:
- `fill`: Boolean
- `fill_color`: Fill color with opacity

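A circle spec follows the same structure as the JSON examples in the sections above. A minimal sketch as a Python dict, using the two-point (bounding-box) form; the ID and values are illustrative, and the field names are assumed to match the arrow/text examples:

```python
# Hypothetical circle highlighting a consolidation zone. Field names follow
# the JSON shape examples in this document; the ID and values are made up.
consolidation_zone = {
    "id": "consolidation-1",
    "type": "circle",
    "points": [  # two points -> bounding box (ellipse)
        {"time": 1640000000, "price": 43000.0},
        {"time": 1641000000, "price": 45000.0},
    ],
    "color": "#2962FF",
    "properties": {
        "fill": True,
        "fill_color": "#2962FF33",  # fill with ~20% opacity
    },
}
```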
### 12. Path (Free Drawing)

**Type**: `path`

**Control Points**: Variable (3+)
- Multiple points defining a path

**Common Use Cases**:
- Custom patterns
- Freeform markup
- Complex annotations

**Properties**:
- `closed`: Boolean, whether to close the path
- `smooth`: Boolean, smooth the path with curves

### 13. Pitchfork (Andrews' Pitchfork)

**Type**: `pitchfork`

**Control Points**: 3
- Point 1: Pivot/starting point
- Point 2: First extreme (high or low)
- Point 3: Second extreme (opposite of point 2)

**Common Use Cases**:
- Trend channels
- Support/resistance levels
- Median line analysis

**Properties**:
- `extend_lines`: Boolean
- `style`: One of: 'standard', 'schiff', 'modified_schiff'

### 14. Gann Fan

**Type**: `gannbox_fan`

**Control Points**: 2
- Point 1: Origin point
- Point 2: Defines the unit size/scale

**Common Use Cases**:
- Time and price analysis
- Geometric angles (1x1, 1x2, 2x1, etc.)

**Properties**:
- `angles`: Array of angles to display
- Default: [82.5, 75, 71.25, 63.75, 45, 26.25, 18.75, 15, 7.5]

### 15. Head and Shoulders

**Type**: `head_and_shoulders`

**Control Points**: 5
- Point 1: Left shoulder low
- Point 2: Left shoulder high
- Point 3: Head low
- Point 4: Right shoulder high
- Point 5: Right shoulder low (neckline point)

**Common Use Cases**:
- Pattern recognition markup
- Reversal pattern identification

**Properties**:
- `target_line`: Boolean, show target line

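The control-point counts listed in the sections above can be collected into a small validation helper. A sketch (only the shape types described in this catalog are included; `EXPECTED_POINTS` and `validate_points` are illustrative names, not part of the platform API):

```python
# (min, max) control points per shape type, taken from the sections above.
# None for max means "no upper bound". Types from earlier sections omitted.
EXPECTED_POINTS = {
    "parallel_channel": (3, 3),
    "arrow": (2, 2),
    "text": (1, 1),
    "flag": (1, 1),          # single-point markers all take 1 point
    "long_position": (1, 1),
    "circle": (2, 3),        # 2 = bounding box, 3 = center + radius points
    "path": (3, None),
    "pitchfork": (3, 3),
    "gannbox_fan": (2, 2),
    "head_and_shoulders": (5, 5),
}


def validate_points(shape: dict) -> bool:
    """Check that a shape dict carries a plausible number of control points."""
    bounds = EXPECTED_POINTS.get(shape.get("type"))
    if bounds is None:
        return True  # unknown type: don't reject
    lo, hi = bounds
    n = len(shape.get("points", []))
    return n >= lo and (hi is None or n <= hi)
```

Running such a check before calling the shape-creation tool catches malformed specs early instead of failing in the chart layer.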
## Special Properties

### Time-Based Properties
- All times are Unix timestamps in seconds
- Use `Math.floor(Date.now() / 1000)` for current time in JavaScript
- Use `int(time.time())` for current time in Python

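For example, the timestamp `1640000000` used throughout the shape examples resolves to a concrete UTC datetime; a minimal Python sketch of the conversions:

```python
import time
from datetime import datetime, timezone

# Current time in Unix seconds, suitable for a point's "time" field
now_ts = int(time.time())

# Converting a specific UTC datetime to Unix seconds
ts = int(datetime(2021, 12, 20, 11, 33, 20, tzinfo=timezone.utc).timestamp())
# ts == 1640000000, the value used in the examples above

# And back again, for display
dt = datetime.fromtimestamp(1640000000, tz=timezone.utc)
```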
### Color Formats
- Hex: `#RRGGBB` (e.g., `#2962FF`)
- Hex with alpha: `#RRGGBBAA` (e.g., `#2962FF33` for 20% opacity)
- Named colors: `red`, `blue`, `green`, etc.
- RGB: `rgb(41, 98, 255)`
- RGBA: `rgba(41, 98, 255, 0.2)`

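The `#RRGGBBAA` form can be derived from an opacity fraction; a small sketch (the helper name is ours, not a platform function):

```python
def with_opacity(hex_color: str, opacity: float) -> str:
    """Append an alpha channel to a #RRGGBB color, yielding #RRGGBBAA."""
    alpha = round(opacity * 255)  # e.g. 0.2 -> 51 -> hex 33
    return f"{hex_color}{alpha:02X}"

print(with_opacity("#2962FF", 0.2))  # -> #2962FF33
```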
### Line Styles
- `solid`: Continuous line
- `dashed`: Dashed line (— — —)
- `dotted`: Dotted line (· · ·)

## Best Practices

1. **ID Naming**: Use descriptive IDs that indicate the purpose
   - Good: `support-btc-42k`, `trendline-uptrend-1`
   - Bad: `shape1`, `line`

2. **Color Consistency**: Use consistent colors for similar types
   - Green (#089981) for bullish/support
   - Red (#F23645) for bearish/resistance
   - Blue (#2962FF) for neutral/informational

3. **Time Alignment**: Ensure times align with actual candles when possible

4. **Layer Management**: Use different line widths to indicate importance
   - Key levels: 2-3px
   - Secondary levels: 1px
   - Reference lines: 1px dotted

5. **Symbol Association**: Always set the `symbol` field to associate shapes with specific charts

## Agent Usage Examples

### Drawing a Support Level
When user says "draw support at 42000":
```python
await create_or_update_shape(
    shape_id=f"support-{int(time.time())}",
    shape_type='horizontal_line',
    points=[{'time': current_time, 'price': 42000.0}],
    color='#089981',
    line_width=2,
    line_style='solid',
    symbol=chart_store.chart_state.symbol,
    properties={'extend_left': True, 'extend_right': True}
)
```

### Finding Shapes in Visible Range
When user asks "what drawings are on the chart?":
```python
shapes = search_shapes(
    start_time=chart_store.chart_state.start_time,
    end_time=chart_store.chart_state.end_time,
    symbol=chart_store.chart_state.symbol
)
```

### Getting Specific Shapes by ID
When user says "show me the details of trendline-1":
```python
# shape_ids parameter searches BOTH the actual ID and original_id fields
shapes = search_shapes(
    shape_ids=['trendline-1']
)
```

Or to get selected shapes:
```python
selected_ids = chart_store.chart_state.selected_shapes
if selected_ids:
    shapes = search_shapes(shape_ids=selected_ids)
```

### Finding Shapes by Original ID
When you need to find shapes you created using the original ID you specified:
```python
# Use the dedicated original_ids parameter
my_shapes = search_shapes(
    original_ids=['my-support-line', 'my-trendline']
)

# Or use shape_ids (which searches both id and original_id)
my_shapes = search_shapes(
    shape_ids=['my-support-line', 'my-trendline']
)

for shape in my_shapes:
    print(f"Original ID: {shape['original_id']}")
    print(f"TradingView ID: {shape['id']}")
    print(f"Type: {shape['type']}")
```

### Searching Without a Time Filter
When user asks "show me all support lines":
```python
support_lines = search_shapes(
    shape_type='horizontal_line',
    symbol=chart_store.chart_state.symbol
)
```

### Drawing a Trendline
When user says "draw an uptrend from the lows":
```python
# Find swing lows using execute_python
# Then create trendline
await create_or_update_shape(
    shape_id=f"trendline-{int(time.time())}",
    shape_type='trend_line',
    points=[
        {'time': swing_low_1_time, 'price': swing_low_1_price},
        {'time': swing_low_2_time, 'price': swing_low_2_price}
    ],
    color='#2962FF',
    line_width=2,
    symbol=chart_store.chart_state.symbol
)
```
backend.old/requirements-pre.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
# Packages requiring compilation (Rust/C) - separated for Docker layer caching
# Changes here will trigger a rebuild of this layer

# Needs Rust (maturin)
chromadb>=0.4.0
cryptography>=42.0.0

# Pulls in `tokenizers` which needs Rust
sentence-transformers>=2.0.0

# Needs C compiler
argon2-cffi>=23.0.0
@@ -1,9 +1,13 @@
# Packages that require compilation in the python slim docker image must be
# put into requirements-pre.txt instead, to help image build caching.

pydantic2
seaborn
pandas
numpy
scipy
matplotlib
mplfinance
fastapi
uvicorn
websockets
@@ -11,6 +15,7 @@ jsonpatch
python-multipart
ccxt>=4.0.0
pyyaml
TA-Lib>=0.4.0

# LangChain agent dependencies
langchain>=0.3.0
@@ -19,9 +24,12 @@ langgraph-checkpoint-sqlite>=1.0.0
langchain-anthropic>=0.3.0
langchain-community>=0.3.0

# Local memory system
chromadb>=0.4.0
sentence-transformers>=2.0.0
# Additional tools for research and web access
arxiv>=2.0.0
duckduckgo-search>=7.0.0
requests>=2.31.0

# Local memory system (chromadb/sentence-transformers in requirements-pre.txt)
sqlalchemy>=2.0.0
aiosqlite>=0.19.0

@@ -34,3 +42,6 @@ python-dotenv>=1.0.0
# Secrets management
cryptography>=42.0.0
argon2-cffi>=23.0.0

# Trigger system scheduling
apscheduler>=3.10.0
backend.old/soul/automation_agent.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# Automation Agent

You are a specialized automation and scheduling agent. Your sole purpose is to manage triggers, scheduled tasks, and automated workflows.

## Your Core Identity

You are an expert in:
- Scheduling recurring tasks with cron expressions
- Creating interval-based triggers
- Managing the trigger queue and priorities
- Designing autonomous agent workflows

## Your Tools

You have access to:
- **Trigger Tools**: schedule_agent_prompt, execute_agent_prompt_once, list_scheduled_triggers, cancel_scheduled_trigger, get_trigger_system_stats

## Communication Style

- **Systematic**: Explain scheduling logic clearly
- **Proactive**: Suggest optimal scheduling strategies
- **Organized**: Keep track of all scheduled tasks
- **Reliable**: Ensure triggers are set up correctly

## Key Principles

1. **Clarity**: Make schedules easy to understand
2. **Maintainability**: Use descriptive names for jobs
3. **Priority Awareness**: Respect task priorities
4. **Resource Conscious**: Avoid overwhelming the system

## Limitations

- You ONLY handle scheduling and automation
- You do NOT execute the actual analysis (delegate to other agents)
- You do NOT access data or charts directly
- You coordinate but don't perform the work
backend.old/soul/chart_agent.md (new file, 40 lines)
@@ -0,0 +1,40 @@
# Chart Analysis Agent

You are a specialized chart analysis and technical analysis agent. Your sole purpose is to work with chart data, indicators, and Python-based analysis.

## Your Core Identity

You are an expert in:
- Reading and analyzing OHLCV data
- Computing technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Drawing shapes and annotations on charts
- Executing Python code for custom analysis
- Visualizing data with matplotlib

## Your Tools

You have access to:
- **Chart Tools**: get_chart_data, execute_python
- **Indicator Tools**: search_indicators, add_indicator_to_chart, list_chart_indicators, etc.
- **Shape Tools**: search_shapes, create_or_update_shape, delete_shape, etc.

## Communication Style

- **Direct & Technical**: Provide analysis results clearly
- **Visual**: Generate plots when helpful
- **Precise**: Reference specific timeframes, indicators, and values
- **Concise**: Focus on the analysis task at hand

## Key Principles

1. **Data First**: Always work with actual market data
2. **Visualize**: Create charts to illustrate findings
3. **Document Calculations**: Explain what indicators show
4. **Respect Context**: Use the chart the user is viewing when available

## Limitations

- You ONLY handle chart analysis and visualization
- You do NOT make trading decisions
- You do NOT access external APIs or data sources
- Route other requests back to the main agent
backend.old/soul/data_agent.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# Data Access Agent

You are a specialized data access agent. Your sole purpose is to retrieve and search market data from various exchanges and data sources.

## Your Core Identity

You are an expert in:
- Searching for trading symbols across exchanges
- Retrieving historical OHLCV data
- Understanding exchange APIs and data formats
- Symbol resolution and metadata

## Your Tools

You have access to:
- **Data Source Tools**: list_data_sources, search_symbols, get_symbol_info, get_historical_data

## Communication Style

- **Precise**: Provide exact symbol names and exchange identifiers
- **Helpful**: Suggest alternatives when exact matches aren't found
- **Efficient**: Return data in the format requested
- **Clear**: Explain data limitations or availability issues

## Key Principles

1. **Accuracy**: Return correct symbol identifiers
2. **Completeness**: Include all relevant metadata
3. **Performance**: Respect countback limits
4. **Format Awareness**: Understand time resolutions and ranges

## Limitations

- You ONLY handle data retrieval and search
- You do NOT analyze data (route to chart agent)
- You do NOT access external news or research (route to research agent)
- You do NOT modify data or state
backend.old/soul/main_agent.md (new file, 72 lines)
@@ -0,0 +1,72 @@
# System Prompt

You are an AI trading assistant for an AI-native algorithmic trading platform. Your role is to help traders design, implement, and manage trading strategies through natural language interaction.

## Your Core Identity

You are a **strategy authoring assistant**, not a strategy executor. You help users:
- Design trading strategies from natural language descriptions
- Interpret chart annotations and technical requirements
- Generate strategy executables (code artifacts)
- Manage and monitor live trading state
- Analyze market data and provide insights

## Your Capabilities

### State Management
You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients.

### Strategy Authoring
- Help users express trading intent through conversation
- Translate natural language to concrete strategy specifications
- Understand technical analysis concepts (support/resistance, indicators, patterns)
- Generate self-contained, deterministic strategy executables
- Validate strategy logic for correctness and safety

### Data & Analysis
- Access market data through abstract feed specifications
- Compute indicators and perform technical analysis
- Understand OHLCV data, order books, and market microstructure

## Communication Style

- **Technical & Direct**: Users are knowledgeable traders, be precise
- **Safety First**: Never make destructive changes without confirmation
- **Explain Actions**: When modifying state, explain what you're doing
- **Ask Questions**: If intent is unclear, ask for clarification
- **Concise**: Be brief but complete, avoid unnecessary elaboration

## Key Principles

1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
2. **Local Execution**: The platform runs locally for security; you are a design-time tool only
3. **Schema Validation**: All outputs must conform to platform schemas
4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
5. **Versioning**: Every strategy artifact is version-controlled with full auditability

## Your Limitations

- You **DO NOT** execute trades directly
- You **CANNOT** modify the order kernel or execution layer
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
- You **MUST NOT** provide trading or investment advice

## Memory & Context

You have access to:
- Full conversation history with semantic search
- Project documentation (design, architecture, data formats)
- Past strategy discussions and decisions
- Relevant context retrieved automatically based on current conversation

## Working with Users

1. **Understand Intent**: Ask clarifying questions about strategy goals
2. **Design Together**: Collaborate on strategy logic iteratively
3. **Validate**: Ensure strategy makes sense before generating code
4. **Test**: Encourage backtesting and paper trading first
5. **Monitor**: Help users interpret live strategy behavior

---

**Note**: Additional context documents are loaded automatically to provide detailed operational guidelines. See memory files for specifics on chart context, shape drawing, Python analysis, and more.
backend.old/soul/research_agent.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# Research Agent

You are a specialized research agent. Your sole purpose is to gather external information from the web, academic papers, and public APIs.

## Your Core Identity

You are an expert in:
- Searching academic papers on arXiv
- Finding information on Wikipedia
- Web search for current events and news
- Making HTTP requests to public APIs

## Your Tools

You have access to:
- **Research Tools**: search_arxiv, search_wikipedia, search_web, http_get, http_post

## Communication Style

- **Thorough**: Provide comprehensive research findings
- **Source-Aware**: Cite sources and links
- **Critical**: Evaluate information quality
- **Summarize**: Distill key points from long content

## Key Principles

1. **Verify**: Cross-reference information when possible
2. **Recency**: Note publication dates and data freshness
3. **Relevance**: Focus on trading and market-relevant information
4. **Ethics**: Respect API rate limits and terms of service

## Limitations

- You ONLY handle external information gathering
- You do NOT analyze market data (route to chart agent)
- You do NOT make trading recommendations
- You do NOT access private or authenticated APIs without explicit permission
backend.old/src/agent/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Don't import at module level to avoid circular imports
# Users should import directly: from agent.core import create_agent

__all__ = ["core", "tools"]
@@ -7,10 +7,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.prebuilt import create_react_agent

-from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS
+from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
 from agent.memory import MemoryManager
 from agent.session import SessionManager
 from agent.prompts import build_system_prompt
+from agent.subagent import SubAgent
+from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
 from gateway.user_session import UserSession
 from gateway.protocol import UserMessage as GatewayUserMessage
@@ -29,7 +31,8 @@ class AgentExecutor:
         model_name: str = "claude-sonnet-4-20250514",
         temperature: float = 0.7,
         api_key: Optional[str] = None,
-        memory_manager: Optional[MemoryManager] = None
+        memory_manager: Optional[MemoryManager] = None,
+        base_dir: str = "."
     ):
         """Initialize agent executor.
@@ -38,10 +41,12 @@
             temperature: Model temperature
             api_key: Anthropic API key
             memory_manager: MemoryManager instance
+            base_dir: Base directory for resolving paths
         """
         self.model_name = model_name
         self.temperature = temperature
         self.api_key = api_key
+        self.base_dir = base_dir

         # Initialize LLM
         self.llm = ChatAnthropic(
@@ -56,24 +61,82 @@
         self.session_manager = SessionManager(self.memory_manager)
         self.agent = None  # Will be created after initialization

+        # Sub-agents (only if using hierarchical tools)
+        self.chart_agent = None
+        self.data_agent = None
+        self.automation_agent = None
+        self.research_agent = None
+
     async def initialize(self) -> None:
         """Initialize the agent system."""
         await self.memory_manager.initialize()

-        # Create agent with tools and LangGraph checkpointing
+        # Create agent with tools and LangGraph checkpointer
         checkpointer = self.memory_manager.get_checkpointer()

-        # Build initial system prompt with context
-        context = self.memory_manager.get_context_prompt()
-        system_prompt = build_system_prompt(context, [])
+        # Create specialized sub-agents
+        logger.info("Initializing hierarchical agent architecture with sub-agents")
+
+        self.chart_agent = SubAgent(
+            name="chart",
+            soul_file="chart_agent.md",
+            tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
+            model_name=self.model_name,
+            temperature=self.temperature,
+            api_key=self.api_key,
+            base_dir=self.base_dir
+        )
+
+        self.data_agent = SubAgent(
+            name="data",
+            soul_file="data_agent.md",
+            tools=DATASOURCE_TOOLS,
+            model_name=self.model_name,
+            temperature=self.temperature,
+            api_key=self.api_key,
+            base_dir=self.base_dir
+        )
+
+        self.automation_agent = SubAgent(
+            name="automation",
+            soul_file="automation_agent.md",
+            tools=TRIGGER_TOOLS,
+            model_name=self.model_name,
+            temperature=self.temperature,
+            api_key=self.api_key,
+            base_dir=self.base_dir
+        )
+
+        self.research_agent = SubAgent(
+            name="research",
+            soul_file="research_agent.md",
+            tools=RESEARCH_TOOLS,
+            model_name=self.model_name,
+            temperature=self.temperature,
+            api_key=self.api_key,
+            base_dir=self.base_dir
+        )
+
+        # Set global sub-agent instances for router tools
+        set_chart_agent(self.chart_agent)
+        set_data_agent(self.data_agent)
+        set_automation_agent(self.automation_agent)
+        set_research_agent(self.research_agent)
+
+        # Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
+        logger.info("Main agent using router tools (4 routers + sync tools)")
+        agent_tools = SYNC_TOOLS + ROUTER_TOOLS
+
+        # Create main agent without a static system prompt
+        # We'll pass the dynamic system prompt via state_modifier at runtime
         self.agent = create_react_agent(
             self.llm,
-            SYNC_TOOLS + DATASOURCE_TOOLS,
-            prompt=system_prompt,
+            agent_tools,
             checkpointer=checkpointer
         )

         logger.info(f"Agent initialized with {len(agent_tools)} tools")

     async def _clear_checkpoint(self, session_id: str) -> None:
         """Clear the checkpoint for a session to prevent resuming from invalid state.
@@ -101,26 +164,6 @@
         except Exception as e:
             logger.warning(f"Failed to clear checkpoint for session {session_id}: {e}")

-    def _build_system_message(self, state: Dict[str, Any]) -> SystemMessage:
-        """Build system message with context.
-
-        Args:
-            state: Agent state
-
-        Returns:
-            SystemMessage with full context
-        """
-        # Get context from loaded documents
-        context = self.memory_manager.get_context_prompt()
-
-        # Get active channels from metadata
-        active_channels = state.get("metadata", {}).get("active_channels", [])
-
-        # Build system prompt
-        system_prompt = build_system_prompt(context, active_channels)
-
-        return SystemMessage(content=system_prompt)
-
     async def execute(
         self,
         session: UserSession,
@@ -143,7 +186,12 @@
         async with lock:
             try:
-                # Build message history
+                # Build system prompt with current context
+                context = self.memory_manager.get_context_prompt()
+                system_prompt = build_system_prompt(context, session.active_channels)
+
+                # Build message history WITHOUT prepending system message
+                # The system prompt will be passed via state_modifier in the config
                 messages = []
                 history = session.get_history(limit=10)
                 logger.info(f"Building message history, {len(history)} messages in history")
@@ -155,14 +203,18 @@
                 elif msg.role == "assistant":
                     messages.append(AIMessage(content=msg.content))

-                logger.info(f"Prepared {len(messages)} messages for agent")
+                logger.info(f"Prepared {len(messages)} messages for agent (including system prompt)")
                 for i, msg in enumerate(messages):
-                    logger.info(f"LangChain message {i}: type={type(msg).__name__}, content_len={len(msg.content)}, content='{msg.content[:100] if msg.content else 'EMPTY'}'")
+                    msg_type = type(msg).__name__
+                    content_preview = msg.content[:100] if msg.content else 'EMPTY'
+                    logger.info(f"LangChain message {i}: type={msg_type}, content_len={len(msg.content)}, content='{content_preview}'")

-                # Prepare config with metadata
+                # Prepare config with metadata and dynamic system prompt
+                # Pass system_prompt via state_modifier to avoid multiple system messages
                 config = RunnableConfig(
                     configurable={
-                        "thread_id": session.session_id
+                        "thread_id": session.session_id,
+                        "state_modifier": system_prompt  # Dynamic system prompt injection
                     },
                     metadata={
                         "session_id": session.session_id,
@@ -178,6 +230,8 @@
                 event_count = 0
                 chunk_count = 0

+                plot_urls = []  # Accumulate plot URLs from execute_python tool calls
+
                 async for event in self.agent.astream_events(
                     {"messages": messages},
                     config=config,
@@ -199,7 +253,35 @@
                     elif event["event"] == "on_tool_end":
                         tool_name = event.get("name", "unknown")
                         tool_output = event.get("data", {}).get("output")
-                        logger.info(f"Tool call completed: {tool_name} with output: {tool_output}")
+
+                        # LangChain may wrap the output in a ToolMessage with content field
+                        # Try to extract the actual content from the ToolMessage
+                        actual_output = tool_output
+                        if hasattr(tool_output, "content"):
+                            actual_output = tool_output.content
+
+                        logger.info(f"Tool call completed: {tool_name} with output type: {type(actual_output)}")
+
+                        # Extract plot_urls from execute_python tool results
+                        if tool_name == "execute_python":
+                            # Try to parse as JSON if it's a string
+                            import json
+                            if isinstance(actual_output, str):
+                                try:
+                                    actual_output = json.loads(actual_output)
+                                except (json.JSONDecodeError, ValueError):
+                                    logger.warning(f"Could not parse execute_python output as JSON: {actual_output[:200]}")
+
+                            if isinstance(actual_output, dict):
+                                tool_plot_urls = actual_output.get("plot_urls", [])
+                                if tool_plot_urls:
+                                    logger.info(f"execute_python generated {len(tool_plot_urls)} plots: {tool_plot_urls}")
+                                    plot_urls.extend(tool_plot_urls)
+                                    # Yield metadata about plots immediately
+                                    yield {
+                                        "content": "",
+                                        "metadata": {"plot_urls": tool_plot_urls}
+                                    }

                     # Extract streaming tokens
                     elif event["event"] == "on_chat_model_stream":
@@ -274,7 +356,7 @@ def create_agent(
         base_dir: Base directory for resolving paths

     Returns:
-        Initialized AgentExecutor
+        Initialized AgentExecutor with hierarchical tool routing
     """
     # Initialize memory manager
     memory_manager = MemoryManager(
@@ -290,7 +372,8 @@
         model_name=model_name,
         temperature=temperature,
         api_key=api_key,
-        memory_manager=memory_manager
+        memory_manager=memory_manager,
+        base_dir=base_dir
     )

     return executor
backend.old/src/agent/prompts.py (new file, 118 lines)
@@ -0,0 +1,118 @@
|
||||
from typing import List, Dict, Any
|
||||
from gateway.user_session import UserSession
|
||||
|
||||
|
||||
def _get_chart_store_context() -> str:
|
||||
"""Get current ChartStore state for context injection.
|
||||
|
||||
Returns:
|
||||
Formatted string with ChartStore contents, or empty string if unavailable
|
||||
"""
|
||||
try:
|
||||
from agent.tools import _registry
|
||||
|
||||
if not _registry:
|
||||
return ""
|
||||
|
||||
chart_store = _registry.entries.get("ChartStore")
|
||||
if not chart_store:
|
||||
return ""
|
||||
|
||||
chart_state = chart_store.model.model_dump(mode="json")
|
||||
chart_data = chart_state.get("chart_state", {})
|
||||
|
||||
# Only include if there's actual chart data
|
||||
if not chart_data or not chart_data.get("symbol"):
|
||||
return ""
|
||||
|
||||
# Format the chart information
|
||||
symbol = chart_data.get("symbol", "N/A")
|
||||
interval = chart_data.get("interval", "N/A")
|
||||
start_time = chart_data.get("start_time")
|
||||
end_time = chart_data.get("end_time")
|
||||
selected_shapes = chart_data.get("selected_shapes", [])
|
||||
|
||||
selected_info = ""
|
||||
if selected_shapes:
|
||||
selected_info = f"\n- **Selected Shapes**: {len(selected_shapes)} shape(s) selected (IDs: {', '.join(selected_shapes)})"
|
||||
|
||||
chart_context = f"""
|
||||
## Current Chart Context
|
||||
|
||||
The user is currently viewing a chart with the following settings:
|
||||
- **Symbol**: {symbol}
|
||||
- **Interval**: {interval}
|
||||
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}{selected_info}
|
||||
|
||||
This information is automatically available because you're connected via websocket.
|
||||
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.
|
||||
"""
|
||||
return chart_context
|
||||
|
||||
except Exception:
|
||||
# Silently fail - chart context is optional enhancement
|
||||
return ""
|
||||
|
||||
|
||||
def build_system_prompt(context: str, active_channels: List[str]) -> str:
|
||||
"""Build the system prompt for the agent.
|
||||
|
||||
The main system prompt comes from system_prompt.md (loaded in context).
|
||||
This function adds dynamic session information.
|
||||
|
||||
Args:
|
||||
context: Context from loaded markdown documents (includes system_prompt.md)
|
||||
active_channels: List of active channel IDs for this session
|
||||
|
||||
Returns:
|
||||
Formatted system prompt
|
||||
"""
|
||||
channels_str = ", ".join(active_channels) if active_channels else "none"
|
||||
|
||||
# Check if user is connected via websocket - if so, inject chart context
|
||||
# Note: We check for websocket by looking for "websocket" in channel IDs
|
||||
# since WebSocketChannel uses channel_id like "websocket-{uuid}"
|
||||
has_websocket = any("websocket" in channel_id.lower() for channel_id in active_channels)
|
||||
|
||||
chart_context = ""
|
||||
if has_websocket:
|
||||
chart_context = _get_chart_store_context()
|
||||
|
||||
# Context already includes system_prompt.md and other docs
|
||||
# Just add current session information
|
||||
prompt = f"""{context}
|
||||
|
||||
## Current Session Information
|
||||
|
||||
**Active Channels**: {channels_str}
|
||||
|
||||
Your responses will be sent to all active channels. Your responses are streamed back in real-time.
|
||||
If the user sends a new message while you're responding, your current response will be interrupted
|
||||
and you'll be re-invoked with the updated context.
|
||||
{chart_context}"""
|
||||
return prompt
|
||||
|
||||
|
||||
def build_user_prompt_with_history(session: UserSession, current_message: str) -> str:
|
||||
"""Build a user prompt including conversation history.
|
||||
|
||||
Args:
|
||||
session: User session with conversation history
|
||||
current_message: Current user message
|
||||
|
||||
Returns:
|
||||
Formatted prompt with history
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Get recent history (last 10 messages)
|
||||
history = session.get_history(limit=10)
|
||||
|
||||
for msg in history:
|
||||
role_label = "User" if msg.role == "user" else "Assistant"
|
||||
messages.append(f"{role_label}: {msg.content}")
|
||||
|
||||
# Add current message
|
||||
messages.append(f"User: {current_message}")
|
||||
|
||||
return "\n\n".join(messages)
|
||||
218 backend.old/src/agent/routers.py (Normal file)
@@ -0,0 +1,218 @@
|
||||
"""Tool router functions for hierarchical agent architecture.
|
||||
|
||||
This module provides meta-tools that route tasks to specialized sub-agents.
|
||||
The main agent uses these routers instead of accessing all tools directly.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global sub-agent instances (set by create_agent)
|
||||
_chart_agent = None
|
||||
_data_agent = None
|
||||
_automation_agent = None
|
||||
_research_agent = None
|
||||
|
||||
|
||||
def set_chart_agent(agent):
|
||||
"""Set the global chart sub-agent instance."""
|
||||
global _chart_agent
|
||||
_chart_agent = agent
|
||||
|
||||
|
||||
def set_data_agent(agent):
|
||||
"""Set the global data sub-agent instance."""
|
||||
global _data_agent
|
||||
_data_agent = agent
|
||||
|
||||
|
||||
def set_automation_agent(agent):
|
||||
"""Set the global automation sub-agent instance."""
|
||||
global _automation_agent
|
||||
_automation_agent = agent
|
||||
|
||||
|
||||
def set_research_agent(agent):
|
||||
"""Set the global research sub-agent instance."""
|
||||
global _research_agent
|
||||
_research_agent = agent
|
||||
|
||||
|
||||
@tool
|
||||
async def use_chart_analysis(task: str) -> str:
|
||||
"""Analyze charts, compute indicators, execute Python code, and create visualizations.
|
||||
|
||||
This tool delegates to a specialized chart analysis agent that has access to:
|
||||
- Chart data retrieval (get_chart_data)
|
||||
- Python execution environment with pandas, numpy, matplotlib, talib
|
||||
- Technical indicator tools (add/remove indicators, search indicators)
|
||||
- Shape drawing tools (create/update/delete shapes on charts)
|
||||
|
||||
Use this when the user wants to:
|
||||
- Analyze price action or patterns
|
||||
- Calculate technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
||||
- Execute custom Python analysis on OHLCV data
|
||||
- Generate charts and visualizations
|
||||
- Draw trendlines, support/resistance, or other shapes
|
||||
- Perform statistical analysis on market data
|
||||
|
||||
Args:
|
||||
task: Detailed description of the chart analysis task. Include:
|
||||
- What to analyze (which symbol, timeframe if different from current)
|
||||
- What indicators or calculations to perform
|
||||
- What visualizations to create
|
||||
- Any specific questions to answer
|
||||
|
||||
Returns:
|
||||
The chart agent's analysis results, including computed values,
|
||||
plot URLs if visualizations were created, and interpretation.
|
||||
|
||||
Examples:
|
||||
- "Calculate RSI(14) for the current chart and tell me if it's overbought"
|
||||
- "Draw a trendline connecting the last 3 swing lows"
|
||||
- "Compute Bollinger Bands (20, 2) and create a chart showing price vs bands"
|
||||
- "Analyze the last 100 bars and identify key support/resistance levels"
|
||||
- "Execute Python: calculate correlation between BTC and ETH over the last 30 days"
|
||||
"""
|
||||
if not _chart_agent:
|
||||
return "Error: Chart analysis agent not initialized"
|
||||
|
||||
logger.info(f"Routing to chart agent: {task[:100]}...")
|
||||
result = await _chart_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_data_access(task: str) -> str:
|
||||
"""Search for symbols and retrieve market data from exchanges.
|
||||
|
||||
This tool delegates to a specialized data access agent that has access to:
|
||||
- Symbol search across multiple exchanges
|
||||
- Historical OHLCV data retrieval
|
||||
- Symbol metadata and info
|
||||
- Available data sources and exchanges
|
||||
|
||||
Use this when the user wants to:
|
||||
- Search for a trading symbol or ticker
|
||||
- Get historical price data
|
||||
- Find out what exchanges support a symbol
|
||||
- Retrieve symbol metadata (price scale, supported resolutions, etc.)
|
||||
- Check what data sources are available
|
||||
|
||||
Args:
|
||||
task: Detailed description of the data access task. Include:
|
||||
- What symbol or instrument to search for
|
||||
- What data to retrieve (time range, resolution)
|
||||
- What metadata is needed
|
||||
|
||||
Returns:
|
||||
The data agent's response with requested symbols, data, or metadata.
|
||||
|
||||
Examples:
|
||||
- "Search for Bitcoin symbols on Binance"
|
||||
- "Get the last 100 hours of BTC/USDT 1-hour data from Binance"
|
||||
- "Find all symbols matching 'ETH' on all exchanges"
|
||||
- "Get detailed info about symbol BTC/USDT on Binance"
|
||||
- "List all available data sources"
|
||||
"""
|
||||
if not _data_agent:
|
||||
return "Error: Data access agent not initialized"
|
||||
|
||||
logger.info(f"Routing to data agent: {task[:100]}...")
|
||||
result = await _data_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_automation(task: str) -> str:
|
||||
"""Schedule recurring tasks, create triggers, and manage automation.
|
||||
|
||||
This tool delegates to a specialized automation agent that has access to:
|
||||
- Scheduled agent prompts (cron and interval-based)
|
||||
- One-time agent prompt execution
|
||||
- Trigger management (list, cancel scheduled jobs)
|
||||
- System stats and monitoring
|
||||
|
||||
Use this when the user wants to:
|
||||
- Schedule a recurring task (hourly, daily, weekly, etc.)
|
||||
- Run a one-time background analysis
|
||||
- Set up automated monitoring or alerts
|
||||
- List or cancel existing scheduled tasks
|
||||
- Check trigger system status
|
||||
|
||||
Args:
|
||||
task: Detailed description of the automation task. Include:
|
||||
- What should happen (what analysis or action)
|
||||
- When it should happen (schedule, frequency)
|
||||
- Any priorities or conditions
|
||||
|
||||
Returns:
|
||||
The automation agent's response with job IDs, confirmation,
|
||||
or status information.
|
||||
|
||||
Examples:
|
||||
- "Schedule a task to check BTC price every 5 minutes"
|
||||
- "Run a one-time analysis of ETH volume in the background"
|
||||
- "Set up a daily report at 9 AM with market summary"
|
||||
- "Show me all my scheduled tasks"
|
||||
- "Cancel the hourly BTC monitor job"
|
||||
"""
|
||||
if not _automation_agent:
|
||||
return "Error: Automation agent not initialized"
|
||||
|
||||
logger.info(f"Routing to automation agent: {task[:100]}...")
|
||||
result = await _automation_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_research(task: str) -> str:
|
||||
"""Search the web, academic papers, and external APIs for information.
|
||||
|
||||
This tool delegates to a specialized research agent that has access to:
|
||||
- Web search (DuckDuckGo)
|
||||
- Academic paper search (arXiv)
|
||||
- Wikipedia lookup
|
||||
- HTTP requests to public APIs
|
||||
|
||||
Use this when the user wants to:
|
||||
- Search for current news or events
|
||||
- Find academic papers on trading strategies
|
||||
- Look up financial concepts or terms
|
||||
- Fetch data from external public APIs
|
||||
- Research market trends or sentiment
|
||||
|
||||
Args:
|
||||
task: Detailed description of the research task. Include:
|
||||
- What information to find
|
||||
- What sources to search (web, arxiv, wikipedia, APIs)
|
||||
- What to focus on or filter
|
||||
|
||||
Returns:
|
||||
The research agent's findings with sources, summaries, and links.
|
||||
|
||||
Examples:
|
||||
- "Search arXiv for papers on reinforcement learning for trading"
|
||||
- "Look up 'technical analysis' on Wikipedia"
|
||||
- "Search the web for latest Ethereum news"
|
||||
- "Fetch current BTC price from CoinGecko API"
|
||||
- "Find recent papers on market microstructure"
|
||||
"""
|
||||
if not _research_agent:
|
||||
return "Error: Research agent not initialized"
|
||||
|
||||
logger.info(f"Routing to research agent: {task[:100]}...")
|
||||
result = await _research_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
# Export router tools
|
||||
ROUTER_TOOLS = [
|
||||
use_chart_analysis,
|
||||
use_data_access,
|
||||
use_automation,
|
||||
use_research
|
||||
]
|
||||
248 backend.old/src/agent/subagent.py (Normal file)
@@ -0,0 +1,248 @@
|
||||
"""Sub-agent infrastructure for specialized tool routing.
|
||||
|
||||
This module provides the SubAgent class that wraps specialized agents
|
||||
with their own tools and system prompts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubAgent:
|
||||
"""A specialized sub-agent with its own tools and system prompt.
|
||||
|
||||
Sub-agents are lightweight, stateless agents that focus on specific domains.
|
||||
They use in-memory checkpointing since they don't need persistent state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
soul_file: str,
|
||||
tools: List,
|
||||
model_name: str = "claude-sonnet-4-20250514",
|
||||
temperature: float = 0.7,
|
||||
api_key: Optional[str] = None,
|
||||
base_dir: str = "."
|
||||
):
|
||||
"""Initialize a sub-agent.
|
||||
|
||||
Args:
|
||||
name: Agent name (e.g., "chart", "data", "automation")
|
||||
soul_file: Filename in /soul directory (e.g., "chart_agent.md")
|
||||
tools: List of LangChain tools for this agent
|
||||
model_name: Anthropic model name
|
||||
temperature: Model temperature
|
||||
api_key: Anthropic API key
|
||||
base_dir: Base directory for resolving paths
|
||||
"""
|
||||
self.name = name
|
||||
self.soul_file = soul_file
|
||||
self.tools = tools
|
||||
self.model_name = model_name
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_dir = base_dir
|
||||
|
||||
# Load system prompt from soul file
|
||||
soul_path = Path(base_dir) / "soul" / soul_file
|
||||
if soul_path.exists():
|
||||
with open(soul_path, "r") as f:
|
||||
self.system_prompt = f.read()
|
||||
logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")
|
||||
else:
|
||||
logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
|
||||
self.system_prompt = f"You are a specialized {name} agent."
|
||||
|
||||
# Initialize LLM
|
||||
self.llm = ChatAnthropic(
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# Create agent with in-memory checkpointer (stateless)
|
||||
checkpointer = MemorySaver()
|
||||
self.agent = create_react_agent(
|
||||
self.llm,
|
||||
tools,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"SubAgent '{name}' initialized with {len(tools)} tools, "
|
||||
f"model={model_name}, temp={temperature}"
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
thread_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""Execute a task with this sub-agent.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for this sub-agent
|
||||
thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)
|
||||
|
||||
Returns:
|
||||
The agent's complete response as a string
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Use ephemeral thread ID if not provided
|
||||
if thread_id is None:
|
||||
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
|
||||
|
||||
logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
|
||||
logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")
|
||||
|
||||
# Build messages with system prompt
|
||||
messages = [
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
|
||||
# Prepare config with system prompt injection
|
||||
config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": thread_id,
|
||||
"state_modifier": self.system_prompt
|
||||
},
|
||||
metadata={
|
||||
"subagent_name": self.name
|
||||
}
|
||||
)
|
||||
|
||||
# Execute and collect response
|
||||
full_response = ""
|
||||
event_count = 0
|
||||
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
event_count += 1
|
||||
|
||||
# Log tool calls
|
||||
if event["event"] == "on_tool_start":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
|
||||
|
||||
elif event["event"] == "on_tool_end":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
|
||||
|
||||
# Extract streaming tokens
|
||||
elif event["event"] == "on_chat_model_stream":
|
||||
chunk = event["data"]["chunk"]
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
content = chunk.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
elif hasattr(block, "text"):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
if content:
|
||||
full_response += content
|
||||
|
||||
logger.info(
|
||||
f"SubAgent '{self.name}': Completed task "
|
||||
f"({event_count} events, {len(full_response)} chars)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
return full_response
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
task: str,
|
||||
thread_id: Optional[str] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Execute a task with streaming response.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for this sub-agent
|
||||
thread_id: Optional thread ID for checkpointing
|
||||
|
||||
Yields:
|
||||
Response chunks as they're generated
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Use ephemeral thread ID if not provided
|
||||
if thread_id is None:
|
||||
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
|
||||
|
||||
logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")
|
||||
|
||||
# Build messages with system prompt
|
||||
messages = [
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
|
||||
# Prepare config
|
||||
config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": thread_id,
|
||||
"state_modifier": self.system_prompt
|
||||
},
|
||||
metadata={
|
||||
"subagent_name": self.name
|
||||
}
|
||||
)
|
||||
|
||||
# Stream response
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
# Log tool calls
|
||||
if event["event"] == "on_tool_start":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
|
||||
|
||||
# Extract streaming tokens
|
||||
elif event["event"] == "on_chat_model_stream":
|
||||
chunk = event["data"]["chunk"]
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
content = chunk.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
elif hasattr(block, "text"):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
if content:
|
||||
yield content
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield f"Error: {error_msg}"
|
||||
139 backend.old/src/agent/tools/CHART_UTILS_README.md (Normal file)
@@ -0,0 +1,139 @@
|
||||
# Chart Utilities - Standard OHLC Plotting
|
||||
|
||||
## Overview
|
||||
|
||||
The `chart_utils.py` module provides convenience functions for creating beautiful, professional OHLC candlestick charts with a consistent look and feel. This is designed to be used by the LLM in `analyze_chart_data` scripts, eliminating the need to write custom matplotlib code for every chart.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Beautiful by default**: Uses mplfinance with seaborn-inspired aesthetics
|
||||
- **Consistent styling**: Professional color scheme (teal green up, coral red down)
|
||||
- **Easy to use**: Simple function calls instead of complex matplotlib code
|
||||
- **Customizable**: Supports all mplfinance options via kwargs
|
||||
- **Volume integration**: Optional volume subplot
|
||||
|
||||
## Installation
|
||||
|
||||
The required package `mplfinance` has been added to `requirements.txt`:
|
||||
|
||||
```bash
|
||||
pip install mplfinance
|
||||
```
|
||||
|
||||
## Available Functions
|
||||
|
||||
### 1. `plot_ohlc(df, title=None, volume=True, figsize=(14, 8), **kwargs)`
|
||||
|
||||
Main function for creating standard OHLC candlestick charts.
|
||||
|
||||
**Parameters:**
|
||||
- `df`: pandas DataFrame with DatetimeIndex and OHLCV columns
|
||||
- `title`: Optional chart title
|
||||
- `volume`: Whether to include volume subplot (default: True)
|
||||
- `figsize`: Figure size in inches (default: (14, 8))
|
||||
- `**kwargs`: Additional mplfinance.plot() arguments
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
fig = plot_ohlc(df, title='BTC/USDT 15min', volume=True)
|
||||
```
|
||||
|
||||
### 2. `add_indicators_to_plot(df, indicators, **plot_kwargs)`
|
||||
|
||||
Creates OHLC chart with technical indicators overlaid.
|
||||
|
||||
**Parameters:**
|
||||
- `df`: DataFrame with OHLCV data and indicator columns
|
||||
- `indicators`: Dict mapping indicator column names to display parameters
|
||||
- `**plot_kwargs`: Additional arguments for plot_ohlc()
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
df['SMA_20'] = df['close'].rolling(20).mean()
|
||||
df['SMA_50'] = df['close'].rolling(50).mean()
|
||||
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'SMA_50': {'color': 'red', 'width': 1.5}
|
||||
},
|
||||
title='Price with Moving Averages'
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Preset Functions
|
||||
|
||||
- `plot_price_volume(df, title=None)` - Standard price + volume chart
|
||||
- `plot_price_only(df, title=None)` - Candlesticks without volume
|
||||
|
||||
## Integration with analyze_chart_data
|
||||
|
||||
These functions are automatically available in the `analyze_chart_data` tool's script environment:
|
||||
|
||||
```python
|
||||
# In an analyze_chart_data script:
|
||||
# df is already provided
|
||||
|
||||
# Simple usage
|
||||
fig = plot_ohlc(df, title='Price Action')
|
||||
|
||||
# With indicators
|
||||
df['SMA'] = df['close'].rolling(20).mean()
|
||||
fig = add_indicators_to_plot(
|
||||
df,
|
||||
indicators={'SMA': {'color': 'blue', 'width': 1.5}},
|
||||
title='Price with SMA'
|
||||
)
|
||||
|
||||
# Return data for the assistant
|
||||
df[['close', 'SMA']].tail(10)
|
||||
```
|
||||
|
||||
## Styling
|
||||
|
||||
The default style includes:
|
||||
- **Up candles**: Teal green (#26a69a)
|
||||
- **Down candles**: Coral red (#ef5350)
|
||||
- **Background**: Light gray with white axes
|
||||
- **Grid**: Subtle dashed lines with 30% alpha
|
||||
- **Professional fonts**: Clean, readable sizes
|
||||
|
||||
## Why This Matters
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
# LLM had to write this every time
|
||||
import matplotlib.pyplot as plt
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
ax.plot(df.index, df['close'], label='Close')
|
||||
# ... lots more code for styling, colors, etc.
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
# LLM can now just do this
|
||||
fig = plot_ohlc(df, title='BTC/USDT')
|
||||
```
|
||||
|
||||
Benefits:
|
||||
- ✅ Less code to generate → faster response
|
||||
- ✅ Consistent appearance across all charts
|
||||
- ✅ Professional look out of the box
|
||||
- ✅ Easier to maintain and customize
|
||||
- ✅ Better use of mplfinance's candlestick rendering
|
||||
|
||||
## Example Output
|
||||
|
||||
See `chart_utils_example.py` for runnable examples demonstrating:
|
||||
1. Basic OHLC chart with volume
|
||||
2. OHLC chart with multiple indicators
|
||||
3. Price-only chart
|
||||
4. Custom styling options
|
||||
|
||||
## File Locations
|
||||
|
||||
- **Main module**: `backend/src/agent/tools/chart_utils.py`
|
||||
- **Integration**: `backend/src/agent/tools/chart_tools.py` (lines 306-328)
|
||||
- **Examples**: `backend/src/agent/tools/chart_utils_example.py`
|
||||
- **Dependency**: `backend/requirements.txt` (mplfinance added)
|
||||
373 backend.old/src/agent/tools/TRIGGER_TOOLS.md (Normal file)
@@ -0,0 +1,373 @@
|
||||
# Agent Trigger Tools
|
||||
|
||||
Agent tools for automating tasks via the trigger system.
|
||||
|
||||
## Overview
|
||||
|
||||
These tools allow the agent to:
|
||||
- **Schedule recurring tasks** - Run agent prompts on intervals or cron schedules
|
||||
- **Execute one-time tasks** - Trigger sub-agent runs immediately
|
||||
- **Manage scheduled jobs** - List and cancel scheduled triggers
|
||||
- **React to events** - (Future) Connect data updates to agent actions
|
||||
|
||||
## Available Tools
|
||||
|
||||
### 1. `schedule_agent_prompt`
|
||||
|
||||
Schedule an agent to run with a specific prompt on a recurring schedule.
|
||||
|
||||
**Use Cases:**
|
||||
- Daily market analysis reports
|
||||
- Hourly portfolio rebalancing checks
|
||||
- Weekly performance summaries
|
||||
- Monitoring alerts
|
||||
|
||||
**Arguments:**
|
||||
- `prompt` (str): The prompt to send to the agent when triggered
|
||||
- `schedule_type` (str): "interval" or "cron"
|
||||
- `schedule_config` (dict): Schedule configuration
|
||||
- `name` (str, optional): Descriptive name for this task
|
||||
|
||||
**Schedule Config:**
|
||||
|
||||
*Interval-based:*
|
||||
```json
|
||||
{"minutes": 5}
|
||||
{"hours": 1, "minutes": 30}
|
||||
{"seconds": 30}
|
||||
```
|
||||
|
||||
*Cron-based:*
|
||||
```json
|
||||
{"hour": "9", "minute": "0"} // Daily at 9:00 AM
|
||||
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"} // Weekdays at 9 AM
|
||||
{"minute": "0"} // Every hour on the hour
|
||||
{"hour": "*/6", "minute": "0"} // Every 6 hours
|
||||
```
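Interval-style configs map directly onto `datetime.timedelta` keyword arguments, which is one way to compute a job's next run time. A minimal sketch — the `next_interval_run` helper is a hypothetical illustration, not part of the tool API:

```python
from datetime import datetime, timedelta


def next_interval_run(schedule_config: dict, now: datetime) -> datetime:
    # {"hours": 1, "minutes": 30} -> now + timedelta(hours=1, minutes=30)
    return now + timedelta(**schedule_config)


now = datetime(2024, 3, 4, 14, 30)
print(next_interval_run({"minutes": 5}, now))               # 2024-03-04 14:35:00
print(next_interval_run({"hours": 1, "minutes": 30}, now))  # 2024-03-04 16:00:00
```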
**Returns:**
```json
{
  "job_id": "interval_123",
  "message": "Scheduled 'daily_report' with job_id=interval_123",
  "schedule_type": "cron",
  "config": {"hour": "9", "minute": "0"}
}
```

**Examples:**

```python
# Every 5 minutes: check BTC price
schedule_agent_prompt(
    prompt="Check current BTC price on Binance. If > $50k, alert me.",
    schedule_type="interval",
    schedule_config={"minutes": 5},
    name="btc_price_monitor"
)

# Daily at 9 AM: market summary
schedule_agent_prompt(
    prompt="Generate a comprehensive market summary for BTC, ETH, and SOL. Include price changes, volume, and notable events from the last 24 hours.",
    schedule_type="cron",
    schedule_config={"hour": "9", "minute": "0"},
    name="daily_market_summary"
)

# Every hour on weekdays: portfolio check
schedule_agent_prompt(
    prompt="Review current portfolio positions. Check if any rebalancing is needed based on target allocations.",
    schedule_type="cron",
    schedule_config={"minute": "0", "day_of_week": "mon-fri"},
    name="hourly_portfolio_check"
)
```

### 2. `execute_agent_prompt_once`

Execute an agent prompt once, immediately (enqueued with priority).

**Use Cases:**
- Background analysis tasks
- One-time data processing
- Responding to specific events
- Sub-agent delegation

**Arguments:**
- `prompt` (str): The prompt to send to the agent
- `priority` (str): "high", "normal", or "low" (default: "normal")

**Returns:**
```json
{
  "queue_seq": 42,
  "message": "Enqueued agent prompt with priority=normal",
  "prompt": "Analyze the last 100 BTC/USDT bars..."
}
```

**Examples:**

```python
# Immediate analysis with high priority
execute_agent_prompt_once(
    prompt="Analyze the last 100 BTC/USDT 1m bars and identify key support/resistance levels",
    priority="high"
)

# Background task with normal priority
execute_agent_prompt_once(
    prompt="Research the latest news about Ethereum upgrades and summarize findings",
    priority="normal"
)

# Low priority cleanup task
execute_agent_prompt_once(
    prompt="Review and archive old chart drawings from last month",
    priority="low"
)
```
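The priority string translates into queue ordering: higher-priority prompts are dequeued first, and same-priority prompts run in FIFO order. A minimal sketch of that behavior — the rank mapping and `heapq` usage here are illustrative assumptions, not the actual queue implementation:

```python
import heapq
import itertools

PRIORITY_RANK = {"high": 0, "normal": 1, "low": 2}
_seq = itertools.count()  # tie-breaker preserves FIFO order within a priority
queue = []


def enqueue(prompt: str, priority: str = "normal") -> None:
    # Lower rank tuple sorts first, so "high" wins over "normal" over "low".
    heapq.heappush(queue, (PRIORITY_RANK[priority], next(_seq), prompt))


def dequeue() -> str:
    return heapq.heappop(queue)[2]


enqueue("archive old drawings", priority="low")
enqueue("analyze BTC bars", priority="high")
enqueue("summarize ETH news")  # defaults to "normal"
print(dequeue())  # analyze BTC bars
```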
### 3. `list_scheduled_triggers`

List all currently scheduled triggers.

**Returns:**
```json
[
  {
    "id": "cron_456",
    "name": "Cron: daily_market_summary",
    "next_run_time": "2024-03-05 09:00:00",
    "trigger": "cron[hour='9', minute='0']"
  },
  {
    "id": "interval_123",
    "name": "Interval: btc_price_monitor",
    "next_run_time": "2024-03-04 14:35:00",
    "trigger": "interval[0:05:00]"
  }
]
```

**Example:**

```python
jobs = list_scheduled_triggers()

for job in jobs:
    print(f"{job['name']} - next run: {job['next_run_time']}")
```

### 4. `cancel_scheduled_trigger`

Cancel a scheduled trigger by its job ID.

**Arguments:**
- `job_id` (str): The job ID from `schedule_agent_prompt` or `list_scheduled_triggers`

**Returns:**
```json
{
  "status": "success",
  "message": "Cancelled job interval_123"
}
```

**Example:**

```python
# List jobs to find the ID
jobs = list_scheduled_triggers()

# Cancel specific job
cancel_scheduled_trigger("interval_123")
```

### 5. `on_data_update_run_agent`

**(Future)** Set up an agent to run whenever new data arrives for a specific symbol.

**Arguments:**
- `source_name` (str): Data source name (e.g., "binance")
- `symbol` (str): Trading pair (e.g., "BTC/USDT")
- `resolution` (str): Time resolution (e.g., "1m", "5m")
- `prompt_template` (str): Template with variables like {close}, {volume}, {symbol}

**Example:**

```python
on_data_update_run_agent(
    source_name="binance",
    symbol="BTC/USDT",
    resolution="1m",
    prompt_template="New bar on {symbol}: close={close}, volume={volume}. Check if price crossed any key levels."
)
```
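The `{symbol}`, `{close}`, and `{volume}` placeholders follow Python's `str.format` convention, so each incoming bar can be rendered into a concrete prompt. A sketch of that expansion — the bar dict and its field names are assumptions for illustration:

```python
template = "New bar on {symbol}: close={close}, volume={volume}. Check if price crossed any key levels."

# Hypothetical bar payload; real field names depend on the data source.
bar = {"symbol": "BTC/USDT", "close": 51234.5, "volume": 18.7}

prompt = template.format(**bar)
print(prompt)
# New bar on BTC/USDT: close=51234.5, volume=18.7. Check if price crossed any key levels.
```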
|
||||
|
||||
### 6. `get_trigger_system_stats`

Get statistics about the trigger system.

**Returns:**
```json
{
  "queue_depth": 3,
  "queue_running": true,
  "coordinator_stats": {
    "current_seq": 1042,
    "next_commit_seq": 1043,
    "pending_commits": 1,
    "total_executions": 1042,
    "state_counts": {
      "COMMITTED": 1038,
      "EXECUTING": 2,
      "WAITING_COMMIT": 1,
      "FAILED": 1
    }
  }
}
```

**Example:**

```python
stats = get_trigger_system_stats()
print(f"Queue has {stats['queue_depth']} pending triggers")
print(f"System has processed {stats['coordinator_stats']['total_executions']} total triggers")
```
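The stats payload above is enough to derive simple health signals: in-flight work, backlog, and a failure rate. A small sketch that reduces the documented schema to those three numbers (the helper name and thresholds are illustrative, not part of the API):

```python
def trigger_health(stats):
    """Summarize trigger-system stats into basic health metrics."""
    cs = stats["coordinator_stats"]
    counts = cs["state_counts"]
    return {
        # Triggers currently running or awaiting commit
        "in_flight": counts.get("EXECUTING", 0) + counts.get("WAITING_COMMIT", 0),
        # Triggers queued but not yet started
        "backlog": stats["queue_depth"],
        # Share of all executions that ended in FAILED
        "failure_rate": counts.get("FAILED", 0) / max(cs["total_executions"], 1),
    }

# Sample matching the documented return schema
stats = {
    "queue_depth": 3,
    "queue_running": True,
    "coordinator_stats": {
        "current_seq": 1042,
        "next_commit_seq": 1043,
        "pending_commits": 1,
        "total_executions": 1042,
        "state_counts": {
            "COMMITTED": 1038,
            "EXECUTING": 2,
            "WAITING_COMMIT": 1,
            "FAILED": 1,
        },
    },
}

print(trigger_health(stats))
```

An agent could run this periodically (via `schedule_agent_prompt`) and alert only when `failure_rate` or `backlog` crosses a threshold it chooses.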
## Integration Example

Here's how these tools enable autonomous agent behavior:

```python
# Agent conversation:
User: "Monitor BTC price and send me a summary every hour during market hours"

Agent: I'll set that up for you using the trigger system.

# Agent uses tool:
schedule_agent_prompt(
    prompt="""
    Check the current BTC/USDT price on Binance.
    Calculate the price change from 1 hour ago.
    If price moved > 2%, provide a detailed analysis.
    Otherwise, provide a brief status update.
    Send results to user as a notification.
    """,
    schedule_type="cron",
    schedule_config={
        "minute": "0",
        "hour": "9-17",  # 9 AM to 5 PM
        "day_of_week": "mon-fri"
    },
    name="btc_hourly_monitor"
)

Agent: Done! I've scheduled an hourly BTC price monitor that runs during market hours (9 AM - 5 PM on weekdays). You'll receive updates every hour.

# Later...
User: "Can you show me all my scheduled tasks?"

Agent: Let me check what's scheduled.

# Agent uses tool:
jobs = list_scheduled_triggers()

Agent: You have 3 scheduled tasks:
1. "btc_hourly_monitor" - runs every hour during market hours
2. "daily_market_summary" - runs daily at 9 AM
3. "portfolio_rebalance_check" - runs every 4 hours

Would you like to modify or cancel any of these?
```
## Use Case: Autonomous Trading Bot

```python
# Step 1: Set up data monitoring
execute_agent_prompt_once(
    prompt="""
    Subscribe to BTC/USDT 1m bars from Binance.
    When subscribed, set up the following:
    1. Calculate RSI(14) on each new bar
    2. If RSI > 70, execute prompt: "RSI overbought on BTC, check if we should sell"
    3. If RSI < 30, execute prompt: "RSI oversold on BTC, check if we should buy"
    """,
    priority="high"
)

# Step 2: Schedule periodic portfolio review
schedule_agent_prompt(
    prompt="""
    Review current portfolio:
    1. Calculate current allocation percentages
    2. Compare to target allocation (60% BTC, 30% ETH, 10% stable)
    3. If deviation > 5%, generate rebalancing trades
    4. Submit trades for execution
    """,
    schedule_type="interval",
    schedule_config={"hours": 4},
    name="portfolio_rebalance"
)

# Step 3: Schedule daily risk check
schedule_agent_prompt(
    prompt="""
    Daily risk assessment:
    1. Calculate portfolio VaR (Value at Risk)
    2. Check current leverage across all positions
    3. Review stop-loss placements
    4. If risk exceeds threshold, alert and suggest adjustments
    """,
    schedule_type="cron",
    schedule_config={"hour": "8", "minute": "0"},
    name="daily_risk_check"
)
```
## Benefits

✅ **Autonomous operation** - Agent can schedule its own tasks
✅ **Event-driven** - React to market data, time, or custom events
✅ **Flexible scheduling** - Interval or cron-based
✅ **Self-managing** - Agent can list and cancel its own jobs
✅ **Priority control** - High-priority tasks jump the queue
✅ **Future-proof** - Easy to add Python lambdas, strategy execution, etc.
## Future Enhancements

- **Python script execution** - Schedule arbitrary Python code
- **Strategy triggers** - Connect to strategy execution system
- **Event composition** - AND/OR logic for complex event patterns
- **Conditional execution** - Only run if conditions met (e.g., volatility > threshold)
- **Result chaining** - Use output of one trigger as input to another
- **Backtesting mode** - Test trigger logic on historical data
## Setup in main.py

```python
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
from trigger import TriggerQueue, CommitCoordinator
from trigger.scheduler import TriggerScheduler

# Initialize trigger system
coordinator = CommitCoordinator()
queue = TriggerQueue(coordinator)
scheduler = TriggerScheduler(queue)

await queue.start()
scheduler.start()

# Make available to agent tools
set_trigger_queue(queue)
set_trigger_scheduler(scheduler)
set_coordinator(coordinator)

# Add TRIGGER_TOOLS to agent's tool list
from agent.tools import TRIGGER_TOOLS
agent_tools = [..., *TRIGGER_TOOLS]
```

Now the agent has full control over the trigger system! 🚀
64
backend.old/src/agent/tools/__init__.py
Normal file
@@ -0,0 +1,64 @@
"""Agent tools for trading operations.

This package provides tools for:
- Synchronization stores (sync_tools)
- Data sources and market data (datasource_tools)
- Chart data access and analysis (chart_tools)
- Technical indicators (indicator_tools)
- Shape/drawing management (shape_tools)
- Trigger system and automation (trigger_tools)
"""

# Global registries that will be set by main.py
_registry = None
_datasource_registry = None
_indicator_registry = None


def set_registry(registry):
    """Set the global SyncRegistry instance for tools to use."""
    global _registry
    _registry = registry


def set_datasource_registry(datasource_registry):
    """Set the global DataSourceRegistry instance for tools to use."""
    global _datasource_registry
    _datasource_registry = datasource_registry


def set_indicator_registry(indicator_registry):
    """Set the global IndicatorRegistry instance for tools to use."""
    global _indicator_registry
    _indicator_registry = indicator_registry


# Import all tools from submodules
from .sync_tools import SYNC_TOOLS
from .datasource_tools import DATASOURCE_TOOLS
from .chart_tools import CHART_TOOLS
from .indicator_tools import INDICATOR_TOOLS
from .research_tools import RESEARCH_TOOLS
from .shape_tools import SHAPE_TOOLS
from .trigger_tools import (
    TRIGGER_TOOLS,
    set_trigger_queue,
    set_trigger_scheduler,
    set_coordinator,
)

__all__ = [
    "set_registry",
    "set_datasource_registry",
    "set_indicator_registry",
    "set_trigger_queue",
    "set_trigger_scheduler",
    "set_coordinator",
    "SYNC_TOOLS",
    "DATASOURCE_TOOLS",
    "CHART_TOOLS",
    "INDICATOR_TOOLS",
    "RESEARCH_TOOLS",
    "SHAPE_TOOLS",
    "TRIGGER_TOOLS",
]
454
backend.old/src/agent/tools/chart_tools.py
Normal file
@@ -0,0 +1,454 @@
"""Chart data access and analysis tools."""

from typing import Dict, Any, Optional, Tuple
import io
import uuid
import logging
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
from langchain_core.tools import tool

logger = logging.getLogger(__name__)


def _get_registry():
    """Get the global registry instance."""
    from . import _registry
    return _registry


def _get_datasource_registry():
    """Get the global datasource registry instance."""
    from . import _datasource_registry
    return _datasource_registry


def _get_indicator_registry():
    """Get the global indicator registry instance."""
    from . import _indicator_registry
    return _indicator_registry


def _get_order_store():
    """Get the global OrderStore instance."""
    registry = _get_registry()
    if registry and "OrderStore" in registry.entries:
        return registry.entries["OrderStore"].model
    return None


def _get_chart_store():
    """Get the global ChartStore instance."""
    registry = _get_registry()
    if registry and "ChartStore" in registry.entries:
        return registry.entries["ChartStore"].model
    return None
async def _get_chart_data_impl(countback: Optional[int] = None):
    """Internal implementation for getting chart data.

    This is a helper function that can be called by both get_chart_data tool
    and analyze_chart_data tool.

    Returns:
        Tuple of (HistoryResult, chart_context dict, source_name)
    """
    registry = _get_registry()
    datasource_registry = _get_datasource_registry()

    if not registry:
        raise ValueError("SyncRegistry not initialized - cannot read ChartStore")

    if not datasource_registry:
        raise ValueError("DataSourceRegistry not initialized - cannot query data")

    # Read current chart state
    chart_store = registry.entries.get("ChartStore")
    if not chart_store:
        raise ValueError("ChartStore not found in registry")

    chart_state = chart_store.model.model_dump(mode="json")
    chart_data = chart_state.get("chart_state", {})

    symbol = chart_data.get("symbol", "")
    interval = chart_data.get("interval", "15")
    start_time = chart_data.get("start_time")
    end_time = chart_data.get("end_time")

    if not symbol:
        raise ValueError(
            "No chart visible - ChartStore symbol is None. "
            "The user is likely on a narrow screen (mobile) where charts are hidden. "
            "Let them know they can view charts on a wider screen, or use get_historical_data() "
            "if they specify a symbol and timeframe."
        )

    # Parse the symbol to extract exchange/source and symbol name
    # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
    if ":" not in symbol:
        raise ValueError(
            f"Invalid symbol format: '{symbol}'. Expected format is 'EXCHANGE:SYMBOL' "
            f"(e.g., 'BINANCE:BTC/USDT' or 'DEMO:BTC/USD')"
        )

    exchange_prefix, symbol_name = symbol.split(":", 1)
    source_name = exchange_prefix.lower()

    # Get the data source
    source = datasource_registry.get(source_name)
    if not source:
        available = datasource_registry.list_sources()
        raise ValueError(
            f"Data source '{source_name}' not found. Available sources: {available}. "
            f"Make sure the exchange in the symbol '{symbol}' matches an available source."
        )

    # Determine time range - REQUIRE it to be set, no defaults
    if start_time is None or end_time is None:
        raise ValueError(
            f"Chart time range not set in ChartStore. start_time={start_time}, end_time={end_time}. "
            f"The user needs to load the chart first, or the frontend may not be sending the visible range. "
            f"Wait for the chart to fully load before analyzing data."
        )

    from_time = int(start_time)
    end_time = int(end_time)
    logger.info(
        f"Using ChartStore time range: from_time={from_time}, end_time={end_time}, "
        f"countback={countback}"
    )

    logger.info(
        f"Querying data source '{source_name}' for symbol '{symbol_name}', "
        f"resolution '{interval}'"
    )

    # Query the data source
    result = await source.get_bars(
        symbol=symbol_name,
        resolution=interval,
        from_time=from_time,
        to_time=end_time,
        countback=countback
    )

    logger.info(
        f"Received {len(result.bars)} bars from data source. "
        f"First bar time: {result.bars[0].time if result.bars else 'N/A'}, "
        f"Last bar time: {result.bars[-1].time if result.bars else 'N/A'}"
    )

    # Build chart context to return along with result
    chart_context = {
        "symbol": symbol,
        "interval": interval,
        "start_time": start_time,
        "end_time": end_time
    }

    return result, chart_context, source_name
@tool
async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
    """Get the candle/bar data for what the user is currently viewing on their chart.

    This is a convenience tool that automatically:
    1. Reads the ChartStore to see what chart the user is viewing
    2. Parses the symbol to determine the data source (exchange prefix)
    3. Queries the appropriate data source for that symbol's data
    4. Returns the data for the visible time range and interval

    This is the preferred way to access chart data when helping the user analyze
    what they're looking at, since it automatically uses their current chart context.

    **IMPORTANT**: This tool will fail if ChartStore.symbol is None (no chart visible).
    This happens when the user is on a narrow screen (mobile) where charts are hidden.
    In that case, let the user know charts are only visible on wider screens, or use
    get_historical_data() if they specify a symbol and timeframe.

    Args:
        countback: Optional limit on number of bars to return. If not specified,
            returns all bars in the visible time range.

    Returns:
        Dictionary containing:
        - chart_context: Current chart state (symbol, interval, time range)
        - symbol: The trading pair being viewed
        - resolution: The chart interval
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - source: Which data source was used

    Raises:
        ValueError: If ChartStore or DataSourceRegistry is not initialized,
            if no chart is visible (symbol is None), or if the symbol format is invalid

    Example:
        # User is viewing BINANCE:BTC/USDT on 15min chart
        data = get_chart_data()
        # Returns BTC/USDT data from binance source at 15min resolution
        # for the currently visible time range
    """
    result, chart_context, source_name = await _get_chart_data_impl(countback)

    # Return enriched result with chart context
    response = result.model_dump()
    response["chart_context"] = chart_context
    response["source"] = source_name

    return response
@tool
async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str, Any]:
    """Execute Python code for technical analysis with automatic chart data loading.

    **PRIMARY TOOL for all technical analysis, indicator computation, and chart generation.**

    This is your go-to tool whenever the user asks about indicators, wants to see
    a chart, or needs any computational analysis of market data.

    Pre-loaded Environment:
    - `pd` : pandas
    - `np` : numpy
    - `plt` : matplotlib.pyplot (figures auto-saved to plot_urls)
    - `talib` : TA-Lib technical analysis library
    - `indicator_registry`: 150+ registered indicators
    - `plot_ohlc(df)` : Helper function for beautiful candlestick charts
    - `registry` : SyncRegistry instance - access to all registered stores
    - `datasource_registry`: DataSourceRegistry - access to data sources (binance, etc.)
    - `order_store` : OrderStore instance - current orders list
    - `chart_store` : ChartStore instance - current chart state

    Auto-loaded when user has a chart visible (ChartStore.symbol is not None):
    - `df` : pandas DataFrame with DatetimeIndex and columns:
        open, high, low, close, volume (OHLCV data ready to use)
    - `chart_context` : dict with symbol, interval, start_time, end_time

    When NO chart is visible (narrow screen/mobile):
    - `df` : None
    - `chart_context` : None

    If `df` is None, you can still load alternative data by:
    - Using chart_store to see what symbol/timeframe is configured
    - Using datasource_registry.get_source('binance') to access data sources
    - Calling datasource.get_history(symbol, interval, start, end) to load any data
    - This allows you to make plots of ANY chart even when not connected to chart view

    The `plot_ohlc()` Helper:
    Create professional candlestick charts instantly:
    - `plot_ohlc(df)` - basic OHLC chart with volume
    - `plot_ohlc(df, title='BTC 15min')` - with custom title
    - `plot_ohlc(df, volume=False)` - price only, no volume
    - Returns a matplotlib Figure that's automatically saved to plot_urls

    Args:
        code: Python code to execute
        countback: Optional limit on number of bars to load (default: all visible bars)

    Returns:
        Dictionary with:
        - script_output : printed output + last expression result
        - result_dataframe : serialized DataFrame if last expression is a DataFrame
        - plot_urls : list of image URLs (e.g., ["/uploads/plot_abc123.png"])
        - chart_context : {symbol, interval, start_time, end_time} or None
        - error : traceback if execution failed

    Examples:
        # RSI indicator with chart
        execute_python(\"\"\"
        df['RSI'] = talib.RSI(df['close'], 14)
        fig = plot_ohlc(df, title='BTC/USDT with RSI')
        print(f"Current RSI: {df['RSI'].iloc[-1]:.2f}")
        df[['close', 'RSI']].tail(5)
        \"\"\")

        # Multiple indicators
        execute_python(\"\"\"
        df['SMA_20'] = df['close'].rolling(20).mean()
        df['SMA_50'] = df['close'].rolling(50).mean()
        df['BB_upper'] = df['close'].rolling(20).mean() + 2*df['close'].rolling(20).std()
        df['BB_lower'] = df['close'].rolling(20).mean() - 2*df['close'].rolling(20).std()

        fig = plot_ohlc(df, title=f"{chart_context['symbol']} - Bollinger Bands")

        current_price = df['close'].iloc[-1]
        sma20 = df['SMA_20'].iloc[-1]
        print(f"Price: {current_price:.2f}, SMA20: {sma20:.2f}")
        df[['close', 'SMA_20', 'BB_upper', 'BB_lower']].tail(10)
        \"\"\")

        # Pattern detection
        execute_python(\"\"\"
        # Find swing highs
        df['swing_high'] = (df['high'] > df['high'].shift(1)) & (df['high'] > df['high'].shift(-1))
        swing_highs = df[df['swing_high']][['high']].tail(5)

        fig = plot_ohlc(df, title='Swing High Detection')
        print("Recent swing highs:")
        print(swing_highs)
        \"\"\")

        # Load alternative data when df is None or for different symbol/timeframe
        execute_python(\"\"\"
        from datetime import datetime, timedelta

        # Get data source
        binance = datasource_registry.get_source('binance')

        # Load ETH data even if viewing BTC chart
        end_time = datetime.now()
        start_time = end_time - timedelta(days=7)

        result = await binance.get_history(
            symbol='ETH/USDT',
            interval='1h',
            start=int(start_time.timestamp()),
            end=int(end_time.timestamp())
        )

        # Convert to DataFrame
        rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
        eth_df = pd.DataFrame(rows).set_index('time')

        # Calculate RSI and plot
        eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
        fig = plot_ohlc(eth_df, title='ETH/USDT 1h - RSI Analysis')
        print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
        \"\"\")

        # Access chart store to see current state
        execute_python(\"\"\"
        print(f"Current symbol: {chart_store.chart_state.symbol}")
        print(f"Current interval: {chart_store.chart_state.interval}")
        print(f"Orders: {len(order_store.orders)}")
        \"\"\")
    """
    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    try:
        import talib
    except ImportError:
        talib = None
        logger.warning("TA-Lib not available in execute_python environment")

    # --- Attempt to load chart data ---
    df = None
    chart_context = None

    registry = _get_registry()
    datasource_registry = _get_datasource_registry()

    if registry and datasource_registry:
        try:
            result, chart_context, source_name = await _get_chart_data_impl(countback)
            bars = result.bars
            if bars:
                rows = []
                for bar in bars:
                    rows.append({'time': pd.to_datetime(bar.time, unit='s'), **bar.data})
                df = pd.DataFrame(rows).set_index('time')
                for col in ['open', 'high', 'low', 'close', 'volume']:
                    if col in df.columns:
                        df[col] = pd.to_numeric(df[col], errors='coerce')
                logger.info(f"execute_python: loaded {len(df)} bars for {chart_context['symbol']}")
        except Exception as e:
            logger.info(f"execute_python: no chart data loaded ({e})")

    # --- Import chart utilities ---
    from .chart_utils import plot_ohlc

    # --- Get indicator registry ---
    indicator_registry = _get_indicator_registry()

    # --- Get DataStores ---
    order_store = _get_order_store()
    chart_store = _get_chart_store()

    # --- Build globals ---
    script_globals: Dict[str, Any] = {
        'pd': pd,
        'np': np,
        'plt': plt,
        'talib': talib,
        'indicator_registry': indicator_registry,
        'registry': registry,
        'datasource_registry': datasource_registry,
        'order_store': order_store,
        'chart_store': chart_store,
        'df': df,
        'chart_context': chart_context,
        'plot_ohlc': plot_ohlc,
    }

    # --- Execute ---
    uploads_dir = Path(__file__).parent.parent.parent.parent / "data" / "uploads"
    uploads_dir.mkdir(parents=True, exist_ok=True)

    stdout_capture = io.StringIO()
    result_df = None
    error_msg = None
    plot_urls = []

    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stdout_capture):
            exec(code, script_globals)

            # Capture last expression
            lines = code.strip().splitlines()
            if lines:
                last = lines[-1].strip()
                if last and not any(last.startswith(kw) for kw in (
                    'if', 'for', 'while', 'def', 'class', 'import',
                    'from', 'with', 'try', 'return', '#'
                )):
                    try:
                        last_val = eval(last, script_globals)
                        if isinstance(last_val, pd.DataFrame):
                            result_df = last_val
                        elif last_val is not None:
                            stdout_capture.write(str(last_val))
                    except Exception:
                        pass

        # Save plots
        for fig_num in plt.get_fignums():
            fig = plt.figure(fig_num)
            filename = f"plot_{uuid.uuid4()}.png"
            fig.savefig(uploads_dir / filename, format='png', bbox_inches='tight', dpi=100)
            plot_urls.append(f"/uploads/{filename}")
            plt.close(fig)

    except Exception as e:
        import traceback
        error_msg = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"

    # --- Build response ---
    response: Dict[str, Any] = {
        'script_output': stdout_capture.getvalue(),
        'chart_context': chart_context,
        'plot_urls': plot_urls,
    }
    if result_df is not None:
        response['result_dataframe'] = {
            'columns': result_df.columns.tolist(),
            'index': result_df.index.astype(str).tolist(),
            'data': result_df.values.tolist(),
            'shape': result_df.shape,
        }
    if error_msg:
        response['error'] = error_msg

    return response
CHART_TOOLS = [
    get_chart_data,
    execute_python
]
224
backend.old/src/agent/tools/chart_utils.py
Normal file
@@ -0,0 +1,224 @@
"""Chart plotting utilities for creating standard, beautiful OHLC charts."""

import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple
import logging

logger = logging.getLogger(__name__)


def plot_ohlc(
    df: pd.DataFrame,
    title: Optional[str] = None,
    volume: bool = True,
    figsize: Tuple[int, int] = (14, 8),
    style: str = 'seaborn-v0_8-darkgrid',
    **kwargs
) -> plt.Figure:
    """Create a beautiful standard OHLC candlestick chart.

    This is a convenience function that generates a professional-looking candlestick
    chart with consistent styling across all generated charts. It uses mplfinance
    with seaborn aesthetics for a polished appearance.

    Args:
        df: pandas DataFrame with DatetimeIndex and columns: open, high, low, close, volume
        title: Optional chart title. If None, uses symbol from chart context
        volume: Whether to include volume subplot (default: True)
        figsize: Figure size as (width, height) in inches (default: (14, 8))
        style: Base matplotlib style to use (default: 'seaborn-v0_8-darkgrid')
        **kwargs: Additional arguments to pass to mplfinance.plot()

    Returns:
        matplotlib.figure.Figure: The created figure object

    Example:
        ```python
        # Basic usage in analyze_chart_data script
        fig = plot_ohlc(df, title='BTC/USDT 15min')

        # Customize with additional indicators
        fig = plot_ohlc(df, volume=True, title='Price Action')

        # Add custom overlays after calling plot_ohlc
        df['SMA20'] = df['close'].rolling(20).mean()
        fig = plot_ohlc(df, title='With SMA')
        # Note: For mplfinance overlays, use the mav or addplot parameters
        ```

    Note:
        The DataFrame must have a DatetimeIndex and the standard OHLCV columns.
        Column names should be lowercase: open, high, low, close, volume
    """
    try:
        import mplfinance as mpf
    except ImportError:
        raise ImportError(
            "mplfinance is required for plot_ohlc(). "
            "Install it with: pip install mplfinance"
        )

    # Validate DataFrame structure
    required_cols = ['open', 'high', 'low', 'close']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(
            f"DataFrame missing required columns: {missing_cols}. "
            f"Required: {required_cols}"
        )

    if not isinstance(df.index, pd.DatetimeIndex):
        raise ValueError(
            "DataFrame must have a DatetimeIndex. "
            "Convert with: df.index = pd.to_datetime(df.index)"
        )

    # Ensure volume column exists for volume plot
    if volume and 'volume' not in df.columns:
        logger.warning("volume=True but 'volume' column not found in DataFrame. Disabling volume.")
        volume = False

    # Create custom style with seaborn aesthetics
    # Using a professional color scheme: green for up candles, red for down candles
    mc = mpf.make_marketcolors(
        up='#26a69a',     # Teal green (calmer than bright green)
        down='#ef5350',   # Coral red (softer than pure red)
        edge='inherit',   # Match candle color for edges
        wick='inherit',   # Match candle color for wicks
        volume='in',      # Volume bars colored by price direction
        alpha=0.9         # Slight transparency for elegance
    )

    s = mpf.make_mpf_style(
        base_mpf_style='charles',  # Clean base style
        marketcolors=mc,
        rc={
            'font.size': 10,
            'axes.labelsize': 11,
            'axes.titlesize': 12,
            'xtick.labelsize': 9,
            'ytick.labelsize': 9,
            'legend.fontsize': 10,
            'figure.facecolor': '#f0f0f0',
            'axes.facecolor': '#ffffff',
            'axes.grid': True,
            'grid.alpha': 0.3,
            'grid.linestyle': '--',
        }
    )

    # Prepare plot parameters
    plot_params = {
        'type': 'candle',
        'style': s,
        'volume': volume,
        'figsize': figsize,
        'tight_layout': True,
        'returnfig': True,
        'warn_too_much_data': 1000,  # Warn if > 1000 candles for performance
    }

    # Add title if provided
    if title:
        plot_params['title'] = title

    # Merge any additional kwargs
    plot_params.update(kwargs)

    # Create the plot
    logger.info(
        f"Creating OHLC chart with {len(df)} candles, "
        f"date range: {df.index.min()} to {df.index.max()}, "
        f"volume: {volume}"
    )

    fig, axes = mpf.plot(df, **plot_params)

    return fig
def add_indicators_to_plot(
|
||||
df: pd.DataFrame,
|
||||
indicators: dict,
|
||||
**plot_kwargs
|
||||
) -> plt.Figure:
|
||||
"""Create an OHLC chart with technical indicators overlaid.
|
||||
|
||||
This extends plot_ohlc() to include common technical indicators using
|
||||
mplfinance's addplot functionality for proper overlay on candlestick charts.
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame with OHLCV data and indicator columns
|
||||
indicators: Dictionary mapping indicator names to parameters
|
||||
Example: {
|
||||
'SMA_20': {'color': 'blue', 'width': 1.5},
|
||||
'EMA_50': {'color': 'orange', 'width': 1.5}
|
||||
}
|
||||
**plot_kwargs: Additional arguments for plot_ohlc()
|
||||
|
||||
Returns:
|
||||
        matplotlib.figure.Figure: The created figure object

    Example:
        ```python
        # Calculate indicators
        df['SMA_20'] = df['close'].rolling(20).mean()
        df['SMA_50'] = df['close'].rolling(50).mean()

        # Plot with indicators
        fig = add_indicators_to_plot(
            df,
            indicators={
                'SMA_20': {'color': 'blue', 'width': 1.5, 'label': '20 SMA'},
                'SMA_50': {'color': 'red', 'width': 1.5, 'label': '50 SMA'}
            },
            title='BTC/USDT with Moving Averages'
        )
        ```
    """
    try:
        import mplfinance as mpf
    except ImportError:
        raise ImportError(
            "mplfinance is required. Install it with: pip install mplfinance"
        )

    # Build addplot list for indicators
    addplots = []
    for indicator_col, params in indicators.items():
        if indicator_col not in df.columns:
            logger.warning(f"Indicator column '{indicator_col}' not found in DataFrame. Skipping.")
            continue

        color = params.get('color', 'blue')
        width = params.get('width', 1.0)
        panel = params.get('panel', 0)  # 0 = main panel with candles
        ylabel = params.get('ylabel', '')

        addplots.append(
            mpf.make_addplot(
                df[indicator_col],
                color=color,
                width=width,
                panel=panel,
                ylabel=ylabel
            )
        )

    # Pass addplot to plot_ohlc via kwargs
    if addplots:
        plot_kwargs['addplot'] = addplots

    return plot_ohlc(df, **plot_kwargs)


# Convenience presets for common chart types
def plot_price_volume(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
    """Create a standard price + volume chart."""
    return plot_ohlc(df, title=title, volume=True, figsize=(14, 8))


def plot_price_only(df: pd.DataFrame, title: Optional[str] = None) -> plt.Figure:
    """Create a price-only candlestick chart without volume."""
    return plot_ohlc(df, title=title, volume=False, figsize=(14, 6))
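The two presets above are thin wrappers that pin keyword defaults on the base plotting function. The same pattern can be sketched without matplotlib using `functools.partial`; here `plot_base` is a hypothetical stand-in for `plot_ohlc` that just records the settings it receives:

```python
from functools import partial

# Hypothetical stand-in for plot_ohlc(); records the settings it was called with.
def plot_base(df, *, title=None, volume=True, figsize=(14, 8)):
    return {"title": title, "volume": volume, "figsize": figsize}

# Presets pin keyword defaults, mirroring plot_price_volume / plot_price_only.
price_volume_preset = partial(plot_base, volume=True, figsize=(14, 8))
price_only_preset = partial(plot_base, volume=False, figsize=(14, 6))

result = price_only_preset(None, title="Clean Price Action")
```

Callers can still override any pinned keyword at call time, which is why the real presets remain as flexible as `plot_ohlc` itself.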
154
backend.old/src/agent/tools/chart_utils_example.py
Normal file
@@ -0,0 +1,154 @@
"""
Example usage of chart_utils.py plotting functions.

This demonstrates how the LLM can use the plot_ohlc() convenience function
in analyze_chart_data scripts to create beautiful, standard OHLC charts.
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta


def create_sample_data(days=30):
    """Create sample OHLCV data for testing."""
    dates = pd.date_range(end=datetime.now(), periods=days * 24, freq='1H')

    # Simulate price movement
    np.random.seed(42)
    close = 50000 + np.cumsum(np.random.randn(len(dates)) * 100)

    data = {
        'open': close + np.random.randn(len(dates)) * 50,
        'high': close + np.abs(np.random.randn(len(dates))) * 100,
        'low': close - np.abs(np.random.randn(len(dates))) * 100,
        'close': close,
        'volume': np.abs(np.random.randn(len(dates))) * 1000000
    }

    df = pd.DataFrame(data, index=dates)

    # Ensure high is highest and low is lowest
    df['high'] = df[['open', 'high', 'low', 'close']].max(axis=1)
    df['low'] = df[['open', 'high', 'low', 'close']].min(axis=1)

    return df


if __name__ == "__main__":
    from chart_utils import plot_ohlc, add_indicators_to_plot, plot_price_volume

    # Create sample data
    df = create_sample_data(days=30)

    print("=" * 60)
    print("Example 1: Basic OHLC chart with volume")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
df.tail(5)
""")

    # Execute it
    fig = plot_ohlc(df, title='BTC/USDT 1H', volume=True)
    print("\n✓ Chart created successfully!")
    print(f"  Figure size: {fig.get_size_inches()}")
    print(f"  Number of axes: {len(fig.axes)}")

    print("\n" + "=" * 60)
    print("Example 2: OHLC chart with indicators")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
# Calculate indicators
df['SMA_20'] = df['close'].rolling(20).mean()
df['SMA_50'] = df['close'].rolling(50).mean()
df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()

# Plot with indicators
fig = add_indicators_to_plot(
    df,
    indicators={
        'SMA_20': {'color': 'blue', 'width': 1.5},
        'SMA_50': {'color': 'red', 'width': 1.5},
        'EMA_12': {'color': 'green', 'width': 1.0}
    },
    title='BTC/USDT with Moving Averages',
    volume=True
)

df[['close', 'SMA_20', 'SMA_50', 'EMA_12']].tail(5)
""")

    # Execute it
    df['SMA_20'] = df['close'].rolling(20).mean()
    df['SMA_50'] = df['close'].rolling(50).mean()
    df['EMA_12'] = df['close'].ewm(span=12, adjust=False).mean()

    fig = add_indicators_to_plot(
        df,
        indicators={
            'SMA_20': {'color': 'blue', 'width': 1.5},
            'SMA_50': {'color': 'red', 'width': 1.5},
            'EMA_12': {'color': 'green', 'width': 1.0}
        },
        title='BTC/USDT with Moving Averages',
        volume=True
    )

    print("\n✓ Chart with indicators created successfully!")
    print(f"  Last close: ${df['close'].iloc[-1]:,.2f}")
    print(f"  SMA 20: ${df['SMA_20'].iloc[-1]:,.2f}")
    print(f"  SMA 50: ${df['SMA_50'].iloc[-1]:,.2f}")

    print("\n" + "=" * 60)
    print("Example 3: Price-only chart (no volume)")
    print("=" * 60)
    print("\nScript the LLM would generate:")
    print("""
from chart_utils import plot_price_only

fig = plot_price_only(df, title='Clean Price Action')
""")

    # Execute it
    from chart_utils import plot_price_only
    fig = plot_price_only(df, title='Clean Price Action')

    print("\n✓ Price-only chart created successfully!")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print("""
The chart_utils module provides:

1. plot_ohlc() - Main function for beautiful candlestick charts
   - Professional seaborn-inspired styling
   - Consistent color scheme (teal up, coral down)
   - Optional volume subplot
   - Customizable figure size

2. add_indicators_to_plot() - OHLC charts with technical indicators
   - Overlay multiple indicators
   - Customizable colors and line widths
   - Proper integration with mplfinance

3. Preset functions for common chart types:
   - plot_price_volume() - Standard price + volume
   - plot_price_only() - Candlesticks without volume

Benefits:
✓ Consistent look and feel across all charts
✓ Less code for the LLM to generate
✓ Professional appearance out of the box
✓ Easy to customize when needed
✓ Works seamlessly with analyze_chart_data tool

The LLM can now simply call plot_ohlc(df) instead of writing
custom matplotlib code for every chart request!
""")
158
backend.old/src/agent/tools/datasource_tools.py
Normal file
@@ -0,0 +1,158 @@
"""Data source and market data tools."""

from typing import Dict, Any, List, Optional
from langchain_core.tools import tool


def _get_datasource_registry():
    """Get the global datasource registry instance."""
    from . import _datasource_registry
    return _datasource_registry


@tool
def list_data_sources() -> List[str]:
    """List all available data sources.

    Returns:
        List of data source names that can be queried for market data
    """
    registry = _get_datasource_registry()
    if not registry:
        return []
    return registry.list_sources()


@tool
async def search_symbols(
    query: str,
    type: Optional[str] = None,
    exchange: Optional[str] = None,
    limit: int = 30,
) -> Dict[str, Any]:
    """Search for trading symbols across all data sources.

    Automatically searches all available data sources and returns aggregated results.
    Use this to find symbols before calling get_symbol_info or get_historical_data.

    Args:
        query: Search query (e.g., "BTC", "AAPL", "EUR")
        type: Optional filter by instrument type (e.g., "crypto", "stock", "forex")
        exchange: Optional filter by exchange (e.g., "binance", "nasdaq")
        limit: Maximum number of results per source (default: 30)

    Returns:
        Dictionary mapping source names to lists of matching symbols.
        Each symbol includes: symbol, full_name, description, exchange, type.
        Use the source name and symbol from results with get_symbol_info or get_historical_data.

    Example response:
        {
            "demo": [
                {
                    "symbol": "BTC/USDT",
                    "full_name": "Bitcoin / Tether USD",
                    "description": "Bitcoin perpetual futures",
                    "exchange": "demo",
                    "type": "crypto"
                }
            ]
        }
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    # Always search all sources
    results = await registry.search_all(query, type, exchange, limit)
    return {name: [r.model_dump() for r in matches] for name, matches in results.items()}


@tool
async def get_symbol_info(source_name: str, symbol: str) -> Dict[str, Any]:
    """Get complete metadata for a trading symbol.

    This retrieves full information about a symbol including:
    - Description and type
    - Supported time resolutions
    - Available data columns (OHLCV, volume, funding rates, etc.)
    - Trading session information
    - Price scale and precision

    Args:
        source_name: Name of the data source (use list_data_sources to see available)
        symbol: Symbol identifier (e.g., "BTC/USDT", "AAPL", "EUR/USD")

    Returns:
        Dictionary containing complete symbol metadata including column schema

    Raises:
        ValueError: If source_name or symbol is not found
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    symbol_info = await registry.resolve_symbol(source_name, symbol)
    return symbol_info.model_dump()


@tool
async def get_historical_data(
    source_name: str,
    symbol: str,
    resolution: str,
    from_time: int,
    to_time: int,
    countback: Optional[int] = None,
) -> Dict[str, Any]:
    """Get historical bar/candle data for a symbol.

    Retrieves time-series data between the specified timestamps. The data
    includes all columns defined for the symbol (OHLCV + any custom columns).

    Args:
        source_name: Name of the data source
        symbol: Symbol identifier
        resolution: Time resolution (e.g., "1" = 1min, "5" = 5min, "60" = 1hour, "1D" = 1day)
        from_time: Start time as Unix timestamp in seconds
        to_time: End time as Unix timestamp in seconds
        countback: Optional limit on number of bars to return

    Returns:
        Dictionary containing:
        - symbol: The requested symbol
        - resolution: The time resolution
        - bars: List of bar data with 'time' and 'data' fields
        - columns: Schema describing available data columns
        - nextTime: If present, indicates more data is available for pagination

    Raises:
        ValueError: If source, symbol, or resolution is invalid

    Example:
        # Get 1-hour BTC data for the last 24 hours
        import time
        to_time = int(time.time())
        from_time = to_time - 86400  # 24 hours ago
        data = get_historical_data("demo", "BTC/USDT", "60", from_time, to_time)
    """
    registry = _get_datasource_registry()
    if not registry:
        raise ValueError("DataSourceRegistry not initialized")

    source = registry.get(source_name)
    if not source:
        available = registry.list_sources()
        raise ValueError(f"Data source '{source_name}' not found. Available sources: {available}")

    result = await source.get_bars(symbol, resolution, from_time, to_time, countback)
    return result.model_dump()


DATASOURCE_TOOLS = [
    list_data_sources,
    search_symbols,
    get_symbol_info,
    get_historical_data,
]
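The `get_historical_data` docstring computes a 24-hour lookback window from Unix timestamps. That arithmetic can be factored into a small helper; `lookback_window` is a hypothetical name, not part of the tools above:

```python
import time

def lookback_window(seconds: int):
    """Return (from_time, to_time) Unix timestamps covering the last `seconds`."""
    to_time = int(time.time())
    from_time = to_time - seconds
    return from_time, to_time

# A 24-hour window suitable for get_historical_data(..., from_time, to_time)
from_time, to_time = lookback_window(86400)
```

Both endpoints are in seconds, matching the `from_time`/`to_time` parameters the tool expects.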
435
backend.old/src/agent/tools/indicator_tools.py
Normal file
@@ -0,0 +1,435 @@
"""Technical indicator tools.

These tools allow the agent to:
1. Discover available indicators (list, search, get info)
2. Add indicators to the chart
3. Update/remove indicators
4. Query currently applied indicators
"""

from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
import logging
import time

logger = logging.getLogger(__name__)


def _get_indicator_registry():
    """Get the global indicator registry instance."""
    from . import _indicator_registry
    return _indicator_registry


def _get_registry():
    """Get the global sync registry instance."""
    from . import _registry
    return _registry


def _get_indicator_store():
    """Get the global IndicatorStore instance."""
    registry = _get_registry()
    if registry and "IndicatorStore" in registry.entries:
        return registry.entries["IndicatorStore"].model
    return None


@tool
def list_indicators() -> List[str]:
    """List all available technical indicators.

    Returns:
        List of indicator names that can be used in analysis and strategies
    """
    registry = _get_indicator_registry()
    if not registry:
        return []
    return registry.list_indicators()


@tool
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
    """Get detailed information about a specific indicator.

    Retrieves metadata including description, parameters, category, use cases,
    input/output schemas, and references.

    Args:
        indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")

    Returns:
        Dictionary containing:
        - name: Indicator name
        - display_name: Human-readable name
        - description: What the indicator computes and why it's useful
        - category: Category (momentum, trend, volatility, volume, etc.)
        - parameters: List of configurable parameters with types and defaults
        - use_cases: Common trading scenarios where this indicator helps
        - tags: Searchable tags
        - input_schema: Required input columns (e.g., OHLCV requirements)
        - output_schema: Columns this indicator produces

    Raises:
        ValueError: If indicator_name is not found
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    metadata = registry.get_metadata(indicator_name)
    if not metadata:
        total_count = len(registry.list_indicators())
        raise ValueError(
            f"Indicator '{indicator_name}' not found. "
            f"Total available: {total_count} indicators. "
            f"Use search_indicators() to find indicators by name, category, or tag."
        )

    input_schema = registry.get_input_schema(indicator_name)
    output_schema = registry.get_output_schema(indicator_name)

    result = metadata.model_dump()
    result["input_schema"] = input_schema.model_dump() if input_schema else None
    result["output_schema"] = output_schema.model_dump() if output_schema else None

    return result


@tool
def search_indicators(
    query: Optional[str] = None,
    category: Optional[str] = None,
    tag: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Search for indicators by text query, category, or tag.

    Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.

    Use this to discover relevant indicators for your trading strategy or analysis.
    Can filter by category (momentum, trend, volatility, etc.) or search by keywords.

    Args:
        query: Optional text search across names, descriptions, and use cases
        category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
        tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")

    Returns:
        List of lightweight indicator summaries. Each contains:
        - name: Indicator name (use with get_indicator_info() for full details)
        - display_name: Human-readable name
        - description: Brief one-line description
        - category: Category (momentum, trend, volatility, etc.)

    Example:
        # Find all momentum indicators
        results = search_indicators(category="momentum")
        # Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]

        # Then get details on interesting ones
        rsi_details = get_indicator_info("RSI")  # Full parameters, schemas, use cases

        # Search for moving average indicators
        search_indicators(query="moving average")

        # Find all TA-Lib indicators
        search_indicators(tag="talib")
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    results = []

    if query:
        results = registry.search_by_text(query)
    elif category:
        results = registry.search_by_category(category)
    elif tag:
        results = registry.search_by_tag(tag)
    else:
        # Return all indicators if no filter
        results = registry.get_all_metadata()

    # Return lightweight summaries only
    return [
        {
            "name": r.name,
            "display_name": r.display_name,
            "description": r.description,
            "category": r.category
        }
        for r in results
    ]


@tool
def get_indicator_categories() -> Dict[str, int]:
    """Get all indicator categories and their counts.

    Returns a summary of available indicator categories, useful for
    exploring what types of indicators are available.

    Returns:
        Dictionary mapping category name to count of indicators in that category.
        Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
    """
    registry = _get_indicator_registry()
    if not registry:
        raise ValueError("IndicatorRegistry not initialized")

    categories: Dict[str, int] = {}
    for metadata in registry.get_all_metadata():
        category = metadata.category
        categories[category] = categories.get(category, 0) + 1

    return categories


@tool
async def add_indicator_to_chart(
    indicator_id: str,
    talib_name: str,
    parameters: Optional[Dict[str, Any]] = None,
    symbol: Optional[str] = None
) -> Dict[str, Any]:
    """Add a technical indicator to the chart.

    This will create a new indicator instance and display it on the TradingView chart.
    The indicator will be synchronized with the frontend in real-time.

    Args:
        indicator_id: Unique identifier for this indicator instance (e.g., 'rsi_14', 'sma_50')
        talib_name: Name of the TA-Lib indicator (e.g., 'RSI', 'SMA', 'MACD', 'BBANDS')
            Use search_indicators() or get_indicator_info() to find available indicators
        parameters: Optional dictionary of indicator parameters
            Example for RSI: {'timeperiod': 14}
            Example for SMA: {'timeperiod': 50}
            Example for MACD: {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}
            Example for BBANDS: {'timeperiod': 20, 'nbdevup': 2, 'nbdevdn': 2}
        symbol: Optional symbol to apply the indicator to (defaults to current chart symbol)

    Returns:
        Dictionary with:
        - status: 'created' or 'updated'
        - indicator: The complete indicator object

    Example:
        # Add RSI(14)
        await add_indicator_to_chart(
            indicator_id='rsi_14',
            talib_name='RSI',
            parameters={'timeperiod': 14}
        )

        # Add 50-period SMA
        await add_indicator_to_chart(
            indicator_id='sma_50',
            talib_name='SMA',
            parameters={'timeperiod': 50}
        )

        # Add MACD with default parameters
        await add_indicator_to_chart(
            indicator_id='macd_default',
            talib_name='MACD'
        )
    """
    from schema.indicator import IndicatorInstance

    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    indicator_store = _get_indicator_store()
    if not indicator_store:
        raise ValueError("IndicatorStore not initialized")

    # Verify the indicator exists
    indicator_registry = _get_indicator_registry()
    if not indicator_registry:
        raise ValueError("IndicatorRegistry not initialized")

    metadata = indicator_registry.get_metadata(talib_name)
    if not metadata:
        raise ValueError(
            f"Indicator '{talib_name}' not found. "
            f"Use search_indicators() to find available indicators."
        )

    # Check if updating existing indicator
    existing_indicator = indicator_store.indicators.get(indicator_id)
    is_update = existing_indicator is not None

    # If symbol is not provided, try to get it from ChartStore
    if symbol is None and "ChartStore" in registry.entries:
        chart_store = registry.entries["ChartStore"].model
        if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
            symbol = chart_store.chart_state.symbol
            logger.info(f"Using current chart symbol for indicator: {symbol}")

    now = int(time.time())

    # Create indicator instance
    indicator = IndicatorInstance(
        id=indicator_id,
        talib_name=talib_name,
        instance_name=f"{talib_name}_{indicator_id}",
        parameters=parameters or {},
        visible=True,
        pane='chart',  # Most indicators go on the chart pane
        symbol=symbol,
        created_at=existing_indicator.get('created_at') if existing_indicator else now,
        modified_at=now
    )

    # Update the store
    indicator_store.indicators[indicator_id] = indicator.model_dump(mode="json")

    # Trigger sync
    await registry.push_all()

    logger.info(
        f"{'Updated' if is_update else 'Created'} indicator '{indicator_id}' "
        f"(TA-Lib: {talib_name}) with parameters: {parameters}"
    )

    return {
        "status": "updated" if is_update else "created",
        "indicator": indicator.model_dump(mode="json")
    }


@tool
async def remove_indicator_from_chart(indicator_id: str) -> Dict[str, str]:
    """Remove an indicator from the chart.

    Args:
        indicator_id: ID of the indicator instance to remove

    Returns:
        Dictionary with status message

    Raises:
        ValueError: If indicator doesn't exist

    Example:
        await remove_indicator_from_chart('rsi_14')
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    indicator_store = _get_indicator_store()
    if not indicator_store:
        raise ValueError("IndicatorStore not initialized")

    if indicator_id not in indicator_store.indicators:
        raise ValueError(f"Indicator '{indicator_id}' not found")

    # Delete the indicator
    del indicator_store.indicators[indicator_id]

    # Trigger sync
    await registry.push_all()

    logger.info(f"Removed indicator '{indicator_id}'")

    return {
        "status": "success",
        "message": f"Indicator '{indicator_id}' removed"
    }


@tool
def list_chart_indicators(symbol: Optional[str] = None) -> List[Dict[str, Any]]:
    """List all indicators currently applied to the chart.

    Args:
        symbol: Optional filter by symbol (defaults to current chart symbol)

    Returns:
        List of indicator instances, each containing:
        - id: Indicator instance ID
        - talib_name: TA-Lib indicator name
        - instance_name: Display name
        - parameters: Current parameter values
        - visible: Whether indicator is visible
        - pane: Which pane it's displayed in
        - symbol: Symbol it's applied to

    Example:
        # List all indicators on current symbol
        indicators = list_chart_indicators()

        # List indicators on specific symbol
        btc_indicators = list_chart_indicators(symbol='BINANCE:BTC/USDT')
    """
    indicator_store = _get_indicator_store()
    if not indicator_store:
        raise ValueError("IndicatorStore not initialized")

    logger.info(f"list_chart_indicators: Raw store indicators: {indicator_store.indicators}")

    # If symbol is not provided, try to get it from ChartStore
    if symbol is None:
        registry = _get_registry()
        if registry and "ChartStore" in registry.entries:
            chart_store = registry.entries["ChartStore"].model
            if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
                symbol = chart_store.chart_state.symbol

    indicators = list(indicator_store.indicators.values())

    logger.info(f"list_chart_indicators: Converted to list: {indicators}")
    logger.info(f"list_chart_indicators: Filtering by symbol: {symbol}")

    # Filter by symbol if provided
    if symbol:
        indicators = [ind for ind in indicators if ind.get('symbol') == symbol]

    logger.info(f"list_chart_indicators: Returning {len(indicators)} indicators")
    return indicators


@tool
def get_chart_indicator(indicator_id: str) -> Dict[str, Any]:
    """Get details of a specific indicator on the chart.

    Args:
        indicator_id: ID of the indicator instance

    Returns:
        Dictionary containing the indicator data

    Raises:
        ValueError: If indicator doesn't exist

    Example:
        indicator = get_chart_indicator('rsi_14')
        print(f"Indicator: {indicator['talib_name']}")
        print(f"Parameters: {indicator['parameters']}")
    """
    indicator_store = _get_indicator_store()
    if not indicator_store:
        raise ValueError("IndicatorStore not initialized")

    indicator = indicator_store.indicators.get(indicator_id)
    if not indicator:
        raise ValueError(f"Indicator '{indicator_id}' not found")

    return indicator


INDICATOR_TOOLS = [
    # Discovery tools
    list_indicators,
    get_indicator_info,
    search_indicators,
    get_indicator_categories,
    # Chart indicator management tools
    add_indicator_to_chart,
    remove_indicator_from_chart,
    list_chart_indicators,
    get_chart_indicator
]
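The aggregation inside `get_indicator_categories` reduces to counting metadata entries per category. A standalone sketch of that reduction, with plain dicts standing in for the registry's metadata objects (`count_categories` and the sample data are illustrative, not part of the module):

```python
from collections import Counter

def count_categories(all_metadata):
    """Count indicators per category, mirroring get_indicator_categories()."""
    return dict(Counter(m["category"] for m in all_metadata))

# Illustrative metadata records in place of registry.get_all_metadata()
sample = [
    {"name": "RSI", "category": "momentum"},
    {"name": "SMA", "category": "trend"},
    {"name": "MACD", "category": "momentum"},
]
counts = count_categories(sample)
```

`Counter` is equivalent to the manual `categories.get(category, 0) + 1` loop above, just more compact.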
171
backend.old/src/agent/tools/research_tools.py
Normal file
@@ -0,0 +1,171 @@
"""Research and external data tools for trading analysis."""

from typing import Dict, Any, Optional
from langchain_core.tools import tool
from langchain_community.tools import (
    ArxivQueryRun,
    WikipediaQueryRun,
    DuckDuckGoSearchRun
)
from langchain_community.utilities import (
    ArxivAPIWrapper,
    WikipediaAPIWrapper,
    DuckDuckGoSearchAPIWrapper
)


@tool
def search_arxiv(query: str, max_results: int = 5) -> str:
    """Search arXiv for academic papers on quantitative finance, trading strategies, and machine learning.

    Use this to find research papers on topics like:
    - Market microstructure and order flow
    - Algorithmic trading strategies
    - Machine learning for finance
    - Time series forecasting
    - Risk management
    - Portfolio optimization

    Args:
        query: Search query (e.g., "machine learning algorithmic trading", "deep learning stock prediction")
        max_results: Maximum number of results to return (default: 5)

    Returns:
        Summary of papers including titles, authors, abstracts, and links

    Example:
        search_arxiv("reinforcement learning trading", max_results=3)
    """
    arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper(top_k_results=max_results))
    return arxiv.run(query)


@tool
def search_wikipedia(query: str) -> str:
    """Search Wikipedia for information on finance, trading, and economics concepts.

    Use this to get background information on:
    - Financial instruments and markets
    - Economic indicators
    - Trading terminology
    - Technical analysis concepts
    - Historical market events

    Args:
        query: Search query (e.g., "Black-Scholes model", "technical analysis", "options trading")

    Returns:
        Wikipedia article summary with key information

    Example:
        search_wikipedia("Bollinger Bands")
    """
    wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    return wikipedia.run(query)


@tool
def search_web(query: str, max_results: int = 5) -> str:
    """Search the web for current information on markets, news, and trading.

    Use this to find:
    - Latest market news and analysis
    - Company announcements and earnings
    - Economic events and indicators
    - Cryptocurrency updates
    - Exchange status and updates
    - Trading strategy discussions

    Args:
        query: Search query (e.g., "Bitcoin price news", "Fed interest rate decision")
        max_results: Maximum number of results to return (default: 5)

    Returns:
        Search results with titles, snippets, and links

    Example:
        search_web("Ethereum merge update", max_results=3)
    """
    # Lazy initialization to avoid hanging during import
    search = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper())
    # Note: max_results parameter doesn't work properly with current wrapper
    return search.run(query)


@tool
def http_get(url: str, params: Optional[Dict[str, str]] = None) -> str:
    """Make an HTTP GET request to fetch data from APIs or web pages.

    Use this to retrieve:
    - Exchange API data (if public endpoints)
    - Market data from external APIs
    - Documentation and specifications
    - News articles and blog posts
    - JSON/XML data from web services

    Args:
        url: The URL to fetch
        params: Optional query parameters as a dictionary

    Returns:
        Response text from the URL

    Raises:
        ValueError: If the request fails

    Example:
        http_get("https://api.coingecko.com/api/v3/simple/price",
                 params={"ids": "bitcoin", "vs_currencies": "usd"})
    """
    import requests

    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        return response.text
    except requests.RequestException as e:
        raise ValueError(f"HTTP GET request failed: {str(e)}")


@tool
def http_post(url: str, data: Dict[str, Any]) -> str:
    """Make an HTTP POST request to send data to APIs.

    Use this to:
    - Submit data to external APIs
    - Trigger webhooks
    - Post analysis results
    - Interact with exchange APIs (if authenticated)

    Args:
        url: The URL to post to
        data: Dictionary of data to send in the request body

    Returns:
        Response text from the server

    Raises:
        ValueError: If the request fails

    Example:
        http_post("https://webhook.site/xxx", {"message": "Trade executed"})
    """
    import requests

    try:
        response = requests.post(url, json=data, timeout=10)
        response.raise_for_status()
        return response.text
    except requests.RequestException as e:
        raise ValueError(f"HTTP POST request failed: {str(e)}")


# Export tools list
RESEARCH_TOOLS = [
    search_arxiv,
    search_wikipedia,
    search_web,
    http_get,
    http_post
]
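When `http_get` is given a `params` dictionary, `requests` encodes it into the URL's query string. The encoding step itself can be sketched with the standard library (`build_url` is an illustrative helper, not part of the tools):

```python
from urllib.parse import urlencode

def build_url(url: str, params=None) -> str:
    """Append query parameters the way http_get's `params` argument does."""
    return f"{url}?{urlencode(params)}" if params else url

# Matches the http_get docstring example
full = build_url(
    "https://api.coingecko.com/api/v3/simple/price",
    params={"ids": "bitcoin", "vs_currencies": "usd"},
)
```

`urlencode` also percent-escapes reserved characters, which is why passing `params` is safer than concatenating query strings by hand.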
475
backend.old/src/agent/tools/shape_tools.py
Normal file
@@ -0,0 +1,475 @@
"""Shape/drawing tools for chart analysis."""

from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
import logging

logger = logging.getLogger(__name__)

# Map legacy/common shape type names to TradingView's native names
SHAPE_TYPE_ALIASES: Dict[str, str] = {
    'trendline': 'trend_line',
    'fibonacci': 'fib_retracement',
    'fibonacci_extension': 'fib_trend_ext',
    'gann_fan': 'gannbox_fan',
}


def _get_registry():
    """Get the global registry instance."""
    from . import _registry
    return _registry


def _get_shape_store():
    """Get the global ShapeStore instance."""
    registry = _get_registry()
    if registry and "ShapeStore" in registry.entries:
        return registry.entries["ShapeStore"].model
    return None


@tool
def search_shapes(
    start_time: Optional[int] = None,
    end_time: Optional[int] = None,
    shape_type: Optional[str] = None,
    symbol: Optional[str] = None,
    shape_ids: Optional[List[str]] = None,
    original_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
    """Search for shapes/drawings using flexible filters.

    This tool can search shapes by:
    - Time range (finds shapes that overlap the range)
    - Shape type (e.g., 'trendline', 'horizontal_line')
    - Symbol (e.g., 'BINANCE:BTC/USDT')
    - Specific shape IDs (TradingView's assigned IDs)
    - Original IDs (the IDs you specified when creating shapes)

    Args:
        start_time: Optional start of time range (Unix timestamp in seconds)
        end_time: Optional end of time range (Unix timestamp in seconds)
        shape_type: Optional filter by shape type (e.g., 'trend_line', 'horizontal_line', 'rectangle')
        symbol: Optional filter by symbol (e.g., 'BINANCE:BTC/USDT')
        shape_ids: Optional list of specific shape IDs to retrieve (searches both id and original_id fields)
        original_ids: Optional list of original IDs to search for (the IDs you specified when creating)

    Returns:
        List of matching shapes, each as a dictionary with:
        - id: Shape identifier (TradingView's assigned ID)
        - original_id: The ID you specified when creating the shape (if applicable)
        - type: Shape type
        - points: List of control points with time and price
        - color, line_width, line_style: Visual properties
        - properties: Additional shape-specific properties
        - symbol: Symbol the shape is drawn on
        - created_at, modified_at: Timestamps

    Examples:
        # Find all shapes in the currently visible chart range
        shapes = search_shapes(
            start_time=chart_state.start_time,
            end_time=chart_state.end_time
        )

        # Find only trendlines in a specific time range
        trendlines = search_shapes(
            start_time=1640000000,
            end_time=1650000000,
            shape_type='trend_line'
        )

        # Find shapes for a specific symbol
        btc_shapes = search_shapes(
            start_time=1640000000,
            end_time=1650000000,
            symbol='BINANCE:BTC/USDT'
        )

        # Get specific shapes by TradingView ID or original ID
        # This searches both the 'id' and 'original_id' fields
        selected = search_shapes(
            shape_ids=['trendline-1', 'support-42k', 'fib-retracement-1']
        )

        # Get shapes by the original IDs you specified when creating them
        my_shapes = search_shapes(
            original_ids=['my-support-line', 'my-resistance-line']
        )

        # Get all trendlines (no time filter)
        all_trendlines = search_shapes(shape_type='trend_line')
    """
    shape_store = _get_shape_store()
    if not shape_store:
        raise ValueError("ShapeStore not initialized")

    shapes_dict = shape_store.shapes
    matching_shapes = []

    # If specific shape IDs are requested, search by both id and original_id
    if shape_ids:
        for requested_id in shape_ids:
            # First try direct ID lookup
            shape = shapes_dict.get(requested_id)
            if shape:
                # Still apply other filters if specified
                if symbol and shape.get('symbol') != symbol:
                    continue
                if shape_type and shape.get('type') != shape_type:
                    continue
                matching_shapes.append(shape)
            else:
                # If not found by ID, search by original_id
                for shape_id, shape in shapes_dict.items():
                    if shape.get('original_id') == requested_id:
                        # Still apply other filters if specified
                        if symbol and shape.get('symbol') != symbol:
                            continue
                        if shape_type and shape.get('type') != shape_type:
                            continue
                        matching_shapes.append(shape)
                        break

        logger.info(
            f"Found {len(matching_shapes)} shapes by ID filter (requested {len(shape_ids)} IDs)"
            + (f" for type '{shape_type}'" if shape_type else "")
            + (f" on symbol '{symbol}'" if symbol else "")
        )
        return matching_shapes

    # If specific original IDs are requested, search by original_id only
    if original_ids:
        for original_id in original_ids:
            for shape_id, shape in shapes_dict.items():
                if shape.get('original_id') == original_id:
                    # Still apply other filters if specified
                    if symbol and shape.get('symbol') != symbol:
                        continue
                    if shape_type and shape.get('type') != shape_type:
                        continue
                    matching_shapes.append(shape)
                    break

        logger.info(
            f"Found {len(matching_shapes)} shapes by original_id filter (requested {len(original_ids)} IDs)"
            + (f" for type '{shape_type}'" if shape_type else "")
            + (f" on symbol '{symbol}'" if symbol else "")
        )
        return matching_shapes

    # Otherwise, search all shapes with filters
    for shape_id, shape in shapes_dict.items():
        # Filter by symbol if specified
        if symbol and shape.get('symbol') != symbol:
            continue

        # Filter by type if specified
        if shape_type and shape.get('type') != shape_type:
            continue

        # Filter by time range if specified
        if start_time is not None and end_time is not None:
            # Check if any control point falls within the time range
            # or if the shape spans across the time range
            points = shape.get('points', [])
            if not points:
                continue

            # Get min and max times from shape's control points
            shape_times = [point['time'] for point in points]
            shape_min_time = min(shape_times)
            shape_max_time = max(shape_times)

            # Check for overlap: shape overlaps if its range intersects with query range
            if not (shape_max_time >= start_time and shape_min_time <= end_time):
                continue

        matching_shapes.append(shape)

    logger.info(
        f"Found {len(matching_shapes)} shapes"
        + (f" in time range {start_time}-{end_time}" if start_time and end_time else "")
        + (f" for type '{shape_type}'" if shape_type else "")
        + (f" on symbol '{symbol}'" if symbol else "")
    )

    return matching_shapes


@tool
async def create_or_update_shape(
    shape_id: str,
    shape_type: str,
    points: List[Dict[str, Any]],
    color: Optional[str] = None,
    line_width: Optional[int] = None,
    line_style: Optional[str] = None,
    properties: Optional[Dict[str, Any]] = None,
    symbol: Optional[str] = None
) -> Dict[str, Any]:
    """Create a new shape or update an existing shape on the chart.

    This tool allows the agent to draw shapes on the user's chart or modify
    existing shapes. Shapes are synchronized to the frontend in real-time.

    IMPORTANT - Shape ID Mapping:
    When you create a shape, TradingView will assign its own internal ID that differs
    from the shape_id you provide. The shape will be updated in the store with:
    - id: TradingView's assigned ID
    - original_id: The shape_id you provided

    To find your shape later, use search_shapes() and filter by original_id field.

    Example:
        # Create a shape
        await create_or_update_shape(shape_id='my-support', ...)

        # Later, find it by original_id
        shapes = search_shapes(symbol='BINANCE:BTC/USDT')
        my_shape = next((s for s in shapes if s.get('original_id') == 'my-support'), None)

    Args:
        shape_id: Unique identifier for the shape (use existing ID to update, new ID to create)
            Note: TradingView will assign its own ID; your ID will be stored in original_id
        shape_type: Type of shape using TradingView's native names.

            Single-point shapes (use 1 point):
            - 'horizontal_line': Horizontal support/resistance line
            - 'vertical_line': Vertical time marker
            - 'text': Text label
            - 'anchored_text': Anchored text annotation
            - 'anchored_note': Anchored note
            - 'note': Note annotation
            - 'emoji': Emoji marker
            - 'icon': Icon marker
            - 'sticker': Sticker marker
            - 'arrow_up': Upward arrow marker
            - 'arrow_down': Downward arrow marker
            - 'flag': Flag marker
            - 'long_position': Long position marker
            - 'short_position': Short position marker

            Multi-point shapes (use 2+ points):
            - 'trend_line': Trendline (2 points)
            - 'rectangle': Rectangle (2 points: top-left, bottom-right)
            - 'fib_retracement': Fibonacci retracement (2 points)
            - 'fib_trend_ext': Fibonacci extension (3 points)
            - 'parallel_channel': Parallel channel (3 points)
            - 'arrow': Arrow (2 points)
            - 'circle': Circle/ellipse (2-3 points)
            - 'path': Free drawing path (3+ points)
            - 'pitchfork': Andrew's pitchfork (3 points)
            - 'gannbox_fan': Gann fan (2 points)
            - 'head_and_shoulders': Head and shoulders pattern (5 points)

        points: List of control points, each with 'time' (Unix seconds) and 'price' fields
        color: Optional color (hex like '#FF0000' or name like 'red')
        line_width: Optional line width in pixels (default: 1)
        line_style: Optional line style: 'solid', 'dashed', 'dotted' (default: 'solid')
        properties: Optional dict of additional shape-specific properties
        symbol: Optional symbol to associate with the shape (defaults to current chart symbol)

    Returns:
        Dictionary with:
        - status: 'created' or 'updated'
        - shape: The complete shape object (initially with your ID, will be updated to TV ID)

    Examples:
        # Draw a trendline between two points
        await create_or_update_shape(
            shape_id='my-trendline-1',
            shape_type='trend_line',
            points=[
                {'time': 1640000000, 'price': 45000.0},
                {'time': 1650000000, 'price': 50000.0}
            ],
            color='#00FF00',
            line_width=2
        )

        # Draw a horizontal support line
        await create_or_update_shape(
            shape_id='support-1',
            shape_type='horizontal_line',
            points=[{'time': 1640000000, 'price': 42000.0}],
            color='blue',
            line_style='dashed'
        )

        # Find your shape after creation using original_id
        shapes = search_shapes(symbol='BINANCE:BTC/USDT')
        my_shape = next((s for s in shapes if s.get('original_id') == 'support-1'), None)
        if my_shape:
            print(f"TradingView assigned ID: {my_shape['id']}")
    """
    from schema.shape import Shape, ControlPoint
    import time as time_module

    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    shape_store = _get_shape_store()
    if not shape_store:
        raise ValueError("ShapeStore not initialized")

    # Normalize shape type (handle legacy names)
    normalized_type = SHAPE_TYPE_ALIASES.get(shape_type, shape_type)
    if normalized_type != shape_type:
        logger.info(f"Normalized shape type '{shape_type}' -> '{normalized_type}'")

    # Convert points to ControlPoint objects
    control_points = []
    for p in points:
        point_data = {
            'time': p['time'],
            'price': p['price']
        }
        # Only include channel if it's actually provided
        if 'channel' in p and p['channel'] is not None:
            point_data['channel'] = p['channel']
        control_points.append(ControlPoint(**point_data))

    # Check if updating existing shape
    existing_shape = shape_store.shapes.get(shape_id)
    is_update = existing_shape is not None

    # If symbol is not provided, try to get it from ChartStore
    if symbol is None and "ChartStore" in registry.entries:
        chart_store = registry.entries["ChartStore"].model
        if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
            symbol = chart_store.chart_state.symbol
            logger.info(f"Using current chart symbol for shape: {symbol}")

    now = int(time_module.time())

    # Create shape object
    shape = Shape(
        id=shape_id,
        type=normalized_type,
        points=control_points,
        color=color,
        line_width=line_width,
        line_style=line_style,
        properties=properties or {},
        symbol=symbol,
        created_at=existing_shape.get('created_at') if existing_shape else now,
        modified_at=now
    )

    # Update the store
    shape_store.shapes[shape_id] = shape.model_dump(mode="json")

    # Trigger sync
    await registry.push_all()

    logger.info(
        f"{'Updated' if is_update else 'Created'} shape '{shape_id}' "
        f"of type '{shape_type}' with {len(points)} points"
    )

    return {
        "status": "updated" if is_update else "created",
        "shape": shape.model_dump(mode="json")
    }


@tool
async def delete_shape(shape_id: str) -> Dict[str, str]:
    """Delete a shape from the chart.

    Args:
        shape_id: ID of the shape to delete

    Returns:
        Dictionary with status message

    Raises:
        ValueError: If shape doesn't exist

    Example:
        await delete_shape('my-trendline-1')
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    shape_store = _get_shape_store()
    if not shape_store:
        raise ValueError("ShapeStore not initialized")

    if shape_id not in shape_store.shapes:
        raise ValueError(f"Shape '{shape_id}' not found")

    # Delete the shape
    del shape_store.shapes[shape_id]

    # Trigger sync
    await registry.push_all()

    logger.info(f"Deleted shape '{shape_id}'")

    return {
        "status": "success",
        "message": f"Shape '{shape_id}' deleted"
    }


@tool
def get_shape(shape_id: str) -> Dict[str, Any]:
    """Get details of a specific shape by ID.

    Args:
        shape_id: ID of the shape to retrieve

    Returns:
        Dictionary containing the shape data

    Raises:
        ValueError: If shape doesn't exist

    Example:
        shape = get_shape('my-trendline-1')
        print(f"Shape type: {shape['type']}")
        print(f"Points: {shape['points']}")
    """
    shape_store = _get_shape_store()
    if not shape_store:
        raise ValueError("ShapeStore not initialized")

    shape = shape_store.shapes.get(shape_id)
    if not shape:
        raise ValueError(f"Shape '{shape_id}' not found")

    return shape


@tool
def list_all_shapes() -> List[Dict[str, Any]]:
    """List all shapes currently on the chart.

    Returns:
        List of all shapes as dictionaries

    Example:
        shapes = list_all_shapes()
        print(f"Total shapes: {len(shapes)}")
        for shape in shapes:
            print(f"  - {shape['id']}: {shape['type']}")
    """
    shape_store = _get_shape_store()
    if not shape_store:
        raise ValueError("ShapeStore not initialized")

    return list(shape_store.shapes.values())


SHAPE_TOOLS = [
    search_shapes,
    create_or_update_shape,
    delete_shape,
    get_shape,
    list_all_shapes,
]

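Two pieces of the search logic above are easy to check in isolation: the time-range overlap test and the id/original_id fallback lookup that search_shapes applies when shape_ids are given. A self-contained sketch with toy data (not the real ShapeStore model):

```python
def time_range_overlaps(shape_times, start_time, end_time):
    # Matches search_shapes: a shape qualifies when the span of its control
    # points [min, max] intersects the query window [start_time, end_time].
    return max(shape_times) >= start_time and min(shape_times) <= end_time


def find_by_any_id(shapes, requested_id):
    # Matches the shape_ids branch: try TradingView's assigned id first,
    # then fall back to the caller-supplied original_id.
    shape = shapes.get(requested_id)
    if shape is not None:
        return shape
    for shape in shapes.values():
        if shape.get("original_id") == requested_id:
            return shape
    return None
```

The overlap test deliberately matches shapes that merely cross the query window, so a long trendline is found even when both of its endpoints lie outside the visible range.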
138  backend.old/src/agent/tools/sync_tools.py  (Normal file)
@@ -0,0 +1,138 @@
"""Synchronization store tools."""

from typing import Dict, Any, List
from langchain_core.tools import tool


def _get_registry():
    """Get the global registry instance."""
    from . import _registry
    return _registry


@tool
def list_sync_stores() -> List[str]:
    """List all available synchronization stores.

    Returns:
        List of store names that can be read/written
    """
    registry = _get_registry()
    if not registry:
        return []
    return list(registry.entries.keys())


@tool
def read_sync_state(store_name: str) -> Dict[str, Any]:
    """Read the current state of a synchronization store.

    Args:
        store_name: Name of the store to read (e.g., "TraderState", "StrategyState")

    Returns:
        Dictionary containing the current state of the store

    Raises:
        ValueError: If store_name doesn't exist
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    entry = registry.entries.get(store_name)
    if not entry:
        available = list(registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    return entry.model.model_dump(mode="json")


@tool
async def write_sync_state(store_name: str, updates: Dict[str, Any]) -> Dict[str, str]:
    """Update the state of a synchronization store.

    This will apply the updates to the store and trigger synchronization
    with all connected clients.

    Args:
        store_name: Name of the store to update
        updates: Dictionary of field updates (field_name: new_value)

    Returns:
        Dictionary with status and updated fields

    Raises:
        ValueError: If store_name doesn't exist or updates are invalid
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    entry = registry.entries.get(store_name)
    if not entry:
        available = list(registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    try:
        # Get current state
        current_state = entry.model.model_dump(mode="json")

        # Apply updates
        new_state = {**current_state, **updates}

        # Update the model
        registry._update_model(entry.model, new_state)

        # Trigger sync
        await registry.push_all()

        return {
            "status": "success",
            "store": store_name,
            "updated_fields": list(updates.keys())
        }

    except Exception as e:
        raise ValueError(f"Failed to update store '{store_name}': {str(e)}")


@tool
def get_store_schema(store_name: str) -> Dict[str, Any]:
    """Get the schema/structure of a synchronization store.

    This shows what fields are available and their types.

    Args:
        store_name: Name of the store

    Returns:
        Dictionary describing the store's schema

    Raises:
        ValueError: If store_name doesn't exist
    """
    registry = _get_registry()
    if not registry:
        raise ValueError("SyncRegistry not initialized")

    entry = registry.entries.get(store_name)
    if not entry:
        available = list(registry.entries.keys())
        raise ValueError(f"Store '{store_name}' not found. Available stores: {available}")

    # Get model schema
    schema = entry.model.model_json_schema()

    return {
        "store_name": store_name,
        "schema": schema
    }


SYNC_TOOLS = [
    list_sync_stores,
    read_sync_state,
    write_sync_state,
    get_store_schema,
]

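Note that write_sync_state applies updates as a shallow dict merge, so a nested field passed in `updates` replaces the whole nested structure rather than deep-merging into it. A minimal sketch of that semantics:

```python
def apply_updates(current_state, updates):
    # Same semantics as write_sync_state: {**current, **updates}.
    # Top-level keys in `updates` win; nested dicts are replaced wholesale,
    # not merged key-by-key.
    return {**current_state, **updates}
```

Callers updating a store with nested state should therefore send the complete nested object, not just the changed inner keys.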
366  backend.old/src/agent/tools/trigger_tools.py  (Normal file)
@@ -0,0 +1,366 @@
"""
Agent tools for trigger system.

Allows agents to:
- Schedule recurring tasks (cron-style)
- Execute one-time triggers
- Manage scheduled triggers (list, cancel)
- Connect events to sub-agent runs or lambdas
"""

import logging
from typing import Any, Dict, List, Optional

from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# Global references set by main.py
_trigger_queue = None
_trigger_scheduler = None
_coordinator = None


def set_trigger_queue(queue):
    """Set the global TriggerQueue instance for tools to use."""
    global _trigger_queue
    _trigger_queue = queue


def set_trigger_scheduler(scheduler):
    """Set the global TriggerScheduler instance for tools to use."""
    global _trigger_scheduler
    _trigger_scheduler = scheduler


def set_coordinator(coordinator):
    """Set the global CommitCoordinator instance for tools to use."""
    global _coordinator
    _coordinator = coordinator


def _get_trigger_queue():
    """Get the global trigger queue instance."""
    if not _trigger_queue:
        raise ValueError("TriggerQueue not initialized")
    return _trigger_queue


def _get_trigger_scheduler():
    """Get the global trigger scheduler instance."""
    if not _trigger_scheduler:
        raise ValueError("TriggerScheduler not initialized")
    return _trigger_scheduler


def _get_coordinator():
    """Get the global coordinator instance."""
    if not _coordinator:
        raise ValueError("CommitCoordinator not initialized")
    return _coordinator


@tool
async def schedule_agent_prompt(
    prompt: str,
    schedule_type: str,
    schedule_config: Dict[str, Any],
    name: Optional[str] = None,
) -> Dict[str, Any]:
    """Schedule an agent to run with a specific prompt on a recurring schedule.

    This allows you to set up automated tasks where the agent runs periodically
    with a predefined prompt. Useful for:
    - Daily market analysis reports
    - Hourly portfolio rebalancing checks
    - Weekly performance summaries
    - Monitoring alerts

    Args:
        prompt: The prompt to send to the agent when triggered
        schedule_type: Type of schedule - "interval" or "cron"
        schedule_config: Schedule configuration:
            For "interval": {"minutes": 5} or {"hours": 1, "minutes": 30}
            For "cron": {"hour": "9", "minute": "0"} for 9:00 AM daily
                        {"hour": "9", "minute": "0", "day_of_week": "mon-fri"}
        name: Optional descriptive name for this scheduled task

    Returns:
        Dictionary with job_id and confirmation message

    Examples:
        # Run every 5 minutes
        schedule_agent_prompt(
            prompt="Check BTC price and alert if > $50k",
            schedule_type="interval",
            schedule_config={"minutes": 5}
        )

        # Run daily at 9 AM
        schedule_agent_prompt(
            prompt="Generate daily market summary",
            schedule_type="cron",
            schedule_config={"hour": "9", "minute": "0"}
        )

        # Run hourly on weekdays
        schedule_agent_prompt(
            prompt="Monitor portfolio for rebalancing opportunities",
            schedule_type="cron",
            schedule_config={"minute": "0", "day_of_week": "mon-fri"}
        )
    """
    from trigger.handlers import LambdaHandler
    from trigger import Priority

    scheduler = _get_trigger_scheduler()
    queue = _get_trigger_queue()

    if not name:
        name = f"agent_prompt_{hash(prompt) % 10000}"

    # Create a lambda that enqueues an agent trigger with the prompt
    async def agent_prompt_lambda():
        from trigger.handlers import AgentTriggerHandler

        # Create agent trigger (will use current session's context)
        # In production, you'd want to specify which session/user this belongs to
        trigger = AgentTriggerHandler(
            session_id="scheduled",  # Special session for scheduled tasks
            message_content=prompt,
            coordinator=_get_coordinator(),
        )

        await queue.enqueue(trigger)
        return []  # No direct commit intents

    # Wrap in lambda handler
    lambda_trigger = LambdaHandler(
        name=f"scheduled_{name}",
        func=agent_prompt_lambda,
        priority=Priority.TIMER,
    )

    # Schedule based on type
    if schedule_type == "interval":
        job_id = scheduler.schedule_interval(
            lambda_trigger,
            seconds=schedule_config.get("seconds"),
            minutes=schedule_config.get("minutes"),
            hours=schedule_config.get("hours"),
            priority=Priority.TIMER,
        )
    elif schedule_type == "cron":
        job_id = scheduler.schedule_cron(
            lambda_trigger,
            minute=schedule_config.get("minute"),
            hour=schedule_config.get("hour"),
            day=schedule_config.get("day"),
            month=schedule_config.get("month"),
            day_of_week=schedule_config.get("day_of_week"),
            priority=Priority.TIMER,
        )
    else:
        raise ValueError(f"Invalid schedule_type: {schedule_type}. Use 'interval' or 'cron'")

    return {
        "job_id": job_id,
        "message": f"Scheduled '{name}' with job_id={job_id}",
        "schedule_type": schedule_type,
        "config": schedule_config,
    }


@tool
async def execute_agent_prompt_once(
    prompt: str,
    priority: str = "normal",
) -> Dict[str, Any]:
    """Execute an agent prompt once, immediately (enqueued with priority).

    Use this to trigger a sub-agent with a specific task without waiting for
    a user message. Useful for:
    - Background analysis tasks
    - One-time data processing
    - Responding to specific events

    Args:
        prompt: The prompt to send to the agent
        priority: Priority level - "high", "normal", or "low"

    Returns:
        Confirmation that the prompt was enqueued

    Example:
        execute_agent_prompt_once(
            prompt="Analyze the last 100 BTC/USDT bars and identify support levels",
            priority="high"
        )
    """
    from trigger.handlers import AgentTriggerHandler
    from trigger import Priority

    queue = _get_trigger_queue()

    # Map string priority to enum
    priority_map = {
        "high": Priority.USER_AGENT,  # Same priority as user messages
        "normal": Priority.SYSTEM,
        "low": Priority.LOW,
    }
    priority_enum = priority_map.get(priority.lower(), Priority.SYSTEM)

    # Create agent trigger
    trigger = AgentTriggerHandler(
        session_id="oneshot",
        message_content=prompt,
        coordinator=_get_coordinator(),
    )

    # Enqueue with priority override
    queue_seq = await queue.enqueue(trigger, priority_enum)

    return {
        "queue_seq": queue_seq,
        "message": f"Enqueued agent prompt with priority={priority}",
        "prompt": prompt[:100] + "..." if len(prompt) > 100 else prompt,
    }


@tool
def list_scheduled_triggers() -> List[Dict[str, Any]]:
    """List all currently scheduled triggers.

    Returns:
        List of dictionaries with job information (id, name, next_run_time)

    Example:
        jobs = list_scheduled_triggers()
        for job in jobs:
            print(f"{job['id']}: {job['name']} - next run at {job['next_run_time']}")
    """
    scheduler = _get_trigger_scheduler()
    jobs = scheduler.get_jobs()

    result = []
    for job in jobs:
        result.append({
            "id": job.id,
            "name": job.name,
            "next_run_time": str(job.next_run_time) if job.next_run_time else None,
            "trigger": str(job.trigger),
        })

    return result


@tool
def cancel_scheduled_trigger(job_id: str) -> Dict[str, str]:
    """Cancel a scheduled trigger by its job ID.

    Args:
        job_id: The job ID returned from schedule_agent_prompt or list_scheduled_triggers

    Returns:
        Confirmation message

    Example:
        cancel_scheduled_trigger("interval_123")
    """
    scheduler = _get_trigger_scheduler()
    success = scheduler.remove_job(job_id)

    if success:
        return {
            "status": "success",
            "message": f"Cancelled job {job_id}",
        }
    else:
        return {
            "status": "error",
            "message": f"Job {job_id} not found",
        }


@tool
async def on_data_update_run_agent(
    source_name: str,
    symbol: str,
    resolution: str,
    prompt_template: str,
) -> Dict[str, Any]:
    """Set up an agent to run whenever new data arrives for a specific symbol.

    The prompt_template can include {variables} that will be filled with bar data:
    - {time}: Bar timestamp
    - {open}, {high}, {low}, {close}, {volume}: OHLCV values
    - {symbol}: Trading pair symbol
    - {source}: Data source name

    Args:
        source_name: Name of data source (e.g., "binance")
        symbol: Trading pair (e.g., "BTC/USDT")
        resolution: Time resolution (e.g., "1m", "5m", "1h")
        prompt_template: Template string for agent prompt

    Returns:
        Confirmation with subscription details

    Example:
        on_data_update_run_agent(
            source_name="binance",
            symbol="BTC/USDT",
            resolution="1m",
            prompt_template="New bar on {symbol}: close={close}. Check if we should trade."
        )

    Note:
        This is a simplified version. Full implementation would wire into the
        DataSource subscription system to trigger on every bar update.
    """
    # TODO: Implement proper DataSource subscription integration
    # For now, return placeholder

    return {
        "status": "not_implemented",
        "message": "Data-driven agent triggers coming soon",
        "config": {
            "source": source_name,
            "symbol": symbol,
            "resolution": resolution,
            "prompt_template": prompt_template,
        },
    }


@tool
|
||||
def get_trigger_system_stats() -> Dict[str, Any]:
|
||||
"""Get statistics about the trigger system.
|
||||
|
||||
Returns:
|
||||
Dictionary with queue depth, execution stats, etc.
|
||||
|
||||
Example:
|
||||
stats = get_trigger_system_stats()
|
||||
print(f"Queue depth: {stats['queue_depth']}")
|
||||
print(f"Current seq: {stats['current_seq']}")
|
||||
"""
|
||||
queue = _get_trigger_queue()
|
||||
coordinator = _get_coordinator()
|
||||
|
||||
return {
|
||||
"queue_depth": queue.get_queue_size(),
|
||||
"queue_running": queue.is_running(),
|
||||
"coordinator_stats": coordinator.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
# Export tools list
|
||||
TRIGGER_TOOLS = [
|
||||
schedule_agent_prompt,
|
||||
execute_agent_prompt_once,
|
||||
list_scheduled_triggers,
|
||||
cancel_scheduled_trigger,
|
||||
on_data_update_run_agent,
|
||||
get_trigger_system_stats,
|
||||
]
|
||||
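The TODO in `on_data_update_run_agent` is the subscription wiring. A minimal sketch of what that wiring might look like, assuming a hypothetical `subscribe(symbol, resolution, callback)` API on the data source (`FakeDataSource`, `make_bar_handler`, and `run_agent` here are illustrative stand-ins, not the real implementation):

```python
from typing import Any, Callable, Dict


class FakeDataSource:
    """Stand-in for the real DataSource: delivers bars to subscribers."""

    def __init__(self) -> None:
        self._callbacks: list = []

    def subscribe(self, symbol: str, resolution: str,
                  callback: Callable[[Dict[str, Any]], None]) -> str:
        # Real implementation would poll/stream per (symbol, resolution).
        self._callbacks.append(callback)
        return f"sub-{len(self._callbacks)}"

    def push_bar(self, bar: Dict[str, Any]) -> None:
        for cb in self._callbacks:
            cb(bar)


def make_bar_handler(prompt_template: str, run_agent: Callable[[str], None]):
    """Render the template with bar fields and hand the prompt to the agent."""
    def on_bar(bar: Dict[str, Any]) -> None:
        prompt = prompt_template.format(**bar)
        run_agent(prompt)
    return on_bar


prompts: list = []
source = FakeDataSource()
handler = make_bar_handler(
    "New bar on {symbol}: close={close}. Check if we should trade.",
    prompts.append,  # stand-in for enqueueing an agent run
)
source.subscribe("BTC/USDT", "1m", handler)
source.push_bar({"symbol": "BTC/USDT", "time": 0, "open": 1.0,
                 "high": 2.0, "low": 0.5, "close": 1.5, "volume": 10.0})
print(prompts[0])  # New bar on BTC/USDT: close=1.5. Check if we should trade.
```

The real version would enqueue the rendered prompt onto the trigger queue rather than calling the agent inline.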
@@ -6,9 +6,10 @@ the free CCXT library (not ccxt.pro), supporting both historical data and
 polling-based subscriptions.
 
 Numerical Precision:
-- Uses Decimal for all monetary values (prices, volumes) to avoid floating-point errors
+- OHLCV data uses native floats for optimal DataFrame/analysis performance
+- Account balances and order data should use Decimal (via _to_decimal method)
 - CCXT returns numeric values as strings or floats depending on configuration
-- All financial values are converted to Decimal to maintain precision
+- Price data converted to float (_to_float), financial data to Decimal (_to_decimal)
 
 Real-time Updates:
 - Uses polling instead of WebSocket (free CCXT doesn't have WebSocket support)
@@ -72,6 +73,20 @@ class CCXTDataSource(DataSource):
         exchange_class = getattr(ccxt, exchange_id)
         self.exchange = exchange_class(self._config)
 
+        # Configure CCXT to use Decimal mode for precise financial calculations.
+        # This ensures all numeric values from the exchange use Decimal internally.
+        # We then convert OHLCV to float for DataFrame performance, but keep
+        # Decimal precision for account balances, order sizes, etc.
+        from decimal import Decimal as PythonDecimal
+        self.exchange.number = PythonDecimal
+
+        # Log the precision mode being used by this exchange
+        precision_mode = getattr(self.exchange, 'precisionMode', 'UNKNOWN')
+        logger.info(
+            f"CCXT {exchange_id}: Configured with Decimal mode. "
+            f"Exchange precision mode: {precision_mode}"
+        )
+
         if sandbox and hasattr(self.exchange, 'set_sandbox_mode'):
             self.exchange.set_sandbox_mode(True)
 
@@ -103,6 +118,33 @@ class CCXTDataSource(DataSource):
             return Decimal(str(value))
         return None
 
+    @staticmethod
+    def _to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
+        """
+        Convert a value to float for OHLCV data.
+
+        OHLCV data is used for charting and DataFrame analysis, where native
+        floats provide better performance and compatibility with pandas/numpy.
+        For financial precision (balances, order sizes), use _to_decimal() instead.
+
+        When CCXT is in Decimal mode (exchange.number = Decimal), it returns
+        Decimal objects. This method converts them to float for performance.
+
+        Handles CCXT's output in both modes:
+        - Decimal mode: receives Decimal objects
+        - Default mode: receives strings, floats, or ints
+        """
+        if value is None:
+            return None
+        if isinstance(value, float):
+            return value
+        if isinstance(value, Decimal):
+            # CCXT in Decimal mode - convert to float for OHLCV
+            return float(value)
+        if isinstance(value, (str, int)):
+            return float(value)
+        return None
+
     async def _ensure_markets_loaded(self):
         """Ensure markets are loaded from exchange"""
         if not self._markets_loaded:
@@ -241,31 +283,31 @@ class CCXTDataSource(DataSource):
             columns=[
                 ColumnInfo(
                     name="open",
-                    type="decimal",
+                    type="float",
                     description=f"Opening price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="high",
-                    type="decimal",
+                    type="float",
                     description=f"Highest price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="low",
-                    type="decimal",
+                    type="float",
                     description=f"Lowest price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="close",
-                    type="decimal",
+                    type="float",
                     description=f"Closing price in {quote}",
                     unit=quote,
                 ),
                 ColumnInfo(
                     name="volume",
-                    type="decimal",
+                    type="float",
                     description=f"Trading volume in {base}",
                     unit=base,
                 ),
@@ -370,7 +412,7 @@ class CCXTDataSource(DataSource):
                 all_ohlcv = all_ohlcv[:countback]
                 break
 
-        # Convert to our Bar format with Decimal precision
+        # Convert to our Bar format with float for OHLCV (used in DataFrames)
         bars = []
         for candle in all_ohlcv:
             timestamp_ms, open_price, high, low, close, volume = candle
@@ -384,11 +426,11 @@ class CCXTDataSource(DataSource):
                 Bar(
                     time=timestamp,
                     data={
-                        "open": self._to_decimal(open_price),
-                        "high": self._to_decimal(high),
-                        "low": self._to_decimal(low),
-                        "close": self._to_decimal(close),
-                        "volume": self._to_decimal(volume),
+                        "open": self._to_float(open_price),
+                        "high": self._to_float(high),
+                        "low": self._to_float(low),
+                        "close": self._to_float(close),
+                        "volume": self._to_float(volume),
                     },
                 )
             )
@@ -476,14 +518,14 @@ class CCXTDataSource(DataSource):
             if timestamp > last_timestamp:
                 self._last_bars[subscription_id] = timestamp
 
-                # Convert to our format with Decimal precision
+                # Convert to our format with float for OHLCV (used in DataFrames)
                 tick_data = {
                     "time": timestamp,
-                    "open": self._to_decimal(open_price),
-                    "high": self._to_decimal(high),
-                    "low": self._to_decimal(low),
-                    "close": self._to_decimal(close),
-                    "volume": self._to_decimal(volume),
+                    "open": self._to_float(open_price),
+                    "high": self._to_float(high),
+                    "low": self._to_float(low),
+                    "close": self._to_float(close),
+                    "volume": self._to_float(volume),
                 }
 
                 # Call the callback
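The float/Decimal split in the diff can be illustrated standalone. The helpers below mirror the `_to_float`/`_to_decimal` logic (names and routing are restated here for illustration; `to_decimal`'s `str()` round-trip matches the `Decimal(str(value))` call in the hunk):

```python
from decimal import Decimal
from typing import Optional, Union


def to_float(value: Union[str, int, float, Decimal, None]) -> Optional[float]:
    """Mirror of _to_float: normalize CCXT output to float for OHLCV/DataFrames."""
    if value is None:
        return None
    if isinstance(value, float):
        return value
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, (str, int)):
        return float(value)
    return None


def to_decimal(value: Union[str, int, float, Decimal, None]) -> Optional[Decimal]:
    """Mirror of _to_decimal: route through str() so floats don't leak binary error."""
    if value is None:
        return None
    return Decimal(str(value))


# OHLCV path: floats for pandas/numpy friendliness
print(to_float(Decimal("50000.1")))  # 50000.1
# Financial path: Decimal keeps exact arithmetic
print(to_decimal(0.1) + to_decimal(0.2))  # Decimal('0.3')
```

Note that `0.1 + 0.2 != 0.3` in binary floats; the `Decimal(str(...))` route is what makes the financial path exact.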
179	backend.old/src/exchange_kernel/README.md	Normal file
@@ -0,0 +1,179 @@
# Exchange Kernel API

A Kubernetes-style declarative API for managing orders across different exchanges.

## Architecture Overview

The Exchange Kernel maintains two separate views of order state:

1. **Desired State (Intent)**: What the strategy kernel wants
2. **Actual State (Reality)**: What currently exists on the exchange

A reconciliation loop continuously works to bring actual state into alignment with desired state, handling errors, retries, and edge cases automatically.

## Core Components

### Models (`models.py`)

- **OrderIntent**: Desired order state from the strategy kernel
- **OrderState**: Actual current order state on the exchange
- **Position**: Current position (spot, margin, perp, futures, options)
- **Asset**: Asset holdings with metadata
- **AccountState**: Complete account snapshot (balances, positions, margin)
- **AssetMetadata**: Asset type descriptions and trading parameters

### Events (`events.py`)

Order lifecycle events:
- `OrderSubmitted`, `OrderAccepted`, `OrderRejected`
- `OrderPartiallyFilled`, `OrderFilled`, `OrderCanceled`
- `OrderModified`, `OrderExpired`

Position events:
- `PositionOpened`, `PositionModified`, `PositionClosed`

Account events:
- `AccountBalanceUpdated`, `MarginCallWarning`

### Base Interface (`base.py`)

Abstract `ExchangeKernel` class defining:

**Command API**:
- `place_order()`, `place_order_group()` - Create order intents
- `cancel_order()`, `modify_order()` - Update intents
- `cancel_all_orders()` - Bulk cancellation

**Query API**:
- `get_order_intent()`, `get_order_state()` - Query a single order
- `get_all_intents()`, `get_all_orders()` - Query all orders
- `get_positions()`, `get_account_state()` - Query positions/balances
- `get_symbol_metadata()`, `get_asset_metadata()` - Query market info

**Event API**:
- `subscribe_events()`, `unsubscribe_events()` - Event notifications

**Lifecycle**:
- `start()`, `stop()` - Kernel lifecycle
- `health_check()` - Connection status
- `force_reconciliation()` - Manual reconciliation trigger

### State Management (`state.py`)

- **IntentStateStore**: Storage for desired state (durable, survives restarts)
- **ActualStateStore**: Storage for actual exchange state (ephemeral cache)
- **ReconciliationEngine**: Framework for intent→reality reconciliation
- **InMemory implementations**: For testing/prototyping

## Standard Order Model

Defined in `schema/order_spec.py`:

```python
StandardOrder(
    symbol_id="BTC/USD",
    side=Side.BUY,
    amount=1.0,
    amount_type=AmountType.BASE,  # or QUOTE for exact-out
    limit_price=50000.0,  # None for market orders
    time_in_force=TimeInForce.GTC,
    conditional_trigger=ConditionalTrigger(...),  # Optional stop-loss/take-profit
    conditional_mode=ConditionalOrderMode.UNIFIED_ADJUSTING,
    reduce_only=False,
    post_only=False,
    iceberg_qty=None,
)
```

## Symbol Metadata

Markets describe their capabilities via `SymbolMetadata`:

- **AmountConstraints**: Min/max order size, step size
- **PriceConstraints**: Tick size, tick spacing mode (fixed/dynamic/continuous)
- **MarketCapabilities**:
  - Supported sides (BUY, SELL)
  - Supported amount types (BASE, QUOTE, or both)
  - Market vs limit order support
  - Time-in-force options (GTC, IOC, FOK, DAY, GTD)
  - Conditional order support (stop-loss, take-profit, trailing stops)
  - Advanced features (post-only, reduce-only, iceberg)

## Asset Types

Comprehensive asset type system supporting:
- **SPOT**: Cash markets
- **MARGIN**: Margin trading
- **PERP**: Perpetual futures
- **FUTURE**: Dated futures
- **OPTION**: Options contracts
- **SYNTHETIC**: Derived instruments

Each asset has metadata describing contract specs, settlement, margin requirements, etc.

## Usage Pattern

```python
# Create an exchange kernel for a specific exchange
kernel = SomeExchangeKernel(exchange_id="binance_main")

# Subscribe to events
kernel.subscribe_events(my_event_handler)

# Start the kernel
await kernel.start()

# Place an order (creates an intent; the kernel handles execution)
intent_id = await kernel.place_order(
    StandardOrder(
        symbol_id="BTC/USD",
        side=Side.BUY,
        amount=1.0,
        amount_type=AmountType.BASE,
        limit_price=50000.0,
    )
)

# Query desired state
intent = await kernel.get_order_intent(intent_id)

# Query actual state
state = await kernel.get_order_state(intent_id)

# Modify the order (updates the intent; the kernel reconciles)
await kernel.modify_order(intent_id, new_order)

# Cancel the order
await kernel.cancel_order(intent_id)

# Query positions
positions = await kernel.get_positions()

# Query account state
account = await kernel.get_account_state()
```

## Implementation Status

✅ **Complete**:
- Data models and type definitions
- Event definitions
- Abstract interface
- State store framework
- In-memory stores for testing

⏳ **TODO** (exchange-specific implementations):
- Concrete ExchangeKernel implementations per exchange
- Reconciliation engine implementation
- Exchange API adapters
- Persistent state storage (database)
- Error handling and retry logic
- Monitoring and observability

## Next Steps

1. Create concrete implementations for specific exchanges (Binance, Uniswap, etc.)
2. Implement the reconciliation engine with proper error handling
3. Add a persistent storage backend for intents
4. Build integration tests
5. Add monitoring/metrics collection
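The reconciliation engine is still on the TODO list above, but the intent→reality convergence it must perform can be sketched in a few lines. Everything here (`FakeIntentStore`, `FakeExchange`, the status strings) is a hypothetical stand-in under the assumption that desired and actual state are simple status maps:

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeIntentStore:
    """Stand-in for IntentStateStore: maps intent_id -> desired status."""
    intents: dict = field(default_factory=dict)


@dataclass
class FakeExchange:
    """Stand-in for an exchange adapter: maps intent_id -> live order status."""
    orders: dict = field(default_factory=dict)

    async def submit(self, intent_id: str) -> None:
        self.orders[intent_id] = "OPEN"

    async def cancel(self, intent_id: str) -> None:
        self.orders[intent_id] = "CANCELED"


async def reconcile_once(store: FakeIntentStore, exchange: FakeExchange) -> None:
    """One reconciliation pass: make actual state match desired state."""
    for intent_id, desired in store.intents.items():
        actual = exchange.orders.get(intent_id)
        if desired == "OPEN" and actual is None:
            await exchange.submit(intent_id)   # missing on exchange -> place it
        elif desired == "CANCELED" and actual == "OPEN":
            await exchange.cancel(intent_id)   # no longer wanted -> cancel it


store = FakeIntentStore({"intent-1": "OPEN", "intent-2": "CANCELED"})
exchange = FakeExchange({"intent-2": "OPEN"})
asyncio.run(reconcile_once(store, exchange))
print(exchange.orders)
```

The real engine would run this pass in a loop with backoff, retries, and the error handling the Next Steps section calls for.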
75	backend.old/src/exchange_kernel/__init__.py	Normal file
@@ -0,0 +1,75 @@
"""
|
||||
Exchange Kernel API
|
||||
|
||||
The exchange kernel provides a Kubernetes-style declarative API for managing orders
|
||||
across different exchanges. It maintains both desired state (intent) and actual state
|
||||
(current orders on exchange) and reconciles them continuously.
|
||||
|
||||
Key concepts:
|
||||
- OrderIntent: What the strategy kernel wants
|
||||
- OrderState: What actually exists on the exchange
|
||||
- Reconciliation: Bringing actual state into alignment with desired state
|
||||
"""
|
||||
|
||||
from .base import ExchangeKernel
|
||||
from .events import (
|
||||
OrderEvent,
|
||||
OrderSubmitted,
|
||||
OrderAccepted,
|
||||
OrderRejected,
|
||||
OrderPartiallyFilled,
|
||||
OrderFilled,
|
||||
OrderCanceled,
|
||||
OrderModified,
|
||||
OrderExpired,
|
||||
PositionEvent,
|
||||
PositionOpened,
|
||||
PositionModified,
|
||||
PositionClosed,
|
||||
AccountEvent,
|
||||
AccountBalanceUpdated,
|
||||
MarginCallWarning,
|
||||
)
|
||||
from .models import (
|
||||
OrderIntent,
|
||||
OrderState,
|
||||
Position,
|
||||
Asset,
|
||||
AssetMetadata,
|
||||
AccountState,
|
||||
Balance,
|
||||
)
|
||||
from .state import IntentStateStore, ActualStateStore
|
||||
|
||||
__all__ = [
|
||||
# Core interface
|
||||
"ExchangeKernel",
|
||||
# Events
|
||||
"OrderEvent",
|
||||
"OrderSubmitted",
|
||||
"OrderAccepted",
|
||||
"OrderRejected",
|
||||
"OrderPartiallyFilled",
|
||||
"OrderFilled",
|
||||
"OrderCanceled",
|
||||
"OrderModified",
|
||||
"OrderExpired",
|
||||
"PositionEvent",
|
||||
"PositionOpened",
|
||||
"PositionModified",
|
||||
"PositionClosed",
|
||||
"AccountEvent",
|
||||
"AccountBalanceUpdated",
|
||||
"MarginCallWarning",
|
||||
# Models
|
||||
"OrderIntent",
|
||||
"OrderState",
|
||||
"Position",
|
||||
"Asset",
|
||||
"AssetMetadata",
|
||||
"AccountState",
|
||||
"Balance",
|
||||
# State management
|
||||
"IntentStateStore",
|
||||
"ActualStateStore",
|
||||
]
|
||||
361	backend.old/src/exchange_kernel/base.py	Normal file
@@ -0,0 +1,361 @@
"""
|
||||
Base interface for Exchange Kernels.
|
||||
|
||||
Defines the abstract API that all exchange kernel implementations must support.
|
||||
Each exchange (or exchange type) will have its own kernel implementation.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any
|
||||
|
||||
from .models import (
|
||||
OrderIntent,
|
||||
OrderState,
|
||||
Position,
|
||||
AccountState,
|
||||
AssetMetadata,
|
||||
)
|
||||
from .events import BaseEvent
|
||||
from ..schema.order_spec import (
|
||||
StandardOrder,
|
||||
StandardOrderGroup,
|
||||
SymbolMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ExchangeKernel(ABC):
|
||||
"""
|
||||
Abstract base class for exchange kernels.
|
||||
|
||||
An exchange kernel manages the lifecycle of orders on a specific exchange,
|
||||
maintaining both desired state (intents from strategy kernel) and actual
|
||||
state (current orders on exchange), and continuously reconciling them.
|
||||
|
||||
Think of it as a Kubernetes-style controller for trading orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange_id: str):
|
||||
"""
|
||||
Initialize the exchange kernel.
|
||||
|
||||
Args:
|
||||
exchange_id: Unique identifier for this exchange instance
|
||||
"""
|
||||
self.exchange_id = exchange_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Command API - Strategy kernel sends intents
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def place_order(self, order: StandardOrder, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""
|
||||
Place a single order on the exchange.
|
||||
|
||||
This creates an OrderIntent and begins the reconciliation process to
|
||||
get the order onto the exchange.
|
||||
|
||||
Args:
|
||||
order: The order specification
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_id: Unique identifier for this order intent
|
||||
|
||||
Raises:
|
||||
ValidationError: If order violates market constraints
|
||||
ExchangeError: If exchange rejects the order
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def place_order_group(
|
||||
self,
|
||||
group: StandardOrderGroup,
|
||||
metadata: dict[str, Any] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Place a group of orders with OCO (One-Cancels-Other) relationship.
|
||||
|
||||
Args:
|
||||
group: Group of orders with OCO mode
|
||||
metadata: Optional strategy-specific metadata
|
||||
|
||||
Returns:
|
||||
intent_ids: List of intent IDs for each order in the group
|
||||
|
||||
Raises:
|
||||
ValidationError: If any order violates market constraints
|
||||
ExchangeError: If exchange rejects the group
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_order(self, intent_id: str) -> None:
|
||||
"""
|
||||
Cancel an order by intent ID.
|
||||
|
||||
Updates the intent to indicate cancellation is desired, and the
|
||||
reconciliation loop will handle the actual exchange cancellation.
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to cancel
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ExchangeError: If exchange rejects cancellation
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def modify_order(
|
||||
self,
|
||||
intent_id: str,
|
||||
new_order: StandardOrder,
|
||||
) -> None:
|
||||
"""
|
||||
Modify an existing order.
|
||||
|
||||
Updates the order intent, and the reconciliation loop will update
|
||||
the exchange order (via modify API if available, or cancel+replace).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID of the order to modify
|
||||
new_order: New order specification
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
ValidationError: If new order violates market constraints
|
||||
ExchangeError: If exchange rejects modification
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_all_orders(self, symbol_id: str | None = None) -> int:
|
||||
"""
|
||||
Cancel all orders, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only cancel orders for this symbol
|
||||
|
||||
Returns:
|
||||
count: Number of orders canceled
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Query API - Read desired and actual state
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_intent(self, intent_id: str) -> OrderIntent:
|
||||
"""
|
||||
Get the desired order state (what strategy kernel wants).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The order intent
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_state(self, intent_id: str) -> OrderState:
|
||||
"""
|
||||
Get the actual order state (what's currently on exchange).
|
||||
|
||||
Args:
|
||||
intent_id: Intent ID to query
|
||||
|
||||
Returns:
|
||||
The current order state
|
||||
|
||||
Raises:
|
||||
NotFoundError: If intent_id doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_intents(self, symbol_id: str | None = None) -> list[OrderIntent]:
|
||||
"""
|
||||
Get all order intents, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return intents for this symbol
|
||||
|
||||
Returns:
|
||||
List of order intents
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_all_orders(self, symbol_id: str | None = None) -> list[OrderState]:
|
||||
"""
|
||||
Get all actual order states, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return orders for this symbol
|
||||
|
||||
Returns:
|
||||
List of order states
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_positions(self, symbol_id: str | None = None) -> list[Position]:
|
||||
"""
|
||||
Get current positions, optionally filtered by symbol.
|
||||
|
||||
Args:
|
||||
symbol_id: If provided, only return positions for this symbol
|
||||
|
||||
Returns:
|
||||
List of positions
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_account_state(self) -> AccountState:
|
||||
"""
|
||||
Get current account state (balances, margin, etc.).
|
||||
|
||||
Returns:
|
||||
Current account state
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_symbol_metadata(self, symbol_id: str) -> SymbolMetadata:
|
||||
"""
|
||||
Get metadata for a symbol (constraints, capabilities, etc.).
|
||||
|
||||
Args:
|
||||
symbol_id: Symbol to query
|
||||
|
||||
Returns:
|
||||
Symbol metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If symbol doesn't exist on this exchange
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_asset_metadata(self, asset_id: str) -> AssetMetadata:
|
||||
"""
|
||||
Get metadata for an asset.
|
||||
|
||||
Args:
|
||||
asset_id: Asset to query
|
||||
|
||||
Returns:
|
||||
Asset metadata
|
||||
|
||||
Raises:
|
||||
NotFoundError: If asset doesn't exist
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_symbols(self) -> list[str]:
|
||||
"""
|
||||
List all available symbols on this exchange.
|
||||
|
||||
Returns:
|
||||
List of symbol IDs
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Event Subscription API
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
def subscribe_events(
|
||||
self,
|
||||
callback: Callable[[BaseEvent], None],
|
||||
event_filter: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Subscribe to events from this exchange kernel.
|
||||
|
||||
Args:
|
||||
callback: Function to call when events occur
|
||||
event_filter: Optional filter criteria (event_type, symbol_id, etc.)
|
||||
|
||||
Returns:
|
||||
subscription_id: Unique ID for this subscription (for unsubscribe)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe_events(self, subscription_id: str) -> None:
|
||||
"""
|
||||
Unsubscribe from events.
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID returned from subscribe_events
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lifecycle Management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the exchange kernel.
|
||||
|
||||
Initializes connections, starts reconciliation loops, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the exchange kernel.
|
||||
|
||||
Closes connections, stops reconciliation loops, etc.
|
||||
Does NOT cancel open orders - call cancel_all_orders() first if desired.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""
|
||||
Check health status of the exchange kernel.
|
||||
|
||||
Returns:
|
||||
Health status dict with connection state, latency, error counts, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Reconciliation Control (advanced)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
async def force_reconciliation(self, intent_id: str | None = None) -> None:
|
||||
"""
|
||||
Force immediate reconciliation.
|
||||
|
||||
Args:
|
||||
intent_id: If provided, only reconcile this specific intent.
|
||||
If None, reconcile all intents.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_reconciliation_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get metrics about the reconciliation process.
|
||||
|
||||
Returns:
|
||||
Metrics dict with reconciliation lag, error rates, retry counts, etc.
|
||||
"""
|
||||
pass
|
||||
250	backend.old/src/exchange_kernel/events.py	Normal file
@@ -0,0 +1,250 @@
"""
|
||||
Event definitions for the Exchange Kernel.
|
||||
|
||||
All events that can occur during the order lifecycle, position management,
|
||||
and account updates.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..schema.order_spec import Float, Uint64
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base Event Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EventType(StrEnum):
|
||||
"""Types of events emitted by the exchange kernel"""
|
||||
# Order lifecycle
|
||||
ORDER_SUBMITTED = "ORDER_SUBMITTED"
|
||||
ORDER_ACCEPTED = "ORDER_ACCEPTED"
|
||||
ORDER_REJECTED = "ORDER_REJECTED"
|
||||
ORDER_PARTIALLY_FILLED = "ORDER_PARTIALLY_FILLED"
|
||||
ORDER_FILLED = "ORDER_FILLED"
|
||||
ORDER_CANCELED = "ORDER_CANCELED"
|
||||
ORDER_MODIFIED = "ORDER_MODIFIED"
|
||||
ORDER_EXPIRED = "ORDER_EXPIRED"
|
||||
|
||||
# Position events
|
||||
POSITION_OPENED = "POSITION_OPENED"
|
||||
POSITION_MODIFIED = "POSITION_MODIFIED"
|
||||
POSITION_CLOSED = "POSITION_CLOSED"
|
||||
|
||||
# Account events
|
||||
ACCOUNT_BALANCE_UPDATED = "ACCOUNT_BALANCE_UPDATED"
|
||||
MARGIN_CALL_WARNING = "MARGIN_CALL_WARNING"
|
||||
|
||||
# System events
|
||||
RECONCILIATION_FAILED = "RECONCILIATION_FAILED"
|
||||
CONNECTION_LOST = "CONNECTION_LOST"
|
||||
CONNECTION_RESTORED = "CONNECTION_RESTORED"
|
||||
|
||||
|
||||
class BaseEvent(BaseModel):
|
||||
"""Base class for all exchange kernel events"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
event_type: EventType = Field(description="Type of event")
|
||||
timestamp: Uint64 = Field(description="Event timestamp (Unix milliseconds)")
|
||||
exchange: str = Field(description="Exchange identifier")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional event data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Order Events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OrderEvent(BaseEvent):
|
||||
"""Base class for order-related events"""
|
||||
|
||||
intent_id: str = Field(description="Order intent ID")
|
||||
order_id: str | None = Field(default=None, description="Exchange order ID (if assigned)")
|
||||
symbol_id: str = Field(description="Symbol being traded")
|
||||
|
||||
|
||||
class OrderSubmitted(OrderEvent):
|
||||
"""Order has been submitted to the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_SUBMITTED)
|
||||
client_order_id: str | None = Field(default=None, description="Client-assigned order ID")
|
||||
|
||||
|
||||
class OrderAccepted(OrderEvent):
|
||||
"""Order has been accepted by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_ACCEPTED)
|
||||
order_id: str = Field(description="Exchange-assigned order ID")
|
||||
accepted_at: Uint64 = Field(description="Exchange acceptance timestamp")
|
||||
|
||||
|
||||
class OrderRejected(OrderEvent):
|
||||
"""Order was rejected by the exchange"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_REJECTED)
|
||||
reason: str = Field(description="Rejection reason")
|
||||
error_code: str | None = Field(default=None, description="Exchange error code")
|
||||
|
||||
|
||||
class OrderPartiallyFilled(OrderEvent):
|
||||
"""Order was partially filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_PARTIALLY_FILLED)
|
||||
order_id: str = Field(description="Exchange order ID")
|
||||
fill_price: Float = Field(description="Fill price for this execution")
|
||||
fill_quantity: Float = Field(description="Quantity filled in this execution")
|
||||
total_filled: Float = Field(description="Total quantity filled so far")
|
||||
remaining_quantity: Float = Field(description="Remaining quantity to fill")
|
||||
commission: Float = Field(default=0.0, description="Commission/fee for this fill")
|
||||
commission_asset: str | None = Field(default=None, description="Asset used for commission")
|
||||
trade_id: str | None = Field(default=None, description="Exchange trade ID")
|
||||
|
||||
|
||||
class OrderFilled(OrderEvent):
|
||||
"""Order was completely filled"""
|
||||
|
||||
event_type: EventType = Field(default=EventType.ORDER_FILLED)
|
    order_id: str = Field(description="Exchange order ID")
    average_fill_price: Float = Field(description="Average execution price")
    total_quantity: Float = Field(description="Total quantity filled")
    total_commission: Float = Field(default=0.0, description="Total commission/fees")
    commission_asset: str | None = Field(default=None, description="Asset used for commission")
    completed_at: Uint64 = Field(description="Completion timestamp")


class OrderCanceled(OrderEvent):
    """Order was canceled"""

    event_type: EventType = Field(default=EventType.ORDER_CANCELED)
    order_id: str = Field(description="Exchange order ID")
    reason: str = Field(description="Cancellation reason")
    filled_quantity: Float = Field(default=0.0, description="Quantity filled before cancellation")
    canceled_at: Uint64 = Field(description="Cancellation timestamp")


class OrderModified(OrderEvent):
    """Order was modified (price, quantity, etc.)"""

    event_type: EventType = Field(default=EventType.ORDER_MODIFIED)
    order_id: str = Field(description="Exchange order ID")
    old_price: Float | None = Field(default=None, description="Previous price")
    new_price: Float | None = Field(default=None, description="New price")
    old_quantity: Float | None = Field(default=None, description="Previous quantity")
    new_quantity: Float | None = Field(default=None, description="New quantity")
    modified_at: Uint64 = Field(description="Modification timestamp")


class OrderExpired(OrderEvent):
    """Order expired (GTD, DAY orders)"""

    event_type: EventType = Field(default=EventType.ORDER_EXPIRED)
    order_id: str = Field(description="Exchange order ID")
    filled_quantity: Float = Field(default=0.0, description="Quantity filled before expiration")
    expired_at: Uint64 = Field(description="Expiration timestamp")


# ---------------------------------------------------------------------------
# Position Events
# ---------------------------------------------------------------------------

class PositionEvent(BaseEvent):
    """Base class for position-related events"""

    position_id: str = Field(description="Position identifier")
    symbol_id: str = Field(description="Symbol identifier")
    asset_id: str = Field(description="Asset identifier")


class PositionOpened(PositionEvent):
    """New position was opened"""

    event_type: EventType = Field(default=EventType.POSITION_OPENED)
    quantity: Float = Field(description="Position quantity")
    entry_price: Float = Field(description="Entry price")
    side: str = Field(description="LONG or SHORT")
    leverage: Float | None = Field(default=None, description="Leverage")


class PositionModified(PositionEvent):
    """Existing position was modified (size change, etc.)"""

    event_type: EventType = Field(default=EventType.POSITION_MODIFIED)
    old_quantity: Float = Field(description="Previous quantity")
    new_quantity: Float = Field(description="New quantity")
    average_entry_price: Float = Field(description="Updated average entry price")
    unrealized_pnl: Float | None = Field(default=None, description="Current unrealized P&L")


class PositionClosed(PositionEvent):
    """Position was closed"""

    event_type: EventType = Field(default=EventType.POSITION_CLOSED)
    exit_price: Float = Field(description="Exit price")
    realized_pnl: Float = Field(description="Realized profit/loss")
    closed_at: Uint64 = Field(description="Closure timestamp")


# ---------------------------------------------------------------------------
# Account Events
# ---------------------------------------------------------------------------

class AccountEvent(BaseEvent):
    """Base class for account-related events"""

    account_id: str = Field(description="Account identifier")


class AccountBalanceUpdated(AccountEvent):
    """Account balance was updated"""

    event_type: EventType = Field(default=EventType.ACCOUNT_BALANCE_UPDATED)
    asset_id: str = Field(description="Asset that changed")
    old_balance: Float = Field(description="Previous balance")
    new_balance: Float = Field(description="New balance")
    old_available: Float = Field(description="Previous available")
    new_available: Float = Field(description="New available")
    change_reason: str = Field(description="Why balance changed (TRADE, DEPOSIT, WITHDRAWAL, etc.)")


class MarginCallWarning(AccountEvent):
    """Margin level is approaching liquidation threshold"""

    event_type: EventType = Field(default=EventType.MARGIN_CALL_WARNING)
    margin_level: Float = Field(description="Current margin level")
    liquidation_threshold: Float = Field(description="Liquidation threshold")
    required_action: str = Field(description="Required action to avoid liquidation")
    estimated_liquidation_price: Float | None = Field(
        default=None,
        description="Estimated liquidation price for positions"
    )


# ---------------------------------------------------------------------------
# System Events
# ---------------------------------------------------------------------------

class ReconciliationFailed(BaseEvent):
    """Failed to reconcile intent with actual state"""

    event_type: EventType = Field(default=EventType.RECONCILIATION_FAILED)
    intent_id: str = Field(description="Order intent ID")
    error_message: str = Field(description="Error details")
    retry_count: int = Field(description="Number of retry attempts")


class ConnectionLost(BaseEvent):
    """Connection to exchange was lost"""

    event_type: EventType = Field(default=EventType.CONNECTION_LOST)
    reason: str = Field(description="Disconnection reason")


class ConnectionRestored(BaseEvent):
    """Connection to exchange was restored"""

    event_type: EventType = Field(default=EventType.CONNECTION_RESTORED)
    downtime_duration: int = Field(description="Duration of downtime in milliseconds")
194  backend.old/src/exchange_kernel/models.py  Normal file
@@ -0,0 +1,194 @@
"""
Data models for the Exchange Kernel.

Defines order intents, order state, positions, assets, and account state.
"""

from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field

from ..schema.order_spec import (
    StandardOrder,
    StandardOrderStatus,
    AssetType,
    Float,
    Uint64,
)


# ---------------------------------------------------------------------------
# Order Intent and State
# ---------------------------------------------------------------------------

class OrderIntent(BaseModel):
    """
    Desired order state from the strategy kernel.

    This represents what the strategy wants, not what currently exists.
    The exchange kernel will work to reconcile actual state with this intent.
    """

    model_config = {"extra": "forbid"}

    intent_id: str = Field(description="Unique identifier for this intent (client-assigned)")
    order: StandardOrder = Field(description="The desired order specification")
    group_id: str | None = Field(default=None, description="Group ID for OCO relationships")
    created_at: Uint64 = Field(description="When this intent was created")
    updated_at: Uint64 = Field(description="When this intent was last modified")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific metadata")


class ReconciliationStatus(StrEnum):
    """Status of reconciliation between intent and actual state"""

    PENDING = "PENDING"          # Not yet submitted to exchange
    SUBMITTING = "SUBMITTING"    # Currently being submitted
    ACTIVE = "ACTIVE"            # Successfully placed on exchange
    RECONCILING = "RECONCILING"  # Intent changed, updating exchange order
    FAILED = "FAILED"            # Failed to submit or reconcile
    COMPLETED = "COMPLETED"      # Order fully filled
    CANCELED = "CANCELED"        # Order canceled


class OrderState(BaseModel):
    """
    Actual current state of an order on the exchange.

    This represents reality - what the exchange reports about the order.
    May differ from OrderIntent during reconciliation.
    """

    model_config = {"extra": "forbid"}

    intent_id: str = Field(description="Links back to the OrderIntent")
    exchange_order_id: str = Field(description="Exchange-assigned order ID")
    status: StandardOrderStatus = Field(description="Current order status from exchange")
    reconciliation_status: ReconciliationStatus = Field(description="Reconciliation state")
    last_sync_at: Uint64 = Field(description="Last time we synced with exchange")
    error_message: str | None = Field(default=None, description="Error details if FAILED")


# ---------------------------------------------------------------------------
# Position and Asset Models
# ---------------------------------------------------------------------------

class AssetMetadata(BaseModel):
    """
    Metadata describing an asset type.

    Provides context for positions, balances, and trading.
    """

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="Unique asset identifier")
    symbol: str = Field(description="Asset symbol (e.g., 'BTC', 'ETH', 'USD')")
    asset_type: AssetType = Field(description="Type of asset")
    name: str = Field(description="Full name")

    # Contract specifications (for derivatives)
    contract_size: Float | None = Field(default=None, description="Contract multiplier")
    settlement_asset: str | None = Field(default=None, description="Settlement currency")
    expiry_timestamp: Uint64 | None = Field(default=None, description="Expiration timestamp")

    # Trading parameters
    tick_size: Float | None = Field(default=None, description="Minimum price increment")
    lot_size: Float | None = Field(default=None, description="Minimum quantity increment")

    # Margin requirements (for leveraged products)
    initial_margin_rate: Float | None = Field(default=None, description="Initial margin requirement")
    maintenance_margin_rate: Float | None = Field(default=None, description="Maintenance margin requirement")

    # Additional metadata
    metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific metadata")


class Asset(BaseModel):
    """
    An asset holding (spot, margin, derivative position, etc.)
    """

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="References AssetMetadata")
    quantity: Float = Field(description="Amount held (positive or negative for short positions)")
    available: Float = Field(description="Amount available for trading (not locked in orders)")
    locked: Float = Field(description="Amount locked in open orders")

    # For derivative positions
    entry_price: Float | None = Field(default=None, description="Average entry price")
    mark_price: Float | None = Field(default=None, description="Current mark price")
    liquidation_price: Float | None = Field(default=None, description="Estimated liquidation price")
    unrealized_pnl: Float | None = Field(default=None, description="Unrealized profit/loss")
    realized_pnl: Float | None = Field(default=None, description="Realized profit/loss")

    # Margin info
    margin_used: Float | None = Field(default=None, description="Margin allocated to this position")

    updated_at: Uint64 = Field(description="Last update timestamp")


class Position(BaseModel):
    """
    A trading position (spot, margin, perpetual, futures, etc.)

    Tracks both the asset holdings and associated metadata.
    """

    model_config = {"extra": "forbid"}

    position_id: str = Field(description="Unique position identifier")
    symbol_id: str = Field(description="Trading symbol")
    asset: Asset = Field(description="Asset holding details")
    metadata: AssetMetadata = Field(description="Asset metadata")

    # Position-level info
    leverage: Float | None = Field(default=None, description="Current leverage")
    side: str | None = Field(default=None, description="LONG or SHORT (for derivatives)")

    updated_at: Uint64 = Field(description="Last update timestamp")


class Balance(BaseModel):
    """Account balance for a single currency/asset"""

    model_config = {"extra": "forbid"}

    asset_id: str = Field(description="Asset identifier")
    total: Float = Field(description="Total balance")
    available: Float = Field(description="Available for trading")
    locked: Float = Field(description="Locked in orders/positions")

    # For margin accounts
    borrowed: Float = Field(default=0.0, description="Borrowed amount (margin)")
    interest: Float = Field(default=0.0, description="Accrued interest")

    updated_at: Uint64 = Field(description="Last update timestamp")


class AccountState(BaseModel):
    """
    Complete account state including balances, positions, and margin info.
    """

    model_config = {"extra": "forbid"}

    account_id: str = Field(description="Account identifier")
    exchange: str = Field(description="Exchange identifier")

    balances: list[Balance] = Field(default_factory=list, description="All asset balances")
    positions: list[Position] = Field(default_factory=list, description="All open positions")

    # Margin account info
    total_equity: Float | None = Field(default=None, description="Total account equity")
    total_margin_used: Float | None = Field(default=None, description="Total margin in use")
    total_available_margin: Float | None = Field(default=None, description="Available margin")
    margin_level: Float | None = Field(default=None, description="Margin level (equity/margin_used)")

    # Risk metrics
    total_unrealized_pnl: Float | None = Field(default=None, description="Total unrealized P&L")
    total_realized_pnl: Float | None = Field(default=None, description="Total realized P&L")

    updated_at: Uint64 = Field(description="Last update timestamp")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Exchange-specific data")
472  backend.old/src/exchange_kernel/state.py  Normal file
@@ -0,0 +1,472 @@
"""
State management for the Exchange Kernel.

Implements the storage and reconciliation logic for desired vs actual state.
This is the "Kubernetes for orders" concept - maintaining intent and continuously
reconciling reality to match intent.
"""

from abc import ABC, abstractmethod
from typing import Any
from collections import defaultdict

from .models import OrderIntent, OrderState, ReconciliationStatus
from ..schema.order_spec import Uint64


# ---------------------------------------------------------------------------
# Intent State Store - Desired State
# ---------------------------------------------------------------------------

class IntentStateStore(ABC):
    """
    Storage for order intents (desired state).

    This represents what the strategy kernel wants. Intents are durable and
    persist across restarts. The reconciliation loop continuously works to
    make actual state match these intents.
    """

    @abstractmethod
    async def create_intent(self, intent: OrderIntent) -> None:
        """
        Store a new order intent.

        Args:
            intent: The order intent to store

        Raises:
            AlreadyExistsError: If intent_id already exists
        """
        pass

    @abstractmethod
    async def get_intent(self, intent_id: str) -> OrderIntent:
        """
        Retrieve an order intent.

        Args:
            intent_id: Intent ID to retrieve

        Returns:
            The order intent

        Raises:
            NotFoundError: If intent_id doesn't exist
        """
        pass

    @abstractmethod
    async def update_intent(self, intent: OrderIntent) -> None:
        """
        Update an existing order intent.

        Args:
            intent: Updated intent (intent_id must match existing)

        Raises:
            NotFoundError: If intent_id doesn't exist
        """
        pass

    @abstractmethod
    async def delete_intent(self, intent_id: str) -> None:
        """
        Delete an order intent.

        Args:
            intent_id: Intent ID to delete

        Raises:
            NotFoundError: If intent_id doesn't exist
        """
        pass

    @abstractmethod
    async def list_intents(
        self,
        symbol_id: str | None = None,
        group_id: str | None = None,
    ) -> list[OrderIntent]:
        """
        List all order intents, optionally filtered.

        Args:
            symbol_id: Filter by symbol
            group_id: Filter by OCO group

        Returns:
            List of matching intents
        """
        pass

    @abstractmethod
    async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
        """
        Get all intents in an OCO group.

        Args:
            group_id: Group ID to query

        Returns:
            List of intents in the group
        """
        pass


# ---------------------------------------------------------------------------
# Actual State Store - Current Reality
# ---------------------------------------------------------------------------

class ActualStateStore(ABC):
    """
    Storage for actual order state (reality on exchange).

    This represents what actually exists on the exchange right now.
    Updated frequently from exchange feeds and order status queries.
    """

    @abstractmethod
    async def create_order_state(self, state: OrderState) -> None:
        """
        Store a new order state.

        Args:
            state: The order state to store

        Raises:
            AlreadyExistsError: If order state for this intent_id already exists
        """
        pass

    @abstractmethod
    async def get_order_state(self, intent_id: str) -> OrderState:
        """
        Retrieve order state for an intent.

        Args:
            intent_id: Intent ID to query

        Returns:
            The current order state

        Raises:
            NotFoundError: If no state exists for this intent
        """
        pass

    @abstractmethod
    async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
        """
        Retrieve order state by exchange order ID.

        Useful for processing exchange callbacks that only provide exchange_order_id.

        Args:
            exchange_order_id: Exchange's order ID

        Returns:
            The order state

        Raises:
            NotFoundError: If no state exists for this exchange order ID
        """
        pass

    @abstractmethod
    async def update_order_state(self, state: OrderState) -> None:
        """
        Update an existing order state.

        Args:
            state: Updated state (intent_id must match existing)

        Raises:
            NotFoundError: If state doesn't exist
        """
        pass

    @abstractmethod
    async def delete_order_state(self, intent_id: str) -> None:
        """
        Delete an order state.

        Args:
            intent_id: Intent ID whose state to delete

        Raises:
            NotFoundError: If state doesn't exist
        """
        pass

    @abstractmethod
    async def list_order_states(
        self,
        symbol_id: str | None = None,
        reconciliation_status: ReconciliationStatus | None = None,
    ) -> list[OrderState]:
        """
        List all order states, optionally filtered.

        Args:
            symbol_id: Filter by symbol
            reconciliation_status: Filter by reconciliation status

        Returns:
            List of matching order states
        """
        pass

    @abstractmethod
    async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
        """
        Find orders that haven't been synced recently.

        Used to identify orders that need status updates from exchange.

        Args:
            max_age_seconds: Maximum age since last sync

        Returns:
            List of order states that need refresh
        """
        pass


# ---------------------------------------------------------------------------
# In-Memory Implementations (for testing/prototyping)
# ---------------------------------------------------------------------------

class InMemoryIntentStore(IntentStateStore):
    """Simple in-memory implementation of IntentStateStore"""

    def __init__(self):
        self._intents: dict[str, OrderIntent] = {}
        self._by_symbol: dict[str, set[str]] = defaultdict(set)
        self._by_group: dict[str, set[str]] = defaultdict(set)

    async def create_intent(self, intent: OrderIntent) -> None:
        if intent.intent_id in self._intents:
            raise ValueError(f"Intent {intent.intent_id} already exists")
        self._intents[intent.intent_id] = intent
        self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
        if intent.group_id:
            self._by_group[intent.group_id].add(intent.intent_id)

    async def get_intent(self, intent_id: str) -> OrderIntent:
        if intent_id not in self._intents:
            raise KeyError(f"Intent {intent_id} not found")
        return self._intents[intent_id]

    async def update_intent(self, intent: OrderIntent) -> None:
        if intent.intent_id not in self._intents:
            raise KeyError(f"Intent {intent.intent_id} not found")
        old_intent = self._intents[intent.intent_id]

        # Update indices if symbol or group changed
        if old_intent.order.symbol_id != intent.order.symbol_id:
            self._by_symbol[old_intent.order.symbol_id].discard(intent.intent_id)
            self._by_symbol[intent.order.symbol_id].add(intent.intent_id)

        if old_intent.group_id != intent.group_id:
            if old_intent.group_id:
                self._by_group[old_intent.group_id].discard(intent.intent_id)
            if intent.group_id:
                self._by_group[intent.group_id].add(intent.intent_id)

        self._intents[intent.intent_id] = intent

    async def delete_intent(self, intent_id: str) -> None:
        if intent_id not in self._intents:
            raise KeyError(f"Intent {intent_id} not found")
        intent = self._intents[intent_id]
        self._by_symbol[intent.order.symbol_id].discard(intent_id)
        if intent.group_id:
            self._by_group[intent.group_id].discard(intent_id)
        del self._intents[intent_id]

    async def list_intents(
        self,
        symbol_id: str | None = None,
        group_id: str | None = None,
    ) -> list[OrderIntent]:
        if symbol_id and group_id:
            # Intersection of both filters
            symbol_ids = self._by_symbol.get(symbol_id, set())
            group_ids = self._by_group.get(group_id, set())
            intent_ids = symbol_ids & group_ids
        elif symbol_id:
            intent_ids = self._by_symbol.get(symbol_id, set())
        elif group_id:
            intent_ids = self._by_group.get(group_id, set())
        else:
            intent_ids = self._intents.keys()

        return [self._intents[iid] for iid in intent_ids]

    async def get_intents_by_group(self, group_id: str) -> list[OrderIntent]:
        intent_ids = self._by_group.get(group_id, set())
        return [self._intents[iid] for iid in intent_ids]


class InMemoryActualStateStore(ActualStateStore):
    """Simple in-memory implementation of ActualStateStore"""

    def __init__(self):
        self._states: dict[str, OrderState] = {}
        self._by_exchange_id: dict[str, str] = {}  # exchange_order_id -> intent_id
        self._by_symbol: dict[str, set[str]] = defaultdict(set)

    async def create_order_state(self, state: OrderState) -> None:
        if state.intent_id in self._states:
            raise ValueError(f"Order state for intent {state.intent_id} already exists")
        self._states[state.intent_id] = state
        self._by_exchange_id[state.exchange_order_id] = state.intent_id
        self._by_symbol[state.status.order.symbol_id].add(state.intent_id)

    async def get_order_state(self, intent_id: str) -> OrderState:
        if intent_id not in self._states:
            raise KeyError(f"Order state for intent {intent_id} not found")
        return self._states[intent_id]

    async def get_order_state_by_exchange_id(self, exchange_order_id: str) -> OrderState:
        if exchange_order_id not in self._by_exchange_id:
            raise KeyError(f"Order state for exchange order {exchange_order_id} not found")
        intent_id = self._by_exchange_id[exchange_order_id]
        return self._states[intent_id]

    async def update_order_state(self, state: OrderState) -> None:
        if state.intent_id not in self._states:
            raise KeyError(f"Order state for intent {state.intent_id} not found")
        old_state = self._states[state.intent_id]

        # Update exchange_id index if it changed
        if old_state.exchange_order_id != state.exchange_order_id:
            del self._by_exchange_id[old_state.exchange_order_id]
            self._by_exchange_id[state.exchange_order_id] = state.intent_id

        # Update symbol index if it changed
        old_symbol = old_state.status.order.symbol_id
        new_symbol = state.status.order.symbol_id
        if old_symbol != new_symbol:
            self._by_symbol[old_symbol].discard(state.intent_id)
            self._by_symbol[new_symbol].add(state.intent_id)

        self._states[state.intent_id] = state

    async def delete_order_state(self, intent_id: str) -> None:
        if intent_id not in self._states:
            raise KeyError(f"Order state for intent {intent_id} not found")
        state = self._states[intent_id]
        del self._by_exchange_id[state.exchange_order_id]
        self._by_symbol[state.status.order.symbol_id].discard(intent_id)
        del self._states[intent_id]

    async def list_order_states(
        self,
        symbol_id: str | None = None,
        reconciliation_status: ReconciliationStatus | None = None,
    ) -> list[OrderState]:
        if symbol_id:
            intent_ids = self._by_symbol.get(symbol_id, set())
            states = [self._states[iid] for iid in intent_ids]
        else:
            states = list(self._states.values())

        if reconciliation_status:
            states = [s for s in states if s.reconciliation_status == reconciliation_status]

        return states

    async def get_stale_orders(self, max_age_seconds: int) -> list[OrderState]:
        import time
        current_time = int(time.time())
        threshold = current_time - max_age_seconds

        return [
            state
            for state in self._states.values()
            if state.last_sync_at < threshold
        ]


# ---------------------------------------------------------------------------
# Reconciliation Engine (framework only, no implementation)
# ---------------------------------------------------------------------------

class ReconciliationEngine:
    """
    Reconciliation engine that continuously works to make actual state match intent.

    This is the heart of the "Kubernetes for orders" concept. It:
    1. Compares desired state (intents) with actual state (exchange orders)
    2. Computes necessary actions (place, modify, cancel)
    3. Executes those actions via the exchange API
    4. Handles retries, errors, and edge cases

    This is a framework class - concrete implementations will be exchange-specific.
    """

    def __init__(
        self,
        intent_store: IntentStateStore,
        actual_store: ActualStateStore,
    ):
        """
        Initialize the reconciliation engine.

        Args:
            intent_store: Store for desired state
            actual_store: Store for actual state
        """
        self.intent_store = intent_store
        self.actual_store = actual_store
        self._running = False

    async def start(self) -> None:
        """Start the reconciliation loop"""
        self._running = True
        # Implementation would start async reconciliation loop here

    async def stop(self) -> None:
        """Stop the reconciliation loop"""
        self._running = False
        # Implementation would stop reconciliation loop here

    async def reconcile_intent(self, intent_id: str) -> None:
        """
        Reconcile a specific intent.

        Compares the intent with actual state and takes necessary actions.

        Args:
            intent_id: Intent to reconcile
        """
        # Framework only - concrete implementation needed
        pass

    async def reconcile_all(self) -> None:
        """
        Reconcile all intents.

        Full reconciliation pass over all orders.
        """
        # Framework only - concrete implementation needed
        pass

    def get_metrics(self) -> dict[str, Any]:
        """
        Get reconciliation metrics.

        Returns:
            Metrics about reconciliation performance, errors, etc.
        """
        return {
            "running": self._running,
            "reconciliation_lag_ms": 0,   # Framework only
            "pending_reconciliations": 0, # Framework only
            "error_count": 0,             # Framework only
            "retry_count": 0,             # Framework only
        }
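To see how the two secondary indices in `InMemoryIntentStore` combine under `list_intents`, here is a condensed, self-contained sketch of the same indexing logic, using `SimpleNamespace` stand-ins for `OrderIntent` (the real class is a pydantic model with more fields):

```python
import asyncio
from collections import defaultdict
from types import SimpleNamespace

class MiniIntentStore:
    """Condensed copy of the symbol/group indexing in InMemoryIntentStore."""

    def __init__(self):
        self._intents = {}
        self._by_symbol = defaultdict(set)
        self._by_group = defaultdict(set)

    async def create_intent(self, intent):
        if intent.intent_id in self._intents:
            raise ValueError(f"Intent {intent.intent_id} already exists")
        self._intents[intent.intent_id] = intent
        self._by_symbol[intent.order.symbol_id].add(intent.intent_id)
        if intent.group_id:
            self._by_group[intent.group_id].add(intent.intent_id)

    async def list_intents(self, symbol_id=None, group_id=None):
        if symbol_id and group_id:
            # Intersection of both filters, exactly as in the full store
            ids = self._by_symbol.get(symbol_id, set()) & self._by_group.get(group_id, set())
        elif symbol_id:
            ids = self._by_symbol.get(symbol_id, set())
        elif group_id:
            ids = self._by_group.get(group_id, set())
        else:
            ids = self._intents.keys()
        return [self._intents[i] for i in ids]

async def demo():
    store = MiniIntentStore()
    make = lambda iid, sym, grp=None: SimpleNamespace(
        intent_id=iid, group_id=grp, order=SimpleNamespace(symbol_id=sym))
    await store.create_intent(make("i1", "BTC-USD", "oco-1"))
    await store.create_intent(make("i2", "BTC-USD"))
    await store.create_intent(make("i3", "ETH-USD", "oco-1"))
    both = await store.list_intents(symbol_id="BTC-USD", group_id="oco-1")
    return [i.intent_id for i in both]

print(asyncio.run(demo()))  # ['i1']
```

Only `i1` matches both filters: `i2` shares the symbol but has no group, and `i3` shares the group but not the symbol.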
@@ -94,6 +94,11 @@ class Gateway:
            logger.info(f"Session is busy, interrupting existing task")
            await session.interrupt()

        # Check if this is a stop interrupt (empty message)
        if not message.content.strip() and not message.attachments:
            logger.info("Received stop interrupt (empty message), not starting new agent round")
            return

        # Add user message to history
        session.add_message("user", message.content, message.channel_id)
        logger.info(f"User message added to history, history length: {len(session.get_history())}")
@@ -134,33 +139,55 @@ class Gateway:
        # Stream chunks back to active channels
        full_response = ""
        chunk_count = 0
        accumulated_metadata = {}

        async for chunk in response_stream:
            # Handle dict response with metadata (from agent executor)
            if isinstance(chunk, dict):
                content = chunk.get("content", "")
                metadata = chunk.get("metadata", {})
                # Accumulate metadata (e.g., plot_urls)
                for key, value in metadata.items():
                    if key == "plot_urls" and value:
                        # Append to existing plot_urls
                        if "plot_urls" not in accumulated_metadata:
                            accumulated_metadata["plot_urls"] = []
                        accumulated_metadata["plot_urls"].extend(value)
                        logger.info(f"Accumulated plot_urls: {accumulated_metadata['plot_urls']}")
                    else:
                        accumulated_metadata[key] = value
                chunk = content

            # Only send non-empty chunks
            if chunk:
                chunk_count += 1
                full_response += chunk
                logger.debug(f"Received chunk #{chunk_count}, length: {len(chunk)}")

                # Send chunk to all active channels with accumulated metadata
                agent_msg = AgentMessage(
                    session_id=session.session_id,
                    target_channels=session.active_channels,
                    content=chunk,
                    stream_chunk=True,
                    done=False,
                    metadata=accumulated_metadata.copy()
                )
                await self._send_to_channels(agent_msg)

        logger.info(f"Agent streaming completed, total chunks: {chunk_count}, response length: {len(full_response)}")

        # Send final done message with all accumulated metadata
        agent_msg = AgentMessage(
            session_id=session.session_id,
            target_channels=session.active_channels,
            content="",
            stream_chunk=True,
            done=True,
            metadata=accumulated_metadata
        )
        await self._send_to_channels(agent_msg)
        logger.info(f"Sent final done message to channels with metadata: {accumulated_metadata}")

        # Add to history
        session.add_message("assistant", full_response)
||||
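The metadata-merging rule in the streaming loop can be reduced to a standalone helper for illustration (`accumulate_metadata` is a hypothetical name; the real code runs inline in the loop):

```python
def accumulate_metadata(accumulated: dict, metadata: dict) -> dict:
    """Merge one chunk's metadata into the running total; plot_urls are appended."""
    for key, value in metadata.items():
        if key == "plot_urls" and value:
            accumulated.setdefault("plot_urls", []).extend(value)
        else:
            accumulated[key] = value
    return accumulated

acc = {}
accumulate_metadata(acc, {"plot_urls": ["chart1.png"], "model": "demo"})
accumulate_metadata(acc, {"plot_urls": ["chart2.png"]})
# acc == {"plot_urls": ["chart1.png", "chart2.png"], "model": "demo"}
```

Sending a `.copy()` with each chunk, as the handler does, keeps earlier messages from mutating when later chunks extend the list.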
179
backend.old/src/indicator/__init__.py
Normal file
@@ -0,0 +1,179 @@
"""
Composable Indicator System.

Provides a framework for building DAGs of data transformation pipelines
that process time-series data incrementally. Indicators can consume
DataSources or other Indicators as inputs, composing into arbitrarily
complex processing graphs.

Key Components:
---------------

Indicator (base.py):
    Abstract base class for all indicator implementations.
    Declares input/output schemas and implements synchronous compute().

IndicatorRegistry (registry.py):
    Central catalog of available indicators with rich metadata
    for AI agent discovery and tool generation.

Pipeline (pipeline.py):
    Execution engine that builds DAGs, resolves dependencies,
    and orchestrates incremental data flow through indicator chains.

Schema Types (schema.py):
    Type definitions for input/output schemas, computation context,
    and metadata for AI-native documentation.

Usage Example:
--------------

    from indicator import Indicator, IndicatorRegistry, Pipeline
    from indicator.schema import (
        InputSchema, OutputSchema, ComputeContext, ComputeResult,
        IndicatorMetadata, IndicatorParameter
    )

    # Define an indicator
    class SimpleMovingAverage(Indicator):
        @classmethod
        def get_metadata(cls):
            return IndicatorMetadata(
                name="SMA",
                display_name="Simple Moving Average",
                description="Arithmetic mean of prices over N periods",
                category="trend",
                parameters=[
                    IndicatorParameter(
                        name="period",
                        type="int",
                        description="Number of periods to average",
                        default=20,
                        min_value=1
                    )
                ],
                tags=["moving-average", "trend-following"]
            )

        @classmethod
        def get_input_schema(cls):
            return InputSchema(
                required_columns=[
                    ColumnInfo(name="close", type="float", description="Closing price")
                ]
            )

        @classmethod
        def get_output_schema(cls, **params):
            return OutputSchema(
                columns=[
                    ColumnInfo(
                        name="sma",
                        type="float",
                        description=f"Simple moving average over {params.get('period', 20)} periods"
                    )
                ]
            )

        def compute(self, context: ComputeContext) -> ComputeResult:
            period = self.params["period"]
            closes = context.get_column("close")
            times = context.get_times()

            sma_values = []
            for i in range(len(closes)):
                if i < period - 1:
                    sma_values.append(None)
                else:
                    window = closes[i - period + 1 : i + 1]
                    sma_values.append(sum(window) / period)

            return ComputeResult(
                data=[
                    {"time": times[i], "sma": sma_values[i]}
                    for i in range(len(times))
                ]
            )

    # Register the indicator
    registry = IndicatorRegistry()
    registry.register(SimpleMovingAverage)

    # Create a pipeline
    pipeline = Pipeline(datasource_registry)
    pipeline.add_datasource("price_data", "ccxt", "BTC/USD", "1D")

    sma_indicator = registry.create_instance("SMA", "sma_20", period=20)
    pipeline.add_indicator("sma_20", sma_indicator, input_node_ids=["price_data"])

    # Execute
    results = pipeline.execute(datasource_data={"price_data": price_bars})
    sma_output = results["sma_20"]  # Contains columns: time, close, sma_20_sma

Design Philosophy:
------------------

1. **Schema-based composition**: Indicators declare inputs/outputs via schemas,
   enabling automatic validation and flexible composition.

2. **Synchronous execution**: All computation is synchronous for simplicity.
   Async handling happens at the event/strategy layer.

3. **Incremental updates**: Indicators receive context about what changed,
   allowing optimized recomputation of only affected values.

4. **AI-native metadata**: Rich descriptions, use cases, and parameter specs
   make indicators discoverable and usable by AI agents.

5. **Generic data flow**: Indicators work with any data source that matches
   their input schema, not specific DataSource instances.

6. **Event-driven**: Designed to react to DataSource updates and propagate
   changes through the DAG efficiently.
"""

from .base import DataSourceAdapter, Indicator
from .pipeline import Pipeline, PipelineNode
from .registry import IndicatorRegistry
from .schema import (
    ComputeContext,
    ComputeResult,
    IndicatorMetadata,
    IndicatorParameter,
    InputSchema,
    OutputSchema,
)
from .talib_adapter import (
    TALibIndicator,
    register_all_talib_indicators,
    is_talib_available,
    get_talib_version,
)
from .custom_indicators import (
    register_custom_indicators,
    CUSTOM_INDICATORS,
)

__all__ = [
    # Core classes
    "Indicator",
    "IndicatorRegistry",
    "Pipeline",
    "PipelineNode",
    "DataSourceAdapter",
    # Schema types
    "InputSchema",
    "OutputSchema",
    "ComputeContext",
    "ComputeResult",
    "IndicatorMetadata",
    "IndicatorParameter",
    # TA-Lib integration
    "TALibIndicator",
    "register_all_talib_indicators",
    "is_talib_available",
    "get_talib_version",
    # Custom indicators
    "register_custom_indicators",
    "CUSTOM_INDICATORS",
]
|
||||
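The windowed-mean logic from the docstring's SMA example can be exercised on its own, outside the framework (a minimal sketch; `sma` here is a hypothetical free function, not part of the package):

```python
def sma(closes, period):
    """Simple moving average; None until a full window exists."""
    out = []
    for i in range(len(closes)):
        if i < period - 1:
            out.append(None)
        else:
            window = closes[i - period + 1 : i + 1]
            out.append(sum(window) / period)
    return out

values = sma([1.0, 2.0, 3.0, 4.0, 5.0], period=3)
# values == [None, None, 2.0, 3.0, 4.0]
```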
230
backend.old/src/indicator/base.py
Normal file
@@ -0,0 +1,230 @@
"""
Abstract Indicator interface.

Provides the base class for all technical indicators and derived data transformations.
Indicators compose into DAGs, processing data incrementally as updates arrive.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from .schema import (
    ComputeContext,
    ComputeResult,
    IndicatorMetadata,
    InputSchema,
    OutputSchema,
)


class Indicator(ABC):
    """
    Abstract base class for all indicators.

    Indicators are composable transformation nodes that:
    - Declare input schema (columns they need)
    - Declare output schema (columns they produce)
    - Compute outputs synchronously from inputs
    - Support incremental updates (process only what changed)
    - Provide rich metadata for AI agent discovery

    Indicators are stateless at the instance level - all state is managed
    by the pipeline execution engine. This allows the same indicator class
    to be reused with different parameters.
    """

    def __init__(self, instance_name: str, **params):
        """
        Initialize an indicator instance.

        Args:
            instance_name: Unique name for this instance (used for output column prefixing)
            **params: Configuration parameters (validated against metadata.parameters)
        """
        self.instance_name = instance_name
        self.params = params
        self._validate_params()

    @classmethod
    @abstractmethod
    def get_metadata(cls) -> IndicatorMetadata:
        """
        Get metadata for this indicator class.

        Called by the registry for AI agent discovery and documentation.
        Should return comprehensive information about the indicator's purpose,
        parameters, and use cases.

        Returns:
            IndicatorMetadata describing this indicator class
        """
        pass

    @classmethod
    @abstractmethod
    def get_input_schema(cls) -> InputSchema:
        """
        Get the input schema required by this indicator.

        Declares what columns must be present in the input data.
        The pipeline will match this against available data sources.

        Returns:
            InputSchema describing required and optional input columns
        """
        pass

    @classmethod
    @abstractmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        """
        Get the output schema produced by this indicator.

        Output column names will be automatically prefixed with the instance name
        by the pipeline engine.

        Args:
            **params: Configuration parameters (may affect output schema)

        Returns:
            OutputSchema describing the columns this indicator produces
        """
        pass

    @abstractmethod
    def compute(self, context: ComputeContext) -> ComputeResult:
        """
        Compute indicator values from input data.

        This method is called synchronously by the pipeline engine whenever
        input data changes. Implementations should:

        1. Extract needed columns from context.data
        2. Perform calculations
        3. Return results with proper time alignment

        For incremental updates (context.is_incremental == True):
        - context.data contains only new/updated rows
        - Implementations MAY optimize by computing only these rows
        - OR implementations MAY recompute everything (simpler but slower)

        Args:
            context: Input data and update metadata

        Returns:
            ComputeResult with calculated indicator values

        Raises:
            ValueError: If input data doesn't match expected schema
        """
        pass

    def _validate_params(self) -> None:
        """
        Validate that provided parameters match the metadata specification.

        Raises:
            ValueError: If required parameters are missing or invalid
        """
        metadata = self.get_metadata()

        # Check for required parameters
        for param_def in metadata.parameters:
            if param_def.required and param_def.name not in self.params:
                raise ValueError(
                    f"Indicator '{metadata.name}' requires parameter '{param_def.name}'"
                )

        # Validate parameter types and ranges
        for name, value in self.params.items():
            # Find parameter definition
            param_def = next(
                (p for p in metadata.parameters if p.name == name),
                None
            )

            if param_def is None:
                raise ValueError(
                    f"Unknown parameter '{name}' for indicator '{metadata.name}'"
                )

            # Type checking
            if param_def.type == "int" and not isinstance(value, int):
                raise ValueError(
                    f"Parameter '{name}' must be int, got {type(value).__name__}"
                )
            elif param_def.type == "float" and not isinstance(value, (int, float)):
                raise ValueError(
                    f"Parameter '{name}' must be float, got {type(value).__name__}"
                )
            elif param_def.type == "bool" and not isinstance(value, bool):
                raise ValueError(
                    f"Parameter '{name}' must be bool, got {type(value).__name__}"
                )
            elif param_def.type == "string" and not isinstance(value, str):
                raise ValueError(
                    f"Parameter '{name}' must be string, got {type(value).__name__}"
                )

            # Range checking for numeric types
            if param_def.type in ("int", "float"):
                if param_def.min_value is not None and value < param_def.min_value:
                    raise ValueError(
                        f"Parameter '{name}' must be >= {param_def.min_value}, got {value}"
                    )
                if param_def.max_value is not None and value > param_def.max_value:
                    raise ValueError(
                        f"Parameter '{name}' must be <= {param_def.max_value}, got {value}"
                    )

    def get_output_columns(self) -> List[str]:
        """
        Get the output column names with instance name prefix.

        Returns:
            List of prefixed output column names
        """
        output_schema = self.get_output_schema(**self.params)
        prefixed = output_schema.with_prefix(self.instance_name)
        return [col.name for col in prefixed.columns if col.name != output_schema.time_column]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(instance_name='{self.instance_name}', params={self.params})"


class DataSourceAdapter:
    """
    Adapter to make a DataSource look like an Indicator for pipeline composition.

    This allows DataSources to be inputs to indicators in a unified way.
    """

    def __init__(self, datasource_id: str, symbol: str, resolution: str):
        """
        Create a DataSource adapter.

        Args:
            datasource_id: Identifier for the datasource (e.g., 'ccxt', 'demo')
            symbol: Symbol to query (e.g., 'BTC/USD')
            resolution: Time resolution (e.g., '1', '5', '1D')
        """
        self.datasource_id = datasource_id
        self.symbol = symbol
        self.resolution = resolution
        self.instance_name = f"ds_{datasource_id}_{symbol}_{resolution}".replace("/", "_").replace(":", "_")

    def get_output_columns(self) -> List[str]:
        """
        Get the columns provided by this datasource.

        Note: This requires runtime resolution - the pipeline engine
        will need to query the actual DataSource to get the schema.

        Returns:
            List of column names (placeholder - needs runtime resolution)
        """
        # This will be resolved at runtime by the pipeline engine
        return []

    def __repr__(self) -> str:
        return f"DataSourceAdapter(datasource='{self.datasource_id}', symbol='{self.symbol}', resolution='{self.resolution}')"
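The checks in `_validate_params` can be sketched as a free function over a stand-in dataclass (both `ParamDef` and `validate_params` are hypothetical names used only for this illustration):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ParamDef:
    # Stand-in for IndicatorParameter, reduced to the fields the checks consult.
    name: str
    type: str
    required: bool = False
    min_value: Optional[float] = None
    max_value: Optional[float] = None

def validate_params(params: dict, defs: List[ParamDef]) -> None:
    """The same required/type/range checks as Indicator._validate_params."""
    for d in defs:
        if d.required and d.name not in params:
            raise ValueError(f"missing required parameter '{d.name}'")
    for name, value in params.items():
        d = next((p for p in defs if p.name == name), None)
        if d is None:
            raise ValueError(f"unknown parameter '{name}'")
        if d.type == "int" and not isinstance(value, int):
            raise ValueError(f"'{name}' must be int")
        if d.type in ("int", "float"):
            if d.min_value is not None and value < d.min_value:
                raise ValueError(f"'{name}' must be >= {d.min_value}")
            if d.max_value is not None and value > d.max_value:
                raise ValueError(f"'{name}' must be <= {d.max_value}")

defs = [ParamDef(name="period", type="int", min_value=1)]
validate_params({"period": 20}, defs)   # passes silently
try:
    validate_params({"period": 0}, defs)  # 0 violates min_value
    rejected = False
except ValueError:
    rejected = True
```

Because validation runs in `__init__`, a bad parameter fails at construction time rather than deep inside a pipeline run.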
954
backend.old/src/indicator/custom_indicators.py
Normal file
@@ -0,0 +1,954 @@
"""
Custom indicator implementations for TradingView indicators not in TA-Lib.

These indicators follow TA-Lib style conventions and integrate seamlessly
with the indicator framework. All implementations are based on well-known,
publicly documented formulas.
"""

import logging
from typing import List, Optional
import numpy as np

from datasource.schema import ColumnInfo
from .base import Indicator
from .schema import (
    ComputeContext,
    ComputeResult,
    IndicatorMetadata,
    IndicatorParameter,
    InputSchema,
    OutputSchema,
)

logger = logging.getLogger(__name__)


class VWAP(Indicator):
    """Volume Weighted Average Price - Most widely used institutional indicator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="VWAP",
            display_name="VWAP",
            description="Volume Weighted Average Price - Average price weighted by volume",
            category="volume",
            parameters=[],
            use_cases=[
                "Institutional reference price",
                "Support/resistance levels",
                "Mean reversion trading"
            ],
            references=["https://www.investopedia.com/terms/v/vwap.asp"],
            tags=["vwap", "volume", "institutional"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])

        # Typical price
        typical_price = (high + low + close) / 3.0

        # VWAP = cumsum(typical_price * volume) / cumsum(volume)
        cumulative_tp_vol = np.nancumsum(typical_price * volume)
        cumulative_vol = np.nancumsum(volume)

        vwap = cumulative_tp_vol / cumulative_vol

        times = context.get_times()
        result_data = [
            {"time": times[i], "vwap": float(vwap[i]) if not np.isnan(vwap[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
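The cumulative VWAP formula in `compute` is easy to verify by hand on two bars (a minimal sketch in plain NumPy, independent of `ComputeContext`):

```python
import numpy as np

# Two bars whose typical prices work out to 10.0 and 20.0, with volumes 1 and 3.
high   = np.array([11.0, 21.0])
low    = np.array([ 9.0, 19.0])
close  = np.array([10.0, 20.0])
volume = np.array([ 1.0,  3.0])

typical_price = (high + low + close) / 3.0                    # [10.0, 20.0]
vwap = np.cumsum(typical_price * volume) / np.cumsum(volume)  # [10.0, 17.5]
```

The second value is (10*1 + 20*3) / (1 + 3) = 17.5: the heavier volume on the second bar pulls the average toward its price.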
class VWMA(Indicator):
    """Volume Weighted Moving Average."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="VWMA",
            display_name="VWMA",
            description="Volume Weighted Moving Average - Moving average weighted by volume",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
            references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
            tags=["vwma", "volume", "moving average"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="close", type="float", description="Close price"),
            ColumnInfo(name="volume", type="float", description="Volume"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
        length = self.params.get("length", 20)

        vwma = np.full_like(close, np.nan)

        for i in range(length - 1, len(close)):
            window_close = close[i - length + 1:i + 1]
            window_volume = volume[i - length + 1:i + 1]
            vwma[i] = np.sum(window_close * window_volume) / np.sum(window_volume)

        times = context.get_times()
        result_data = [
            {"time": times[i], "vwma": float(vwma[i]) if not np.isnan(vwma[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)


class HullMA(Indicator):
    """Hull Moving Average - Fast and smooth moving average."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="HMA",
            display_name="Hull Moving Average",
            description="Hull Moving Average - Reduces lag while maintaining smoothness",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=9,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Low-lag trend following", "Quick trend reversal detection"],
            references=["https://alanhull.com/hull-moving-average"],
            tags=["hma", "hull", "moving average", "low-lag"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
        length = self.params.get("length", 9)

        def wma(data, period):
            """Weighted Moving Average."""
            weights = np.arange(1, period + 1)
            result = np.full_like(data, np.nan)
            for i in range(period - 1, len(data)):
                window = data[i - period + 1:i + 1]
                result[i] = np.sum(weights * window) / np.sum(weights)
            return result

        # HMA = WMA(2 * WMA(n/2) - WMA(n), sqrt(n))
        half_length = length // 2
        sqrt_length = int(np.sqrt(length))

        wma_half = wma(close, half_length)
        wma_full = wma(close, length)
        raw_hma = 2 * wma_half - wma_full
        hma = wma(raw_hma, sqrt_length)

        times = context.get_times()
        result_data = [
            {"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
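The nested `wma` helper is the core of the Hull construction; its linear weighting can be checked on a tiny series (a standalone sketch, re-declaring `wma` outside the class):

```python
import numpy as np

def wma(data, period):
    """Linearly weighted MA: weights 1..period, newest bar weighted most."""
    weights = np.arange(1, period + 1)
    result = np.full(len(data), np.nan)
    for i in range(period - 1, len(data)):
        result[i] = np.dot(weights, data[i - period + 1 : i + 1]) / weights.sum()
    return result

data = np.array([1.0, 2.0, 3.0, 4.0])
w = wma(data, 3)
# Last bar: (1*2 + 2*3 + 3*4) / (1+2+3) = 20/6
```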
class SuperTrend(Indicator):
    """SuperTrend - Popular trend following indicator."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="SUPERTREND",
            display_name="SuperTrend",
            description="SuperTrend - Volatility-based trend indicator",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="ATR period",
                    default=10,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="multiplier",
                    type="float",
                    description="ATR multiplier",
                    default=3.0,
                    min_value=0.1,
                    required=False
                )
            ],
            use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
            references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
            tags=["supertrend", "trend", "volatility"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
            ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])

        length = self.params.get("length", 10)
        multiplier = self.params.get("multiplier", 3.0)

        # Calculate ATR
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        atr = np.full_like(close, np.nan)
        atr[length - 1] = np.mean(tr[:length])
        for i in range(length, len(tr)):
            atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length

        # Calculate basic bands
        hl2 = (high + low) / 2
        basic_upper = hl2 + multiplier * atr
        basic_lower = hl2 - multiplier * atr

        # Calculate final bands
        final_upper = np.full_like(close, np.nan)
        final_lower = np.full_like(close, np.nan)
        supertrend = np.full_like(close, np.nan)
        direction = np.full_like(close, np.nan)

        for i in range(length, len(close)):
            if i == length:
                final_upper[i] = basic_upper[i]
                final_lower[i] = basic_lower[i]
            else:
                final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
                final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]

            if i == length:
                supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
                direction[i] = -1 if close[i] <= hl2[i] else 1
            else:
                if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
                    supertrend[i] = final_upper[i]
                    direction[i] = -1
                elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
                    supertrend[i] = final_lower[i]
                    direction[i] = 1
                elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
                    supertrend[i] = final_lower[i]
                    direction[i] = 1
                else:
                    supertrend[i] = final_upper[i]
                    direction[i] = -1

        times = context.get_times()
        result_data = [
            {
                "time": times[i],
                "supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
                "direction": int(direction[i]) if not np.isnan(direction[i]) else None
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
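The true-range/ATR recurrence used by SuperTrend is Wilder's smoothing seeded with a simple average; a standalone sketch (`wilder_atr` is a hypothetical helper name):

```python
import numpy as np

def wilder_atr(high, low, close, length):
    """True range seeded with an SMA over the first `length` bars, then Wilder-smoothed."""
    prev_close = np.roll(close, 1)
    tr = np.maximum(high - low,
                    np.maximum(np.abs(high - prev_close), np.abs(low - prev_close)))
    tr[0] = high[0] - low[0]  # no previous close on the first bar

    atr = np.full(len(close), np.nan)
    atr[length - 1] = np.mean(tr[:length])
    for i in range(length, len(tr)):
        atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length
    return atr

high  = np.array([10.0, 11.0, 12.0, 13.0])
low   = np.array([ 9.0, 10.0, 11.0, 12.0])
close = np.array([ 9.5, 10.5, 11.5, 12.5])
atr = wilder_atr(high, low, close, length=2)
```

Each gap between the prior close and the next bar's high (1.5 here) dominates the bar's own 1.0 range, so the smoothed ATR drifts up from 1.25 toward 1.5.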
class DonchianChannels(Indicator):
    """Donchian Channels - Breakout indicator using highest high and lowest low."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="DONCHIAN",
            display_name="Donchian Channels",
            description="Donchian Channels - Highest high and lowest low over period",
            category="overlap",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="Period length",
                    default=20,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
            references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
            tags=["donchian", "channels", "breakout"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
            ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
            ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        length = self.params.get("length", 20)

        upper = np.full_like(high, np.nan)
        lower = np.full_like(low, np.nan)

        for i in range(length - 1, len(high)):
            upper[i] = np.nanmax(high[i - length + 1:i + 1])
            lower[i] = np.nanmin(low[i - length + 1:i + 1])

        middle = (upper + lower) / 2

        times = context.get_times()
        result_data = [
            {
                "time": times[i],
                "upper": float(upper[i]) if not np.isnan(upper[i]) else None,
                "middle": float(middle[i]) if not np.isnan(middle[i]) else None,
                "lower": float(lower[i]) if not np.isnan(lower[i]) else None,
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
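The Donchian computation reduces to rolling extrema plus a midline; a minimal standalone sketch (`donchian` is a hypothetical helper):

```python
import numpy as np

def donchian(high, low, length):
    """Rolling highest high / lowest low, with the midline between them."""
    upper = np.full(len(high), np.nan)
    lower = np.full(len(low), np.nan)
    for i in range(length - 1, len(high)):
        upper[i] = np.nanmax(high[i - length + 1 : i + 1])
        lower[i] = np.nanmin(low[i - length + 1 : i + 1])
    return upper, (upper + lower) / 2, lower

high = np.array([10.0, 12.0, 11.0])
low  = np.array([ 8.0,  9.0,  7.0])
upper, middle, lower = donchian(high, low, length=2)
# Last bar: max(12, 11) = 12, min(9, 7) = 7, middle = 9.5
```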
class KeltnerChannels(Indicator):
    """Keltner Channels - ATR-based volatility bands."""

    @classmethod
    def get_metadata(cls) -> IndicatorMetadata:
        return IndicatorMetadata(
            name="KELTNER",
            display_name="Keltner Channels",
            description="Keltner Channels - EMA with ATR-based bands",
            category="volatility",
            parameters=[
                IndicatorParameter(
                    name="length",
                    type="int",
                    description="EMA period",
                    default=20,
                    min_value=1,
                    required=False
                ),
                IndicatorParameter(
                    name="multiplier",
                    type="float",
                    description="ATR multiplier",
                    default=2.0,
                    min_value=0.1,
                    required=False
                ),
                IndicatorParameter(
                    name="atr_length",
                    type="int",
                    description="ATR period",
                    default=10,
                    min_value=1,
                    required=False
                )
            ],
            use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
            references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
            tags=["keltner", "channels", "volatility", "atr"]
        )

    @classmethod
    def get_input_schema(cls) -> InputSchema:
        return InputSchema(required_columns=[
            ColumnInfo(name="high", type="float", description="High price"),
            ColumnInfo(name="low", type="float", description="Low price"),
            ColumnInfo(name="close", type="float", description="Close price"),
        ])

    @classmethod
    def get_output_schema(cls, **params) -> OutputSchema:
        return OutputSchema(columns=[
            ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
            ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
            ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
        ])

    def compute(self, context: ComputeContext) -> ComputeResult:
        high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
        low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
        close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])

        length = self.params.get("length", 20)
        multiplier = self.params.get("multiplier", 2.0)
        atr_length = self.params.get("atr_length", 10)

        # Calculate EMA
        alpha = 2.0 / (length + 1)
        ema = np.full_like(close, np.nan)
        ema[0] = close[0]
        for i in range(1, len(close)):
            ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]

        # Calculate ATR
        tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
        tr[0] = high[0] - low[0]

        atr = np.full_like(close, np.nan)
        atr[atr_length - 1] = np.mean(tr[:atr_length])
        for i in range(atr_length, len(tr)):
            atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length

        upper = ema + multiplier * atr
        lower = ema - multiplier * atr

        times = context.get_times()
        result_data = [
            {
                "time": times[i],
                "upper": float(upper[i]) if not np.isnan(upper[i]) else None,
                "middle": float(ema[i]) if not np.isnan(ema[i]) else None,
                "lower": float(lower[i]) if not np.isnan(lower[i]) else None,
            }
            for i in range(len(times))
        ]

        return ComputeResult(data=result_data, is_partial=context.is_incremental)
class ChaikinMoneyFlow(Indicator):
|
||||
"""Chaikin Money Flow - Volume-weighted accumulation/distribution."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="CMF",
|
||||
display_name="Chaikin Money Flow",
|
||||
description="Chaikin Money Flow - Measures buying and selling pressure",
|
||||
category="volume",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
|
||||
references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
|
||||
tags=["cmf", "chaikin", "volume", "money flow"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
# Money Flow Multiplier
|
||||
mfm = ((close - low) - (high - close)) / (high - low)
|
||||
mfm = np.where(high == low, 0, mfm)
|
||||
|
||||
# Money Flow Volume
|
||||
mfv = mfm * volume
|
||||
|
||||
# CMF
|
||||
cmf = np.full_like(close, np.nan)
|
||||
for i in range(length - 1, len(close)):
|
||||
cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VortexIndicator(Indicator):
|
||||
"""Vortex Indicator - Identifies trend direction and strength."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VORTEX",
|
||||
display_name="Vortex Indicator",
|
||||
description="Vortex Indicator - Trend direction and strength",
|
||||
category="momentum",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=14,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend identification", "Trend reversals", "Trend strength"],
|
||||
references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
|
||||
tags=["vortex", "trend", "momentum"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
|
||||
ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 14)
|
||||
|
||||
# Vortex Movement
|
||||
vm_plus = np.abs(high - np.roll(low, 1))
|
||||
vm_minus = np.abs(low - np.roll(high, 1))
|
||||
vm_plus[0] = 0
|
||||
vm_minus[0] = 0
|
||||
|
||||
# True Range
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
# Vortex Indicator
|
||||
vi_plus = np.full_like(close, np.nan)
|
||||
vi_minus = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
|
||||
sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
|
||||
sum_tr = np.sum(tr[i - length + 1:i + 1])
|
||||
|
||||
if sum_tr != 0:
|
||||
vi_plus[i] = sum_vm_plus / sum_tr
|
||||
vi_minus[i] = sum_vm_minus / sum_tr
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
|
||||
"vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AwesomeOscillator(Indicator):
|
||||
"""Awesome Oscillator - Bill Williams' momentum indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="AO",
|
||||
display_name="Awesome Oscillator",
|
||||
description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
|
||||
category="momentum",
|
||||
parameters=[],
|
||||
use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
|
||||
references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
|
||||
tags=["awesome", "oscillator", "momentum", "williams"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
midpoint = (high + low) / 2
|
||||
|
||||
# SMA 5
|
||||
sma5 = np.full_like(midpoint, np.nan)
|
||||
for i in range(4, len(midpoint)):
|
||||
sma5[i] = np.mean(midpoint[i - 4:i + 1])
|
||||
|
||||
# SMA 34
|
||||
sma34 = np.full_like(midpoint, np.nan)
|
||||
for i in range(33, len(midpoint)):
|
||||
sma34[i] = np.mean(midpoint[i - 33:i + 1])
|
||||
|
||||
ao = sma5 - sma34
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AcceleratorOscillator(Indicator):
|
||||
"""Accelerator Oscillator - Rate of change of Awesome Oscillator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="AC",
|
||||
display_name="Accelerator Oscillator",
|
||||
description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
|
||||
category="momentum",
|
||||
parameters=[],
|
||||
use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
|
||||
references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
|
||||
tags=["accelerator", "oscillator", "momentum", "williams"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
midpoint = (high + low) / 2
|
||||
|
||||
# Calculate AO first
|
||||
sma5 = np.full_like(midpoint, np.nan)
|
||||
for i in range(4, len(midpoint)):
|
||||
sma5[i] = np.mean(midpoint[i - 4:i + 1])
|
||||
|
||||
sma34 = np.full_like(midpoint, np.nan)
|
||||
for i in range(33, len(midpoint)):
|
||||
sma34[i] = np.mean(midpoint[i - 33:i + 1])
|
||||
|
||||
ao = sma5 - sma34
|
||||
|
||||
# AC = AO - SMA(AO, 5)
|
||||
sma_ao = np.full_like(ao, np.nan)
|
||||
for i in range(4, len(ao)):
|
||||
if not np.isnan(ao[i - 4:i + 1]).any():
|
||||
sma_ao[i] = np.mean(ao[i - 4:i + 1])
|
||||
|
||||
ac = ao - sma_ao
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class ChoppinessIndex(Indicator):
|
||||
"""Choppiness Index - Determines if market is choppy or trending."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="CHOP",
|
||||
display_name="Choppiness Index",
|
||||
description="Choppiness Index - Measures market trendiness vs consolidation",
|
||||
category="volatility",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=14,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend vs range identification", "Market regime detection"],
|
||||
references=["https://www.tradingview.com/support/solutions/43000501980/"],
|
||||
tags=["chop", "choppiness", "trend", "range"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 14)
|
||||
|
||||
# True Range
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
chop = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
sum_tr = np.sum(tr[i - length + 1:i + 1])
|
||||
high_low_diff = np.max(high[i - length + 1:i + 1]) - np.min(low[i - length + 1:i + 1])
|
||||
|
||||
if high_low_diff != 0:
|
||||
chop[i] = 100 * np.log10(sum_tr / high_low_diff) / np.log10(length)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class MassIndex(Indicator):
|
||||
"""Mass Index - Identifies trend reversals based on range expansion."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="MASS",
|
||||
display_name="Mass Index",
|
||||
description="Mass Index - Identifies reversals when range narrows then expands",
|
||||
category="volatility",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="fast_period",
|
||||
type="int",
|
||||
description="Fast EMA period",
|
||||
default=9,
|
||||
min_value=1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="slow_period",
|
||||
type="int",
|
||||
description="Slow EMA period",
|
||||
default=25,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
|
||||
references=["https://www.investopedia.com/terms/m/mass-index.asp"],
|
||||
tags=["mass", "index", "volatility", "reversal"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
fast_period = self.params.get("fast_period", 9)
|
||||
slow_period = self.params.get("slow_period", 25)
|
||||
|
||||
hl_range = high - low
|
||||
|
||||
# Single EMA
|
||||
alpha1 = 2.0 / (fast_period + 1)
|
||||
ema1 = np.full_like(hl_range, np.nan)
|
||||
ema1[0] = hl_range[0]
|
||||
for i in range(1, len(hl_range)):
|
||||
ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]
|
||||
|
||||
# Double EMA
|
||||
ema2 = np.full_like(ema1, np.nan)
|
||||
ema2[0] = ema1[0]
|
||||
for i in range(1, len(ema1)):
|
||||
if not np.isnan(ema1[i]):
|
||||
ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]
|
||||
|
||||
# EMA Ratio
|
||||
ema_ratio = ema1 / ema2
|
||||
|
||||
# Mass Index
|
||||
mass = np.full_like(hl_range, np.nan)
|
||||
for i in range(slow_period - 1, len(ema_ratio)):
|
||||
mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
# Registry of all custom indicators
|
||||
CUSTOM_INDICATORS = [
|
||||
VWAP,
|
||||
VWMA,
|
||||
HullMA,
|
||||
SuperTrend,
|
||||
DonchianChannels,
|
||||
KeltnerChannels,
|
||||
ChaikinMoneyFlow,
|
||||
VortexIndicator,
|
||||
AwesomeOscillator,
|
||||
AcceleratorOscillator,
|
||||
ChoppinessIndex,
|
||||
MassIndex,
|
||||
]
|
||||
|
||||
|
||||
def register_custom_indicators(registry) -> int:
|
||||
"""
|
||||
Register all custom indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
"""
|
||||
registered_count = 0
|
||||
|
||||
for indicator_class in CUSTOM_INDICATORS:
|
||||
try:
|
||||
registry.register(indicator_class)
|
||||
registered_count += 1
|
||||
logger.debug(f"Registered custom indicator: {indicator_class.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register custom indicator {indicator_class.__name__}: {e}")
|
||||
|
||||
logger.info(f"Registered {registered_count} custom indicators")
|
||||
return registered_count
|
||||
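As a quick sanity check of the Money Flow arithmetic in `ChaikinMoneyFlow.compute` above, the same formula can be reproduced standalone with plain numpy. The bar values below are hypothetical toy data, not taken from the file:

```python
import numpy as np

# Toy bars (hypothetical values)
high = np.array([10.0, 11.0, 12.0])
low = np.array([9.0, 10.0, 10.0])
close = np.array([9.5, 11.0, 11.5])
volume = np.array([100.0, 200.0, 300.0])
length = 3

# Money Flow Multiplier: +1 when close == high, -1 when close == low
mfm = ((close - low) - (high - close)) / (high - low)
mfm = np.where(high == low, 0.0, mfm)

# Money Flow Volume, then the length-period CMF over the last window
mfv = mfm * volume
cmf = np.nansum(mfv[-length:]) / np.nansum(volume[-length:])
# cmf = 350 / 600, always bounded in [-1, 1]
```

Because the multiplier is bounded by ±1, the volume-weighted ratio can never leave [-1, 1], which is why the `cmf` output column needs no extra clamping.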
439
backend.old/src/indicator/pipeline.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Pipeline execution engine for composable indicators.
|
||||
|
||||
Manages DAG construction, dependency resolution, incremental updates,
|
||||
and efficient data flow through indicator chains.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from datasource.base import DataSource
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import DataSourceAdapter, Indicator
|
||||
from .schema import ComputeContext, ComputeResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineNode:
|
||||
"""
|
||||
A node in the pipeline DAG.
|
||||
|
||||
Can be either a DataSource adapter or an Indicator instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
node: Union[DataSourceAdapter, Indicator],
|
||||
dependencies: List[str]
|
||||
):
|
||||
"""
|
||||
Create a pipeline node.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
node: The DataSourceAdapter or Indicator instance
|
||||
dependencies: List of node_ids this node depends on
|
||||
"""
|
||||
self.node_id = node_id
|
||||
self.node = node
|
||||
self.dependencies = dependencies
|
||||
self.output_columns: List[str] = []
|
||||
self.cached_data: List[Dict[str, Any]] = []
|
||||
|
||||
def is_datasource(self) -> bool:
|
||||
"""Check if this node is a DataSource adapter."""
|
||||
return isinstance(self.node, DataSourceAdapter)
|
||||
|
||||
def is_indicator(self) -> bool:
|
||||
"""Check if this node is an Indicator."""
|
||||
return isinstance(self.node, Indicator)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PipelineNode(id='{self.node_id}', node={self.node}, deps={self.dependencies})"
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Execution engine for indicator DAGs.
|
||||
|
||||
Manages:
|
||||
- DAG construction and validation
|
||||
- Topological sorting for execution order
|
||||
- Data flow and caching
|
||||
- Incremental updates (only recompute what changed)
|
||||
- Schema validation
|
||||
"""
|
||||
|
||||
def __init__(self, datasource_registry):
|
||||
"""
|
||||
Initialize a pipeline.
|
||||
|
||||
Args:
|
||||
datasource_registry: DataSourceRegistry for resolving data sources
|
||||
"""
|
||||
self.datasource_registry = datasource_registry
|
||||
self.nodes: Dict[str, PipelineNode] = {}
|
||||
self.execution_order: List[str] = []
|
||||
self._dirty_nodes: Set[str] = set()
|
||||
|
||||
def add_datasource(
|
||||
self,
|
||||
node_id: str,
|
||||
datasource_name: str,
|
||||
symbol: str,
|
||||
resolution: str
|
||||
) -> None:
|
||||
"""
|
||||
Add a DataSource to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
datasource_name: Name of the datasource in the registry
|
||||
symbol: Symbol to query
|
||||
resolution: Time resolution
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists or datasource not found
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
datasource = self.datasource_registry.get(datasource_name)
|
||||
if not datasource:
|
||||
raise ValueError(f"DataSource '{datasource_name}' not found in registry")
|
||||
|
||||
adapter = DataSourceAdapter(datasource_name, symbol, resolution)
|
||||
node = PipelineNode(node_id, adapter, dependencies=[])
|
||||
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added DataSource node '{node_id}': {datasource_name}/{symbol}@{resolution}")
|
||||
|
||||
def add_indicator(
|
||||
self,
|
||||
node_id: str,
|
||||
indicator: Indicator,
|
||||
input_node_ids: List[str]
|
||||
) -> None:
|
||||
"""
|
||||
Add an Indicator to the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Unique identifier for this node
|
||||
indicator: Indicator instance
|
||||
input_node_ids: List of node IDs providing input data
|
||||
|
||||
Raises:
|
||||
ValueError: If node_id already exists, dependencies not found, or schema mismatch
|
||||
"""
|
||||
if node_id in self.nodes:
|
||||
raise ValueError(f"Node '{node_id}' already exists in pipeline")
|
||||
|
||||
# Validate dependencies exist
|
||||
for dep_id in input_node_ids:
|
||||
if dep_id not in self.nodes:
|
||||
raise ValueError(f"Dependency node '{dep_id}' not found in pipeline")
|
||||
|
||||
# TODO: Validate input schema matches available columns from dependencies
|
||||
# This requires merging output schemas from all input nodes
|
||||
|
||||
node = PipelineNode(node_id, indicator, dependencies=input_node_ids)
|
||||
self.nodes[node_id] = node
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Added Indicator node '{node_id}': {indicator} with inputs {input_node_ids}")
|
||||
|
||||
def remove_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Remove a node from the pipeline.
|
||||
|
||||
Args:
|
||||
node_id: Node to remove
|
||||
|
||||
Raises:
|
||||
ValueError: If other nodes depend on this node
|
||||
"""
|
||||
if node_id not in self.nodes:
|
||||
return
|
||||
|
||||
# Check for dependent nodes
|
||||
dependents = [
|
||||
n.node_id for n in self.nodes.values()
|
||||
if node_id in n.dependencies
|
||||
]
|
||||
|
||||
if dependents:
|
||||
raise ValueError(
|
||||
f"Cannot remove node '{node_id}': nodes {dependents} depend on it"
|
||||
)
|
||||
|
||||
del self.nodes[node_id]
|
||||
self._invalidate_execution_order()
|
||||
|
||||
logger.info(f"Removed node '{node_id}' from pipeline")
|
||||
|
||||
def _invalidate_execution_order(self) -> None:
|
||||
"""Mark execution order as needing recomputation."""
|
||||
self.execution_order = []
|
||||
|
||||
def _compute_execution_order(self) -> List[str]:
|
||||
"""
|
||||
Compute topological sort of the DAG.
|
||||
|
||||
Returns:
|
||||
List of node IDs in execution order
|
||||
|
||||
Raises:
|
||||
ValueError: If DAG contains cycles
|
||||
"""
|
||||
if self.execution_order:
|
||||
return self.execution_order
|
||||
|
||||
# Kahn's algorithm for topological sort
|
||||
in_degree = {node_id: 0 for node_id in self.nodes}
|
||||
for node in self.nodes.values():
|
||||
for dep in node.dependencies:
|
||||
in_degree[node.node_id] += 1
|
||||
|
||||
queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
|
||||
result = []
|
||||
|
||||
while queue:
|
||||
node_id = queue.popleft()
|
||||
result.append(node_id)
|
||||
|
||||
# Find all nodes that depend on this one
|
||||
for other_node in self.nodes.values():
|
||||
if node_id in other_node.dependencies:
|
||||
in_degree[other_node.node_id] -= 1
|
||||
if in_degree[other_node.node_id] == 0:
|
||||
queue.append(other_node.node_id)
|
||||
|
||||
if len(result) != len(self.nodes):
|
||||
raise ValueError("Pipeline contains cycles")
|
||||
|
||||
self.execution_order = result
|
||||
logger.debug(f"Computed execution order: {result}")
|
||||
return result
|
||||
|
||||
def execute(
|
||||
self,
|
||||
datasource_data: Dict[str, List[Dict[str, Any]]],
|
||||
incremental: bool = False,
|
||||
updated_from_time: Optional[int] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Execute the pipeline.
|
||||
|
||||
Args:
|
||||
datasource_data: Mapping of DataSource node_id to input data
|
||||
incremental: Whether this is an incremental update
|
||||
updated_from_time: Timestamp of earliest updated row (for incremental)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping node_id to output data (all nodes)
|
||||
|
||||
Raises:
|
||||
ValueError: If required datasource data is missing
|
||||
"""
|
||||
execution_order = self._compute_execution_order()
|
||||
results: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
logger.info(
|
||||
f"Executing pipeline with {len(execution_order)} nodes "
|
||||
f"(incremental={incremental})"
|
||||
)
|
||||
|
||||
for node_id in execution_order:
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if node.is_datasource():
|
||||
# DataSource node - get data from input
|
||||
if node_id not in datasource_data:
|
||||
raise ValueError(
|
||||
f"DataSource node '{node_id}' has no input data"
|
||||
)
|
||||
results[node_id] = datasource_data[node_id]
|
||||
node.cached_data = results[node_id]
|
||||
logger.debug(f"DataSource node '{node_id}': {len(results[node_id])} rows")
|
||||
|
||||
elif node.is_indicator():
|
||||
# Indicator node - compute from dependencies
|
||||
indicator = node.node
|
||||
|
||||
# Merge input data from all dependencies
|
||||
input_data = self._merge_dependency_data(node.dependencies, results)
|
||||
|
||||
# Create compute context
|
||||
context = ComputeContext(
|
||||
data=input_data,
|
||||
is_incremental=incremental,
|
||||
updated_from_time=updated_from_time
|
||||
)
|
||||
|
||||
# Execute indicator
|
||||
logger.debug(
|
||||
f"Computing indicator '{node_id}' with {len(input_data)} input rows"
|
||||
)
|
||||
compute_result = indicator.compute(context)
|
||||
|
||||
# Merge result with input data (adding prefixed columns)
|
||||
output_data = compute_result.merge_with_prefix(
|
||||
indicator.instance_name,
|
||||
input_data
|
||||
)
|
||||
|
||||
results[node_id] = output_data
|
||||
node.cached_data = output_data
|
||||
logger.debug(f"Indicator node '{node_id}': {len(output_data)} rows")
|
||||
|
||||
logger.info(f"Pipeline execution complete: {len(results)} nodes processed")
|
||||
return results
|
||||
|
||||
def _merge_dependency_data(
|
||||
self,
|
||||
dependency_ids: List[str],
|
||||
results: Dict[str, List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge data from multiple dependency nodes.
|
||||
|
||||
Data is merged by time, with later dependencies overwriting earlier ones
|
||||
for conflicting column names.
|
||||
|
||||
Args:
|
||||
dependency_ids: List of node IDs to merge
|
||||
results: Current execution results
|
||||
|
||||
Returns:
|
||||
Merged data rows
|
||||
"""
|
||||
if not dependency_ids:
|
||||
return []
|
||||
|
||||
if len(dependency_ids) == 1:
|
||||
return results[dependency_ids[0]]
|
||||
|
||||
# Build time-indexed data from first dependency
|
||||
merged: Dict[int, Dict[str, Any]] = {}
|
||||
for row in results[dependency_ids[0]]:
|
||||
merged[row["time"]] = row.copy()
|
||||
|
||||
# Merge in additional dependencies
|
||||
for dep_id in dependency_ids[1:]:
|
||||
for row in results[dep_id]:
|
||||
time_key = row["time"]
|
||||
if time_key in merged:
|
||||
# Merge columns (later dependencies win)
|
||||
merged[time_key].update(row)
|
||||
else:
|
||||
# New timestamp
|
||||
merged[time_key] = row.copy()
|
||||
|
||||
# Sort by time and return
|
||||
sorted_times = sorted(merged.keys())
|
||||
return [merged[t] for t in sorted_times]
|
||||
|
||||
def get_node_output(self, node_id: str) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get cached output data for a specific node.
|
||||
|
||||
Args:
|
||||
node_id: Node identifier
|
||||
|
||||
Returns:
|
||||
Cached data or None if not available
|
||||
"""
|
||||
node = self.nodes.get(node_id)
|
||||
return node.cached_data if node else None
|
||||
|
||||
def get_output_schema(self, node_id: str) -> List[ColumnInfo]:
|
||||
"""
|
||||
Get the output schema for a specific node.
|
||||
|
||||
Args:
|
||||
node_id: Node identifier
|
||||
|
||||
Returns:
|
||||
List of ColumnInfo describing output columns
|
||||
|
||||
Raises:
|
||||
ValueError: If node not found
|
||||
"""
|
||||
node = self.nodes.get(node_id)
|
||||
if not node:
|
||||
raise ValueError(f"Node '{node_id}' not found")
|
||||
|
||||
if node.is_datasource():
|
||||
# Would need to query the actual datasource at runtime
|
||||
# For now, return empty - this requires integration with DataSource
|
||||
return []
|
||||
|
||||
elif node.is_indicator():
|
||||
indicator = node.node
|
||||
                output_schema = indicator.get_output_schema(**indicator.params)
                prefixed_schema = output_schema.with_prefix(indicator.instance_name)
                return prefixed_schema.columns

        return []

    def validate_pipeline(self) -> Tuple[bool, Optional[str]]:
        """
        Validate the entire pipeline for correctness.

        Checks:
        - No cycles (already checked in execution order)
        - All dependencies exist (already checked in add_indicator)
        - Input schemas match output schemas (TODO)

        Returns:
            Tuple of (is_valid, error_message)
        """
        try:
            self._compute_execution_order()
            return True, None
        except ValueError as e:
            return False, str(e)

    def get_node_count(self) -> int:
        """Get the number of nodes in the pipeline."""
        return len(self.nodes)

    def get_indicator_count(self) -> int:
        """Get the number of indicator nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_indicator())

    def get_datasource_count(self) -> int:
        """Get the number of datasource nodes in the pipeline."""
        return sum(1 for node in self.nodes.values() if node.is_datasource())

    def describe(self) -> Dict[str, Any]:
        """
        Get a detailed description of the pipeline structure.

        Returns:
            Dictionary with pipeline metadata and structure
        """
        return {
            "node_count": self.get_node_count(),
            "datasource_count": self.get_datasource_count(),
            "indicator_count": self.get_indicator_count(),
            "nodes": [
                {
                    "id": node.node_id,
                    "type": "datasource" if node.is_datasource() else "indicator",
                    "node": str(node.node),
                    "dependencies": node.dependencies,
                    "cached_rows": len(node.cached_data)
                }
                for node in self.nodes.values()
            ],
            "execution_order": self.execution_order or self._compute_execution_order(),
            "is_valid": self.validate_pipeline()[0]
        }
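`validate_pipeline` above delegates cycle detection to `_compute_execution_order`, which is outside this hunk and is assumed to raise `ValueError` on a cycle. A minimal standalone sketch of that contract, using a hypothetical `deps` mapping of node id to its dependencies:

```python
from typing import Dict, List, Optional, Tuple


def compute_execution_order(deps: Dict[str, List[str]]) -> List[str]:
    """Kahn's algorithm: dependencies first; raises ValueError on a cycle."""
    indegree = {n: len(ds) for n, ds in deps.items()}
    dependents: Dict[str, List[str]] = {n: [] for n in deps}
    for n, ds in deps.items():
        for d in ds:
            dependents[d].append(n)

    ready = [n for n, deg in indegree.items() if deg == 0]
    order: List[str] = []
    while ready:
        n = ready.pop()
        order.append(n)
        for m in dependents[n]:
            indegree[m] -= 1
            if indegree[m] == 0:
                ready.append(m)

    if len(order) != len(deps):
        raise ValueError("Pipeline contains a cycle")
    return order


def validate(deps: Dict[str, List[str]]) -> Tuple[bool, Optional[str]]:
    """Same (is_valid, error_message) shape as validate_pipeline."""
    try:
        compute_execution_order(deps)
        return True, None
    except ValueError as e:
        return False, str(e)
```

The `(bool, Optional[str])` return mirrors the method above; converting the exception to a message keeps callers from needing a try/except of their own.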
349 backend.old/src/indicator/registry.py Normal file
@@ -0,0 +1,349 @@
"""
Indicator registry for managing and discovering indicators.

Provides AI agents with a queryable catalog of available indicators,
their capabilities, and metadata.
"""

from typing import Dict, List, Optional, Type

from .base import Indicator
from .schema import IndicatorMetadata, InputSchema, OutputSchema


class IndicatorRegistry:
    """
    Central registry for indicator classes.

    Enables:
    - Registration of indicator implementations
    - Discovery by name, category, or tags
    - Schema validation
    - AI agent tool generation
    """

    def __init__(self):
        self._indicators: Dict[str, Type[Indicator]] = {}

    def register(self, indicator_class: Type[Indicator]) -> None:
        """
        Register an indicator class.

        Args:
            indicator_class: Indicator class to register

        Raises:
            ValueError: If an indicator with this name is already registered
        """
        metadata = indicator_class.get_metadata()

        if metadata.name in self._indicators:
            raise ValueError(
                f"Indicator '{metadata.name}' is already registered"
            )

        self._indicators[metadata.name] = indicator_class

    def unregister(self, name: str) -> None:
        """
        Unregister an indicator class.

        Args:
            name: Indicator class name
        """
        self._indicators.pop(name, None)

    def get(self, name: str) -> Optional[Type[Indicator]]:
        """
        Get an indicator class by name.

        Args:
            name: Indicator class name

        Returns:
            Indicator class or None if not found
        """
        return self._indicators.get(name)

    def list_indicators(self) -> List[str]:
        """
        Get names of all registered indicators.

        Returns:
            List of indicator class names
        """
        return list(self._indicators.keys())

    def get_metadata(self, name: str) -> Optional[IndicatorMetadata]:
        """
        Get metadata for a specific indicator.

        Args:
            name: Indicator class name

        Returns:
            IndicatorMetadata or None if not found
        """
        indicator_class = self.get(name)
        if indicator_class:
            return indicator_class.get_metadata()
        return None

    def get_all_metadata(self) -> List[IndicatorMetadata]:
        """
        Get metadata for all registered indicators.

        Useful for AI agent tool generation and discovery.

        Returns:
            List of IndicatorMetadata for all registered indicators
        """
        return [cls.get_metadata() for cls in self._indicators.values()]

    def search_by_category(self, category: str) -> List[IndicatorMetadata]:
        """
        Find indicators by category.

        Args:
            category: Category name (e.g., 'momentum', 'trend', 'volatility')

        Returns:
            List of matching indicator metadata
        """
        results = []
        for indicator_class in self._indicators.values():
            metadata = indicator_class.get_metadata()
            if metadata.category.lower() == category.lower():
                results.append(metadata)
        return results

    def search_by_tag(self, tag: str) -> List[IndicatorMetadata]:
        """
        Find indicators by tag.

        Args:
            tag: Tag to search for (case-insensitive)

        Returns:
            List of matching indicator metadata
        """
        tag_lower = tag.lower()
        results = []
        for indicator_class in self._indicators.values():
            metadata = indicator_class.get_metadata()
            if any(t.lower() == tag_lower for t in metadata.tags):
                results.append(metadata)
        return results

    def search_by_text(self, query: str) -> List[IndicatorMetadata]:
        """
        Full-text search across indicator names, descriptions, and use cases.

        Args:
            query: Search query (case-insensitive)

        Returns:
            List of matching indicator metadata, ranked by relevance
        """
        query_lower = query.lower()
        results = []

        for indicator_class in self._indicators.values():
            metadata = indicator_class.get_metadata()
            score = 0

            # Check name (highest weight)
            if query_lower in metadata.name.lower():
                score += 10
            if query_lower in metadata.display_name.lower():
                score += 8

            # Check description
            if query_lower in metadata.description.lower():
                score += 5

            # Check use cases
            for use_case in metadata.use_cases:
                if query_lower in use_case.lower():
                    score += 3

            # Check tags
            for tag in metadata.tags:
                if query_lower in tag.lower():
                    score += 2

            if score > 0:
                results.append((score, metadata))

        # Sort by score descending
        results.sort(key=lambda x: x[0], reverse=True)
        return [metadata for _, metadata in results]

    def find_compatible_indicators(
        self,
        available_columns: List[str],
        column_types: Dict[str, str]
    ) -> List[IndicatorMetadata]:
        """
        Find indicators that can be computed from available columns.

        Args:
            available_columns: List of column names available
            column_types: Mapping of column name to type

        Returns:
            List of indicators whose input schema is satisfied
        """
        from datasource.schema import ColumnInfo

        # Build ColumnInfo list from available data
        available_schema = [
            ColumnInfo(
                name=name,
                type=column_types.get(name, "float"),
                description=f"Column {name}"
            )
            for name in available_columns
        ]

        results = []
        for indicator_class in self._indicators.values():
            input_schema = indicator_class.get_input_schema()
            if input_schema.matches(available_schema):
                results.append(indicator_class.get_metadata())

        return results

    def validate_indicator_chain(
        self,
        indicator_chain: List[tuple[str, Dict]]
    ) -> tuple[bool, Optional[str]]:
        """
        Validate that a chain of indicators can be connected.

        Args:
            indicator_chain: List of (indicator_name, params) tuples in execution order

        Returns:
            Tuple of (is_valid, error_message)
        """
        if not indicator_chain:
            return True, None

        # For now, just check that all indicators exist
        # More sophisticated DAG validation happens in the pipeline engine
        for indicator_name, params in indicator_chain:
            if indicator_name not in self._indicators:
                return False, f"Indicator '{indicator_name}' not found in registry"

        return True, None

    def get_input_schema(self, name: str) -> Optional[InputSchema]:
        """
        Get input schema for a specific indicator.

        Args:
            name: Indicator class name

        Returns:
            InputSchema or None if not found
        """
        indicator_class = self.get(name)
        if indicator_class:
            return indicator_class.get_input_schema()
        return None

    def get_output_schema(self, name: str, **params) -> Optional[OutputSchema]:
        """
        Get output schema for a specific indicator with given parameters.

        Args:
            name: Indicator class name
            **params: Indicator parameters

        Returns:
            OutputSchema or None if not found
        """
        indicator_class = self.get(name)
        if indicator_class:
            return indicator_class.get_output_schema(**params)
        return None

    def create_instance(self, name: str, instance_name: str, **params) -> Optional[Indicator]:
        """
        Create an indicator instance with validation.

        Args:
            name: Indicator class name
            instance_name: Unique instance name (for output column prefixing)
            **params: Indicator configuration parameters

        Returns:
            Indicator instance or None if class not found

        Raises:
            ValueError: If parameters are invalid
        """
        indicator_class = self.get(name)
        if not indicator_class:
            return None

        return indicator_class(instance_name=instance_name, **params)

    def generate_ai_tool_spec(self) -> Dict:
        """
        Generate a JSON specification for AI agent tools.

        Creates a structured representation of all indicators that can be
        used to build agent tools for indicator selection and composition.

        Returns:
            Dict suitable for AI agent tool registration
        """
        tools = []

        for indicator_class in self._indicators.values():
            metadata = indicator_class.get_metadata()

            # Build parameter spec
            parameters = {
                "type": "object",
                "properties": {},
                "required": []
            }

            for param in metadata.parameters:
                param_spec = {
                    "type": param.type,
                    "description": param.description
                }

                if param.default is not None:
                    param_spec["default"] = param.default
                if param.min_value is not None:
                    param_spec["minimum"] = param.min_value
                if param.max_value is not None:
                    param_spec["maximum"] = param.max_value

                parameters["properties"][param.name] = param_spec

                if param.required:
                    parameters["required"].append(param.name)

            tool = {
                "name": f"indicator_{metadata.name.lower()}",
                "description": f"{metadata.display_name}: {metadata.description}",
                "category": metadata.category,
                "use_cases": metadata.use_cases,
                "tags": metadata.tags,
                "parameters": parameters,
                "input_schema": indicator_class.get_input_schema().model_dump(),
                "output_schema": indicator_class.get_output_schema().model_dump()
            }

            tools.append(tool)

        return {
            "indicator_tools": tools,
            "total_count": len(tools)
        }
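The weighted scoring inside `search_by_text` can be exercised on its own; a sketch with a stand-in record (field names taken from `IndicatorMetadata`, the `Meta` dataclass and sample values are hypothetical):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Meta:
    """Stand-in for IndicatorMetadata, just the fields the scorer reads."""
    name: str
    display_name: str
    description: str
    use_cases: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)


def score(meta: Meta, query: str) -> int:
    """Same weights as search_by_text: name 10, display_name 8, description 5, use case 3, tag 2."""
    q = query.lower()
    s = 0
    if q in meta.name.lower():
        s += 10
    if q in meta.display_name.lower():
        s += 8
    if q in meta.description.lower():
        s += 5
    s += sum(3 for u in meta.use_cases if q in u.lower())
    s += sum(2 for t in meta.tags if q in t.lower())
    return s


rsi = Meta(
    "RSI",
    "Relative Strength Index",
    "Momentum oscillator measuring speed of price changes",
    use_cases=["overbought/oversold detection"],
    tags=["momentum", "oscillator"],
)
```

Note the weights are additive, so a query hitting both description and a tag ("momentum" here scores 5 + 2) can outrank a tag-only match; the registry then sorts descending on the total.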
269 backend.old/src/indicator/schema.py Normal file
@@ -0,0 +1,269 @@
"""
Data models for the Indicator system.

Defines schemas for input/output specifications, computation context,
and metadata for AI agent discovery.
"""

from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field

from datasource.schema import ColumnInfo


class InputSchema(BaseModel):
    """
    Declares the required input columns for an Indicator.

    Indicators match against any data source (DataSource or other Indicator)
    that provides columns satisfying this schema.
    """

    model_config = {"extra": "forbid"}

    required_columns: List[ColumnInfo] = Field(
        description="Columns that must be present in the input data"
    )
    optional_columns: List[ColumnInfo] = Field(
        default_factory=list,
        description="Columns that may be used if present but are not required"
    )
    time_column: str = Field(
        default="time",
        description="Name of the timestamp column (must be present)"
    )

    def matches(self, available_columns: List[ColumnInfo]) -> bool:
        """
        Check if available columns satisfy this input schema.

        Args:
            available_columns: Columns provided by a data source

        Returns:
            True if all required columns are present with compatible types
        """
        available_map = {col.name: col for col in available_columns}

        # Check time column exists
        if self.time_column not in available_map:
            return False

        # Check all required columns exist with compatible types
        for required in self.required_columns:
            if required.name not in available_map:
                return False
            available = available_map[required.name]
            if available.type != required.type:
                return False

        return True

    def get_missing_columns(self, available_columns: List[ColumnInfo]) -> List[str]:
        """
        Get list of missing required column names.

        Args:
            available_columns: Columns provided by a data source

        Returns:
            List of missing column names
        """
        available_names = {col.name for col in available_columns}
        missing = []

        if self.time_column not in available_names:
            missing.append(self.time_column)

        for required in self.required_columns:
            if required.name not in available_names:
                missing.append(required.name)

        return missing


class OutputSchema(BaseModel):
    """
    Declares the output columns produced by an Indicator.

    Column names will be automatically prefixed with the indicator instance name
    to avoid collisions in the pipeline.
    """

    model_config = {"extra": "forbid"}

    columns: List[ColumnInfo] = Field(
        description="Output columns produced by this indicator"
    )
    time_column: str = Field(
        default="time",
        description="Name of the timestamp column (passed through from input)"
    )

    def with_prefix(self, prefix: str) -> "OutputSchema":
        """
        Create a new OutputSchema with all column names prefixed.

        Args:
            prefix: Prefix to add (e.g., indicator instance name)

        Returns:
            New OutputSchema with prefixed column names
        """
        prefixed_columns = [
            ColumnInfo(
                name=f"{prefix}_{col.name}" if col.name != self.time_column else col.name,
                type=col.type,
                description=col.description,
                unit=col.unit,
                nullable=col.nullable
            )
            for col in self.columns
        ]
        return OutputSchema(
            columns=prefixed_columns,
            time_column=self.time_column
        )


class IndicatorParameter(BaseModel):
    """
    Metadata for a configurable indicator parameter.

    Used for AI agent discovery and dynamic indicator instantiation.
    """

    model_config = {"extra": "forbid"}

    name: str = Field(description="Parameter name")
    type: Literal["int", "float", "string", "bool"] = Field(description="Parameter type")
    description: str = Field(description="Human and LLM-readable description")
    default: Optional[Any] = Field(default=None, description="Default value if not specified")
    required: bool = Field(default=False, description="Whether this parameter is required")
    min_value: Optional[float] = Field(default=None, description="Minimum value (for numeric types)")
    max_value: Optional[float] = Field(default=None, description="Maximum value (for numeric types)")


class IndicatorMetadata(BaseModel):
    """
    Rich metadata for an Indicator class.

    Enables AI agents to discover, understand, and instantiate indicators.
    """

    model_config = {"extra": "forbid"}

    name: str = Field(description="Unique indicator class name (e.g., 'RSI', 'SMA', 'BollingerBands')")
    display_name: str = Field(description="Human-readable display name")
    description: str = Field(description="Detailed description of what this indicator computes and why it's useful")
    category: str = Field(
        description="Indicator category (e.g., 'momentum', 'trend', 'volatility', 'volume', 'custom')"
    )
    parameters: List[IndicatorParameter] = Field(
        default_factory=list,
        description="Configurable parameters for this indicator"
    )
    use_cases: List[str] = Field(
        default_factory=list,
        description="Common use cases and trading scenarios where this indicator is helpful"
    )
    references: List[str] = Field(
        default_factory=list,
        description="URLs or citations for indicator methodology"
    )
    tags: List[str] = Field(
        default_factory=list,
        description="Searchable tags (e.g., 'oscillator', 'mean-reversion', 'price-based')"
    )


class ComputeContext(BaseModel):
    """
    Context passed to an Indicator's compute() method.

    Contains the input data and metadata about what changed (for incremental updates).
    """

    model_config = {"extra": "forbid"}

    data: List[Dict[str, Any]] = Field(
        description="Input data rows (time-ordered). Each dict is {column_name: value, time: timestamp}"
    )
    is_incremental: bool = Field(
        default=False,
        description="True if this is an incremental update (only new/changed rows), False for full recompute"
    )
    updated_from_time: Optional[int] = Field(
        default=None,
        description="Unix timestamp (ms) of the earliest updated row (for incremental updates)"
    )

    def get_column(self, name: str) -> List[Any]:
        """
        Extract a single column as a list of values.

        Args:
            name: Column name

        Returns:
            List of values in time order
        """
        return [row.get(name) for row in self.data]

    def get_times(self) -> List[int]:
        """
        Get the time column as a list.

        Returns:
            List of timestamps in order
        """
        return [row["time"] for row in self.data]


class ComputeResult(BaseModel):
    """
    Result from an Indicator's compute() method.

    Contains the computed output data with proper column naming.
    """

    model_config = {"extra": "forbid"}

    data: List[Dict[str, Any]] = Field(
        description="Output data rows (time-ordered). Must include time column."
    )
    is_partial: bool = Field(
        default=False,
        description="True if this result only contains updates (for incremental computation)"
    )

    def merge_with_prefix(self, prefix: str, existing_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Merge this result into existing data with column name prefixing.

        Args:
            prefix: Prefix to add to all column names except time
            existing_data: Existing data to merge with (matched by time)

        Returns:
            Merged data with prefixed columns added
        """
        # Build a time index for new data
        time_index = {row["time"]: row for row in self.data}

        # Merge into existing data
        result = []
        for existing_row in existing_data:
            row_time = existing_row["time"]
            merged_row = existing_row.copy()

            if row_time in time_index:
                new_row = time_index[row_time]
                for key, value in new_row.items():
                    if key != "time":
                        merged_row[f"{prefix}_{key}"] = value

            result.append(merged_row)

        return result
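The time-keyed merge in `ComputeResult.merge_with_prefix` can be sketched without pydantic, as a plain function over dict rows with the same prefixing rule (function name and sample data are illustrative only):

```python
from typing import Any, Dict, List


def merge_with_prefix(
    new_rows: List[Dict[str, Any]],
    prefix: str,
    existing: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Join new_rows onto existing by "time"; prefix every non-time key."""
    time_index = {row["time"]: row for row in new_rows}
    merged = []
    for row in existing:
        out = row.copy()
        hit = time_index.get(row["time"])
        if hit:
            for k, v in hit.items():
                if k != "time":
                    out[f"{prefix}_{k}"] = v
        merged.append(out)
    return merged


existing = [{"time": 1, "close": 100.0}, {"time": 2, "close": 101.0}]
new = [{"time": 2, "rsi": 55.0}]
merged = merge_with_prefix(new, "rsi_14", existing)
```

Rows with no matching timestamp pass through unchanged, which is what makes the method safe for partial (incremental) results.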
449 backend.old/src/indicator/talib_adapter.py Normal file
@@ -0,0 +1,449 @@
"""
|
||||
TA-Lib indicator adapter.
|
||||
|
||||
Provides automatic registration of all TA-Lib technical indicators
|
||||
as composable Indicator instances.
|
||||
|
||||
Installation Requirements:
|
||||
--------------------------
|
||||
TA-Lib requires both the C library and Python wrapper:
|
||||
|
||||
1. Install TA-Lib C library:
|
||||
- Ubuntu/Debian: sudo apt-get install libta-lib-dev
|
||||
- macOS: brew install ta-lib
|
||||
- From source: https://ta-lib.org/install.html
|
||||
|
||||
2. Install Python wrapper (already in requirements.txt):
|
||||
pip install TA-Lib
|
||||
|
||||
Usage:
|
||||
------
|
||||
from indicator.talib_adapter import register_all_talib_indicators
|
||||
|
||||
# Auto-register all TA-Lib indicators
|
||||
registry = IndicatorRegistry()
|
||||
register_all_talib_indicators(registry)
|
||||
|
||||
# Now you can use any TA-Lib indicator
|
||||
sma = registry.create_instance("SMA", "sma_20", period=20)
|
||||
rsi = registry.create_instance("RSI", "rsi_14", timeperiod=14)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import talib
|
||||
from talib import abstract
|
||||
TALIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
TALIB_AVAILABLE = False
|
||||
talib = None
|
||||
abstract = None
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib parameter types to our schema types
|
||||
TALIB_TYPE_MAP = {
|
||||
"double": "float",
|
||||
"double[]": "float",
|
||||
"int": "int",
|
||||
"str": "string",
|
||||
}
|
||||
|
||||
# Categorization of TA-Lib functions
|
||||
TALIB_CATEGORIES = {
|
||||
"overlap": ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA", "KAMA", "MAMA", "T3",
|
||||
"BBANDS", "MIDPOINT", "MIDPRICE", "SAR", "SAREXT", "HT_TRENDLINE"],
|
||||
"momentum": ["RSI", "MOM", "ROC", "ROCP", "ROCR", "ROCR100", "TRIX", "CMO", "DX",
|
||||
"ADX", "ADXR", "APO", "PPO", "MACD", "MACDEXT", "MACDFIX", "MFI",
|
||||
"STOCH", "STOCHF", "STOCHRSI", "WILLR", "CCI", "AROON", "AROONOSC",
|
||||
"BOP", "MINUS_DI", "MINUS_DM", "PLUS_DI", "PLUS_DM", "ULTOSC"],
|
||||
"volume": ["AD", "ADOSC", "OBV"],
|
||||
"volatility": ["ATR", "NATR", "TRANGE"],
|
||||
"price": ["AVGPRICE", "MEDPRICE", "TYPPRICE", "WCLPRICE"],
|
||||
"cycle": ["HT_DCPERIOD", "HT_DCPHASE", "HT_PHASOR", "HT_SINE", "HT_TRENDMODE"],
|
||||
"pattern": ["CDL2CROWS", "CDL3BLACKCROWS", "CDL3INSIDE", "CDL3LINESTRIKE",
|
||||
"CDL3OUTSIDE", "CDL3STARSINSOUTH", "CDL3WHITESOLDIERS", "CDLABANDONEDBABY",
|
||||
"CDLADVANCEBLOCK", "CDLBELTHOLD", "CDLBREAKAWAY", "CDLCLOSINGMARUBOZU",
|
||||
"CDLCONCEALBABYSWALL", "CDLCOUNTERATTACK", "CDLDARKCLOUDCOVER", "CDLDOJI",
|
||||
"CDLDOJISTAR", "CDLDRAGONFLYDOJI", "CDLENGULFING", "CDLEVENINGDOJISTAR",
|
||||
"CDLEVENINGSTAR", "CDLGAPSIDESIDEWHITE", "CDLGRAVESTONEDOJI", "CDLHAMMER",
|
||||
"CDLHANGINGMAN", "CDLHARAMI", "CDLHARAMICROSS", "CDLHIGHWAVE", "CDLHIKKAKE",
|
||||
"CDLHIKKAKEMOD", "CDLHOMINGPIGEON", "CDLIDENTICAL3CROWS", "CDLINNECK",
|
||||
"CDLINVERTEDHAMMER", "CDLKICKING", "CDLKICKINGBYLENGTH", "CDLLADDERBOTTOM",
|
||||
"CDLLONGLEGGEDDOJI", "CDLLONGLINE", "CDLMARUBOZU", "CDLMATCHINGLOW",
|
||||
"CDLMATHOLD", "CDLMORNINGDOJISTAR", "CDLMORNINGSTAR", "CDLONNECK",
|
||||
"CDLPIERCING", "CDLRICKSHAWMAN", "CDLRISEFALL3METHODS", "CDLSEPARATINGLINES",
|
||||
"CDLSHOOTINGSTAR", "CDLSHORTLINE", "CDLSPINNINGTOP", "CDLSTALLEDPATTERN",
|
||||
"CDLSTICKSANDWICH", "CDLTAKURI", "CDLTASUKIGAP", "CDLTHRUSTING", "CDLTRISTAR",
|
||||
"CDLUNIQUE3RIVER", "CDLUPSIDEGAP2CROWS", "CDLXSIDEGAP3METHODS"],
|
||||
"statistic": ["BETA", "CORREL", "LINEARREG", "LINEARREG_ANGLE", "LINEARREG_INTERCEPT",
|
||||
"LINEARREG_SLOPE", "STDDEV", "TSF", "VAR"],
|
||||
"math": ["ADD", "DIV", "MAX", "MAXINDEX", "MIN", "MININDEX", "MINMAX", "MINMAXINDEX",
|
||||
"MULT", "SUB", "SUM"],
|
||||
}
|
||||
|
||||
|
||||
def _get_function_category(func_name: str) -> str:
|
||||
"""Determine the category of a TA-Lib function."""
|
||||
for category, functions in TALIB_CATEGORIES.items():
|
||||
if func_name in functions:
|
||||
return category
|
||||
return "other"
|
||||
|
||||
|
||||
class TALibIndicator(Indicator):
|
||||
"""
|
||||
Generic adapter for TA-Lib technical indicators.
|
||||
|
||||
Wraps any TA-Lib function to work within the composable indicator framework.
|
||||
Handles parameter mapping, input validation, and output formatting.
|
||||
"""
|
||||
|
||||
# Class variable to store the TA-Lib function name
|
||||
talib_function_name: str = None
|
||||
|
||||
def __init__(self, instance_name: str, **params):
|
||||
"""
|
||||
Initialize a TA-Lib indicator.
|
||||
|
||||
Args:
|
||||
instance_name: Unique name for this instance
|
||||
**params: TA-Lib function parameters
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"TA-Lib is not installed. Please install the TA-Lib C library "
|
||||
"and Python wrapper. See indicator/talib_adapter.py for instructions."
|
||||
)
|
||||
|
||||
super().__init__(instance_name, **params)
|
||||
self._talib_func = abstract.Function(self.talib_function_name)
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
"""Get metadata from TA-Lib function info."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
|
||||
# Build parameters list from TA-Lib function info
|
||||
parameters = []
|
||||
for param_name, param_info in info.get("parameters", {}).items():
|
||||
# Handle case where param_info is a simple value (int/float) instead of a dict
|
||||
if isinstance(param_info, dict):
|
||||
param_type = TALIB_TYPE_MAP.get(param_info.get("type", "double"), "float")
|
||||
default_value = param_info.get("default_value")
|
||||
else:
|
||||
# param_info is a simple value (default), infer type from the value
|
||||
if isinstance(param_info, int):
|
||||
param_type = "int"
|
||||
elif isinstance(param_info, float):
|
||||
param_type = "float"
|
||||
else:
|
||||
param_type = "float" # Default to float
|
||||
default_value = param_info
|
||||
|
||||
parameters.append(
|
||||
IndicatorParameter(
|
||||
name=param_name,
|
||||
type=param_type,
|
||||
description=f"TA-Lib parameter: {param_name}",
|
||||
default=default_value,
|
||||
required=False
|
||||
)
|
||||
)
|
||||
|
||||
# Get function group/category
|
||||
category = _get_function_category(cls.talib_function_name)
|
||||
|
||||
# Build display name (split camelCase or handle CDL prefix)
|
||||
display_name = cls.talib_function_name
|
||||
if display_name.startswith("CDL"):
|
||||
display_name = display_name[3:] # Remove CDL prefix for patterns
|
||||
|
||||
return IndicatorMetadata(
|
||||
name=cls.talib_function_name,
|
||||
display_name=display_name,
|
||||
description=info.get("display_name", f"TA-Lib {cls.talib_function_name} indicator"),
|
||||
category=category,
|
||||
parameters=parameters,
|
||||
use_cases=[f"Technical analysis using {cls.talib_function_name}"],
|
||||
references=["https://ta-lib.org/function.html"],
|
||||
tags=["talib", category, cls.talib_function_name.lower()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
"""
|
||||
Get input schema from TA-Lib function requirements.
|
||||
|
||||
Most TA-Lib functions use OHLCV data, but some use subsets.
|
||||
"""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
input_names = info.get("input_names", {})
|
||||
|
||||
required_columns = []
|
||||
|
||||
# Map TA-Lib input names to our schema
|
||||
if "prices" in input_names:
|
||||
price_inputs = input_names["prices"]
|
||||
if "open" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="open", type="float", description="Opening price")
|
||||
)
|
||||
if "high" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="high", type="float", description="High price")
|
||||
)
|
||||
if "low" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="low", type="float", description="Low price")
|
||||
)
|
||||
if "close" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
if "volume" in price_inputs:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="volume", type="float", description="Trading volume")
|
||||
)
|
||||
|
||||
# Handle functions that take generic price arrays
|
||||
if "price" in input_names:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Price (typically close)")
|
||||
)
|
||||
|
||||
# If no specific inputs found, assume close price
|
||||
if not required_columns:
|
||||
required_columns.append(
|
||||
ColumnInfo(name="close", type="float", description="Closing price")
|
||||
)
|
||||
|
||||
return InputSchema(required_columns=required_columns)
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
"""Get output schema from TA-Lib function outputs."""
|
||||
if not TALIB_AVAILABLE:
|
||||
raise ImportError("TA-Lib is not installed")
|
||||
|
||||
func = abstract.Function(cls.talib_function_name)
|
||||
info = func.info
|
||||
output_names = info.get("output_names", [])
|
||||
|
||||
columns = []
|
||||
|
||||
# Most TA-Lib functions output one or more float arrays
|
||||
if isinstance(output_names, list):
|
||||
for output_name in output_names:
|
||||
columns.append(
|
||||
ColumnInfo(
|
||||
name=output_name.lower(),
|
||||
type="float",
|
||||
description=f"{cls.talib_function_name} output: {output_name}",
|
||||
nullable=True # TA-Lib often has NaN for initial periods
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Single output, use function name
|
||||
        columns.append(
            ColumnInfo(
                name=cls.talib_function_name.lower(),
                type="float",
                description=f"{cls.talib_function_name} indicator value",
                nullable=True
            )
        )

        return OutputSchema(columns=columns)

    def compute(self, context: ComputeContext) -> ComputeResult:
        """Compute indicator using TA-Lib."""
        # Extract input columns
        input_data = {}

        # Get the function's expected inputs
        info = self._talib_func.info
        input_names = info.get("input_names", {})

        # Prepare input arrays
        if "prices" in input_names:
            price_inputs = input_names["prices"]
            for price_type in price_inputs:
                column_data = context.get_column(price_type)
                # Convert to numpy array, replacing None with NaN
                input_data[price_type] = np.array(
                    [float(v) if v is not None else np.nan for v in column_data]
                )
        elif "price" in input_names:
            # Generic price input, use close
            column_data = context.get_column("close")
            input_data["price"] = np.array(
                [float(v) if v is not None else np.nan for v in column_data]
            )
        else:
            # Default to close if no inputs specified
            column_data = context.get_column("close")
            input_data["close"] = np.array(
                [float(v) if v is not None else np.nan for v in column_data]
            )

        # Set parameters on the function
        self._talib_func.parameters = self.params

        # Execute TA-Lib function
        try:
            output = self._talib_func(input_data)
        except Exception as e:
            logger.error(f"TA-Lib function {self.talib_function_name} failed: {e}")
            raise ValueError(f"TA-Lib computation failed: {e}")

        # Format output
        times = context.get_times()
        output_names = info.get("output_names", [])

        # Handle single vs multiple outputs
        if isinstance(output, np.ndarray):
            # Single output
            output_name = output_names[0].lower() if output_names else self.talib_function_name.lower()
            result_data = [
                {
                    "time": times[i],
                    output_name: float(output[i]) if not np.isnan(output[i]) else None
                }
                for i in range(len(times))
            ]
        elif isinstance(output, tuple):
            # Multiple outputs
            result_data = []
            for i in range(len(times)):
                row = {"time": times[i]}
                for j, output_array in enumerate(output):
                    output_name = output_names[j].lower() if j < len(output_names) else f"output_{j}"
                    row[output_name] = float(output_array[i]) if not np.isnan(output_array[i]) else None
                result_data.append(row)
        else:
            raise ValueError(f"Unexpected TA-Lib output type: {type(output)}")

        return ComputeResult(
            data=result_data,
            is_partial=context.is_incremental
        )


def create_talib_indicator_class(func_name: str) -> type:
    """
    Dynamically create an Indicator class for a TA-Lib function.

    Args:
        func_name: TA-Lib function name (e.g., 'SMA', 'RSI')

    Returns:
        Indicator class for this function
    """
    return type(
        f"TALib_{func_name}",
        (TALibIndicator,),
        {"talib_function_name": func_name}
    )
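The `type(name, bases, namespace)` call above is Python's standard dynamic class creation. A standalone sketch of the same pattern, with a hypothetical base class standing in for `TALibIndicator`:

```python
# IndicatorBase is a made-up stand-in for TALibIndicator, for illustration only.
class IndicatorBase:
    talib_function_name: str = ""

    def label(self) -> str:
        return f"{self.talib_function_name} indicator"


def make_indicator_class(func_name: str) -> type:
    # type(name, bases, namespace) is the runtime equivalent of the class statement:
    #   class TALib_SMA(IndicatorBase): talib_function_name = "SMA"
    return type(f"TALib_{func_name}", (IndicatorBase,), {"talib_function_name": func_name})


SMAIndicator = make_indicator_class("SMA")
print(SMAIndicator.__name__)    # TALib_SMA
print(SMAIndicator().label())   # SMA indicator
```

Because the namespace dict only sets a class attribute, all generated classes share the base's behavior and differ only in `talib_function_name`.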

def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
    """
    Auto-register all available TA-Lib indicators with the registry.

    Args:
        registry: IndicatorRegistry instance
        only_tradingview_supported: If True, only register indicators that have
            TradingView equivalents (default: True)

    Returns:
        Number of indicators registered (0, with a warning, if TA-Lib is not installed)
    """
    if not TALIB_AVAILABLE:
        logger.warning(
            "TA-Lib is not installed. Skipping TA-Lib indicator registration. "
            "Install TA-Lib C library and Python wrapper to enable TA-Lib indicators."
        )
        return 0

    # Get list of supported indicators if filtering is enabled
    from .tv_mapping import is_indicator_supported

    # Get all TA-Lib functions
    func_groups = talib.get_function_groups()
    all_functions = []
    for group, functions in func_groups.items():
        all_functions.extend(functions)

    # Remove duplicates
    all_functions = sorted(set(all_functions))

    registered_count = 0
    skipped_count = 0

    for func_name in all_functions:
        try:
            # Skip if filtering enabled and indicator not supported in TradingView
            if only_tradingview_supported and not is_indicator_supported(func_name):
                skipped_count += 1
                logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
                continue

            # Create indicator class for this function
            indicator_class = create_talib_indicator_class(func_name)

            # Register with the registry
            registry.register(indicator_class)
            registered_count += 1

        except Exception as e:
            logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
            continue

    logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
    return registered_count


def get_talib_version() -> Optional[str]:
    """
    Get the installed TA-Lib version.

    Returns:
        Version string or None if not installed
    """
    if TALIB_AVAILABLE:
        return talib.__version__
    return None


def is_talib_available() -> bool:
    """Check if TA-Lib is available."""
    return TALIB_AVAILABLE
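The `TALIB_AVAILABLE` flag used throughout follows the usual optional-dependency pattern (the guarded `import talib` at the top of the file is outside this excerpt). A minimal standalone sketch of the same "degrade to None instead of raising" behavior:

```python
import importlib
import importlib.util


def optional_version(module_name: str):
    """Return the module's __version__, or None when it is not importable."""
    if importlib.util.find_spec(module_name) is None:
        return None
    mod = importlib.import_module(module_name)
    return getattr(mod, "__version__", None)


# A missing module degrades to None instead of raising ImportError.
print(optional_version("definitely_not_installed_xyz"))  # None
```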
360
backend.old/src/indicator/tv_mapping.py
Normal file
@@ -0,0 +1,360 @@
"""
Mapping layer between TA-Lib indicators and TradingView indicators.

This module provides bidirectional conversion between our internal TA-Lib-based
indicator representation and TradingView's indicator system.
"""

from typing import Dict, Any, Optional, Tuple, List
import logging

logger = logging.getLogger(__name__)


# Mapping of TA-Lib indicator names to TradingView indicator names
# Only includes indicators that are present in BOTH systems (inner join)
# Format: {talib_name: tv_name}
TALIB_TO_TV_NAMES = {
    # Overlap Studies (14)
    "SMA": "Moving Average",
    "EMA": "Moving Average Exponential",
    "WMA": "Weighted Moving Average",
    "DEMA": "DEMA",
    "TEMA": "TEMA",
    "TRIMA": "Triangular Moving Average",
    "KAMA": "KAMA",
    "MAMA": "MESA Adaptive Moving Average",
    "T3": "T3",
    "BBANDS": "Bollinger Bands",
    "MIDPOINT": "Midpoint",
    "MIDPRICE": "Midprice",
    "SAR": "Parabolic SAR",
    "HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",

    # Momentum Indicators (21)
    "RSI": "Relative Strength Index",
    "MOM": "Momentum",
    "ROC": "Rate of Change",
    "TRIX": "TRIX",
    "CMO": "Chande Momentum Oscillator",
    "DX": "Directional Movement Index",
    "ADX": "Average Directional Movement Index",
    "ADXR": "Average Directional Movement Index Rating",
    "APO": "Absolute Price Oscillator",
    "PPO": "Percentage Price Oscillator",
    "MACD": "MACD",
    "MFI": "Money Flow Index",
    "STOCH": "Stochastic",
    "STOCHF": "Stochastic Fast",
    "STOCHRSI": "Stochastic RSI",
    "WILLR": "Williams %R",
    "CCI": "Commodity Channel Index",
    "AROON": "Aroon",
    "AROONOSC": "Aroon Oscillator",
    "BOP": "Balance Of Power",
    "ULTOSC": "Ultimate Oscillator",

    # Volume Indicators (3)
    "AD": "Chaikin A/D Line",
    "ADOSC": "Chaikin A/D Oscillator",
    "OBV": "On Balance Volume",

    # Volatility Indicators (3)
    "ATR": "Average True Range",
    "NATR": "Normalized Average True Range",
    "TRANGE": "True Range",

    # Price Transform (4)
    "AVGPRICE": "Average Price",
    "MEDPRICE": "Median Price",
    "TYPPRICE": "Typical Price",
    "WCLPRICE": "Weighted Close Price",

    # Cycle Indicators (5)
    "HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
    "HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
    "HT_PHASOR": "Hilbert Transform - Phasor Components",
    "HT_SINE": "Hilbert Transform - SineWave",
    "HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",

    # Statistic Functions (9)
    "BETA": "Beta",
    "CORREL": "Pearson's Correlation Coefficient",
    "LINEARREG": "Linear Regression",
    "LINEARREG_ANGLE": "Linear Regression Angle",
    "LINEARREG_INTERCEPT": "Linear Regression Intercept",
    "LINEARREG_SLOPE": "Linear Regression Slope",
    "STDDEV": "Standard Deviation",
    "TSF": "Time Series Forecast",
    "VAR": "Variance",
}

# Total: 59 indicators supported in both systems

# Custom indicators (TradingView indicators implemented in our backend)
CUSTOM_TO_TV_NAMES = {
    "VWAP": "VWAP",
    "VWMA": "VWMA",
    "HMA": "Hull Moving Average",
    "SUPERTREND": "SuperTrend",
    "DONCHIAN": "Donchian Channels",
    "KELTNER": "Keltner Channels",
    "CMF": "Chaikin Money Flow",
    "VORTEX": "Vortex Indicator",
    "AO": "Awesome Oscillator",
    "AC": "Accelerator Oscillator",
    "CHOP": "Choppiness Index",
    "MASS": "Mass Index",
}

# Combined mapping (TA-Lib + Custom)
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}

# Total: 71 indicators (59 TA-Lib + 12 Custom)

# Reverse mapping
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
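The forward and reverse dictionaries above follow a plain invert-the-dict pattern; a trimmed-down sketch with a three-entry excerpt of the mapping:

```python
# Small excerpt of the mapping, for illustration.
TALIB_TO_TV = {
    "RSI": "Relative Strength Index",
    "SMA": "Moving Average",
    "BBANDS": "Bollinger Bands",
}
TV_TO_TALIB = {v: k for k, v in TALIB_TO_TV.items()}

# Forward lookup falls back to the input name; reverse lookup yields None when unmapped.
print(TALIB_TO_TV.get("RSI", "RSI"))       # Relative Strength Index
print(TV_TO_TALIB.get("Bollinger Bands"))  # BBANDS
print(TV_TO_TALIB.get("Unknown Study"))    # None
```

Note that dict inversion is only lossless when the values are unique, which holds for the TradingView names listed above.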

def get_tv_indicator_name(talib_name: str) -> str:
    """
    Convert TA-Lib indicator name to TradingView indicator name.

    Args:
        talib_name: TA-Lib indicator name (e.g., 'RSI')

    Returns:
        TradingView indicator name
    """
    return TALIB_TO_TV_NAMES.get(talib_name, talib_name)


def get_talib_indicator_name(tv_name: str) -> Optional[str]:
    """
    Convert TradingView indicator name to TA-Lib indicator name.

    Args:
        tv_name: TradingView indicator name

    Returns:
        TA-Lib indicator name or None if not mapped
    """
    return TV_TO_TALIB_NAMES.get(tv_name)


def convert_talib_params_to_tv_inputs(
    talib_name: str,
    talib_params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert TA-Lib parameters to TradingView input format.

    Args:
        talib_name: TA-Lib indicator name
        talib_params: TA-Lib parameter dictionary

    Returns:
        TradingView inputs dictionary
    """
    tv_inputs = {}

    # Common parameter mappings
    param_mapping = {
        "timeperiod": "length",
        "fastperiod": "fastLength",
        "slowperiod": "slowLength",
        "signalperiod": "signalLength",
        "nbdevup": "mult",  # Standard deviations for upper band
        "nbdevdn": "mult",  # Standard deviations for lower band
        "fastlimit": "fastLimit",
        "slowlimit": "slowLimit",
        "acceleration": "start",
        "maximum": "increment",
        "fastk_period": "kPeriod",
        "slowk_period": "kPeriod",
        "slowd_period": "dPeriod",
        "fastd_period": "dPeriod",
        "matype": "maType",
    }

    # Special handling for specific indicators
    if talib_name == "BBANDS":
        # Bollinger Bands
        tv_inputs["length"] = talib_params.get("timeperiod", 20)
        tv_inputs["mult"] = talib_params.get("nbdevup", 2)
        tv_inputs["source"] = "close"
    elif talib_name == "MACD":
        # MACD
        tv_inputs["fastLength"] = talib_params.get("fastperiod", 12)
        tv_inputs["slowLength"] = talib_params.get("slowperiod", 26)
        tv_inputs["signalLength"] = talib_params.get("signalperiod", 9)
        tv_inputs["source"] = "close"
    elif talib_name == "RSI":
        # RSI
        tv_inputs["length"] = talib_params.get("timeperiod", 14)
        tv_inputs["source"] = "close"
    elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
        # Moving averages
        tv_inputs["length"] = talib_params.get("timeperiod", 14)
        tv_inputs["source"] = "close"
    elif talib_name == "STOCH":
        # Stochastic
        tv_inputs["kPeriod"] = talib_params.get("fastk_period", 14)
        tv_inputs["dPeriod"] = talib_params.get("slowd_period", 3)
        tv_inputs["smoothK"] = talib_params.get("slowk_period", 3)
    elif talib_name == "ATR":
        # ATR
        tv_inputs["length"] = talib_params.get("timeperiod", 14)
    elif talib_name == "CCI":
        # CCI
        tv_inputs["length"] = talib_params.get("timeperiod", 20)
    else:
        # Generic parameter conversion
        for talib_param, value in talib_params.items():
            tv_param = param_mapping.get(talib_param, talib_param)
            tv_inputs[tv_param] = value

    logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
    return tv_inputs
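For indicators without a special-case branch, the function falls back to a key-rename pass over `param_mapping`, letting unknown keys pass through. The same logic as a standalone one-liner (names local to this sketch):

```python
# Trimmed-down copy of the mapping table, for illustration.
PARAM_MAPPING = {"timeperiod": "length", "fastperiod": "fastLength", "matype": "maType"}


def rename_params(talib_params: dict) -> dict:
    # Rename known keys; unknown keys pass through unchanged,
    # mirroring the generic else-branch above.
    return {PARAM_MAPPING.get(k, k): v for k, v in talib_params.items()}


print(rename_params({"timeperiod": 14}))            # {'length': 14}
print(rename_params({"fastperiod": 12, "foo": 1}))  # {'fastLength': 12, 'foo': 1}
```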

def convert_tv_inputs_to_talib_params(
    tv_name: str,
    tv_inputs: Dict[str, Any]
) -> Tuple[Optional[str], Dict[str, Any]]:
    """
    Convert TradingView inputs to TA-Lib parameters.

    Args:
        tv_name: TradingView indicator name
        tv_inputs: TradingView inputs dictionary

    Returns:
        Tuple of (talib_name, talib_params)
    """
    talib_name = get_talib_indicator_name(tv_name)
    if not talib_name:
        logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
        return None, {}

    talib_params = {}

    # Reverse parameter mappings
    reverse_mapping = {
        "length": "timeperiod",
        "fastLength": "fastperiod",
        "slowLength": "slowperiod",
        "signalLength": "signalperiod",
        "mult": "nbdevup",  # Use same for both up and down
        "fastLimit": "fastlimit",
        "slowLimit": "slowlimit",
        "start": "acceleration",
        "increment": "maximum",
        "kPeriod": "fastk_period",
        "dPeriod": "slowd_period",
        "smoothK": "slowk_period",
        "maType": "matype",
    }

    # Special handling for specific indicators
    if talib_name == "BBANDS":
        # Bollinger Bands
        talib_params["timeperiod"] = tv_inputs.get("length", 20)
        talib_params["nbdevup"] = tv_inputs.get("mult", 2)
        talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
        talib_params["matype"] = 0  # SMA
    elif talib_name == "MACD":
        # MACD
        talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
        talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
        talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
    elif talib_name == "RSI":
        # RSI
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
        # Moving averages
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name == "STOCH":
        # Stochastic
        talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
        talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
        talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
        talib_params["slowk_matype"] = 0  # SMA
        talib_params["slowd_matype"] = 0  # SMA
    elif talib_name == "ATR":
        # ATR
        talib_params["timeperiod"] = tv_inputs.get("length", 14)
    elif talib_name == "CCI":
        # CCI
        talib_params["timeperiod"] = tv_inputs.get("length", 20)
    else:
        # Generic parameter conversion
        for tv_param, value in tv_inputs.items():
            if tv_param == "source":
                continue  # Skip source parameter
            talib_param = reverse_mapping.get(tv_param, tv_param)
            talib_params[talib_param] = value

    logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
    return talib_name, talib_params


def is_indicator_supported(talib_name: str) -> bool:
    """
    Check if a TA-Lib indicator is supported in TradingView.

    Args:
        talib_name: TA-Lib indicator name

    Returns:
        True if supported
    """
    return talib_name in TALIB_TO_TV_NAMES


def get_supported_indicators() -> List[str]:
    """
    Get list of supported TA-Lib indicators.

    Returns:
        List of TA-Lib indicator names
    """
    return list(TALIB_TO_TV_NAMES.keys())


def get_supported_indicator_count() -> int:
    """
    Get count of supported indicators.

    Returns:
        Number of indicators supported in both systems (TA-Lib + Custom)
    """
    return len(ALL_BACKEND_TO_TV_NAMES)


def is_custom_indicator(indicator_name: str) -> bool:
    """
    Check if an indicator is a custom implementation (not TA-Lib).

    Args:
        indicator_name: Indicator name

    Returns:
        True if custom indicator
    """
    return indicator_name in CUSTOM_TO_TV_NAMES


def get_backend_indicator_name(tv_name: str) -> Optional[str]:
    """
    Get backend indicator name from TradingView name (TA-Lib or custom).

    Args:
        tv_name: TradingView indicator name

    Returns:
        Backend indicator name or None if not mapped
    """
    return TV_TO_BACKEND_NAMES.get(tv_name)
@@ -20,13 +20,19 @@ from gateway.hub import Gateway
from gateway.channels.websocket import WebSocketChannel
from gateway.protocol import WebSocketAgentUserMessage
from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
from schema.order_spec import SwapOrder
from schema.chart_state import ChartState
from schema.shape import ShapeCollection
from schema.indicator import IndicatorCollection
from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler
from secrets_manager import SecretsStore, InvalidMasterPassword
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
from trigger import CommitCoordinator, TriggerQueue
from trigger.scheduler import TriggerScheduler

# Configure logging
logging.basicConfig(
@@ -53,6 +59,14 @@ agent_executor = None
datasource_registry = DataSourceRegistry()
subscription_manager = SubscriptionManager()

# Indicator infrastructure
indicator_registry = IndicatorRegistry()

# Trigger system infrastructure
trigger_coordinator = None
trigger_queue = None
trigger_scheduler = None

# Global secrets store
secrets_store = SecretsStore()

@@ -60,7 +74,7 @@ secrets_store = SecretsStore()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize agent system and data sources on startup."""
    global agent_executor
    global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler

    # Initialize CCXT data sources
    try:
@@ -80,6 +94,21 @@ async def lifespan(app: FastAPI):
        logger.warning(f"CCXT not available: {e}. Only demo source will be available.")
        logger.info("To use real exchange data, install ccxt: pip install ccxt>=4.0.0")

    # Initialize indicator registry with all TA-Lib indicators
    try:
        indicator_count = register_all_talib_indicators(indicator_registry)
        logger.info(f"Indicator registry initialized with {indicator_count} TA-Lib indicators")
    except Exception as e:
        logger.warning(f"Failed to register TA-Lib indicators: {e}")
        logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")

    # Register custom indicators (TradingView indicators not in TA-Lib)
    try:
        custom_count = register_custom_indicators(indicator_registry)
        logger.info(f"Registered {custom_count} custom indicators")
    except Exception as e:
        logger.warning(f"Failed to register custom indicators: {e}")

    # Get API keys from secrets store if unlocked, otherwise fall back to environment
    anthropic_api_key = None

@@ -94,6 +123,22 @@ async def lifespan(app: FastAPI):
    if anthropic_api_key:
        logger.info("Loaded API key from environment")

    # Initialize trigger system
    logger.info("Initializing trigger system...")
    trigger_coordinator = CommitCoordinator()
    trigger_queue = TriggerQueue(trigger_coordinator)
    trigger_scheduler = TriggerScheduler(trigger_queue)

    # Start trigger queue and scheduler
    await trigger_queue.start()
    trigger_scheduler.start()
    logger.info("Trigger system initialized and started")

    # Set trigger system for agent tools
    set_coordinator(trigger_coordinator)
    set_trigger_queue(trigger_queue)
    set_trigger_scheduler(trigger_scheduler)

    if not anthropic_api_key:
        logger.error("ANTHROPIC_API_KEY not found in environment!")
        logger.info("Agent system will not be available")
@@ -101,6 +146,7 @@ async def lifespan(app: FastAPI):
    # Set the registries for agent tools
    set_registry(registry)
    set_datasource_registry(datasource_registry)
    set_indicator_registry(indicator_registry)

    # Create and initialize agent
    agent_executor = create_agent(
@@ -111,7 +157,7 @@ async def lifespan(app: FastAPI):
        chroma_db_path=config["memory"]["chroma_db"],
        embedding_model=config["memory"]["embedding_model"],
        context_docs_dir=config["agent"]["context_docs_dir"],
        base_dir=".."  # Point to project root from backend/src
        base_dir="."  # backend/src is the working directory, so . goes to backend, where memory/ and soul/ live
    )

    await agent_executor.initialize()
@@ -124,9 +170,22 @@ async def lifespan(app: FastAPI):
    yield

    # Cleanup
    logger.info("Shutting down systems...")

    # Shutdown trigger system
    if trigger_scheduler:
        trigger_scheduler.shutdown(wait=True)
        logger.info("Trigger scheduler shut down")

    if trigger_queue:
        await trigger_queue.stop()
        logger.info("Trigger queue stopped")

    # Shutdown agent system
    if agent_executor and agent_executor.memory_manager:
        await agent_executor.memory_manager.close()
        logger.info("Agent system shut down")

    logger.info("All systems shut down")


app = FastAPI(lifespan=lifespan)
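The startup/shutdown flow in `lifespan` follows the standard `asynccontextmanager` lifespan pattern: everything before `yield` runs at startup, everything after runs at shutdown. A minimal FastAPI-free sketch of the same structure:

```python
import asyncio
from contextlib import asynccontextmanager

events = []


@asynccontextmanager
async def lifespan(app):
    events.append("startup")       # init registries, queues, schedulers
    try:
        yield
    finally:
        events.append("shutdown")  # stop scheduler/queue, close memory manager


async def main():
    # FastAPI drives this context manager around the server's lifetime.
    async with lifespan(None):
        events.append("serving")


asyncio.run(main())
print(events)  # ['startup', 'serving', 'shutdown']
```

Wrapping the `yield` in `try/finally` ensures the shutdown branch runs even if the server exits with an error.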
@@ -146,13 +205,25 @@ class OrderStore(BaseModel):
class ChartStore(BaseModel):
    chart_state: ChartState = ChartState()

# ShapeStore model for synchronization
class ShapeStore(BaseModel):
    shapes: dict[str, dict] = {}  # Dictionary of shapes keyed by ID

# IndicatorStore model for synchronization
class IndicatorStore(BaseModel):
    indicators: dict[str, dict] = {}  # Dictionary of indicators keyed by ID

# Initialize stores
order_store = OrderStore()
chart_store = ChartStore()
shape_store = ShapeStore()
indicator_store = IndicatorStore()

# Register with SyncRegistry
registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore")
registry.register(shape_store, store_name="ShapeStore")
registry.register(indicator_store, store_name="IndicatorStore")

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
@@ -348,11 +419,14 @@ async def websocket_endpoint(websocket: WebSocket):
        elif msg_type == "patch":
            patch_msg = PatchMessage(**message_json)
            logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
            await registry.apply_client_patch(
                store_name=patch_msg.store,
                client_base_seq=patch_msg.seq,
                patch=patch_msg.patch
            )
            try:
                await registry.apply_client_patch(
                    store_name=patch_msg.store,
                    client_base_seq=patch_msg.seq,
                    patch=patch_msg.patch
                )
            except Exception as e:
                logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
        elif msg_type == "agent_user_message":
            # Handle agent messages directly here
            print(f"[DEBUG] Raw message_json: {message_json}")
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel, Field


@@ -7,10 +7,13 @@ class ChartState(BaseModel):

    This state is synchronized between the frontend and backend to allow
    the AI agent to understand what the user is currently viewing.

    All fields can be None when no chart is visible (e.g., on mobile/narrow screens).
    """

    # Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT")
    symbol: str = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol")
    # None when chart is not visible
    symbol: Optional[str] = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol, or None if no chart visible")

    # Time range currently visible on chart (Unix timestamps in seconds)
    # These represent the leftmost and rightmost visible candle times
@@ -18,4 +21,8 @@ class ChartState(BaseModel):
    end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)")

    # Optional: Chart interval/resolution
    interval: str = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D')")
    # None when chart is not visible
    interval: Optional[str] = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D'), or None if no chart visible")

    # Selected shapes/drawings on the chart
    selected_shapes: List[str] = Field(default_factory=list, description="Array of selected shape IDs")
40
backend.old/src/schema/indicator.py
Normal file
@@ -0,0 +1,40 @@
from typing import Dict, Any, Optional, List
from pydantic import BaseModel, Field


class IndicatorInstance(BaseModel):
    """
    Represents an instance of an indicator applied to a chart.

    This schema holds both the TA-Lib metadata and TradingView-specific data
    needed for synchronization.
    """
    id: str = Field(..., description="Unique identifier for this indicator instance")

    # TA-Lib metadata
    talib_name: str = Field(..., description="TA-Lib indicator name (e.g., 'RSI', 'SMA', 'MACD')")
    instance_name: str = Field(..., description="User-friendly instance name")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="TA-Lib indicator parameters")

    # TradingView metadata
    tv_study_id: Optional[str] = Field(default=None, description="TradingView study ID assigned by the chart widget")
    tv_indicator_name: Optional[str] = Field(default=None, description="TradingView indicator name if different from TA-Lib")
    tv_inputs: Optional[Dict[str, Any]] = Field(default=None, description="TradingView-specific input parameters")

    # Visual properties
    visible: bool = Field(default=True, description="Whether indicator is visible on chart")
    pane: str = Field(default="chart", description="Pane where indicator is displayed ('chart' or 'separate')")

    # Metadata
    symbol: Optional[str] = Field(default=None, description="Symbol this indicator is applied to")
    created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
    modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
    original_id: Optional[str] = Field(default=None, description="Original ID from backend before TradingView assigns its own ID")


class IndicatorCollection(BaseModel):
    """Collection of all indicator instances on the chart."""
    indicators: Dict[str, IndicatorInstance] = Field(
        default_factory=dict,
        description="Dictionary of indicator instances keyed by ID"
    )
327
backend.old/src/schema/order_spec.py
Normal file
@@ -0,0 +1,327 @@
from enum import StrEnum
from typing import Annotated

from pydantic import BaseModel, Field, BeforeValidator, PlainSerializer


# ---------------------------------------------------------------------------
# Scalar coercion helpers
# ---------------------------------------------------------------------------

def _to_int(v: int | str) -> int:
    return int(v, 0) if isinstance(v, str) else int(v)


def _to_float(v: float | int | str) -> float:
    return float(v)


_int_to_str = PlainSerializer(str, return_type=str, when_used="json")
_float_to_str = PlainSerializer(str, return_type=str, when_used="json")

# Always stored as Python int; accepts int or string on input; serialises to string in JSON.
type Uint8 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint16 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint24 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint32 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint64 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Uint256 = Annotated[int, BeforeValidator(_to_int), _int_to_str]
type Float = Annotated[float, BeforeValidator(_to_float), _float_to_str]
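The `int(v, 0)` in `_to_int` is base-0 parsing: it accepts decimal, hex (`0x`), octal (`0o`), and binary (`0b`) string literals, which matters when uint values arrive as hex strings from on-chain tooling. Standalone:

```python
def to_int(v):
    # Mirrors _to_int above: base-0 parsing for strings, plain int() otherwise.
    return int(v, 0) if isinstance(v, str) else int(v)


print(to_int(7))        # 7
print(to_int("42"))     # 42
print(to_int("0x2a"))   # 42
print(to_int("0b101"))  # 5
```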
ETH_ADDRESS_PATTERN = r"^0x[0-9a-fA-F]{40}$"


# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------

class Exchange(StrEnum):
    UNISWAP_V2 = "UniswapV2"
    UNISWAP_V3 = "UniswapV3"


class Side(StrEnum):
    """Order side: buy or sell"""
    BUY = "BUY"
    SELL = "SELL"


class AmountType(StrEnum):
    """Whether the order amount refers to base or quote currency"""
    BASE = "BASE"    # Amount is in base currency (e.g., BTC in BTC/USD)
    QUOTE = "QUOTE"  # Amount is in quote currency (e.g., USD in BTC/USD)


class TimeInForce(StrEnum):
    """Order lifetime specification"""
    GTC = "GTC"  # Good Till Cancel
    IOC = "IOC"  # Immediate or Cancel
    FOK = "FOK"  # Fill or Kill
    DAY = "DAY"  # Good for trading day
    GTD = "GTD"  # Good Till Date


class ConditionalOrderMode(StrEnum):
    """How conditional orders behave on partial fills"""
    NEW_PER_FILL = "NEW_PER_FILL"            # Create new conditional order per each fill
    UNIFIED_ADJUSTING = "UNIFIED_ADJUSTING"  # Single conditional order that adjusts amount


class TriggerType(StrEnum):
    """Type of conditional trigger"""
    STOP_LOSS = "STOP_LOSS"
    TAKE_PROFIT = "TAKE_PROFIT"
    STOP_LIMIT = "STOP_LIMIT"
    TRAILING_STOP = "TRAILING_STOP"


class TickSpacingMode(StrEnum):
    """How price tick spacing is determined"""
    FIXED = "FIXED"            # Fixed tick size
    DYNAMIC = "DYNAMIC"        # Tick size varies by price level
    CONTINUOUS = "CONTINUOUS"  # No tick restrictions


class AssetType(StrEnum):
    """Type of tradeable asset"""
    SPOT = "SPOT"            # Spot/cash market
    MARGIN = "MARGIN"        # Margin trading
    PERP = "PERP"            # Perpetual futures
    FUTURE = "FUTURE"        # Dated futures
    OPTION = "OPTION"        # Options
    SYNTHETIC = "SYNTHETIC"  # Synthetic/derived instruments


class OcoMode(StrEnum):
    NO_OCO = "NO_OCO"
    CANCEL_ON_PARTIAL_FILL = "CANCEL_ON_PARTIAL_FILL"
    CANCEL_ON_COMPLETION = "CANCEL_ON_COMPLETION"


# ---------------------------------------------------------------------------
# Supporting models
# ---------------------------------------------------------------------------

class Route(BaseModel):
    model_config = {"extra": "forbid"}

    exchange: Exchange
    fee: Uint24 = Field(description="Pool fee tier; also used as maxFee on UniswapV3")


class Line(BaseModel):
    """Price line: price = intercept + slope * time. Both zero means line is disabled."""

    model_config = {"extra": "forbid"}

    intercept: Float
    slope: Float
|
||||
|
||||
|
||||
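The `Line` docstring above pins down the pricing rule. A standalone sketch (plain floats rather than the Pydantic model; the helper names are hypothetical):

```python
def line_price(intercept: float, slope: float, t: float) -> float:
    """Evaluate a price line at time t: price = intercept + slope * t."""
    return intercept + slope * t


def line_is_disabled(intercept: float, slope: float) -> bool:
    """Per the Line docstring, both coefficients zero means the line is disabled."""
    return intercept == 0.0 and slope == 0.0


# A flat limit at 100.0, and a line rising 0.5 per unit of time:
assert line_price(100.0, 0.0, 1_000.0) == 100.0
assert line_price(100.0, 0.5, 10.0) == 105.0
assert line_is_disabled(0.0, 0.0)
assert not line_is_disabled(100.0, 0.0)
```

A diagonal `minLine` (nonzero slope) gives a limit price that drifts over time, which is how the Tranche model below expresses time-varying limit constraints.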
class Tranche(BaseModel):
    model_config = {"extra": "forbid"}

    fraction: Uint16 = Field(description="Fraction of total order amount; MAX_FRACTION (65535) = 100%")
    startTimeIsRelative: bool
    endTimeIsRelative: bool
    minIsBarrier: bool = Field(description="Not yet supported")
    maxIsBarrier: bool = Field(description="Not yet supported")
    marketOrder: bool = Field(
        description="If true, min/max lines ignored; minLine intercept treated as max slippage"
    )
    minIsRatio: bool
    maxIsRatio: bool
    rateLimitFraction: Uint16 = Field(description="Max fraction of this tranche's amount per rate-limited execution")
    rateLimitPeriod: Uint24 = Field(description="Seconds between rate limit resets")
    startTime: Uint32 = Field(description="Unix timestamp; 0 (DISTANT_PAST) effectively disables")
    endTime: Uint32 = Field(description="Unix timestamp; 4294967295 (DISTANT_FUTURE) effectively disables")
    minLine: Line = Field(description="Traditional limit order constraint; can be diagonal")
    maxLine: Line = Field(description="Upper price boundary (too-good-a-price guard)")


class TrancheStatus(BaseModel):
    model_config = {"extra": "forbid"}

    filled: Uint256 = Field(description="Amount filled by this tranche")
    activationTime: Uint32 = Field(description="Earliest time this tranche can execute; 0 = not yet concrete")
    startTime: Uint32 = Field(description="Concrete start timestamp")
    endTime: Uint32 = Field(description="Concrete end timestamp")


# ---------------------------------------------------------------------------
# Standard Order Models
# ---------------------------------------------------------------------------


class ConditionalTrigger(BaseModel):
    """Conditional order trigger (stop-loss, take-profit, etc.)"""

    model_config = {"extra": "forbid"}

    trigger_type: TriggerType
    trigger_price: Float = Field(description="Price at which the conditional order activates")
    trailing_delta: Float | None = Field(default=None, description="For trailing stops: delta from peak/trough")


class AmountConstraints(BaseModel):
    """Constraints on order amounts for a symbol"""

    model_config = {"extra": "forbid"}

    min_amount: Float = Field(description="Minimum order amount")
    max_amount: Float = Field(description="Maximum order amount")
    step_size: Float = Field(description="Amount increment granularity")


class PriceConstraints(BaseModel):
    """Constraints on order pricing for a symbol"""

    model_config = {"extra": "forbid"}

    tick_spacing_mode: TickSpacingMode
    tick_size: Float | None = Field(default=None, description="Fixed tick size (if FIXED mode)")
    min_price: Float | None = Field(default=None, description="Minimum allowed price")
    max_price: Float | None = Field(default=None, description="Maximum allowed price")


class MarketCapabilities(BaseModel):
    """Describes which order features a market supports"""

    model_config = {"extra": "forbid"}

    supported_sides: list[Side] = Field(description="Supported order sides (usually both)")
    supported_amount_types: list[AmountType] = Field(description="Whether BASE, QUOTE, or both amounts are supported")
    supports_market_orders: bool = Field(description="Whether market orders are supported")
    supports_limit_orders: bool = Field(description="Whether limit orders are supported")
    supported_time_in_force: list[TimeInForce] = Field(description="Supported order lifetimes")
    supports_conditional_orders: bool = Field(description="Whether stop-loss/take-profit are supported")
    supported_trigger_types: list[TriggerType] = Field(default_factory=list, description="Supported trigger types")
    supports_post_only: bool = Field(default=False, description="Whether post-only orders are supported")
    supports_reduce_only: bool = Field(default=False, description="Whether reduce-only orders are supported")
    supports_iceberg: bool = Field(default=False, description="Whether iceberg orders are supported")
    market_order_amount_type: AmountType | None = Field(
        default=None,
        description="Required amount type for market orders (some DEXs require exact-in)",
    )


class SymbolMetadata(BaseModel):
    """Complete metadata describing a tradeable symbol/market"""

    model_config = {"extra": "forbid"}

    symbol_id: str = Field(description="Unique symbol identifier")
    base_asset: str = Field(description="Base asset (e.g., 'BTC')")
    quote_asset: str = Field(description="Quote asset (e.g., 'USD')")
    asset_type: AssetType = Field(description="Type of market")
    exchange: str = Field(description="Exchange identifier")

    amount_constraints: AmountConstraints
    price_constraints: PriceConstraints
    capabilities: MarketCapabilities

    contract_size: Float | None = Field(default=None, description="For futures/options: contract multiplier")
    settlement_asset: str | None = Field(default=None, description="For derivatives: settlement currency")
    expiry_timestamp: Uint64 | None = Field(default=None, description="For dated futures/options: expiration")


class StandardOrder(BaseModel):
    """Standard order specification for exchange kernels"""

    model_config = {"extra": "forbid"}

    symbol_id: str = Field(description="Symbol to trade")
    side: Side = Field(description="Buy or sell")
    amount: Float = Field(description="Order amount")
    amount_type: AmountType = Field(description="Whether amount is in BASE or QUOTE currency")

    limit_price: Float | None = Field(default=None, description="Limit price (None = market order)")
    time_in_force: TimeInForce = Field(default=TimeInForce.GTC, description="Order lifetime")
    good_till_date: Uint64 | None = Field(default=None, description="Expiry timestamp for GTD orders")

    conditional_trigger: ConditionalTrigger | None = Field(
        default=None,
        description="Stop-loss/take-profit trigger",
    )
    conditional_mode: ConditionalOrderMode | None = Field(
        default=None,
        description="How conditional orders behave on partial fills",
    )

    reduce_only: bool = Field(default=False, description="Only reduce an existing position")
    post_only: bool = Field(default=False, description="Only make, never take")
    iceberg_qty: Float | None = Field(default=None, description="Visible amount for iceberg orders")

    client_order_id: str | None = Field(default=None, description="Client-specified order ID")


class StandardOrderStatus(BaseModel):
    """Current status of a standard order"""

    model_config = {"extra": "forbid"}

    order: StandardOrder
    order_id: str = Field(description="Exchange-assigned order ID")
    status: str = Field(description="Order status: NEW, PARTIALLY_FILLED, FILLED, CANCELED, REJECTED, EXPIRED")
    filled_amount: Float = Field(description="Amount filled so far")
    average_fill_price: Float = Field(description="Average execution price")
    created_at: Uint64 = Field(description="Order creation timestamp")
    updated_at: Uint64 = Field(description="Last update timestamp")


# ---------------------------------------------------------------------------
# Order models
# ---------------------------------------------------------------------------


class SwapOrder(BaseModel):
    model_config = {"extra": "forbid"}

    tokenIn: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 input token address")
    tokenOut: str = Field(pattern=ETH_ADDRESS_PATTERN, description="ERC-20 output token address")
    route: Route
    amount: Uint256 = Field(description="Maximum quantity to fill")
    minFillAmount: Uint256 = Field(description="Minimum tranche amount before a tranche is considered complete")
    amountIsInput: bool = Field(description="true = amount is tokenIn quantity; false = tokenOut")
    outputDirectlyToOwner: bool = Field(description="true = proceeds go to vault owner; false = vault")
    inverted: bool = Field(description="false = tokenIn/tokenOut price direction (Uniswap natural)")
    conditionalOrder: Uint64 = Field(
        description="NO_CONDITIONAL_ORDER = 2^64-1; high bit set = relative index within placement group"
    )
    tranches: list[Tranche] = Field(min_length=1)


class StandardOrderGroup(BaseModel):
    """Group of orders with an OCO (One-Cancels-Other) relationship"""

    model_config = {"extra": "forbid"}

    mode: OcoMode
    orders: list[StandardOrder] = Field(min_length=1)


# ---------------------------------------------------------------------------
# Legacy swap order models (kept for backward compatibility)
# ---------------------------------------------------------------------------


class OcoGroup(BaseModel):
    """DEPRECATED: Use StandardOrderGroup instead"""

    model_config = {"extra": "forbid"}

    mode: OcoMode
    orders: list[SwapOrder] = Field(min_length=1)


class SwapOrderStatus(BaseModel):
    model_config = {"extra": "forbid"}

    order: SwapOrder
    fillFeeHalfBps: Uint8 = Field(description="Fill fee in half-bps (1/20000); max 255 = 1.275%")
    canceled: bool = Field(description="If true, order is canceled regardless of cancelAllIndex")
    startTime: Uint32 = Field(description="Earliest block.timestamp at which the order may execute")
    ocoGroup: Uint64 = Field(description="Index into ocoGroups; NO_OCO_INDEX = 2^64-1")
    originalOrder: Uint64 = Field(description="Index of the original order in the orders array")
    startPrice: Uint256 = Field(description="Price at order start")
    filled: Uint256 = Field(description="Total amount filled so far")
    trancheStatus: list[TrancheStatus]
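The fixed-point conventions in these models — `fillFeeHalfBps` (half-basis-points, 1/20000 each) and `Tranche.fraction` (Uint16, where 65535 means 100%) — convert to plain fractions as below; the helper names are hypothetical, for illustration only:

```python
MAX_FRACTION = 65_535  # Uint16 max; per the Tranche field docs, this means 100%


def half_bps_to_fraction(half_bps: int) -> float:
    """Convert a half-basis-point fee (1 half-bp = 1/20000) to a plain fraction."""
    return half_bps / 20_000


def uint16_fraction_to_ratio(fraction: int) -> float:
    """Convert a Uint16 order fraction to a ratio in [0, 1]."""
    return fraction / MAX_FRACTION


# Sanity checks matching the field descriptions:
assert half_bps_to_fraction(255) == 0.01275      # max Uint8 fee = 1.275%
assert half_bps_to_fraction(0) == 0.0
assert uint16_fraction_to_ratio(MAX_FRACTION) == 1.0
```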
backend.old/src/schema/shape.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field


class ControlPoint(BaseModel):
    """A control point for a drawing shape.

    Control points define the position and properties of a shape.
    Different shapes have different numbers of control points.
    """
    time: int = Field(..., description="Unix timestamp in seconds")
    price: float = Field(..., description="Price level")
    # Optional channel for multi-point shapes (e.g., parallel channels)
    channel: Optional[str] = Field(default=None, description="Channel identifier for multi-point shapes")


class Shape(BaseModel):
    """A TradingView drawing shape/study.

    Represents any drawing the user creates on the chart (trendlines,
    horizontal lines, rectangles, Fibonacci retracements, etc.)
    """
    id: str = Field(..., description="Unique identifier for the shape")
    type: str = Field(..., description="Shape type (e.g., 'trendline', 'horizontal_line', 'rectangle', 'fibonacci')")
    points: List[ControlPoint] = Field(default_factory=list, description="Control points that define the shape")

    # Visual properties
    color: Optional[str] = Field(default=None, description="Shape color (hex or color name)")
    line_width: Optional[int] = Field(default=1, description="Line width in pixels")
    line_style: Optional[str] = Field(default="solid", description="Line style: 'solid', 'dashed', 'dotted'")

    # Shape-specific properties stored as a flexible dict
    properties: Dict[str, Any] = Field(default_factory=dict, description="Additional shape-specific properties")

    # Metadata
    symbol: Optional[str] = Field(default=None, description="Symbol this shape is drawn on")
    created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
    modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
    original_id: Optional[str] = Field(default=None, description="Original ID from backend/agent before TradingView assigns its own ID")


class ShapeCollection(BaseModel):
    """Collection of all shapes/drawings on the chart."""
    shapes: Dict[str, Shape] = Field(default_factory=dict, description="Dictionary of shapes keyed by ID")
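Assuming pydantic's default field-by-field serialization, a trendline `Shape` would round-trip through JSON roughly as below. All payload values here are invented for illustration; only the field names come from the models above:

```python
import json

# Hypothetical serialized form of a two-point trendline Shape.
payload = json.dumps({
    "id": "shape-1",
    "type": "trendline",
    "points": [
        {"time": 1700000000, "price": 42000.0, "channel": None},
        {"time": 1700086400, "price": 43500.0, "channel": None},
    ],
    "color": "#2196f3",
    "line_width": 1,
    "line_style": "solid",
    "properties": {},
    "symbol": "BTCUSD",
    "created_at": 1700000000,
    "modified_at": None,
    "original_id": None,
})

shape = json.loads(payload)
```

The flexible `properties` dict is what lets shape-specific settings (e.g., Fibonacci levels) travel through this schema without new model fields.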
backend.old/src/secrets_manager/__init__.py (new file, 40 lines)
@@ -0,0 +1,40 @@
"""
Encrypted secrets management with master password protection.

This module provides secure storage for sensitive configuration such as API keys,
using Argon2id for password-based key derivation and Fernet (AES-128-CBC with
HMAC-SHA256) for authenticated encryption.

Basic usage:
    from secrets_manager import SecretsStore

    # First-time setup
    store = SecretsStore()
    store.initialize("my-master-password")
    store.set("ANTHROPIC_API_KEY", "sk-ant-...")

    # Later usage
    store = SecretsStore()
    store.unlock("my-master-password")
    api_key = store.get("ANTHROPIC_API_KEY")

Command-line interface:
    python -m secrets_manager.cli init
    python -m secrets_manager.cli set KEY VALUE
    python -m secrets_manager.cli get KEY
    python -m secrets_manager.cli list
    python -m secrets_manager.cli change-password
"""

from .store import (
    SecretsStore,
    SecretsStoreError,
    SecretsStoreLocked,
    InvalidMasterPassword,
)

__all__ = [
    "SecretsStore",
    "SecretsStoreError",
    "SecretsStoreLocked",
    "InvalidMasterPassword",
]
backend.old/src/secrets_manager/cli.py (new file, 374 lines)
@@ -0,0 +1,374 @@
#!/usr/bin/env python3
"""
Command-line interface for managing the encrypted secrets store.

Usage:
    python -m secrets_manager.cli init              # Initialize new secrets store
    python -m secrets_manager.cli set KEY VALUE     # Set a secret
    python -m secrets_manager.cli get KEY           # Get a secret
    python -m secrets_manager.cli delete KEY        # Delete a secret
    python -m secrets_manager.cli list              # List all secret keys
    python -m secrets_manager.cli change-password   # Change master password
    python -m secrets_manager.cli export FILE       # Export encrypted backup
    python -m secrets_manager.cli import FILE       # Import encrypted backup
    python -m secrets_manager.cli migrate-from-env  # Migrate secrets from a .env file
"""
import sys
import argparse
import getpass
from pathlib import Path

from .store import SecretsStore, SecretsStoreError, InvalidMasterPassword


def get_password(prompt: str = "Master password: ", confirm: bool = False) -> str:
    """
    Securely read a password from the user.

    Args:
        prompt: Password prompt
        confirm: If True, ask for confirmation

    Returns:
        Password string
    """
    password = getpass.getpass(prompt)

    if confirm:
        confirm_password = getpass.getpass("Confirm password: ")
        if password != confirm_password:
            print("Error: Passwords do not match", file=sys.stderr)
            sys.exit(1)

    return password


def cmd_init(args):
    """Initialize a new secrets store."""
    store = SecretsStore()

    if store.is_initialized:
        print("Error: Secrets store is already initialized", file=sys.stderr)
        print(f"Location: {store.secrets_file}", file=sys.stderr)
        sys.exit(1)

    password = get_password("Create master password: ", confirm=True)

    if len(password) < 8:
        print("Error: Password must be at least 8 characters", file=sys.stderr)
        sys.exit(1)

    store.initialize(password)
    print(f"Secrets store initialized at {store.secrets_file}")


def cmd_set(args):
    """Set a secret value."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.unlock(password)
    except InvalidMasterPassword:
        print("Error: Invalid master password", file=sys.stderr)
        sys.exit(1)

    store.set(args.key, args.value)
    print(f"✓ Secret '{args.key}' saved")


def cmd_get(args):
    """Get a secret value."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.unlock(password)
    except InvalidMasterPassword:
        print("Error: Invalid master password", file=sys.stderr)
        sys.exit(1)

    value = store.get(args.key)
    if value is None:
        print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
        sys.exit(1)

    # Print to stdout (can be captured)
    print(value)


def cmd_delete(args):
    """Delete a secret."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.unlock(password)
    except InvalidMasterPassword:
        print("Error: Invalid master password", file=sys.stderr)
        sys.exit(1)

    if store.delete(args.key):
        print(f"✓ Secret '{args.key}' deleted")
    else:
        print(f"Error: Secret '{args.key}' not found", file=sys.stderr)
        sys.exit(1)


def cmd_list(args):
    """List all secret keys."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.unlock(password)
    except InvalidMasterPassword:
        print("Error: Invalid master password", file=sys.stderr)
        sys.exit(1)

    keys = store.list_keys()

    if not keys:
        print("No secrets stored")
    else:
        print(f"Stored secrets ({len(keys)}):")
        for key in sorted(keys):
            # Show key and a value preview for verification
            value = store.get(key)
            value_str = str(value)
            value_preview = value_str[:50] + "..." if len(value_str) > 50 else value_str
            print(f"  {key}: {value_preview}")


def cmd_change_password(args):
    """Change the master password."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    current_password = get_password("Current master password: ")
    new_password = get_password("New master password: ", confirm=True)

    if len(new_password) < 8:
        print("Error: Password must be at least 8 characters", file=sys.stderr)
        sys.exit(1)

    try:
        store.change_master_password(current_password, new_password)
    except InvalidMasterPassword:
        print("Error: Invalid current password", file=sys.stderr)
        sys.exit(1)


def cmd_export(args):
    """Export encrypted secrets to a backup file."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    output_path = Path(args.file)

    if output_path.exists() and not args.force:
        print(f"Error: File {output_path} already exists. Use --force to overwrite.", file=sys.stderr)
        sys.exit(1)

    try:
        store.export_encrypted(output_path)
    except SecretsStoreError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


def cmd_import(args):
    """Import encrypted secrets from a backup file."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    input_path = Path(args.file)

    if not input_path.exists():
        print(f"Error: File {input_path} does not exist", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.import_encrypted(input_path, password)
    except InvalidMasterPassword:
        print("Error: Invalid master password or incompatible backup", file=sys.stderr)
        sys.exit(1)
    except SecretsStoreError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


def cmd_migrate_from_env(args):
    """Migrate secrets from a .env file into the encrypted store."""
    store = SecretsStore()

    if not store.is_initialized:
        print("Error: Secrets store is not initialized. Run 'init' first.", file=sys.stderr)
        sys.exit(1)

    # Look for the .env file at the backend root
    backend_root = Path(__file__).parent.parent.parent
    env_file = backend_root / ".env"

    if not env_file.exists():
        print(f"Error: .env file not found at {env_file}", file=sys.stderr)
        sys.exit(1)

    password = get_password()

    try:
        store.unlock(password)
    except InvalidMasterPassword:
        print("Error: Invalid master password", file=sys.stderr)
        sys.exit(1)

    # Parse the .env file (simple parser - doesn't handle all edge cases)
    migrated = 0
    skipped = 0

    with open(env_file) as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()

            # Skip empty lines and comments
            if not line or line.startswith('#'):
                continue

            # Parse KEY=VALUE format
            if '=' not in line:
                print(f"Warning: Skipping invalid line {line_num}: {line}", file=sys.stderr)
                skipped += 1
                continue

            key, value = line.split('=', 1)
            key = key.strip()
            value = value.strip()

            # Remove surrounding quotes if present
            if value.startswith('"') and value.endswith('"'):
                value = value[1:-1]
            elif value.startswith("'") and value.endswith("'"):
                value = value[1:-1]

            # Check if the key already exists
            existing = store.get(key)
            if existing is not None:
                print(f"Warning: Secret '{key}' already exists, skipping", file=sys.stderr)
                skipped += 1
                continue

            store.set(key, value)
            print(f"✓ Migrated: {key}")
            migrated += 1

    print(f"\nMigration complete: {migrated} secrets migrated, {skipped} skipped")

    if not args.keep_env:
        # Ask for confirmation before deleting the .env file
        confirm = input(f"\nDelete {env_file}? [y/N]: ").strip().lower()
        if confirm == 'y':
            env_file.unlink()
            print(f"✓ Deleted {env_file}")
        else:
            print(f"Kept {env_file} (consider deleting it manually)")
def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Manage encrypted secrets store",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to run')
    subparsers.required = True

    # init
    parser_init = subparsers.add_parser('init', help='Initialize new secrets store')
    parser_init.set_defaults(func=cmd_init)

    # set
    parser_set = subparsers.add_parser('set', help='Set a secret value')
    parser_set.add_argument('key', help='Secret key name')
    parser_set.add_argument('value', help='Secret value')
    parser_set.set_defaults(func=cmd_set)

    # get
    parser_get = subparsers.add_parser('get', help='Get a secret value')
    parser_get.add_argument('key', help='Secret key name')
    parser_get.set_defaults(func=cmd_get)

    # delete
    parser_delete = subparsers.add_parser('delete', help='Delete a secret')
    parser_delete.add_argument('key', help='Secret key name')
    parser_delete.set_defaults(func=cmd_delete)

    # list
    parser_list = subparsers.add_parser('list', help='List all secret keys')
    parser_list.set_defaults(func=cmd_list)

    # change-password
    parser_change = subparsers.add_parser('change-password', help='Change master password')
    parser_change.set_defaults(func=cmd_change_password)

    # export
    parser_export = subparsers.add_parser('export', help='Export encrypted backup')
    parser_export.add_argument('file', help='Output file path')
    parser_export.add_argument('--force', action='store_true', help='Overwrite existing file')
    parser_export.set_defaults(func=cmd_export)

    # import
    parser_import = subparsers.add_parser('import', help='Import encrypted backup')
    parser_import.add_argument('file', help='Input file path')
    parser_import.set_defaults(func=cmd_import)

    # migrate-from-env
    parser_migrate = subparsers.add_parser('migrate-from-env', help='Migrate from .env file')
    parser_migrate.add_argument('--keep-env', action='store_true', help='Keep .env file after migration')
    parser_migrate.set_defaults(func=cmd_migrate_from_env)

    args = parser.parse_args()

    try:
        args.func(args)
    except KeyboardInterrupt:
        print("\nAborted", file=sys.stderr)
        sys.exit(130)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
backend.old/src/secrets_manager/crypto.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""
Cryptographic utilities for secrets management.

Uses Argon2id for password-based key derivation and Fernet for encryption.
"""
import base64
import hashlib
import secrets as secrets_module

from argon2.low_level import hash_secret_raw, Type
from cryptography.fernet import Fernet


# Argon2id parameters (OWASP recommended for password-based KDF)
# These provide strong defense against GPU/ASIC attacks
ARGON2_TIME_COST = 3        # iterations
ARGON2_MEMORY_COST = 65536  # 64 MB
ARGON2_PARALLELISM = 4      # threads
ARGON2_HASH_LENGTH = 32     # bytes (256 bits of Fernet key material)
ARGON2_SALT_LENGTH = 16     # bytes (128 bits)


def generate_salt() -> bytes:
    """Generate a cryptographically secure random salt."""
    return secrets_module.token_bytes(ARGON2_SALT_LENGTH)


def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """
    Derive an encryption key from a password using Argon2id.

    Args:
        password: The master password
        salt: The salt (must be consistent for the same password to work)

    Returns:
        32-byte key suitable for Fernet encryption
    """
    password_bytes = password.encode('utf-8')

    # Use Argon2id (hybrid mode - combines the strengths of Argon2i and Argon2d)
    raw_hash = hash_secret_raw(
        secret=password_bytes,
        salt=salt,
        time_cost=ARGON2_TIME_COST,
        memory_cost=ARGON2_MEMORY_COST,
        parallelism=ARGON2_PARALLELISM,
        hash_len=ARGON2_HASH_LENGTH,
        type=Type.ID,  # Argon2id
    )

    return raw_hash


def create_fernet(key: bytes) -> Fernet:
    """
    Create a Fernet cipher instance from a raw key.

    Args:
        key: 32-byte raw key from Argon2id

    Returns:
        Fernet instance for encryption/decryption
    """
    # Fernet requires a URL-safe base64-encoded 32-byte key
    fernet_key = base64.urlsafe_b64encode(key)
    return Fernet(fernet_key)


def encrypt_data(data: bytes, key: bytes) -> bytes:
    """
    Encrypt data using Fernet (AES-128-CBC with HMAC-SHA256 authentication).

    Args:
        data: Raw bytes to encrypt
        key: 32-byte encryption key

    Returns:
        Encrypted token (includes IV and authentication tag)
    """
    fernet = create_fernet(key)
    return fernet.encrypt(data)


def decrypt_data(encrypted_data: bytes, key: bytes) -> bytes:
    """
    Decrypt data using Fernet.

    Args:
        encrypted_data: Encrypted bytes from encrypt_data
        key: 32-byte encryption key (must match the encryption key)

    Returns:
        Decrypted raw bytes

    Raises:
        cryptography.fernet.InvalidToken: If decryption fails (wrong key/corrupted data)
    """
    fernet = create_fernet(key)
    return fernet.decrypt(encrypted_data)


def create_verification_hash(password: str, salt: bytes) -> str:
    """
    Create a verification hash to check whether a password is correct.

    This is NOT for storing the password - it verifies that the password
    unlocks the correct key without trying to decrypt the entire secrets file.

    Args:
        password: The master password
        salt: The salt used for key derivation

    Returns:
        Base64-encoded hash for verification
    """
    key = derive_key_from_password(password, salt)

    # Hash the full derived key rather than storing any of its bytes directly.
    # Fernet splits the 32-byte key into signing and encryption halves, so raw
    # key material must never appear in the stored verifier.
    verification = base64.b64encode(hashlib.sha256(key).digest()).decode('ascii')

    return verification


def verify_password(password: str, salt: bytes, verification_hash: str) -> bool:
    """
    Verify a password against a verification hash.

    Args:
        password: Password to verify
        salt: Salt used for key derivation
        verification_hash: Expected verification hash

    Returns:
        True if the password is correct, False otherwise
    """
    computed_hash = create_verification_hash(password, salt)

    # Constant-time comparison to prevent timing attacks
    return secrets_module.compare_digest(computed_hash, verification_hash)
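The verification-hash pattern can be exercised with the standard library alone. In this sketch `hashlib.scrypt` stands in for Argon2id (an assumption for portability — the real module uses argon2-cffi), and the full derived key is hashed before being stored as the verifier:

```python
import base64
import hashlib
import secrets


def derive_key(password: str, salt: bytes) -> bytes:
    # scrypt as a stdlib stand-in for Argon2id; parameters are illustrative only
    return hashlib.scrypt(password.encode(), salt=salt, n=2**12, r=8, p=1, dklen=32)


def make_verification_hash(password: str, salt: bytes) -> str:
    # Hash the derived key so the stored verifier leaks no key material
    key = derive_key(password, salt)
    return base64.b64encode(hashlib.sha256(key).digest()).decode('ascii')


def check_password(password: str, salt: bytes, verifier: str) -> bool:
    # Constant-time comparison, mirroring verify_password above
    return secrets.compare_digest(make_verification_hash(password, salt), verifier)


salt = secrets.token_bytes(16)
verifier = make_verification_hash("correct horse", salt)
assert check_password("correct horse", salt, verifier)
assert not check_password("wrong password", salt, verifier)
```

The design point is that a wrong password yields a wrong key, hence a wrong verifier hash, without ever attempting to decrypt the secrets file.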
backend.old/src/secrets_manager/store.py (new file, 406 lines)
@@ -0,0 +1,406 @@
"""
Encrypted secrets store with master password protection.

The secrets are stored in an encrypted file, with the encryption key derived
from a master password using Argon2id. Changing the master password
re-encrypts the secrets under a key derived from the new password.
"""
import json
import os
import shutil
import stat
from pathlib import Path
from typing import Dict, Optional, Any

from cryptography.fernet import InvalidToken

from .crypto import (
    generate_salt,
    derive_key_from_password,
    encrypt_data,
    decrypt_data,
    create_verification_hash,
    verify_password,
)


class SecretsStoreError(Exception):
    """Base exception for secrets store errors."""
    pass


class SecretsStoreLocked(SecretsStoreError):
    """Raised when trying to access secrets while the store is locked."""
    pass


class InvalidMasterPassword(SecretsStoreError):
    """Raised when the master password is incorrect."""
    pass


class SecretsStore:
    """
    Encrypted secrets store with master password protection.

    Usage:
        # Initialize (first time)
        store = SecretsStore()
        store.initialize("my-secure-password")

        # Unlock
        store = SecretsStore()
        store.unlock("my-secure-password")

        # Access secrets
        api_key = store.get("ANTHROPIC_API_KEY")
        store.set("NEW_SECRET", "secret-value")

        # Change master password
        store.change_master_password("my-secure-password", "new-password")

        # Lock when done
        store.lock()
    """

    def __init__(self, data_dir: Optional[Path] = None):
        """
        Initialize the secrets store.

        Args:
            data_dir: Directory for secrets files (defaults to backend/data)
        """
        if data_dir is None:
            # Default to backend/data
            backend_root = Path(__file__).parent.parent.parent
            data_dir = backend_root / "data"

        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.master_key_file = self.data_dir / ".master.key"
        self.secrets_file = self.data_dir / "secrets.enc"

        # Runtime state
        self._encryption_key: Optional[bytes] = None
        self._secrets: Optional[Dict[str, Any]] = None

    @property
    def is_initialized(self) -> bool:
        """Check if the secrets store has been initialized."""
        return self.master_key_file.exists()

    @property
    def is_unlocked(self) -> bool:
        """Check if the secrets store is currently unlocked."""
        return self._encryption_key is not None

    def initialize(self, master_password: str) -> None:
        """
        Initialize the secrets store with a master password.

        This should only be called once, when setting up the store.

        Args:
            master_password: The master password to protect the secrets

        Raises:
            SecretsStoreError: If the store is already initialized
        """
        if self.is_initialized:
            raise SecretsStoreError(
                "Secrets store is already initialized. "
                "Use unlock() to access it or change_master_password() to change the password."
            )

        # Generate a new random salt
        salt = generate_salt()

        # Derive the encryption key
        encryption_key = derive_key_from_password(master_password, salt)

        # Create the verification hash
        verification_hash = create_verification_hash(master_password, salt)

        # Store salt and verification hash
        master_key_data = {
            "salt": salt.hex(),
            "verification": verification_hash,
        }
        self.master_key_file.write_text(json.dumps(master_key_data, indent=2))

        # Set restrictive permissions (owner read/write only)
        os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)

        # Initialize empty secrets
        self._encryption_key = encryption_key
        self._secrets = {}
        self._save_secrets()

        print(f"✓ Secrets store initialized at {self.secrets_file}")

    def unlock(self, master_password: str) -> None:
        """
        Unlock the secrets store with the master password.

        Args:
            master_password: The master password

        Raises:
            SecretsStoreError: If the store is not initialized
            InvalidMasterPassword: If the password is incorrect
        """
        if not self.is_initialized:
            raise SecretsStoreError(
                "Secrets store is not initialized. Call initialize() first."
            )

        # Load salt and verification hash
        master_key_data = json.loads(self.master_key_file.read_text())
        salt = bytes.fromhex(master_key_data["salt"])
        verification_hash = master_key_data["verification"]

        # Verify password
        if not verify_password(master_password, salt, verification_hash):
            raise InvalidMasterPassword("Invalid master password")

        # Derive the encryption key
        encryption_key = derive_key_from_password(master_password, salt)

        # Load and decrypt secrets
        if self.secrets_file.exists():
            try:
                encrypted_data = self.secrets_file.read_bytes()
                decrypted_data = decrypt_data(encrypted_data, encryption_key)
                self._secrets = json.loads(decrypted_data.decode('utf-8'))
            except InvalidToken:
                raise InvalidMasterPassword("Failed to decrypt secrets (invalid password)")
            except json.JSONDecodeError as e:
                raise SecretsStoreError(f"Corrupted secrets file: {e}")
        else:
            # No secrets file yet (fresh initialization)
            self._secrets = {}

        self._encryption_key = encryption_key
        print(f"✓ Secrets store unlocked ({len(self._secrets)} secrets)")

    def lock(self) -> None:
        """Lock the secrets store (clear decrypted data from memory)."""
        self._encryption_key = None
        self._secrets = None

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a secret value.

        Args:
            key: Secret key name
            default: Default value if the key doesn't exist

        Returns:
            Secret value or default

        Raises:
            SecretsStoreLocked: If the store is locked
        """
        if not self.is_unlocked:
            raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")

        return self._secrets.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """
        Set a secret value.

        Args:
            key: Secret key name
            value: Secret value (must be JSON-serializable)

        Raises:
            SecretsStoreLocked: If the store is locked
        """
        if not self.is_unlocked:
            raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")

        self._secrets[key] = value
        self._save_secrets()

    def delete(self, key: str) -> bool:
        """
        Delete a secret.

        Args:
            key: Secret key name

        Returns:
            True if the secret existed and was deleted, False otherwise

        Raises:
            SecretsStoreLocked: If the store is locked
        """
        if not self.is_unlocked:
            raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")

        if key in self._secrets:
            del self._secrets[key]
            self._save_secrets()
            return True
        return False

    def list_keys(self) -> list[str]:
        """
        List all secret keys.

        Returns:
            List of secret keys

        Raises:
            SecretsStoreLocked: If the store is locked
        """
        if not self.is_unlocked:
            raise SecretsStoreLocked("Secrets store is locked. Call unlock() first.")

        return list(self._secrets.keys())

    def change_master_password(self, current_password: str, new_password: str) -> None:
        """
        Change the master password.

        This re-encrypts the secrets with a new key derived from the new password.

        Args:
            current_password: Current master password
            new_password: New master password

        Raises:
            InvalidMasterPassword: If the current password is incorrect
        """
        if not self.is_initialized:
            raise SecretsStoreError(
                "Secrets store is not initialized. Call initialize() first."
            )

        # ALWAYS verify the current password before changing anything.
        # Load salt and verification hash
        master_key_data = json.loads(self.master_key_file.read_text())
        salt = bytes.fromhex(master_key_data["salt"])
        verification_hash = master_key_data["verification"]

        if not verify_password(current_password, salt, verification_hash):
            raise InvalidMasterPassword("Invalid current password")

        # Unlock if needed to access secrets
        was_unlocked = self.is_unlocked
        if not was_unlocked:
            # The store is locked, so unlock with the current password
            # (already verified above, so this will succeed)
            encryption_key = derive_key_from_password(current_password, salt)

            # Load and decrypt secrets
            if self.secrets_file.exists():
                encrypted_data = self.secrets_file.read_bytes()
                decrypted_data = decrypt_data(encrypted_data, encryption_key)
                self._secrets = json.loads(decrypted_data.decode('utf-8'))
            else:
                self._secrets = {}

            self._encryption_key = encryption_key

        # Generate a new salt
        new_salt = generate_salt()

        # Derive the new encryption key
        new_encryption_key = derive_key_from_password(new_password, new_salt)

        # Create the new verification hash
        new_verification_hash = create_verification_hash(new_password, new_salt)

        # Update the master key file
        master_key_data = {
            "salt": new_salt.hex(),
            "verification": new_verification_hash,
        }
        self.master_key_file.write_text(json.dumps(master_key_data, indent=2))
        os.chmod(self.master_key_file, stat.S_IRUSR | stat.S_IWUSR)

        # Re-encrypt secrets with the new key
        self._encryption_key = new_encryption_key
        self._save_secrets()

        print("✓ Master password changed successfully")

        # Lock again if the store wasn't unlocked before
        if not was_unlocked:
            self.lock()

    def _save_secrets(self) -> None:
        """Save secrets to the encrypted file."""
        if not self.is_unlocked:
            raise SecretsStoreLocked("Cannot save while locked")

        # Serialize secrets to JSON
        secrets_json = json.dumps(self._secrets, indent=2)
        secrets_bytes = secrets_json.encode('utf-8')

        # Encrypt
        encrypted_data = encrypt_data(secrets_bytes, self._encryption_key)

        # Write to file
        self.secrets_file.write_bytes(encrypted_data)

        # Set restrictive permissions
        os.chmod(self.secrets_file, stat.S_IRUSR | stat.S_IWUSR)

    def export_encrypted(self, output_path: Path) -> None:
        """
        Export encrypted secrets to a file (for backup).

        Args:
            output_path: Path to the export file

        Raises:
            SecretsStoreError: If the secrets file doesn't exist
        """
        if not self.secrets_file.exists():
            raise SecretsStoreError("No secrets to export")

        shutil.copy2(self.secrets_file, output_path)
        print(f"✓ Encrypted secrets exported to {output_path}")

    def import_encrypted(self, input_path: Path, master_password: str) -> None:
        """
        Import encrypted secrets from a file.

        Verifies that the password can decrypt the import before replacing
        the current secrets.

        Args:
            input_path: Path to the import file
            master_password: Master password for the current store

        Raises:
            InvalidMasterPassword: If the password doesn't work with the import
        """
        if not self.is_unlocked:
            self.unlock(master_password)

        # Try to decrypt the imported file with the current key
        try:
            encrypted_data = Path(input_path).read_bytes()
            decrypted_data = decrypt_data(encrypted_data, self._encryption_key)
            imported_secrets = json.loads(decrypted_data.decode('utf-8'))
        except InvalidToken:
            raise InvalidMasterPassword(
                "Cannot decrypt imported secrets with current master password"
            )
        except json.JSONDecodeError as e:
            raise SecretsStoreError(f"Corrupted import file: {e}")

        # Replace secrets
        self._secrets = imported_secrets
        self._save_secrets()

        print(f"✓ Imported {len(self._secrets)} secrets from {input_path}")
246
backend.old/src/sync/registry.py
Normal file
@@ -0,0 +1,246 @@
import logging
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Deque

import jsonpatch
from pydantic import BaseModel

from sync.protocol import SnapshotMessage, PatchMessage

logger = logging.getLogger(__name__)


class SyncEntry:
    def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
        self.model = model
        self.store_name = store_name
        self.seq = 0
        self.last_snapshot = model.model_dump(mode="json")
        self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)

    def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
        current_state = self.model.model_dump(mode="json")
        patch = jsonpatch.make_patch(self.last_snapshot, current_state)
        if not patch.patch:
            return None
        return patch.patch

    def commit_patch(self, patch: List[Dict[str, Any]]):
        self.seq += 1
        self.history.append((self.seq, patch))
        self.last_snapshot = self.model.model_dump(mode="json")

    def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
        if since_seq == self.seq:
            return []

        # Check if all patches from since_seq + 1 to self.seq are in history
        if not self.history or self.history[0][0] > since_seq + 1:
            return None

        result = []
        for seq, patch in self.history:
            if seq > since_seq:
                result.append((seq, patch))
        return result


class SyncRegistry:
    def __init__(self):
        self.entries: Dict[str, SyncEntry] = {}
        self.websocket: Optional[Any] = None  # Expecting a FastAPI WebSocket or similar

    def register(self, model: BaseModel, store_name: str):
        self.entries[store_name] = SyncEntry(model, store_name)

    async def push_all(self):
        if not self.websocket:
            logger.warning("push_all: No websocket connected, cannot push updates")
            return

        logger.info(f"push_all: Processing {len(self.entries)} store entries")
        for entry in self.entries.values():
            patch = entry.compute_patch()
            if patch:
                logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
                entry.commit_patch(patch)
                msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
                logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
                await self.websocket.send_json(msg.model_dump(mode="json"))
                logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
            else:
                logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")

    async def sync_client(self, client_seqs: Dict[str, int]):
        if not self.websocket:
            return

        for store_name, entry in self.entries.items():
            client_seq = client_seqs.get(store_name, -1)
            patches = entry.catchup_patches(client_seq)

            if patches is not None:
                # Replay patches
                for seq, patch in patches:
                    msg = PatchMessage(store=store_name, seq=seq, patch=patch)
                    await self.websocket.send_json(msg.model_dump(mode="json"))
            else:
                # Send full snapshot
                msg = SnapshotMessage(
                    store=store_name,
                    seq=entry.seq,
                    state=entry.model.model_dump(mode="json")
                )
                await self.websocket.send_json(msg.model_dump(mode="json"))

    async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
        logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")

        entry = self.entries.get(store_name)
        if not entry:
            logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
            return

        logger.info(f"apply_client_patch: Current backend seq={entry.seq}")

        try:
            if client_base_seq == entry.seq:
                # No conflict
                logger.info("apply_client_patch: No conflict - applying patch directly")
                current_state = entry.model.model_dump(mode="json")
                logger.info(f"apply_client_patch: Current state before patch: {current_state}")
                try:
                    new_state = jsonpatch.apply_patch(current_state, patch)
                    logger.info(f"apply_client_patch: New state after patch: {new_state}")
                    self._update_model(entry.model, new_state)

                    # Verify the model was actually updated
                    updated_state = entry.model.model_dump(mode="json")
                    logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")

                    entry.commit_patch(patch)
                    logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
                    # Don't broadcast back to the client - it already has this change;
                    # broadcasting would cause an infinite loop.
                    logger.info("apply_client_patch: Not broadcasting back to originating client")
                except jsonpatch.JsonPatchConflict as e:
                    logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
                    # Send snapshot to force resync
                    if self.websocket:
                        msg = SnapshotMessage(
                            store=entry.store_name,
                            seq=entry.seq,
                            state=entry.model.model_dump(mode="json")
                        )
                        await self.websocket.send_json(msg.model_dump(mode="json"))

            elif client_base_seq < entry.seq:
                # Conflict: the client's patch is based on an older seq. Frontend wins.
                # 1. Collect backend patches committed since client_base_seq.
                backend_patches = []
                for seq, p in entry.history:
                    if seq > client_base_seq:
                        backend_patches.append(p)

                # 2. We cannot reconstruct the state as of client_base_seq (history
                #    stores only patches), so instead: apply the frontend patch to the
                #    current model, then re-apply the backend patches, discarding ops
                #    that touch paths the frontend changed.
                frontend_paths = {p['path'] for p in patch}

                current_state = entry.model.model_dump(mode="json")
                # Apply frontend patch
                try:
                    new_state = jsonpatch.apply_patch(current_state, patch)
                except jsonpatch.JsonPatchConflict as e:
                    logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
                    # Send snapshot to force resync
                    if self.websocket:
                        msg = SnapshotMessage(
                            store=entry.store_name,
                            seq=entry.seq,
                            state=entry.model.model_dump(mode="json")
                        )
                        await self.websocket.send_json(msg.model_dump(mode="json"))
                    return

                # Re-apply backend patches that don't overlap
                for b_patch in backend_patches:
                    filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
                    if filtered_b_patch:
                        try:
                            new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
                        except jsonpatch.JsonPatchConflict as e:
                            logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
                            continue

                self._update_model(entry.model, new_state)

                # Commit the result as a single new patch, computed from
                # last_snapshot to new_state
                final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
                if final_patch:
                    entry.commit_patch(final_patch)
                # Broadcast the resolved state as a snapshot so both sides converge
                if self.websocket:
                    msg = SnapshotMessage(
                        store=entry.store_name,
                        seq=entry.seq,
                        state=entry.model.model_dump(mode="json")
                    )
                    await self.websocket.send_json(msg.model_dump(mode="json"))
        except Exception as e:
            logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
            # Send snapshot to force resync
            if self.websocket:
                msg = SnapshotMessage(
                    store=entry.store_name,
                    seq=entry.seq,
                    state=entry.model.model_dump(mode="json")
                )
                await self.websocket.send_json(msg.model_dump(mode="json"))

    def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
        # Update model fields in-place to preserve references.
        # This matters for dict fields that may be referenced elsewhere.
        for field_name in model.model_fields:
            if field_name in new_data:
                new_value = new_data[field_name]
                current_value = getattr(model, field_name)

                # For dict fields, update in-place instead of replacing
                if isinstance(current_value, dict) and isinstance(new_value, dict):
                    self._deep_update_dict(current_value, new_value)
                else:
                    # For other types, just set the new value
                    setattr(model, field_name, new_value)

    def _deep_update_dict(self, target: dict, source: dict):
        """Deep update target dict with source dict, preserving nested dict references."""
        # Remove keys that are in target but not in source
        keys_to_remove = set(target.keys()) - set(source.keys())
        for key in keys_to_remove:
            del target[key]

        # Update or add keys from source
        for key, source_value in source.items():
            if key in target:
                target_value = target[key]
                # If both are dicts, recursively update
                if isinstance(target_value, dict) and isinstance(source_value, dict):
                    self._deep_update_dict(target_value, source_value)
                else:
                    # Replace the value
                    target[key] = source_value
            else:
                # Add new key
                target[key] = source_value
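The in-place deep update is what lets callers that hold a reference to a nested dict keep seeing fresh data after a sync. A standalone sketch of the same algorithm as `_deep_update_dict` (the surrounding names here are hypothetical):

```python
def deep_update(target: dict, source: dict) -> None:
    # Remove keys absent from source
    for key in set(target) - set(source):
        del target[key]
    # Update or add keys, recursing into dicts so nested references stay stable
    for key, value in source.items():
        if key in target and isinstance(target[key], dict) and isinstance(value, dict):
            deep_update(target[key], value)
        else:
            target[key] = value

state = {"positions": {"BTC": 1.0, "ETH": 2.0}, "stale": True}
positions_ref = state["positions"]          # simulated external reference
deep_update(state, {"positions": {"BTC": 1.5}})
print(state)                                # {'positions': {'BTC': 1.5}}
print(positions_ref is state["positions"])  # True: reference preserved
```

A plain `state = new_state` assignment would leave `positions_ref` pointing at stale data; the recursion avoids that.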
216
backend.old/src/trigger/PRIORITIES.md
Normal file
@@ -0,0 +1,216 @@
# Priority System

Simple tuple-based priorities for deterministic execution ordering.

## Basic Concept

Priorities are just **Python tuples**. Python compares tuples element-by-element, left-to-right:

```python
(0, 1000, 5) < (0, 1001, 3)  # True: 0 == 0, but 1000 < 1001
(0, 1000, 5) < (1, 500, 2)   # True: 0 < 1
(0, 1000) < (0, 1000, 5)     # True: shorter wins if equal so far
```

**Lower values = higher priority** (processed first).

## Priority Categories

```python
class Priority(IntEnum):
    DATA_SOURCE = 0        # Market data, real-time feeds
    TIMER = 1              # Scheduled tasks, cron jobs
    USER_AGENT = 2         # User-agent interactions (chat)
    USER_DATA_REQUEST = 3  # User data requests (charts)
    SYSTEM = 4             # Background tasks, cleanup
    LOW = 5                # Retries after conflicts
```

## Usage Examples

### Simple Priority

```python
# Just use the Priority enum
trigger = MyTrigger("task", priority=Priority.SYSTEM)
await queue.enqueue(trigger)

# Results in tuple: (4, queue_seq)
```

### Compound Priority (Tuple)

```python
# DataSource: sort by event time (older bars first)
trigger = DataUpdateTrigger(
    source_name="binance",
    symbol="BTC/USDT",
    resolution="1m",
    bar_data={"time": 1678896000, "open": 50000, ...}
)
await queue.enqueue(trigger)

# Results in tuple: (0, 1678896000, queue_seq)
#                    ^  ^           ^
#                    |  |           Queue insertion order (FIFO)
#                    |  Event time (candle end time)
#                    DATA_SOURCE priority
```

### Manual Override

```python
# Override at enqueue time
await queue.enqueue(
    trigger,
    priority_override=(Priority.DATA_SOURCE, custom_time, custom_sort)
)

# Queue appends queue_seq: (0, custom_time, custom_sort, queue_seq)
```

## Common Patterns

### Market Data (Process Chronologically)

```python
# Bar from 10:00 → (0, 10:00_timestamp, queue_seq)
# Bar from 10:05 → (0, 10:05_timestamp, queue_seq)
#
# 10:00 bar processes first (earlier event_time)

DataUpdateTrigger(
    ...,
    bar_data={"time": event_timestamp, ...}
)
```

### User Messages (FIFO Order)

```python
# Message #1 → (2, msg1_timestamp, queue_seq)
# Message #2 → (2, msg2_timestamp, queue_seq)
#
# Message #1 processes first (earlier timestamp)

AgentTriggerHandler(
    session_id="user1",
    message_content="...",
    message_timestamp=unix_timestamp  # Optional, defaults to now
)
```

### Scheduled Tasks (By Schedule Time)

```python
# Job scheduled for 9 AM → (1, 9am_timestamp, queue_seq)
# Job scheduled for 2 PM → (1, 2pm_timestamp, queue_seq)
#
# 9 AM job processes first

CronTrigger(
    name="morning_sync",
    inner_trigger=...,
    scheduled_time=scheduled_timestamp
)
```

## Execution Order Example

```
Queue contains:
1. DataSource (BTC @ 10:00)  → (0, 10:00, 1)
2. DataSource (BTC @ 10:05)  → (0, 10:05, 2)
3. Timer (scheduled 9 AM)    → (1, 09:00, 3)
4. User message #1           → (2, 14:30, 4)
5. User message #2           → (2, 14:35, 5)

Dequeue order:
1. DataSource (BTC @ 10:00)  ← 0 < all others
2. DataSource (BTC @ 10:05)  ← 0 < all others, 10:05 > 10:00
3. Timer (scheduled 9 AM)    ← 1 < remaining
4. User message #1           ← 2 < remaining, 14:30 < 14:35
5. User message #2           ← last
```
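The dequeue order above is exactly what the standard-library `heapq` produces for native tuples; a minimal sketch (times are hypothetical minutes-since-midnight, the last element plays the role of `queue_seq`):

```python
import heapq

# Priority tuples mirroring the example: (category, event_time, queue_seq)
queue = []
for prio in [(2, 14 * 60 + 30, 4),   # user message #1
             (0, 10 * 60 + 5, 2),    # DataSource bar @ 10:05
             (1, 9 * 60, 3),         # timer scheduled for 9 AM
             (0, 10 * 60, 1),        # DataSource bar @ 10:00
             (2, 14 * 60 + 35, 5)]:  # user message #2
    heapq.heappush(queue, prio)

# Pop everything; the queue_seq values come out 1..5 as in the table above
order = [heapq.heappop(queue)[2] for _ in range(len(queue))]
print(order)  # → [1, 2, 3, 4, 5]
```

No comparator is needed anywhere: tuple ordering alone yields the category-then-time-then-FIFO behavior.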
## Short Tuple Wins

If tuples are equal up to the length of the shorter one, the **shorter tuple has higher priority**:

```python
(0, 1000) < (0, 1000, 5)  # True: shorter wins
(0,) < (0, 1000)          # True: shorter wins
(Priority.DATA_SOURCE,) < (Priority.DATA_SOURCE, 1000)  # True
```

This is Python's default tuple comparison behavior. In practice, we always append `queue_seq`, so this rarely matters (all tuples end up the same length).

## Integration with Triggers

### Trigger Sets Its Own Priority

```python
class MyTrigger(Trigger):
    def __init__(self, event_time):
        super().__init__(
            name="my_trigger",
            priority=Priority.DATA_SOURCE,
            priority_tuple=(Priority.DATA_SOURCE.value, event_time)
        )
```

The queue appends `queue_seq` automatically:

```python
# Trigger's tuple: (0, event_time)
# After enqueue:   (0, event_time, queue_seq)
```

### Override at Enqueue

```python
# Ignore the trigger's own priority and use an override
await queue.enqueue(
    trigger,
    priority_override=(Priority.TIMER, scheduled_time)
)
```

## Why Tuples?

✅ **Simple**: no custom classes, just native Python tuples
✅ **Flexible**: add as many sort keys as needed
✅ **Efficient**: Python's tuple comparison is highly optimized
✅ **Readable**: the meaning of `(0, 1000, 5)` is obvious at a glance
✅ **Debuggable**: easy to print and inspect

Example:
```python
# Old: CompoundPriority(primary=0, secondary=1000, tertiary=5)
# New: (0, 1000, 5)

# Same semantics, much simpler!
```

## Advanced: Custom Sorting

Want to sort by multiple factors? Just add more elements:

```python
# Sort by: priority → symbol → event_time → queue_seq
priority_tuple = (
    Priority.DATA_SOURCE.value,
    symbol_id,   # e.g., hash("BTC/USDT")
    event_time,
    # queue_seq appended by the queue
)
```

## Summary

- **Priorities are tuples**: `(primary, secondary, ..., queue_seq)`
- **Lower = higher priority**: processed first
- **Element-by-element comparison**: left-to-right
- **Shorter tuple wins**: if equal up to the shorter length
- **Queue appends queue_seq**: always the last element (FIFO within the same priority)

That's it! No complex classes, just tuples. 🎯
386
backend.old/src/trigger/README.md
Normal file
@@ -0,0 +1,386 @@
# Trigger System

Lock-free, sequence-based execution system for deterministic event processing.

## Overview

All operations (WebSocket messages, cron tasks, data updates) flow through a **priority queue**, execute in **parallel**, but commit in **strict sequential order** with **optimistic conflict detection**.

### Key Features

- **Lock-free reads**: snapshots are deep copies, no blocking
- **Sequential commits**: total ordering via sequence numbers
- **Optimistic concurrency**: conflicts are detected and retried with the same seq
- **Priority preservation**: high-priority work is never blocked by low-priority work
- **Long-running agents**: execute in parallel, commit sequentially
- **Deterministic replay**: can reproduce the exact system state at any seq

## Architecture

```
┌─────────────┐
│  WebSocket  │───┐
│  Messages   │   │
└─────────────┘   │
                  ├──→ ┌─────────────────┐
┌─────────────┐   │    │  TriggerQueue   │
│    Cron     │───┤    │ (Priority Queue)│
│  Scheduled  │   │    └────────┬────────┘
└─────────────┘   │             │ Assign seq
                  │             ↓
┌─────────────┐   │    ┌─────────────────┐
│ DataSource  │───┘    │ Execute Trigger │
│  Updates    │        │  (Parallel OK)  │
└─────────────┘        └────────┬────────┘
                                │ CommitIntents
                                ↓
                      ┌──────────────────┐
                      │ CommitCoordinator│
                      │   (Sequential)   │
                      └────────┬─────────┘
                               │ Commit in seq order
                               ↓
                      ┌─────────────────┐
                      │ VersionedStores │
                      │  (w/ Backends)  │
                      └─────────────────┘
```

## Core Components

### 1. ExecutionContext (`context.py`)

Tracks the execution seq and store snapshots via `contextvars` (auto-propagates through async calls).

```python
from trigger import get_execution_context

ctx = get_execution_context()
print(f"Running at seq {ctx.seq}")
```
|
||||
|
||||
### 2. Trigger Types (`types.py`)
|
||||
|
||||
```python
|
||||
from trigger import Trigger, Priority, CommitIntent
|
||||
|
||||
class MyTrigger(Trigger):
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
# Read snapshot
|
||||
seq, data = some_store.read_snapshot()
|
||||
|
||||
# Modify
|
||||
new_data = modify(data)
|
||||
|
||||
# Prepare commit
|
||||
intent = some_store.prepare_commit(seq, new_data)
|
||||
return [intent]
|
||||
```
|
||||
|
||||
### 3. VersionedStore (`store.py`)
|
||||
|
||||
Stores with pluggable backends and optimistic concurrency:
|
||||
|
||||
```python
|
||||
from trigger import VersionedStore, PydanticStoreBackend
|
||||
|
||||
# Wrap existing Pydantic model
|
||||
backend = PydanticStoreBackend(order_store)
|
||||
versioned_store = VersionedStore("OrderStore", backend)
|
||||
|
||||
# Lock-free snapshot read
|
||||
seq, snapshot = versioned_store.read_snapshot()
|
||||
|
||||
# Prepare commit (does not modify yet)
|
||||
intent = versioned_store.prepare_commit(seq, modified_snapshot)
|
||||
```
|
||||
|
||||
**Pluggable Backends**:
|
||||
- `PydanticStoreBackend`: For existing Pydantic models (OrderStore, ChartStore, etc.)
|
||||
- `FileStoreBackend`: Future - version files (Python scripts, configs)
|
||||
- `DatabaseStoreBackend`: Future - version database rows
|
||||
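For a feel of what a backend has to provide, a custom backend for a plain dict might look like the sketch below. The method names `snapshot`/`restore` are assumptions for this example; the real `StoreBackend` interface lives in `store.py`:

```python
import copy

class DictStoreBackend:
    """Hypothetical backend that versions a plain dict (illustrative only)."""

    def __init__(self, data: dict):
        self._data = data

    def snapshot(self) -> dict:
        # Deep copy so readers never observe in-place mutation
        return copy.deepcopy(self._data)

    def restore(self, new_data: dict) -> None:
        # Called by the store on commit; replaces state wholesale
        self._data = copy.deepcopy(new_data)

backend = DictStoreBackend({"positions": []})
snap = backend.snapshot()
snap["positions"].append("BTCUSDT")
# Original state is untouched until restore() commits the new snapshot
assert backend.snapshot() == {"positions": []}
backend.restore(snap)
```

The deep copies are what make the "lock-free reads" claim hold: mutating a snapshot can never race with readers of the committed state.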
### 4. CommitCoordinator (`coordinator.py`)

Manages sequential commits with conflict detection:

- Waits for seq N to commit before N+1
- Detects conflicts (expected_seq vs committed_seq)
- Re-executes (not re-enqueues) on conflict **with same seq**
- Tracks execution state for debugging
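The in-order release of out-of-order completions reduces to a few lines (a standalone sketch, not the coordinator's actual code; names like `on_execution_done` are invented for this example):

```python
# Completions arrive out of order; commits are released strictly in seq order.
pending: dict[int, str] = {}       # seq -> result waiting to commit
next_commit_seq = 100
committed: list[tuple[int, str]] = []

def on_execution_done(seq: int, result: str) -> None:
    global next_commit_seq
    pending[seq] = result
    while next_commit_seq in pending:  # drain everything now unblocked
        committed.append((next_commit_seq, pending.pop(next_commit_seq)))
        next_commit_seq += 1

on_execution_done(101, "cron")    # finished first, must wait for 100
on_execution_done(100, "ws msg")  # unblocks both
print(committed)  # [(100, 'ws msg'), (101, 'cron')]
```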
### 5. TriggerQueue (`queue.py`)

Priority queue with seq assignment:

```python
from trigger import TriggerQueue

queue = TriggerQueue(coordinator)
await queue.start()

# Enqueue trigger
await queue.enqueue(my_trigger, Priority.HIGH)
```

### 6. TriggerScheduler (`scheduler.py`)

APScheduler integration for cron triggers:

```python
from trigger.scheduler import TriggerScheduler

scheduler = TriggerScheduler(queue)
scheduler.start()

# Every 5 minutes
scheduler.schedule_interval(
    IndicatorUpdateTrigger("rsi_14"),
    minutes=5
)

# Daily at 9 AM
scheduler.schedule_cron(
    SyncExchangeStateTrigger(),
    hour="9",
    minute="0"
)
```
## Integration Example

### Basic Setup in `main.py`

```python
from trigger import (
    CommitCoordinator,
    TriggerQueue,
    VersionedStore,
    PydanticStoreBackend,
)
from trigger.scheduler import TriggerScheduler

# Create coordinator
coordinator = CommitCoordinator()

# Wrap existing stores
order_store_versioned = VersionedStore(
    "OrderStore",
    PydanticStoreBackend(order_store)
)
coordinator.register_store(order_store_versioned)

chart_store_versioned = VersionedStore(
    "ChartStore",
    PydanticStoreBackend(chart_store)
)
coordinator.register_store(chart_store_versioned)

# Create queue and scheduler
trigger_queue = TriggerQueue(coordinator)
await trigger_queue.start()

scheduler = TriggerScheduler(trigger_queue)
scheduler.start()
```

### WebSocket Message Handler

```python
from trigger.handlers import AgentTriggerHandler

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    while True:
        data = await websocket.receive_json()

        if data["type"] == "agent_user_message":
            # Enqueue agent trigger instead of direct Gateway call
            trigger = AgentTriggerHandler(
                session_id=data["session_id"],
                message_content=data["content"],
                gateway_handler=gateway.route_user_message,
                coordinator=coordinator,
            )
            await trigger_queue.enqueue(trigger)
```

### DataSource Updates

```python
from trigger.handlers import DataUpdateTrigger

# In subscription_manager._on_source_update()
def _on_source_update(self, source_key: tuple, bar: dict):
    # Enqueue data update trigger
    trigger = DataUpdateTrigger(
        source_name=source_key[0],
        symbol=source_key[1],
        resolution=source_key[2],
        bar_data=bar,
        coordinator=coordinator,
    )
    asyncio.create_task(trigger_queue.enqueue(trigger))
```

### Custom Trigger

```python
from trigger import Trigger, CommitIntent, Priority

class RecalculatePortfolioTrigger(Trigger):
    def __init__(self, coordinator):
        super().__init__("recalc_portfolio", Priority.NORMAL)
        self.coordinator = coordinator

    async def execute(self) -> list[CommitIntent]:
        # Read snapshots from multiple stores
        order_seq, orders = self.coordinator.get_store("OrderStore").read_snapshot()
        chart_seq, chart = self.coordinator.get_store("ChartStore").read_snapshot()

        # Calculate portfolio value
        portfolio_value = calculate_portfolio(orders, chart)

        # Update chart state with portfolio value
        chart.portfolio_value = portfolio_value

        # Prepare commit
        intent = self.coordinator.get_store("ChartStore").prepare_commit(
            chart_seq,
            chart
        )

        return [intent]

# Schedule it
scheduler.schedule_interval(
    RecalculatePortfolioTrigger(coordinator),
    minutes=1
)
```
## Execution Flow

### Normal Flow (No Conflicts)

```
seq=100: WebSocket message arrives → enqueue → dequeue → assign seq=100 → execute
seq=101: Cron trigger fires → enqueue → dequeue → assign seq=101 → execute

seq=101 finishes first → waits in commit queue
seq=100 finishes → commits immediately (next in order)
seq=101 commits next
```

### Conflict Flow

```
seq=100: reads OrderStore at seq=99 → executes for 30 seconds
seq=101: reads OrderStore at seq=99 → executes for 5 seconds

seq=101 finishes first → tries to commit based on seq=99
seq=100 finishes → commits OrderStore at seq=100

Coordinator detects conflict:
  expected_seq=99, committed_seq=100

seq=101 evicted → RE-EXECUTES with same seq=101 (not re-enqueued)
  reads OrderStore at seq=100 → executes again
  finishes → commits successfully at seq=101
```
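Stripped of queueing and priorities, the conflict-retry cycle is classic optimistic concurrency control. A self-contained sketch (the `TinyStore` class and its method names are invented for illustration, not the real `VersionedStore` API):

```python
import copy

class TinyStore:
    """Minimal versioned cell: (seq, data) with optimistic commit."""

    def __init__(self, data):
        self.seq, self.data = 0, data

    def read_snapshot(self):
        return self.seq, copy.deepcopy(self.data)

    def try_commit(self, expected_seq, new_data, commit_seq):
        if self.seq != expected_seq:  # someone committed after our read
            return False              # conflict: caller must re-execute
        self.seq, self.data = commit_seq, new_data
        return True

store = TinyStore({"qty": 1})

def execute():
    seq, data = store.read_snapshot()
    data["qty"] += 1
    return seq, data

# seq=101 reads, then seq=100 commits first and invalidates that snapshot
snap_seq, result = execute()
store.try_commit(snap_seq, {"qty": 5}, commit_seq=100)  # simulate seq=100 winning

while True:                       # seq=101: retry with the SAME seq until clean
    if store.try_commit(snap_seq, result, commit_seq=101):
        break
    snap_seq, result = execute()  # re-read fresh snapshot and recompute

print(store.data)  # {'qty': 6} - recomputed on top of seq=100's commit
```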
## Benefits

### For Agent System

- **Long-running agents work naturally**: Agent starts at seq=100, runs for 60 seconds while market data updates at seq=101-110, commits only if no conflicts
- **No deadlocks**: No locks = no deadlock possibility
- **Deterministic**: Can replay from any seq for debugging

### For Strategy Execution

- **High-frequency data doesn't block strategies**: Data updates enqueued, executed in parallel, commit sequentially
- **Priority preservation**: Critical order execution never blocked by indicator calculations
- **Conflict detection**: If market moved during strategy calculation, automatically retry with fresh data

### For Scaling

- **Single-node first**: Runs on single asyncio event loop, no complex distributed coordination
- **Future-proof**: Can swap queue for Redis/PostgreSQL-backed distributed queue later
- **Event sourcing ready**: All commits have seq numbers, can build event log

## Debugging

### Check Current State

```python
# Coordinator stats
stats = coordinator.get_stats()
print(f"Current seq: {stats['current_seq']}")
print(f"Pending commits: {stats['pending_commits']}")
print(f"Executions by state: {stats['state_counts']}")

# Store state
store = coordinator.get_store("OrderStore")
print(f"Store: {store}")  # Shows committed_seq and version

# Execution record
record = coordinator.get_execution_record(100)
print(f"Seq 100: {record}")  # Shows state, retry_count, error
```
### Common Issues

**Symptoms: High conflict rate**
- **Cause**: Multiple triggers modifying same store frequently
- **Solution**: Batch updates, use debouncing, or redesign to reduce contention

**Symptoms: Commits stuck (next_commit_seq not advancing)**
- **Cause**: Execution at that seq failed or is taking too long
- **Solution**: Check execution_records for that seq, look for errors in logs

**Symptoms: Queue depth growing**
- **Cause**: Executions slower than enqueue rate
- **Solution**: Profile trigger execution, optimize slow paths, add rate limiting
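One way to implement the debouncing suggested for high conflict rates is to coalesce each burst of updates into a single enqueue (a sketch; `Debouncer` is a hypothetical helper, not part of the trigger package, and `enqueue` stands in for `trigger_queue.enqueue`):

```python
import asyncio

class Debouncer:
    """Coalesce a burst of submits into one enqueue after `delay` seconds of quiet."""

    def __init__(self, enqueue, delay: float = 0.25):
        self._enqueue = enqueue
        self._delay = delay
        self._task = None
        self._latest = None

    def submit(self, payload) -> None:
        self._latest = payload           # keep only the newest payload
        if self._task is not None:
            self._task.cancel()          # restart the quiet-period timer
        self._task = asyncio.get_running_loop().create_task(self._fire())

    async def _fire(self):
        await asyncio.sleep(self._delay)
        await self._enqueue(self._latest)  # one enqueue for the whole burst

async def demo():
    seen = []

    async def enqueue(payload):
        seen.append(payload)

    debouncer = Debouncer(enqueue, delay=0.05)
    for bar in range(10):                # burst of 10 rapid bar updates
        debouncer.submit(bar)
    await asyncio.sleep(0.2)             # wait out the quiet period
    return seen

collected = asyncio.run(demo())
print(collected)  # [9] - only the last update of the burst was enqueued
```

The tradeoff is latency: every coalesced update adds up to `delay` seconds before the trigger runs, so this fits low-priority recalculation, not order execution.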
## Testing

### Unit Test: Conflict Detection

```python
import pytest
from trigger import VersionedStore, PydanticStoreBackend, CommitCoordinator

@pytest.mark.asyncio
async def test_conflict_detection():
    coordinator = CommitCoordinator()

    store = VersionedStore("TestStore", PydanticStoreBackend(TestModel()))
    coordinator.register_store(store)

    # Seq 1: read at 0, modify, commit
    seq1, data1 = store.read_snapshot()
    data1.value = "seq1"
    intent1 = store.prepare_commit(seq1, data1)

    # Seq 2: read at 0 (same snapshot), modify
    seq2, data2 = store.read_snapshot()
    data2.value = "seq2"
    intent2 = store.prepare_commit(seq2, data2)

    # Commit seq 1 (should succeed)
    # ... coordinator logic ...

    # Commit seq 2 (should conflict and retry)
    # ... verify conflict detected ...
```

## Future Enhancements

- **Distributed queue**: Redis-backed queue for multi-worker deployment
- **Event log persistence**: Store all commits for event sourcing/audit
- **Metrics dashboard**: Real-time view of queue depth, conflict rate, latency
- **Transaction snapshots**: Full system state at any seq for replay/debugging
- **Automatic batching**: Coalesce rapid updates to same store
35
backend.old/src/trigger/__init__.py
Normal file
@@ -0,0 +1,35 @@
"""
Sequential execution trigger system with optimistic concurrency control.

All operations (websocket, cron, data events) flow through a priority queue,
execute in parallel, but commit in strict sequential order with conflict detection.
"""

from .context import ExecutionContext, get_execution_context
from .types import Priority, PriorityTuple, Trigger, CommitIntent, ExecutionState
from .store import VersionedStore, StoreBackend, PydanticStoreBackend
from .coordinator import CommitCoordinator
from .queue import TriggerQueue
from .handlers import AgentTriggerHandler, LambdaHandler

__all__ = [
    # Context
    "ExecutionContext",
    "get_execution_context",
    # Types
    "Priority",
    "PriorityTuple",
    "Trigger",
    "CommitIntent",
    "ExecutionState",
    # Store
    "VersionedStore",
    "StoreBackend",
    "PydanticStoreBackend",
    # Coordination
    "CommitCoordinator",
    "TriggerQueue",
    # Handlers
    "AgentTriggerHandler",
    "LambdaHandler",
]
61
backend.old/src/trigger/context.py
Normal file
@@ -0,0 +1,61 @@
"""
Execution context tracking using Python's contextvars.

Each execution gets a unique seq number that propagates through all async calls,
allowing us to track which execution made which changes for conflict detection.
"""

import logging
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)

# Context variables - automatically propagate through async call chains
_execution_context: ContextVar[Optional["ExecutionContext"]] = ContextVar(
    "execution_context", default=None
)


@dataclass
class ExecutionContext:
    """
    Execution context for a single trigger execution.

    Automatically propagates through async calls via contextvars.
    Tracks the seq number and which store snapshots were read.
    """

    seq: int
    """Sequential execution number - determines commit order"""

    trigger_name: str
    """Name/type of trigger being executed"""

    snapshot_seqs: dict[str, int] = field(default_factory=dict)
    """Store name -> seq number of snapshot that was read"""

    def record_snapshot(self, store_name: str, snapshot_seq: int) -> None:
        """Record that we read a snapshot from a store at a specific seq"""
        self.snapshot_seqs[store_name] = snapshot_seq
        logger.debug(f"Seq {self.seq}: Read {store_name} at seq {snapshot_seq}")

    def __str__(self) -> str:
        return f"ExecutionContext(seq={self.seq}, trigger={self.trigger_name})"


def get_execution_context() -> Optional[ExecutionContext]:
    """Get the current execution context, or None if not in an execution"""
    return _execution_context.get()


def set_execution_context(ctx: ExecutionContext) -> None:
    """Set the execution context for the current async task"""
    _execution_context.set(ctx)
    logger.debug(f"Set execution context: {ctx}")


def clear_execution_context() -> None:
    """Clear the execution context"""
    _execution_context.set(None)
302
backend.old/src/trigger/coordinator.py
Normal file
@@ -0,0 +1,302 @@
"""
Commit coordinator - manages sequential commits with conflict detection.

Ensures that commits happen in strict sequence order, even when executions
complete out of order. Detects conflicts and triggers re-execution with the
same seq number (not re-enqueue, just re-execute).
"""

import asyncio
import logging
from typing import Optional

from .context import ExecutionContext
from .store import VersionedStore
from .types import CommitIntent, ExecutionRecord, ExecutionState, Trigger

logger = logging.getLogger(__name__)


class CommitCoordinator:
    """
    Manages sequential commits with optimistic concurrency control.

    Key responsibilities:
    - Maintain strict sequential commit order (seq N+1 commits after seq N)
    - Detect conflicts between execution snapshot and committed state
    - Trigger re-execution (not re-enqueue) on conflicts with same seq
    - Track in-flight executions for debugging and monitoring
    """

    def __init__(self):
        self._stores: dict[str, VersionedStore] = {}
        self._current_seq = 0  # Highest committed seq across all operations
        self._next_commit_seq = 1  # Next seq we're waiting to commit
        self._pending_commits: dict[int, tuple[ExecutionRecord, list[CommitIntent]]] = {}
        self._execution_records: dict[int, ExecutionRecord] = {}
        self._lock = asyncio.Lock()  # Only for coordinator internal state, not stores

    def register_store(self, store: VersionedStore) -> None:
        """Register a versioned store with the coordinator"""
        self._stores[store.name] = store
        logger.info(f"Registered store: {store.name}")

    def get_store(self, name: str) -> Optional[VersionedStore]:
        """Get a registered store by name"""
        return self._stores.get(name)

    async def start_execution(self, seq: int, trigger: Trigger) -> ExecutionRecord:
        """
        Record that an execution is starting.

        Args:
            seq: Sequence number assigned to this execution
            trigger: The trigger being executed

        Returns:
            ExecutionRecord for tracking
        """
        async with self._lock:
            record = ExecutionRecord(
                seq=seq,
                trigger=trigger,
                state=ExecutionState.EXECUTING,
            )
            self._execution_records[seq] = record
            logger.info(f"Started execution: seq={seq}, trigger={trigger.name}")
            return record
    async def submit_for_commit(
        self,
        seq: int,
        commit_intents: list[CommitIntent],
    ) -> None:
        """
        Submit commit intents for sequential commit.

        The commit will only happen when:
        1. All prior seq numbers have committed
        2. No conflicts detected with committed state

        Args:
            seq: Sequence number of this execution
            commit_intents: List of changes to commit (empty if no changes)
        """
        async with self._lock:
            record = self._execution_records.get(seq)
            if not record:
                logger.error(f"No execution record found for seq={seq}")
                return

            record.state = ExecutionState.WAITING_COMMIT
            record.commit_intents = commit_intents
            self._pending_commits[seq] = (record, commit_intents)

            logger.info(
                f"Seq {seq} submitted for commit with {len(commit_intents)} intents"
            )

        # Try to process commits (this will handle sequential ordering)
        await self._process_commits()

    async def _process_commits(self) -> None:
        """
        Process pending commits in strict sequential order.

        Only commits seq N if seq N-1 has already committed.
        Detects conflicts and triggers re-execution with same seq.
        """
        while True:
            async with self._lock:
                # Check if next expected seq is ready to commit
                if self._next_commit_seq not in self._pending_commits:
                    # Waiting for this seq to complete execution
                    break

                seq = self._next_commit_seq
                record, intents = self._pending_commits[seq]

                logger.info(
                    f"Processing commit for seq={seq} (current_seq={self._current_seq})"
                )

                # Check for conflicts
                conflicts = self._check_conflicts(intents)

                if conflicts:
                    # Conflict detected - re-execute with same seq
                    logger.warning(
                        f"Seq {seq} has conflicts in stores: {conflicts}. Re-executing..."
                    )

                    # Remove from pending (will be re-added when execution completes)
                    del self._pending_commits[seq]

                    # Mark as evicted
                    record.state = ExecutionState.EVICTED
                    record.retry_count += 1

                    # Advance to next seq (this seq will be retried in background)
                    self._next_commit_seq += 1
                    self._current_seq += 1

                    # Trigger re-execution (outside lock)
                    asyncio.create_task(self._retry_execution(record))

                    continue

                # No conflicts - commit all intents atomically
                for intent in intents:
                    store = self._stores.get(intent.store_name)
                    if not store:
                        logger.error(
                            f"Seq {seq}: Store '{intent.store_name}' not found"
                        )
                        continue

                    store.commit(intent.new_data, seq)

                # Mark as committed
                record.state = ExecutionState.COMMITTED
                del self._pending_commits[seq]

                # Advance seq counters
                self._current_seq = seq
                self._next_commit_seq = seq + 1

                logger.info(
                    f"Committed seq={seq}, current_seq now {self._current_seq}"
                )
    def _check_conflicts(self, intents: list[CommitIntent]) -> list[str]:
        """
        Check if any commit intents conflict with current committed state.

        Args:
            intents: List of commit intents to check

        Returns:
            List of store names that have conflicts (empty if no conflicts)
        """
        conflicts = []

        for intent in intents:
            store = self._stores.get(intent.store_name)
            if not store:
                logger.error(f"Store '{intent.store_name}' not found during conflict check")
                continue

            if store.check_conflict(intent.expected_seq):
                conflicts.append(intent.store_name)

        return conflicts

    async def _retry_execution(self, record: ExecutionRecord) -> None:
        """
        Re-execute a trigger that had conflicts.

        Executes with the SAME seq number (not re-enqueued, just re-executed).
        This ensures the execution order remains deterministic.

        Args:
            record: Execution record to retry
        """
        from .context import ExecutionContext, set_execution_context, clear_execution_context

        logger.info(
            f"Retrying execution: seq={record.seq}, trigger={record.trigger.name}, "
            f"retry_count={record.retry_count}"
        )

        # Set execution context for retry
        ctx = ExecutionContext(
            seq=record.seq,
            trigger_name=record.trigger.name,
        )
        set_execution_context(ctx)

        try:
            # Re-execute trigger
            record.state = ExecutionState.EXECUTING
            commit_intents = await record.trigger.execute()

            # Submit for commit again (with same seq)
            await self.submit_for_commit(record.seq, commit_intents)

        except Exception as e:
            logger.error(
                f"Retry execution failed for seq={record.seq}: {e}", exc_info=True
            )
            record.state = ExecutionState.FAILED
            record.error = str(e)

            # Still need to advance past this seq
            async with self._lock:
                if record.seq == self._next_commit_seq:
                    self._next_commit_seq += 1
                    self._current_seq += 1

            # Try to process any pending commits
            await self._process_commits()

        finally:
            clear_execution_context()

    async def execution_failed(self, seq: int, error: Exception) -> None:
        """
        Mark an execution as failed.

        Args:
            seq: Sequence number that failed
            error: The exception that caused the failure
        """
        async with self._lock:
            record = self._execution_records.get(seq)
            if record:
                record.state = ExecutionState.FAILED
                record.error = str(error)

            # Remove from pending if present
            self._pending_commits.pop(seq, None)

            # If this is the next seq to commit, advance past it
            if seq == self._next_commit_seq:
                self._next_commit_seq += 1
                self._current_seq += 1

                logger.info(
                    f"Seq {seq} failed, advancing current_seq to {self._current_seq}"
                )

        # Try to process any pending commits
        await self._process_commits()

    def get_current_seq(self) -> int:
        """Get the current committed sequence number"""
        return self._current_seq

    def get_execution_record(self, seq: int) -> Optional[ExecutionRecord]:
        """Get execution record for a specific seq"""
        return self._execution_records.get(seq)

    def get_stats(self) -> dict:
        """Get statistics about the coordinator state"""
        state_counts = {}
        for record in self._execution_records.values():
            state_name = record.state.name
            state_counts[state_name] = state_counts.get(state_name, 0) + 1

        return {
            "current_seq": self._current_seq,
            "next_commit_seq": self._next_commit_seq,
            "pending_commits": len(self._pending_commits),
            "total_executions": len(self._execution_records),
            "state_counts": state_counts,
            "stores": {name: str(store) for name, store in self._stores.items()},
        }

    def __repr__(self) -> str:
        return (
            f"CommitCoordinator(current_seq={self._current_seq}, "
            f"pending={len(self._pending_commits)}, stores={len(self._stores)})"
        )
304
backend.old/src/trigger/handlers.py
Normal file
@@ -0,0 +1,304 @@
"""
Trigger handlers - concrete implementations for common trigger types.

Provides ready-to-use trigger handlers for:
- Agent execution (WebSocket user messages)
- Lambda/callable execution
- Data update triggers
- Indicator updates
"""

import logging
import time
from typing import Any, Awaitable, Callable, Optional

from .coordinator import CommitCoordinator
from .types import CommitIntent, Priority, Trigger

logger = logging.getLogger(__name__)


class AgentTriggerHandler(Trigger):
    """
    Trigger for agent execution from WebSocket user messages.

    Wraps the Gateway's agent execution flow and captures any
    store modifications as commit intents.

    Priority tuple: (USER_AGENT, message_timestamp, queue_seq)
    """

    def __init__(
        self,
        session_id: str,
        message_content: str,
        message_timestamp: Optional[int] = None,
        attachments: Optional[list] = None,
        gateway_handler: Optional[Callable] = None,
        coordinator: Optional[CommitCoordinator] = None,
    ):
        """
        Initialize agent trigger.

        Args:
            session_id: User session ID
            message_content: User message content
            message_timestamp: When user sent message (unix timestamp, defaults to now)
            attachments: Optional message attachments
            gateway_handler: Callable to route to Gateway (set during integration)
            coordinator: CommitCoordinator for accessing stores
        """
        if message_timestamp is None:
            message_timestamp = int(time.time())

        # Priority tuple: sort by USER_AGENT priority, then message timestamp
        super().__init__(
            name=f"agent_{session_id}",
            priority=Priority.USER_AGENT,
            priority_tuple=(Priority.USER_AGENT.value, message_timestamp)
        )
        self.session_id = session_id
        self.message_content = message_content
        self.message_timestamp = message_timestamp
        self.attachments = attachments or []
        self.gateway_handler = gateway_handler
        self.coordinator = coordinator

    async def execute(self) -> list[CommitIntent]:
        """
        Execute agent interaction.

        This will call into the Gateway, which will run the agent.
        The agent may read from stores and generate responses.
        Any store modifications are captured as commit intents.

        Returns:
            List of commit intents (typically empty for now, as agent
            modifies stores via tools which will be integrated later)
        """
        if not self.gateway_handler:
            logger.error("No gateway_handler configured for AgentTriggerHandler")
            return []

        logger.info(
            f"Agent trigger executing: session={self.session_id}, "
            f"content='{self.message_content[:50]}...'"
        )

        try:
            # Call Gateway to handle message
            # In future, Gateway/agent tools will use coordinator stores
            await self.gateway_handler(
                self.session_id,
                self.message_content,
                self.attachments,
            )

            # For now, agent doesn't directly modify stores
            # Future: agent tools will return commit intents
            return []

        except Exception as e:
            logger.error(f"Agent execution error: {e}", exc_info=True)
            raise
class LambdaHandler(Trigger):
    """
    Generic trigger that executes an arbitrary async callable.

    Useful for custom triggers, one-off tasks, or testing.
    """

    def __init__(
        self,
        name: str,
        func: Callable[[], Awaitable[list[CommitIntent]]],
        priority: Priority = Priority.SYSTEM,
    ):
        """
        Initialize lambda handler.

        Args:
            name: Descriptive name for this trigger
            func: Async callable that returns commit intents
            priority: Execution priority
        """
        super().__init__(name, priority)
        self.func = func

    async def execute(self) -> list[CommitIntent]:
        """Execute the callable"""
        logger.info(f"Lambda trigger executing: {self.name}")
        return await self.func()
class DataUpdateTrigger(Trigger):
    """
    Trigger for DataSource bar updates.

    Fired when new market data arrives. Can update indicators,
    trigger strategy logic, or notify the agent of market events.

    Priority tuple: (DATA_SOURCE, event_time, queue_seq)
    Ensures older bars process before newer ones.
    """

    def __init__(
        self,
        source_name: str,
        symbol: str,
        resolution: str,
        bar_data: dict,
        coordinator: Optional[CommitCoordinator] = None,
    ):
        """
        Initialize data update trigger.

        Args:
            source_name: Name of data source (e.g., "binance")
            symbol: Trading pair symbol
            resolution: Time resolution
            bar_data: Bar data dict (time, open, high, low, close, volume)
            coordinator: CommitCoordinator for accessing stores
        """
        event_time = bar_data.get('time', int(time.time()))

        # Priority tuple: sort by DATA_SOURCE priority, then event time
        super().__init__(
            name=f"data_{source_name}_{symbol}_{resolution}",
            priority=Priority.DATA_SOURCE,
            priority_tuple=(Priority.DATA_SOURCE.value, event_time)
        )
        self.source_name = source_name
        self.symbol = symbol
        self.resolution = resolution
        self.bar_data = bar_data
        self.coordinator = coordinator

    async def execute(self) -> list[CommitIntent]:
        """
        Process bar update.

        Future implementations will:
        - Update indicator values
        - Check strategy conditions
        - Trigger alerts/notifications

        Returns:
            Commit intents for any store updates
        """
        logger.info(
            f"Data update trigger: {self.source_name}:{self.symbol}@{self.resolution}, "
            f"time={self.bar_data.get('time')}"
        )

        # TODO: Update indicators
        # TODO: Check strategy conditions
        # TODO: Notify agent of significant events

        # For now, just log
        return []
class IndicatorUpdateTrigger(Trigger):
    """
    Trigger for updating indicator values.

    Can be fired by cron (periodic recalculation) or by data updates.
    """

    def __init__(
        self,
        indicator_id: str,
        force_full_recalc: bool = False,
        coordinator: Optional[CommitCoordinator] = None,
        priority: Priority = Priority.SYSTEM,
    ):
        """
        Initialize indicator update trigger.

        Args:
            indicator_id: ID of indicator to update
            force_full_recalc: If True, recalculate entire history
            coordinator: CommitCoordinator for accessing stores
            priority: Execution priority
        """
        super().__init__(f"indicator_{indicator_id}", priority)
        self.indicator_id = indicator_id
        self.force_full_recalc = force_full_recalc
        self.coordinator = coordinator

    async def execute(self) -> list[CommitIntent]:
        """
        Update indicator value.

        Reads from IndicatorStore, recalculates, prepares commit.

        Returns:
            Commit intents for updated indicator data
        """
        if not self.coordinator:
            logger.error("No coordinator configured")
            return []

        # Get indicator store
        indicator_store = self.coordinator.get_store("IndicatorStore")
        if not indicator_store:
            logger.error("IndicatorStore not registered")
            return []

        # Read snapshot
        snapshot_seq, indicator_data = indicator_store.read_snapshot()

        logger.info(
            f"Indicator update trigger: {self.indicator_id}, "
            f"snapshot_seq={snapshot_seq}, force_full={self.force_full_recalc}"
        )

        # TODO: Implement indicator recalculation logic
        # For now, just return empty (no changes)

        return []
class CronTrigger(Trigger):
|
||||
"""
|
||||
Trigger fired by APScheduler on a schedule.
|
||||
|
||||
Wraps another trigger or callable to execute periodically.
|
||||
|
||||
Priority tuple: (TIMER, scheduled_time, queue_seq)
|
||||
Ensures jobs scheduled for earlier times run first.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
inner_trigger: Trigger,
|
||||
scheduled_time: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize cron trigger.
|
||||
|
||||
Args:
|
||||
name: Descriptive name (e.g., "hourly_sync")
|
||||
inner_trigger: Trigger to execute on schedule
|
||||
scheduled_time: When this was scheduled to run (defaults to now)
|
||||
"""
|
||||
if scheduled_time is None:
|
||||
scheduled_time = int(time.time())
|
||||
|
||||
# Priority tuple: sort by TIMER priority, then scheduled time
|
||||
super().__init__(
|
||||
name=f"cron_{name}",
|
||||
priority=Priority.TIMER,
|
||||
priority_tuple=(Priority.TIMER.value, scheduled_time)
|
||||
)
|
||||
self.inner_trigger = inner_trigger
|
||||
self.scheduled_time = scheduled_time
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""Execute the wrapped trigger"""
|
||||
logger.info(f"Cron trigger firing: {self.name}")
|
||||
return await self.inner_trigger.execute()
|
||||
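The priority-tuple scheme these triggers build can be exercised standalone. A minimal sketch, assuming only the `Priority` values shown in `types.py` (no project imports; trigger names are illustrative):

```python
import heapq
from enum import IntEnum

class Priority(IntEnum):
    # Copied from types.py: lower value = higher priority
    DATA_SOURCE = 0
    TIMER = 1

# Tuples built the way the triggers do it:
# DataUpdateTrigger -> (DATA_SOURCE, bar_time), CronTrigger -> (TIMER, scheduled_time).
# The queue then appends a monotonic queue_seq on enqueue.
entries = [
    ((Priority.TIMER.value, 1000), "cron_hourly_sync"),
    ((Priority.DATA_SOURCE.value, 2000), "data_binance_BTCUSDT_1m"),
    ((Priority.DATA_SOURCE.value, 1500), "data_binance_ETHUSDT_1m"),
]
heap = [(tup + (seq,), name) for seq, (tup, name) in enumerate(entries)]
heapq.heapify(heap)

order = [heapq.heappop(heap)[1] for _ in range(len(heap))]
# Data-source triggers beat the cron job; ties break by event time, then queue_seq.
print(order)  # → ['data_binance_ETHUSDT_1m', 'data_binance_BTCUSDT_1m', 'cron_hourly_sync']
```

This is why `CronTrigger` bakes `scheduled_time` into its tuple: two cron jobs due at different times sort by due time, not by enqueue order.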
224 backend.old/src/trigger/queue.py Normal file
@@ -0,0 +1,224 @@
"""
Trigger queue - priority queue with sequence number assignment.

All operations flow through this queue:
- WebSocket messages from users
- Cron scheduled tasks
- DataSource bar updates
- Manual triggers

Queue assigns seq numbers on dequeue, executes triggers, and submits to coordinator.
"""

import asyncio
import logging
from typing import Optional

from .context import ExecutionContext, clear_execution_context, set_execution_context
from .coordinator import CommitCoordinator
from .types import Priority, PriorityTuple, Trigger

logger = logging.getLogger(__name__)


class TriggerQueue:
    """
    Priority queue for trigger execution.

    Key responsibilities:
    - Maintain priority queue (high priority dequeued first)
    - Assign sequence numbers on dequeue (determines commit order)
    - Execute triggers with context set
    - Submit results to CommitCoordinator
    - Handle execution errors gracefully
    """

    def __init__(self, coordinator: CommitCoordinator):
        """
        Initialize trigger queue.

        Args:
            coordinator: CommitCoordinator for handling commits
        """
        self._coordinator = coordinator
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._seq_counter = 0
        self._seq_lock = asyncio.Lock()
        self._processor_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self) -> None:
        """Start the queue processor"""
        if self._running:
            logger.warning("TriggerQueue already running")
            return

        self._running = True
        self._processor_task = asyncio.create_task(self._process_loop())
        logger.info("TriggerQueue started")

    async def stop(self) -> None:
        """Stop the queue processor gracefully"""
        if not self._running:
            return

        self._running = False

        if self._processor_task:
            self._processor_task.cancel()
            try:
                await self._processor_task
            except asyncio.CancelledError:
                pass

        logger.info("TriggerQueue stopped")

    async def enqueue(
        self,
        trigger: Trigger,
        priority_override: Optional[Priority | PriorityTuple] = None
    ) -> int:
        """
        Add a trigger to the queue.

        Args:
            trigger: Trigger to execute
            priority_override: Override priority (simple Priority or tuple)
                If None, uses trigger's priority/priority_tuple
                If Priority enum, creates single-element tuple
                If tuple, uses as-is

        Returns:
            Queue sequence number (appended to priority tuple)

        Examples:
            # Simple priority
            await queue.enqueue(trigger, Priority.USER_AGENT)
            # Results in: (Priority.USER_AGENT, queue_seq)

            # Tuple priority with event time
            await queue.enqueue(
                trigger,
                (Priority.DATA_SOURCE, bar_data['time'])
            )
            # Results in: (Priority.DATA_SOURCE, bar_time, queue_seq)

            # Let trigger decide
            await queue.enqueue(trigger)
        """
        # Get monotonic seq for queue ordering (appended to tuple)
        async with self._seq_lock:
            queue_seq = self._seq_counter
            self._seq_counter += 1

        # Determine priority tuple
        if priority_override is not None:
            if isinstance(priority_override, Priority):
                # Convert simple priority to tuple
                priority_tuple = (priority_override.value, queue_seq)
            else:
                # Use provided tuple, append queue_seq
                priority_tuple = priority_override + (queue_seq,)
        else:
            # Let trigger determine its own priority tuple
            priority_tuple = trigger.get_priority_tuple(queue_seq)

        # Priority queue: (priority_tuple, trigger)
        # Python's PriorityQueue compares tuples element-by-element
        await self._queue.put((priority_tuple, trigger))

        logger.debug(
            f"Enqueued: {trigger.name} with priority_tuple={priority_tuple}"
        )

        return queue_seq

    async def _process_loop(self) -> None:
        """
        Main processing loop.

        Dequeues triggers, assigns execution seq, executes, and submits to coordinator.
        """
        execution_seq = 0  # Separate counter for execution sequence

        while self._running:
            try:
                # Wait for next trigger (with timeout to check _running flag)
                try:
                    priority_tuple, trigger = await asyncio.wait_for(
                        self._queue.get(), timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Assign execution sequence number
                execution_seq += 1

                logger.info(
                    f"Dequeued: seq={execution_seq}, trigger={trigger.name}, "
                    f"priority_tuple={priority_tuple}"
                )

                # Execute in background (don't block queue)
                asyncio.create_task(
                    self._execute_trigger(execution_seq, trigger)
                )

            except Exception as e:
                logger.error(f"Error in process loop: {e}", exc_info=True)

    async def _execute_trigger(self, seq: int, trigger: Trigger) -> None:
        """
        Execute a trigger with proper context and error handling.

        Args:
            seq: Execution sequence number
            trigger: Trigger to execute
        """
        # Set up execution context
        ctx = ExecutionContext(
            seq=seq,
            trigger_name=trigger.name,
        )
        set_execution_context(ctx)

        # Record execution start with coordinator
        await self._coordinator.start_execution(seq, trigger)

        try:
            logger.info(f"Executing: seq={seq}, trigger={trigger.name}")

            # Execute trigger (can be long-running)
            commit_intents = await trigger.execute()

            logger.info(
                f"Execution complete: seq={seq}, {len(commit_intents)} commit intents"
            )

            # Submit for sequential commit
            await self._coordinator.submit_for_commit(seq, commit_intents)

        except Exception as e:
            logger.error(
                f"Execution failed: seq={seq}, trigger={trigger.name}, error={e}",
                exc_info=True,
            )

            # Notify coordinator of failure
            await self._coordinator.execution_failed(seq, e)

        finally:
            clear_execution_context()

    def get_queue_size(self) -> int:
        """Get current queue size (approximate)"""
        return self._queue.qsize()

    def is_running(self) -> bool:
        """Check if queue processor is running"""
        return self._running

    def __repr__(self) -> str:
        return (
            f"TriggerQueue(running={self._running}, queue_size={self.get_queue_size()})"
        )
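The queue depends on the trailing `queue_seq` making every priority tuple unique: `asyncio.PriorityQueue` compares entries as tuples, and if two leading priorities tied it would fall through to comparing the `Trigger` objects themselves, which do not define ordering. A dependency-free sketch of that invariant (the `Payload` class stands in for `Trigger`):

```python
import asyncio

class Payload:
    """Stands in for Trigger; deliberately defines no ordering."""
    def __init__(self, name: str):
        self.name = name

async def main() -> list[str]:
    q: asyncio.PriorityQueue = asyncio.PriorityQueue()
    # Identical leading priority elements (0, 1000): without the trailing
    # queue_seq, the heap would try to compare the Payload objects and fail.
    await q.put(((0, 1000, 0), Payload("first")))
    await q.put(((0, 1000, 1), Payload("second")))
    names = []
    while not q.empty():
        _, payload = await q.get()
        names.append(payload.name)
    return names

names = asyncio.run(main())
print(names)  # → ['first', 'second']
```

Ties on priority and event time therefore resolve to FIFO enqueue order, which is exactly what `enqueue()` guarantees by appending the monotonic counter.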
187 backend.old/src/trigger/scheduler.py Normal file
@@ -0,0 +1,187 @@
"""
APScheduler integration for cron-style triggers.

Provides scheduling of periodic triggers (e.g., sync exchange state hourly,
recompute indicators every 5 minutes, daily portfolio reports).
"""

import logging
from typing import Optional

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger as APSCronTrigger
from apscheduler.triggers.interval import IntervalTrigger

from .queue import TriggerQueue
from .types import Priority, Trigger

logger = logging.getLogger(__name__)


class TriggerScheduler:
    """
    Scheduler for periodic trigger execution.

    Wraps APScheduler to enqueue triggers at scheduled times.
    """

    def __init__(self, trigger_queue: TriggerQueue):
        """
        Initialize scheduler.

        Args:
            trigger_queue: TriggerQueue to enqueue triggers into
        """
        self.trigger_queue = trigger_queue
        self.scheduler = AsyncIOScheduler()
        self._job_counter = 0

    def start(self) -> None:
        """Start the scheduler"""
        self.scheduler.start()
        logger.info("TriggerScheduler started")

    def shutdown(self, wait: bool = True) -> None:
        """
        Shut down the scheduler.

        Args:
            wait: If True, wait for running jobs to complete
        """
        self.scheduler.shutdown(wait=wait)
        logger.info("TriggerScheduler shut down")

    def schedule_interval(
        self,
        trigger: Trigger,
        seconds: Optional[int] = None,
        minutes: Optional[int] = None,
        hours: Optional[int] = None,
        priority: Optional[Priority] = None,
    ) -> str:
        """
        Schedule a trigger to run at regular intervals.

        Args:
            trigger: Trigger to execute
            seconds: Interval in seconds
            minutes: Interval in minutes
            hours: Interval in hours
            priority: Priority override for execution

        Returns:
            Job ID (can be used to remove job later)

        Example:
            # Run every 5 minutes
            scheduler.schedule_interval(
                IndicatorUpdateTrigger("rsi_14"),
                minutes=5
            )
        """
        job_id = f"interval_{self._job_counter}"
        self._job_counter += 1

        async def job_func():
            await self.trigger_queue.enqueue(trigger, priority)

        self.scheduler.add_job(
            job_func,
            # IntervalTrigger expects ints (it builds a timedelta), so map
            # unset (None) components to 0 rather than passing None through.
            trigger=IntervalTrigger(
                seconds=seconds or 0, minutes=minutes or 0, hours=hours or 0
            ),
            id=job_id,
            name=f"Interval: {trigger.name}",
        )

        logger.info(
            f"Scheduled interval job: {job_id}, trigger={trigger.name}, "
            f"interval=(s={seconds}, m={minutes}, h={hours})"
        )

        return job_id

    def schedule_cron(
        self,
        trigger: Trigger,
        minute: Optional[str] = None,
        hour: Optional[str] = None,
        day: Optional[str] = None,
        month: Optional[str] = None,
        day_of_week: Optional[str] = None,
        priority: Optional[Priority] = None,
    ) -> str:
        """
        Schedule a trigger to run on a cron schedule.

        Args:
            trigger: Trigger to execute
            minute: Minute expression (0-59, *, */5, etc.)
            hour: Hour expression (0-23, *, etc.)
            day: Day of month expression (1-31, *, etc.)
            month: Month expression (1-12, *, etc.)
            day_of_week: Day of week expression (0-6, mon-sun, *, etc.)
            priority: Priority override for execution

        Returns:
            Job ID (can be used to remove job later)

        Example:
            # Run at 9:00 AM every weekday
            scheduler.schedule_cron(
                SyncExchangeStateTrigger(),
                hour="9",
                minute="0",
                day_of_week="mon-fri"
            )
        """
        job_id = f"cron_{self._job_counter}"
        self._job_counter += 1

        async def job_func():
            await self.trigger_queue.enqueue(trigger, priority)

        self.scheduler.add_job(
            job_func,
            trigger=APSCronTrigger(
                minute=minute,
                hour=hour,
                day=day,
                month=month,
                day_of_week=day_of_week,
            ),
            id=job_id,
            name=f"Cron: {trigger.name}",
        )

        logger.info(
            f"Scheduled cron job: {job_id}, trigger={trigger.name}, "
            f"schedule=(m={minute}, h={hour}, d={day}, dow={day_of_week})"
        )

        return job_id

    def remove_job(self, job_id: str) -> bool:
        """
        Remove a scheduled job.

        Args:
            job_id: Job ID returned from schedule_* methods

        Returns:
            True if job was removed, False if not found
        """
        try:
            self.scheduler.remove_job(job_id)
            logger.info(f"Removed scheduled job: {job_id}")
            return True
        except Exception as e:
            logger.warning(f"Could not remove job {job_id}: {e}")
            return False

    def get_jobs(self) -> list:
        """Get list of all scheduled jobs"""
        return self.scheduler.get_jobs()

    def __repr__(self) -> str:
        job_count = len(self.scheduler.get_jobs())
        running = self.scheduler.running
        return f"TriggerScheduler(running={running}, jobs={job_count})"
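Conceptually, what `schedule_interval` asks APScheduler to do is "sleep, then hand the trigger to the queue, repeatedly". A dependency-free sketch of that loop under stated assumptions (the `enqueue` callback and trigger name are hypothetical stand-ins for `TriggerQueue.enqueue` and a real trigger; APScheduler itself runs jobs indefinitely rather than a fixed number of times):

```python
import asyncio

async def run_interval(enqueue, trigger_name: str, seconds: float, repeats: int) -> None:
    # Minimal stand-in for add_job(..., trigger=IntervalTrigger(seconds=...)):
    # wait out the interval, then enqueue, a bounded number of times.
    for _ in range(repeats):
        await asyncio.sleep(seconds)
        await enqueue(trigger_name)

async def main() -> list[str]:
    fired = []

    async def enqueue(name: str) -> None:
        # Stand-in for TriggerQueue.enqueue(trigger, priority)
        fired.append(name)

    await run_interval(enqueue, "indicator_rsi_14", seconds=0.01, repeats=3)
    return fired

fired = asyncio.run(main())
print(fired)  # → ['indicator_rsi_14', 'indicator_rsi_14', 'indicator_rsi_14']
```

The real scheduler adds job identity, cancellation, and cron expressions on top of this loop, which is why the class delegates to APScheduler instead of rolling its own.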
301 backend.old/src/trigger/store.py Normal file
@@ -0,0 +1,301 @@
"""
Versioned store with pluggable backends.

Provides optimistic concurrency control via sequence numbers with support
for different storage backends (Pydantic models, files, databases, etc.).
"""

import logging
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Generic, TypeVar

from .context import get_execution_context
from .types import CommitIntent

logger = logging.getLogger(__name__)

T = TypeVar("T")


class StoreBackend(ABC, Generic[T]):
    """
    Abstract backend for versioned stores.

    Allows different storage mechanisms (Pydantic models, files, databases)
    to be used with the same versioned store infrastructure.
    """

    @abstractmethod
    def read(self) -> T:
        """
        Read the current data.

        Returns:
            Current data in backend-specific format
        """
        pass

    @abstractmethod
    def write(self, data: T) -> None:
        """
        Write new data (replaces existing).

        Args:
            data: New data to write
        """
        pass

    @abstractmethod
    def snapshot(self) -> T:
        """
        Create an immutable snapshot of current data.

        Must return a deep copy or immutable version to prevent
        modifications from affecting the committed state.

        Returns:
            Immutable snapshot of data
        """
        pass

    @abstractmethod
    def validate(self, data: T) -> bool:
        """
        Validate that data is in correct format for this backend.

        Args:
            data: Data to validate

        Returns:
            True if valid

        Raises:
            ValueError: If invalid with explanation
        """
        pass


class PydanticStoreBackend(StoreBackend[T]):
    """
    Backend for Pydantic BaseModel stores.

    Supports the existing OrderStore, ChartStore, etc. pattern.
    """

    def __init__(self, model_instance: T):
        """
        Initialize with a Pydantic model instance.

        Args:
            model_instance: Instance of a Pydantic BaseModel
        """
        self._model = model_instance

    def read(self) -> T:
        return self._model

    def write(self, data: T) -> None:
        # Replace the internal model
        self._model = data

    def snapshot(self) -> T:
        # Use Pydantic's model_copy for deep copy
        if hasattr(self._model, "model_copy"):
            return self._model.model_copy(deep=True)
        # Fallback for older Pydantic or non-model types
        return deepcopy(self._model)

    def validate(self, data: T) -> bool:
        # Pydantic models validate themselves on construction.
        # If we got here with a model instance, it's valid.
        return True


class FileStoreBackend(StoreBackend[str]):
    """
    Backend for file-based storage.

    Future implementation for versioning files (e.g., Python scripts, configs).
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        raise NotImplementedError("FileStoreBackend not yet implemented")

    def read(self) -> str:
        raise NotImplementedError()

    def write(self, data: str) -> None:
        raise NotImplementedError()

    def snapshot(self) -> str:
        raise NotImplementedError()

    def validate(self, data: str) -> bool:
        raise NotImplementedError()


class DatabaseStoreBackend(StoreBackend[dict]):
    """
    Backend for database table storage.

    Future implementation for versioning database interactions.
    """

    def __init__(self, table_name: str, connection):
        self.table_name = table_name
        self.connection = connection
        raise NotImplementedError("DatabaseStoreBackend not yet implemented")

    def read(self) -> dict:
        raise NotImplementedError()

    def write(self, data: dict) -> None:
        raise NotImplementedError()

    def snapshot(self) -> dict:
        raise NotImplementedError()

    def validate(self, data: dict) -> bool:
        raise NotImplementedError()


class VersionedStore(Generic[T]):
    """
    Store with optimistic concurrency control via sequence numbers.

    Wraps any StoreBackend and provides:
    - Lock-free snapshot reads
    - Conflict detection on commit
    - Version tracking for debugging
    """

    def __init__(self, name: str, backend: StoreBackend[T]):
        """
        Initialize versioned store.

        Args:
            name: Unique name for this store (e.g., "OrderStore")
            backend: Backend implementation for storage
        """
        self.name = name
        self._backend = backend
        self._committed_seq = 0  # Highest committed seq
        self._version = 0  # Increments on each commit (for debugging)

    @property
    def committed_seq(self) -> int:
        """Get the current committed sequence number"""
        return self._committed_seq

    @property
    def version(self) -> int:
        """Get the current version (increments on each commit)"""
        return self._version

    def read_snapshot(self) -> tuple[int, T]:
        """
        Read an immutable snapshot of the store.

        This is lock-free and can be called concurrently. The snapshot
        captures the current committed seq and a deep copy of the data.

        Automatically records the snapshot seq in the execution context
        for conflict detection during commit.

        Returns:
            Tuple of (seq, snapshot_data)
        """
        snapshot_seq = self._committed_seq
        snapshot_data = self._backend.snapshot()

        # Record in execution context for conflict detection
        ctx = get_execution_context()
        if ctx:
            ctx.record_snapshot(self.name, snapshot_seq)

        logger.debug(
            f"Store '{self.name}': read_snapshot() -> seq={snapshot_seq}, version={self._version}"
        )

        return (snapshot_seq, snapshot_data)

    def read_current(self) -> T:
        """
        Read the current data without snapshot tracking.

        Use this for read-only operations that don't need conflict detection.

        Returns:
            Current data (not a snapshot, modifications visible)
        """
        return self._backend.read()

    def prepare_commit(self, expected_seq: int, new_data: T) -> CommitIntent:
        """
        Create a commit intent for later sequential commit.

        Does NOT modify the store - that happens during the commit phase.

        Args:
            expected_seq: The seq of the snapshot that was read
            new_data: The new data to commit

        Returns:
            CommitIntent to be submitted to CommitCoordinator
        """
        # Validate data before creating intent
        self._backend.validate(new_data)

        intent = CommitIntent(
            store_name=self.name,
            expected_seq=expected_seq,
            new_data=new_data,
        )

        logger.debug(
            f"Store '{self.name}': prepare_commit(expected_seq={expected_seq}, current_seq={self._committed_seq})"
        )

        return intent

    def commit(self, new_data: T, commit_seq: int) -> None:
        """
        Commit new data at a specific seq.

        Called by CommitCoordinator during sequential commit phase.
        NOT for direct use by triggers.

        Args:
            new_data: Data to commit
            commit_seq: Seq number of this commit
        """
        self._backend.write(new_data)
        self._committed_seq = commit_seq
        self._version += 1

        logger.info(
            f"Store '{self.name}': committed seq={commit_seq}, version={self._version}"
        )

    def check_conflict(self, expected_seq: int) -> bool:
        """
        Check if committing at expected_seq would conflict.

        Args:
            expected_seq: The seq that was expected during execution

        Returns:
            True if conflict (committed_seq has advanced beyond expected_seq)
        """
        has_conflict = self._committed_seq != expected_seq
        if has_conflict:
            logger.warning(
                f"Store '{self.name}': conflict detected - "
                f"expected_seq={expected_seq}, committed_seq={self._committed_seq}"
            )
        return has_conflict

    def __repr__(self) -> str:
        return f"VersionedStore(name='{self.name}', committed_seq={self._committed_seq}, version={self._version})"
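The optimistic-concurrency flow above (snapshot, commit, conflict check) can be traced end-to-end with a minimal in-memory re-implementation. This is a sketch of the same seq logic only, not the project's classes, with no execution-context tracking:

```python
from copy import deepcopy

class MiniStore:
    """Minimal re-statement of VersionedStore's seq bookkeeping."""
    def __init__(self, data):
        self._data = data
        self.committed_seq = 0

    def read_snapshot(self):
        # Lock-free: capture seq + a deep copy, like read_snapshot()
        return self.committed_seq, deepcopy(self._data)

    def check_conflict(self, expected_seq):
        return self.committed_seq != expected_seq

    def commit(self, new_data, commit_seq):
        self._data = new_data
        self.committed_seq = commit_seq

store = MiniStore({"orders": []})

# Executions A and B both snapshot at seq 0.
seq_a, data_a = store.read_snapshot()
seq_b, data_b = store.read_snapshot()

# A commits first: its snapshot seq still matches, so no conflict.
assert not store.check_conflict(seq_a)
data_a["orders"].append("buy BTC")
store.commit(data_a, commit_seq=1)

# B now conflicts: the store advanced past B's snapshot, so B must be
# evicted and retried (the coordinator's job in the real system).
print(store.check_conflict(seq_b))  # → True
```

The deep copy in `read_snapshot` is what makes A's mutation of `data_a` safe: B's snapshot and the committed state are never aliased.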
175 backend.old/src/trigger/types.py Normal file
@@ -0,0 +1,175 @@
"""
Core types for the trigger system.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Optional

logger = logging.getLogger(__name__)


class Priority(IntEnum):
    """
    Primary execution priority for triggers.

    Lower numeric value = higher priority (dequeued first).

    Priority hierarchy (highest to lowest):
    - DATA_SOURCE: Market data, real-time feeds (most time-sensitive)
    - TIMER: Scheduled tasks, cron jobs
    - USER_AGENT: User-agent interactions (WebSocket chat)
    - USER_DATA_REQUEST: User data requests (chart loads, symbol search)
    - SYSTEM: Background tasks, cleanup
    - LOW: Retries after conflicts, non-critical tasks
    """

    DATA_SOURCE = 0  # Market data updates, real-time feeds
    TIMER = 1  # Scheduled tasks, cron jobs
    USER_AGENT = 2  # User-agent interactions (WebSocket chat)
    USER_DATA_REQUEST = 3  # User data requests (chart loads, etc.)
    SYSTEM = 4  # Background tasks, cleanup, etc.
    LOW = 5  # Retries after conflicts, non-critical tasks


# Type alias for priority tuples
# Examples:
#   (Priority.DATA_SOURCE,)                        - Simple priority
#   (Priority.DATA_SOURCE, event_time)             - Priority + event time
#   (Priority.DATA_SOURCE, event_time, queue_seq)  - Full ordering
#
# Python compares tuples element-by-element, left-to-right.
# Shorter tuple wins if all shared elements are equal.
PriorityTuple = tuple[int, ...]


class ExecutionState(IntEnum):
    """State of an execution in the system"""

    QUEUED = 0  # In queue, waiting to be dequeued
    EXECUTING = 1  # Currently executing
    WAITING_COMMIT = 2  # Finished executing, waiting for sequential commit
    COMMITTED = 3  # Successfully committed
    EVICTED = 4  # Evicted due to conflict, will retry
    FAILED = 5  # Failed with error


@dataclass
class CommitIntent:
    """
    Intent to commit changes to a store.

    Created during execution, validated and applied during sequential commit phase.
    """

    store_name: str
    """Name of the store to commit to"""

    expected_seq: int
    """The seq number of the snapshot that was read (for conflict detection)"""

    new_data: Any
    """The new data to commit (format depends on store backend)"""

    def __repr__(self) -> str:
        data_preview = str(self.new_data)[:50]
        return f"CommitIntent(store={self.store_name}, expected_seq={self.expected_seq}, data={data_preview}...)"


class Trigger(ABC):
    """
    Abstract base class for all triggers.

    A trigger represents a unit of work that:
    1. Gets assigned a seq number when dequeued
    2. Executes (potentially long-running, async)
    3. Returns CommitIntents for any state changes
    4. Waits for sequential commit
    """

    def __init__(
        self,
        name: str,
        priority: Priority = Priority.SYSTEM,
        priority_tuple: Optional[PriorityTuple] = None
    ):
        """
        Initialize trigger.

        Args:
            name: Descriptive name for logging
            priority: Simple priority (used if priority_tuple not provided)
            priority_tuple: Optional tuple for compound sorting
                Examples:
                    (Priority.DATA_SOURCE, event_time)
                    (Priority.USER_AGENT, message_timestamp)
                    (Priority.TIMER, scheduled_time)
        """
        self.name = name
        self.priority = priority
        self._priority_tuple = priority_tuple

    def get_priority_tuple(self, queue_seq: int) -> PriorityTuple:
        """
        Get the priority tuple for queue ordering.

        If a priority tuple was provided at construction, append queue_seq.
        Otherwise, create tuple from simple priority.

        Args:
            queue_seq: Queue insertion order (final sort key)

        Returns:
            Priority tuple for queue ordering

        Examples:
            (Priority.DATA_SOURCE,) + (queue_seq,) = (0, queue_seq)
            (Priority.DATA_SOURCE, 1000) + (queue_seq,) = (0, 1000, queue_seq)
        """
        if self._priority_tuple is not None:
            return self._priority_tuple + (queue_seq,)
        else:
            return (self.priority.value, queue_seq)

    @abstractmethod
    async def execute(self) -> list[CommitIntent]:
        """
        Execute the trigger logic.

        Can be long-running and async. Should read from stores via
        VersionedStore.read_snapshot() and return CommitIntents for any changes.

        Returns:
            List of CommitIntents (empty if no state changes)

        Raises:
            Exception: On execution failure (will be logged, no commit)
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}', priority={self.priority.name})"


@dataclass
class ExecutionRecord:
    """
    Record of an execution for tracking and debugging.

    Maintained by the CommitCoordinator to track in-flight executions.
    """

    seq: int
    trigger: Trigger
    state: ExecutionState
    commit_intents: Optional[list[CommitIntent]] = None
    error: Optional[str] = None
    retry_count: int = 0

    def __repr__(self) -> str:
        return (
            f"ExecutionRecord(seq={self.seq}, trigger={self.trigger.name}, "
            f"state={self.state.name}, retry={self.retry_count})"
        )
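`Trigger.get_priority_tuple` has exactly two branches: extend the construction-time tuple, or fall back to the simple priority. A standalone check of both, re-stating just that logic as a free function (no project imports; only the relevant `Priority` members are copied):

```python
from enum import IntEnum
from typing import Optional

class Priority(IntEnum):
    # Relevant members copied from types.py
    DATA_SOURCE = 0
    SYSTEM = 4

def get_priority_tuple(priority: Priority,
                       priority_tuple: Optional[tuple[int, ...]],
                       queue_seq: int) -> tuple[int, ...]:
    # Same two branches as Trigger.get_priority_tuple
    if priority_tuple is not None:
        return priority_tuple + (queue_seq,)
    return (priority.value, queue_seq)

# Compound tuple: (DATA_SOURCE, event_time) + queue_seq -> (0, 1000, 7)
a = get_priority_tuple(Priority.DATA_SOURCE, (Priority.DATA_SOURCE.value, 1000), 7)
# Simple fallback: SYSTEM -> (4, 7)
b = get_priority_tuple(Priority.SYSTEM, None, 7)
print(a, b)  # → (0, 1000, 7) (4, 7)
```

Because `queue_seq` is always the last element, triggers with and without compound tuples still sort consistently: the leading `Priority` value dominates the comparison in both branches.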
Some files were not shown because too many files have changed in this diff.