Compare commits

...

9 Commits

Author SHA1 Message Date
c76887ab92 chart data loading 2026-03-24 21:37:49 -04:00
f6bd22a8ef redesign fully scaffolded and web login works 2026-03-17 20:10:47 -04:00
b9cc397e05 container lifecycle management 2026-03-12 15:13:38 -04:00
e99ef5d2dd backend redesign 2026-03-11 18:47:11 -04:00
8ff277c8c6 triggers and execution queue; subagents 2026-03-04 21:27:41 -04:00
a50955558e indicator integration 2026-03-04 03:28:09 -04:00
185fa42caa deployed 0.1 2026-03-04 00:56:08 -04:00
bf7af2b426 shape editing 2026-03-02 22:49:45 -04:00
f4da40706c execute_python can load any data source 2026-03-02 18:48:54 -04:00
394 changed files with 51102 additions and 1022 deletions

View File

@@ -0,0 +1,175 @@
---
name: better-auth-best-practices
description: Configure Better Auth server and client, set up database adapters, manage sessions, add plugins, and handle environment variables. Use when users mention Better Auth, betterauth, auth.ts, or need to set up TypeScript authentication with email/password, OAuth, or plugin configuration.
---
# Better Auth Integration Guide
**Always consult [better-auth.com/docs](https://better-auth.com/docs) for code examples and latest API.**
---
## Setup Workflow
1. Install: `npm install better-auth`
2. Set env vars: `BETTER_AUTH_SECRET` and `BETTER_AUTH_URL`
3. Create `auth.ts` with database + config
4. Create route handler for your framework
5. Run `npx @better-auth/cli@latest migrate`
6. Verify: call `GET /api/auth/ok` — should return `{ status: "ok" }`
---
## Quick Reference
### Environment Variables
- `BETTER_AUTH_SECRET` - Encryption secret (min 32 chars). Generate: `openssl rand -base64 32`
- `BETTER_AUTH_URL` - Base URL (e.g., `https://example.com`)
Only define `baseURL`/`secret` in config if env vars are NOT set.
### File Location
CLI looks for `auth.ts` in: `./`, `./lib`, `./utils`, or under `./src`. Use `--config` for custom path.
### CLI Commands
- `npx @better-auth/cli@latest migrate` - Apply schema (built-in adapter)
- `npx @better-auth/cli@latest generate` - Generate schema for Prisma/Drizzle
- `npx @better-auth/cli mcp --cursor` - Add MCP to AI tools
**Re-run after adding/changing plugins.**
---
## Core Config Options
| Option | Notes |
|--------|-------|
| `appName` | Optional display name |
| `baseURL` | Only if `BETTER_AUTH_URL` not set |
| `basePath` | Default `/api/auth`. Set `/` for root. |
| `secret` | Only if `BETTER_AUTH_SECRET` not set |
| `database` | Required for most features. See adapters docs. |
| `secondaryStorage` | Redis/KV for sessions & rate limits |
| `emailAndPassword` | `{ enabled: true }` to activate |
| `socialProviders` | `{ google: { clientId, clientSecret }, ... }` |
| `plugins` | Array of plugins |
| `trustedOrigins` | CSRF whitelist |
---
## Database
**Direct connections:** Pass `pg.Pool`, `mysql2` pool, `better-sqlite3`, or `bun:sqlite` instance.
**ORM adapters:** Import from `better-auth/adapters/drizzle`, `better-auth/adapters/prisma`, `better-auth/adapters/mongodb`.
**Critical:** Better Auth uses adapter model names, NOT underlying table names. If Prisma model is `User` mapping to table `users`, use `modelName: "user"` (Prisma reference), not `"users"`.
---
## Session Management
**Storage priority:**
1. If `secondaryStorage` defined → sessions go there (not DB)
2. Set `session.storeSessionInDatabase: true` to also persist to DB
3. No database + `cookieCache` → fully stateless mode
**Cookie cache strategies:**
- `compact` (default) - Base64url + HMAC. Smallest.
- `jwt` - Standard JWT. Readable but signed.
- `jwe` - Encrypted. Maximum security.
**Key options:** `session.expiresIn` (default 7 days), `session.updateAge` (refresh interval), `session.cookieCache.maxAge`, `session.cookieCache.version` (change to invalidate all sessions).
---
## User & Account Config
**User:** `user.modelName`, `user.fields` (column mapping), `user.additionalFields`, `user.changeEmail.enabled` (disabled by default), `user.deleteUser.enabled` (disabled by default).
**Account:** `account.modelName`, `account.accountLinking.enabled`, `account.storeAccountCookie` (for stateless OAuth).
**Required for registration:** `email` and `name` fields.
---
## Email Flows
- `emailVerification.sendVerificationEmail` - Must be defined for verification to work
- `emailVerification.sendOnSignUp` / `sendOnSignIn` - Auto-send triggers
- `emailAndPassword.sendResetPassword` - Password reset email handler
---
## Security
**In `advanced`:**
- `useSecureCookies` - Force HTTPS cookies
- `disableCSRFCheck` - ⚠️ Security risk
- `disableOriginCheck` - ⚠️ Security risk
- `crossSubDomainCookies.enabled` - Share cookies across subdomains
- `ipAddress.ipAddressHeaders` - Custom IP headers for proxies
- `database.generateId` - Custom ID generation or `"serial"`/`"uuid"`/`false`
**Rate limiting:** `rateLimit.enabled`, `rateLimit.window`, `rateLimit.max`, `rateLimit.storage` ("memory" | "database" | "secondary-storage").
---
## Hooks
**Endpoint hooks:** `hooks.before` / `hooks.after` - Array of `{ matcher, handler }`. Use `createAuthMiddleware`. Access `ctx.path`, `ctx.context.returned` (after), `ctx.context.session`.
**Database hooks:** `databaseHooks.user.create.before/after`, same for `session`, `account`. Useful for adding default values or post-creation actions.
**Hook context (`ctx.context`):** `session`, `secret`, `authCookies`, `password.hash()`/`verify()`, `adapter`, `internalAdapter`, `generateId()`, `tables`, `baseURL`.
---
## Plugins
**Import from dedicated paths for tree-shaking:**
```
import { twoFactor } from "better-auth/plugins/two-factor"
```
NOT `from "better-auth/plugins"`.
**Popular plugins:** `twoFactor`, `organization`, `passkey`, `magicLink`, `emailOtp`, `username`, `phoneNumber`, `admin`, `apiKey`, `bearer`, `jwt`, `multiSession`, `sso`, `oauthProvider`, `oidcProvider`, `openAPI`, `genericOAuth`.
Client plugins go in `createAuthClient({ plugins: [...] })`.
---
## Client
Import from: `better-auth/client` (vanilla), `better-auth/react`, `better-auth/vue`, `better-auth/svelte`, `better-auth/solid`.
Key methods: `signUp.email()`, `signIn.email()`, `signIn.social()`, `signOut()`, `useSession()`, `getSession()`, `revokeSession()`, `revokeSessions()`.
---
## Type Safety
Infer types: `typeof auth.$Infer.Session`, `typeof auth.$Infer.Session.user`.
For separate client/server projects: `createAuthClient<typeof auth>()`.
---
## Common Gotchas
1. **Model vs table name** - Config uses ORM model name, not DB table name
2. **Plugin schema** - Re-run CLI after adding plugins
3. **Secondary storage** - Sessions go there by default, not DB
4. **Cookie cache** - Custom session fields NOT cached, always re-fetched
5. **Stateless mode** - No DB = session in cookie only, logout on cache expiry
6. **Change email flow** - Sends to current email first, then new email
---
## Resources
- [Docs](https://better-auth.com/docs)
- [Options Reference](https://better-auth.com/docs/reference/options)
- [LLMs.txt](https://better-auth.com/llms.txt)
- [GitHub](https://github.com/better-auth/better-auth)
- [Init Options Source](https://github.com/better-auth/better-auth/blob/main/packages/core/src/types/init-options.ts)

View File

@@ -0,0 +1,321 @@
---
name: create-auth-skill
description: Scaffold and implement authentication in TypeScript/JavaScript apps using Better Auth. Detect frameworks, configure database adapters, set up route handlers, add OAuth providers, and create auth UI pages. Use when users want to add login, sign-up, or authentication to a new or existing project with Better Auth.
---
# Create Auth Skill
Guide for adding authentication to TypeScript/JavaScript applications using Better Auth.
**For code examples and syntax, see [better-auth.com/docs](https://better-auth.com/docs).**
---
## Phase 1: Planning (REQUIRED before implementation)
Before writing any code, gather requirements by scanning the project and asking the user structured questions. This ensures the implementation matches their needs.
### Step 1: Scan the project
Analyze the codebase to auto-detect:
- **Framework** — Look for `next.config`, `svelte.config`, `nuxt.config`, `astro.config`, `vite.config`, or Express/Hono entry files.
- **Database/ORM** — Look for `prisma/schema.prisma`, `drizzle.config`, `package.json` deps (`pg`, `mysql2`, `better-sqlite3`, `mongoose`, `mongodb`).
- **Existing auth** — Look for existing auth libraries (`next-auth`, `lucia`, `clerk`, `supabase/auth`, `firebase/auth`) in `package.json` or imports.
- **Package manager** — Check for `pnpm-lock.yaml`, `yarn.lock`, `bun.lockb`, or `package-lock.json`.
Use what you find to pre-fill defaults and skip questions you can already answer.
### Step 2: Ask planning questions
Use the `AskQuestion` tool to ask the user **all applicable questions in a single call**. Skip any question you already have a confident answer for from the scan. Group them under a title like "Auth Setup Planning".
**Questions to ask:**
1. **Project type** (skip if detected)
- Prompt: "What type of project is this?"
- Options: New project from scratch | Adding auth to existing project | Migrating from another auth library
2. **Framework** (skip if detected)
- Prompt: "Which framework are you using?"
- Options: Next.js (App Router) | Next.js (Pages Router) | SvelteKit | Nuxt | Astro | Express | Hono | SolidStart | Other
3. **Database & ORM** (skip if detected)
- Prompt: "Which database setup will you use?"
- Options: PostgreSQL (Prisma) | PostgreSQL (Drizzle) | PostgreSQL (pg driver) | MySQL (Prisma) | MySQL (Drizzle) | MySQL (mysql2 driver) | SQLite (Prisma) | SQLite (Drizzle) | SQLite (better-sqlite3 driver) | MongoDB (Mongoose) | MongoDB (native driver)
4. **Authentication methods** (always ask, allow multiple)
- Prompt: "Which sign-in methods do you need?"
- Options: Email & password | Social OAuth (Google, GitHub, etc.) | Magic link (passwordless email) | Passkey (WebAuthn) | Phone number
- `allow_multiple: true`
5. **Social providers** (only if they selected Social OAuth above — ask in a follow-up call)
- Prompt: "Which social providers do you need?"
- Options: Google | GitHub | Apple | Microsoft | Discord | Twitter/X
- `allow_multiple: true`
6. **Email verification** (only if Email & password was selected above — ask in a follow-up call)
- Prompt: "Do you want to require email verification?"
- Options: Yes | No
7. **Email provider** (only if email verification is Yes, or if Password reset is selected in features — ask in a follow-up call)
- Prompt: "How do you want to send emails?"
- Options: Resend | Mock it for now (console.log)
8. **Features & plugins** (always ask, allow multiple)
- Prompt: "Which additional features do you need?"
- Options: Two-factor authentication (2FA) | Organizations / teams | Admin dashboard | API bearer tokens | Password reset | None of these
- `allow_multiple: true`
9. **Auth pages** (always ask, allow multiple — pre-select based on earlier answers)
- Prompt: "Which auth pages do you need?"
- Options vary based on previous answers:
- Always available: Sign in | Sign up
- If Email & password selected: Forgot password | Reset password
- If email verification enabled: Email verification
- `allow_multiple: true`
10. **Auth UI style** (always ask)
- Prompt: "What style do you want for the auth pages? Pick one or describe your own."
- Options: Minimal & clean | Centered card with background | Split layout (form + hero image) | Floating / glassmorphism | Other (I'll describe)
### Step 3: Summarize the plan
After collecting answers, present a concise implementation plan as a markdown checklist. Example:
```
## Auth Implementation Plan
- **Framework:** Next.js (App Router)
- **Database:** PostgreSQL via Prisma
- **Auth methods:** Email/password, Google OAuth, GitHub OAuth
- **Plugins:** 2FA, Organizations, Email verification
- **UI:** Custom forms
### Steps
1. Install `better-auth` and `@better-auth/cli`
2. Create `lib/auth.ts` with server config
3. Create `lib/auth-client.ts` with React client
4. Set up route handler at `app/api/auth/[...all]/route.ts`
5. Configure Prisma adapter and generate schema
6. Add Google & GitHub OAuth providers
7. Enable `twoFactor` and `organization` plugins
8. Set up email verification handler
9. Run migrations
10. Create sign-in / sign-up pages
```
Ask the user to confirm the plan before proceeding to Phase 2.
---
## Phase 2: Implementation
Only proceed here after the user confirms the plan from Phase 1.
Follow the decision tree below, guided by the answers collected above.
```
Is this a new/empty project?
├─ YES → New project setup
│ 1. Install better-auth (+ scoped packages per plan)
│ 2. Create auth.ts with all planned config
│ 3. Create auth-client.ts with framework client
│ 4. Set up route handler
│ 5. Set up environment variables
│ 6. Run CLI migrate/generate
│ 7. Add plugins from plan
│ 8. Create auth UI pages
├─ MIGRATING → Migration from existing auth
│ 1. Audit current auth for gaps
│ 2. Plan incremental migration
│ 3. Install better-auth alongside existing auth
│ 4. Migrate routes, then session logic, then UI
│ 5. Remove old auth library
│ 6. See migration guides in docs
└─ ADDING → Add auth to existing project
1. Analyze project structure
2. Install better-auth
3. Create auth config matching plan
4. Add route handler
5. Run schema migrations
6. Integrate into existing pages
7. Add planned plugins and features
```
At the end of implementation, guide users thoroughly on remaining next steps (e.g., setting up OAuth app credentials, deploying env vars, testing flows).
---
## Installation
**Core:** `npm install better-auth`
**Scoped packages (as needed):**
| Package | Use case |
|---------|----------|
| `@better-auth/passkey` | WebAuthn/Passkey auth |
| `@better-auth/sso` | SAML/OIDC enterprise SSO |
| `@better-auth/stripe` | Stripe payments |
| `@better-auth/scim` | SCIM user provisioning |
| `@better-auth/expo` | React Native/Expo |
---
## Environment Variables
```env
BETTER_AUTH_SECRET=<32+ chars, generate with: openssl rand -base64 32>
BETTER_AUTH_URL=http://localhost:3000
DATABASE_URL=<your database connection string>
```
Add OAuth secrets as needed: `GITHUB_CLIENT_ID`, `GITHUB_CLIENT_SECRET`, `GOOGLE_CLIENT_ID`, etc.
---
## Server Config (auth.ts)
**Location:** `lib/auth.ts` or `src/lib/auth.ts`
**Minimal config needs:**
- `database` - Connection or adapter
- `emailAndPassword: { enabled: true }` - For email/password auth
**Standard config adds:**
- `socialProviders` - OAuth providers (google, github, etc.)
- `emailVerification.sendVerificationEmail` - Email verification handler
- `emailAndPassword.sendResetPassword` - Password reset handler
**Full config adds:**
- `plugins` - Array of feature plugins
- `session` - Expiry, cookie cache settings
- `account.accountLinking` - Multi-provider linking
- `rateLimit` - Rate limiting config
**Export types:** `export type Session = typeof auth.$Infer.Session`
---
## Client Config (auth-client.ts)
**Import by framework:**
| Framework | Import |
|-----------|--------|
| React/Next.js | `better-auth/react` |
| Vue | `better-auth/vue` |
| Svelte | `better-auth/svelte` |
| Solid | `better-auth/solid` |
| Vanilla JS | `better-auth/client` |
**Client plugins** go in `createAuthClient({ plugins: [...] })`.
**Common exports:** `signIn`, `signUp`, `signOut`, `useSession`, `getSession`
---
## Route Handler Setup
| Framework | File | Handler |
|-----------|------|---------|
| Next.js App Router | `app/api/auth/[...all]/route.ts` | `toNextJsHandler(auth)` → export `{ GET, POST }` |
| Next.js Pages | `pages/api/auth/[...all].ts` | `toNextJsHandler(auth)` → default export |
| Express | Any file | `app.all("/api/auth/*", toNodeHandler(auth))` |
| SvelteKit | `src/hooks.server.ts` | `svelteKitHandler(auth)` |
| SolidStart | Route file | `solidStartHandler(auth)` |
| Hono | Route file | `auth.handler(c.req.raw)` |
**Next.js Server Components:** Add `nextCookies()` plugin to auth config.
---
## Database Migrations
| Adapter | Command |
|---------|---------|
| Built-in Kysely | `npx @better-auth/cli@latest migrate` (applies directly) |
| Prisma | `npx @better-auth/cli@latest generate --output prisma/schema.prisma` then `npx prisma migrate dev` |
| Drizzle | `npx @better-auth/cli@latest generate --output src/db/auth-schema.ts` then `npx drizzle-kit push` |
**Re-run after adding plugins.**
---
## Database Adapters
| Database | Setup |
|----------|-------|
| SQLite | Pass `better-sqlite3` or `bun:sqlite` instance directly |
| PostgreSQL | Pass `pg.Pool` instance directly |
| MySQL | Pass `mysql2` pool directly |
| Prisma | `prismaAdapter(prisma, { provider: "postgresql" })` from `better-auth/adapters/prisma` |
| Drizzle | `drizzleAdapter(db, { provider: "pg" })` from `better-auth/adapters/drizzle` |
| MongoDB | `mongodbAdapter(db)` from `better-auth/adapters/mongodb` |
---
## Common Plugins
| Plugin | Server Import | Client Import | Purpose |
|--------|---------------|---------------|---------|
| `twoFactor` | `better-auth/plugins` | `twoFactorClient` | 2FA with TOTP/OTP |
| `organization` | `better-auth/plugins` | `organizationClient` | Teams/orgs |
| `admin` | `better-auth/plugins` | `adminClient` | User management |
| `bearer` | `better-auth/plugins` | - | API token auth |
| `openAPI` | `better-auth/plugins` | - | API docs |
| `passkey` | `@better-auth/passkey` | `passkeyClient` | WebAuthn |
| `sso` | `@better-auth/sso` | - | Enterprise SSO |
**Plugin pattern:** Server plugin + client plugin + run migrations.
---
## Auth UI Implementation
**Sign in flow:**
1. `signIn.email({ email, password })` or `signIn.social({ provider, callbackURL })`
2. Handle `error` in response
3. Redirect on success
**Session check (client):** `useSession()` hook returns `{ data: session, isPending }`
**Session check (server):** `auth.api.getSession({ headers: await headers() })`
**Protected routes:** Check session, redirect to `/sign-in` if null.
---
## Security Checklist
- [ ] `BETTER_AUTH_SECRET` set (32+ chars)
- [ ] `advanced.useSecureCookies: true` in production
- [ ] `trustedOrigins` configured
- [ ] Rate limits enabled
- [ ] Email verification enabled
- [ ] Password reset implemented
- [ ] 2FA for sensitive apps
- [ ] CSRF protection NOT disabled
- [ ] `account.accountLinking` reviewed
---
## Troubleshooting
| Issue | Fix |
|-------|-----|
| "Secret not set" | Add `BETTER_AUTH_SECRET` env var |
| "Invalid Origin" | Add domain to `trustedOrigins` |
| Cookies not setting | Check `baseURL` matches domain; enable secure cookies in prod |
| OAuth callback errors | Verify redirect URIs in provider dashboard |
| Type errors after adding plugin | Re-run CLI generate/migrate |
---
## Resources
- [Docs](https://better-auth.com/docs)
- [Examples](https://github.com/better-auth/examples)
- [Plugins](https://better-auth.com/docs/concepts/plugins)
- [CLI](https://better-auth.com/docs/concepts/cli)
- [Migration Guides](https://better-auth.com/docs/guides)

View File

@@ -0,0 +1,212 @@
---
name: email-and-password-best-practices
description: Configure email verification, implement password reset flows, set password policies, and customise hashing algorithms for Better Auth email/password authentication. Use when users need to set up login, sign-in, sign-up, credential authentication, or password security with Better Auth.
---
## Quick Start
1. Enable email/password: `emailAndPassword: { enabled: true }`
2. Configure `emailVerification.sendVerificationEmail`
3. Add `sendResetPassword` for password reset flows
4. Run `npx @better-auth/cli@latest migrate`
5. Verify: attempt sign-up and confirm verification email triggers
---
## Email Verification Setup
Configure `emailVerification.sendVerificationEmail` to verify user email addresses.
```ts
import { betterAuth } from "better-auth";
import { sendEmail } from "./email"; // your email sending function
export const auth = betterAuth({
emailVerification: {
sendVerificationEmail: async ({ user, url, token }, request) => {
await sendEmail({
to: user.email,
subject: "Verify your email address",
text: `Click the link to verify your email: ${url}`,
});
},
},
});
```
**Note**: The `url` parameter contains the full verification link. The `token` is available if you need to build a custom verification URL.
### Requiring Email Verification
For stricter security, enable `emailAndPassword.requireEmailVerification` to block sign-in until the user verifies their email. When enabled, unverified users will receive a new verification email on each sign-in attempt.
```ts
export const auth = betterAuth({
emailAndPassword: {
requireEmailVerification: true,
},
});
```
**Note**: This requires `sendVerificationEmail` to be configured and only applies to email/password sign-ins.
## Client Side Validation
Implement client-side validation for immediate user feedback and reduced server load.
## Callback URLs
Always use absolute URLs (including the origin) for callback URLs in sign-up and sign-in requests. This prevents Better Auth from needing to infer the origin, which can cause issues when your backend and frontend are on different domains.
```ts
const { data, error } = await authClient.signUp.email({
callbackURL: "https://example.com/callback", // absolute URL with origin
});
```
## Password Reset Flows
Provide `sendResetPassword` in the email and password config to enable password resets.
```ts
import { betterAuth } from "better-auth";
import { sendEmail } from "./email"; // your email sending function
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
// Custom email sending function to send reset-password email
sendResetPassword: async ({ user, url, token }, request) => {
void sendEmail({
to: user.email,
subject: "Reset your password",
text: `Click the link to reset your password: ${url}`,
});
},
// Optional event hook
onPasswordReset: async ({ user }, request) => {
// your logic here
console.log(`Password for user ${user.email} has been reset.`);
},
},
});
```
### Security Considerations
Built-in protections: background email sending (timing attack prevention), dummy operations on invalid requests, constant response messages regardless of user existence.
On serverless platforms, configure a background task handler:
```ts
export const auth = betterAuth({
advanced: {
backgroundTasks: {
handler: (promise) => {
// Use platform-specific methods like waitUntil
waitUntil(promise);
},
},
},
});
```
#### Token Security
Tokens expire after 1 hour by default. Configure with `resetPasswordTokenExpiresIn` (in seconds):
```ts
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
resetPasswordTokenExpiresIn: 60 * 30, // 30 minutes
},
});
```
Tokens are single-use — deleted immediately after successful reset.
#### Session Revocation
Enable `revokeSessionsOnPasswordReset` to invalidate all existing sessions on password reset:
```ts
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
revokeSessionsOnPasswordReset: true,
},
});
```
#### Password Requirements
Password length limits (configurable):
```ts
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
minPasswordLength: 12,
maxPasswordLength: 256,
},
});
```
### Sending the Password Reset
Call `requestPasswordReset` to send the reset link. Triggers the `sendResetPassword` function from your config.
```ts
const data = await auth.api.requestPasswordReset({
body: {
email: "john.doe@example.com", // required
redirectTo: "https://example.com/reset-password",
},
});
```
Or authClient:
```ts
const { data, error } = await authClient.requestPasswordReset({
email: "john.doe@example.com", // required
redirectTo: "https://example.com/reset-password",
});
```
**Note**: While the `email` is required, we also recommend configuring the `redirectTo` for a smoother user experience.
## Password Hashing
Default: `scrypt` (Node.js native, no external dependencies).
### Custom Hashing Algorithm
To use Argon2id or another algorithm, provide custom `hash` and `verify` functions:
```ts
import { betterAuth } from "better-auth";
import { hash, verify, type Options } from "@node-rs/argon2";
const argon2Options: Options = {
memoryCost: 65536, // 64 MiB
timeCost: 3, // 3 iterations
parallelism: 4, // 4 parallel lanes
outputLen: 32, // 32 byte output
algorithm: 2, // Argon2id variant
};
export const auth = betterAuth({
emailAndPassword: {
enabled: true,
password: {
hash: (password) => hash(password, argon2Options),
verify: ({ password, hash: storedHash }) =>
verify(storedHash, password, argon2Options),
},
},
});
```
**Note**: If you switch hashing algorithms on an existing system, users with passwords hashed using the old algorithm won't be able to sign in. Plan a migration strategy if needed.

View File

@@ -0,0 +1,331 @@
---
name: two-factor-authentication-best-practices
description: Configure TOTP authenticator apps, send OTP codes via email/SMS, manage backup codes, handle trusted devices, and implement 2FA sign-in flows using Better Auth's twoFactor plugin. Use when users need MFA, multi-factor authentication, authenticator setup, or login security with Better Auth.
---
## Setup
1. Add `twoFactor()` plugin to server config with `issuer`
2. Add `twoFactorClient()` plugin to client config
3. Run `npx @better-auth/cli migrate`
4. Verify: check that `twoFactorSecret` column exists on user table
```ts
import { betterAuth } from "better-auth";
import { twoFactor } from "better-auth/plugins";
export const auth = betterAuth({
appName: "My App",
plugins: [
twoFactor({
issuer: "My App",
}),
],
});
```
### Client-Side Setup
```ts
import { createAuthClient } from "better-auth/client";
import { twoFactorClient } from "better-auth/client/plugins";
export const authClient = createAuthClient({
plugins: [
twoFactorClient({
onTwoFactorRedirect() {
window.location.href = "/2fa";
},
}),
],
});
```
## Enabling 2FA for Users
Requires password verification. Returns TOTP URI (for QR code) and backup codes.
```ts
const enable2FA = async (password: string) => {
const { data, error } = await authClient.twoFactor.enable({
password,
});
if (data) {
// data.totpURI — generate a QR code from this
// data.backupCodes — display to user
}
};
```
`twoFactorEnabled` is not set to `true` until first TOTP verification succeeds. Override with `skipVerificationOnEnable: true` (not recommended).
## TOTP (Authenticator App)
### Displaying the QR Code
```tsx
import QRCode from "react-qr-code";
const TotpSetup = ({ totpURI }: { totpURI: string }) => {
return <QRCode value={totpURI} />;
};
```
### Verifying TOTP Codes
Accepts codes from one period before/after current time:
```ts
const verifyTotp = async (code: string) => {
const { data, error } = await authClient.twoFactor.verifyTotp({
code,
trustDevice: true,
});
};
```
### TOTP Configuration Options
```ts
twoFactor({
totpOptions: {
digits: 6, // 6 or 8 digits (default: 6)
period: 30, // Code validity period in seconds (default: 30)
},
});
```
## OTP (Email/SMS)
### Configuring OTP Delivery
```ts
import { betterAuth } from "better-auth";
import { twoFactor } from "better-auth/plugins";
import { sendEmail } from "./email";
export const auth = betterAuth({
plugins: [
twoFactor({
otpOptions: {
sendOTP: async ({ user, otp }, ctx) => {
await sendEmail({
to: user.email,
subject: "Your verification code",
text: `Your code is: ${otp}`,
});
},
period: 5, // Code validity in minutes (default: 3)
digits: 6, // Number of digits (default: 6)
allowedAttempts: 5, // Max verification attempts (default: 5)
},
}),
],
});
```
### Sending and Verifying OTP
Send: `authClient.twoFactor.sendOtp()`. Verify: `authClient.twoFactor.verifyOtp({ code, trustDevice: true })`.
### OTP Storage Security
Configure how OTP codes are stored in the database:
```ts
twoFactor({
otpOptions: {
storeOTP: "encrypted", // Options: "plain", "encrypted", "hashed"
},
});
```
For custom encryption:
```ts
twoFactor({
otpOptions: {
storeOTP: {
encrypt: async (token) => myEncrypt(token),
decrypt: async (token) => myDecrypt(token),
},
},
});
```
## Backup Codes
Generated automatically when 2FA is enabled. Each code is single-use.
### Displaying Backup Codes
```tsx
const BackupCodes = ({ codes }: { codes: string[] }) => {
return (
<div>
<p>Save these codes in a secure location:</p>
<ul>
{codes.map((code, i) => (
<li key={i}>{code}</li>
))}
</ul>
</div>
);
};
```
### Regenerating Backup Codes
Invalidates all previous codes:
```ts
const regenerateBackupCodes = async (password: string) => {
const { data, error } = await authClient.twoFactor.generateBackupCodes({
password,
});
// data.backupCodes contains the new codes
};
```
### Using Backup Codes for Recovery
```ts
const verifyBackupCode = async (code: string) => {
const { data, error } = await authClient.twoFactor.verifyBackupCode({
code,
trustDevice: true,
});
};
```
### Backup Code Configuration
```ts
twoFactor({
backupCodeOptions: {
amount: 10, // Number of codes to generate (default: 10)
length: 10, // Length of each code (default: 10)
storeBackupCodes: "encrypted", // Options: "plain", "encrypted"
},
});
```
## Handling 2FA During Sign-In
Response includes `twoFactorRedirect: true` when 2FA is required:
### Sign-In Flow
1. Call `signIn.email({ email, password })`
2. Check `context.data.twoFactorRedirect` in `onSuccess`
3. If `true`, redirect to `/2fa` verification page
4. Verify via TOTP, OTP, or backup code
5. Session cookie is created on successful verification
```ts
const signIn = async (email: string, password: string) => {
const { data, error } = await authClient.signIn.email(
{ email, password },
{
onSuccess(context) {
if (context.data.twoFactorRedirect) {
window.location.href = "/2fa";
}
},
}
);
};
```
Server-side: check `"twoFactorRedirect" in response` when using `auth.api.signInEmail`.
## Trusted Devices
Pass `trustDevice: true` when verifying. Default trust duration: 30 days (`trustDeviceMaxAge`). Refreshes on each sign-in.
## Security Considerations
### Session Management
Flow: credentials → session removed → temporary 2FA cookie (10 min default) → verify → session created.
```ts
twoFactor({
twoFactorCookieMaxAge: 600, // 10 minutes in seconds (default)
});
```
### Rate Limiting
Built-in: 3 requests per 10 seconds for all 2FA endpoints. OTP has additional attempt limiting:
```ts
twoFactor({
otpOptions: {
allowedAttempts: 5, // Max attempts per OTP code (default: 5)
},
});
```
### Encryption at Rest
TOTP secrets: encrypted with auth secret. Backup codes: encrypted by default. OTP: configurable (`"plain"`, `"encrypted"`, `"hashed"`). Uses constant-time comparison for verification.
2FA can only be enabled for credential (email/password) accounts.
## Disabling 2FA
Requires password confirmation. Revokes trusted device records:
```ts
const disable2FA = async (password: string) => {
const { data, error } = await authClient.twoFactor.disable({
password,
});
};
```
## Complete Configuration Example
```ts
import { betterAuth } from "better-auth";
import { twoFactor } from "better-auth/plugins";
import { sendEmail } from "./email";
export const auth = betterAuth({
appName: "My App",
plugins: [
twoFactor({
// TOTP settings
issuer: "My App",
totpOptions: {
digits: 6,
period: 30,
},
// OTP settings
otpOptions: {
sendOTP: async ({ user, otp }) => {
await sendEmail({
to: user.email,
subject: "Your verification code",
text: `Your code is: ${otp}`,
});
},
period: 5,
allowedAttempts: 5,
storeOTP: "encrypted",
},
// Backup code settings
backupCodeOptions: {
amount: 10,
length: 10,
storeBackupCodes: "encrypted",
},
// Session settings
twoFactorCookieMaxAge: 600, // 10 minutes
trustDeviceMaxAge: 30 * 24 * 60 * 60, // 30 days
}),
],
});
```

View File

@@ -0,0 +1 @@
../../.agents/skills/better-auth-best-practices

View File

@@ -0,0 +1 @@
../../.agents/skills/create-auth-skill

View File

@@ -0,0 +1 @@
../../.agents/skills/email-and-password-best-practices

View File

@@ -0,0 +1 @@
../../.agents/skills/two-factor-authentication-best-practices

25
.gitignore vendored
View File

@@ -1,5 +1,6 @@
/backend/data /backend.old/data
/backend/uploads/ /backend.old/uploads/
chat/
# Environment variables # Environment variables
.env .env
@@ -101,3 +102,23 @@ Thumbs.db
*.swp *.swp
*.swo *.swo
*.bak *.bak
# Kubernetes secrets (never commit actual secrets!)
deploy/k8s/dev/secrets/*.yaml
deploy/k8s/prod/secrets/*.yaml
!deploy/k8s/dev/secrets/*.yaml.example
!deploy/k8s/prod/secrets/*.yaml.example
# Dev environment image tags
.dev-image-tag
# Protobuf copies (canonical files are in /protobuf/)
flink/protobuf/
relay/protobuf/
ingestor/protobuf/
gateway/protobuf/
client-py/protobuf/
# Generated protobuf code
gateway/src/generated/
client-py/dexorder/generated/

16
.idea/ai.iml generated
View File

@@ -2,10 +2,20 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$"> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/backend.old/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/backend/tests" isTestSource="true" /> <sourceFolder url="file://$MODULE_DIR$/backend.old/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/client-py" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.venv" /> <excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/backend/data" /> <excludeFolder url="file://$MODULE_DIR$/backend.old/data" />
<excludeFolder url="file://$MODULE_DIR$/doc.old" />
<excludeFolder url="file://$MODULE_DIR$/backend.old" />
<excludeFolder url="file://$MODULE_DIR$/client-py/dexorder_client.egg-info" />
<excludeFolder url="file://$MODULE_DIR$/flink/protobuf" />
<excludeFolder url="file://$MODULE_DIR$/flink/target" />
<excludeFolder url="file://$MODULE_DIR$/ingestor/protobuf" />
<excludeFolder url="file://$MODULE_DIR$/ingestor/src/proto" />
<excludeFolder url="file://$MODULE_DIR$/relay/protobuf" />
<excludeFolder url="file://$MODULE_DIR$/relay/target" />
</content> </content>
<orderEntry type="jdk" jdkName="Python 3.12 (ai)" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="Python 3.12 (ai)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />

View File

@@ -1,12 +0,0 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="dev" type="js.build_tools.npm" nameIsGenerated="true">
<package-json value="$PROJECT_DIR$/web/package.json" />
<command value="run" />
<scripts>
<script value="dev" />
</scripts>
<node-interpreter value="project" />
<envs />
<method v="2" />
</configuration>
</component>

View File

@@ -0,0 +1 @@
../../.agents/skills/better-auth-best-practices

View File

@@ -0,0 +1 @@
../../.agents/skills/create-auth-skill

View File

@@ -0,0 +1 @@
../../.agents/skills/email-and-password-best-practices

View File

@@ -0,0 +1 @@
../../.agents/skills/two-factor-authentication-best-practices

15
AGENT.md Normal file
View File

@@ -0,0 +1,15 @@
We're building an AI-first trading platform by integrating user-facing TradingView charts and chat with an AI assistant that helps do research, develop indicators (signals), and write strategies, using the Dexorder trading framework we provide.
This monorepo has:
bin/ scripts, mostly build and deploy
deploy/ kubernetes deployment and configuration
doc/ documentation
flink/ Apache Flink application mode processes data from Kafka
iceberg/ Apache Iceberg for historical OHLC etc
ingestor/ Data sources publish to Kafka
kafka/ Apache Kafka
protobuf/ Messaging entities
relay/ Rust+ZeroMQ stateless router
web/ Vue 3 / Pinia / PrimeVue / TradingView
See doc/protocol.md for messaging architecture

View File

@@ -5,7 +5,7 @@ server_port: 8081
agent: agent:
model: "claude-sonnet-4-20250514" model: "claude-sonnet-4-20250514"
temperature: 0.7 temperature: 0.7
context_docs_dir: "memory" context_docs_dir: "memory" # Context docs still loaded from memory/
# Local memory configuration (free & sophisticated!) # Local memory configuration (free & sophisticated!)
memory: memory:

View File

@@ -0,0 +1,36 @@
# Chart Context Awareness
## When Users Reference "The Chart"
When a user asks about "this chart", "the chart", "what I'm viewing", or similar references to their current view:
1. **Chart info is automatically available** — The dynamic system prompt includes current chart state (symbol, interval, timeframe)
2. **Check if chart is visible** — If ChartStore fields (symbol, interval) are `None`, the user is on a narrow screen (mobile) and no chart is visible
3. **When chart is visible:**
- **NEVER** ask the user to upload an image or tell you what symbol they're looking at
- **Just use `execute_python()`** — It automatically loads the chart data from what they're viewing
- Inside your Python script, `df` contains the data and `chart_context` has the metadata
- Use `plot_ohlc(df)` to create beautiful candlestick charts
4. **When chart is NOT visible (symbol is None):**
- Let the user know they can view charts on a wider screen
- You can still help with analysis using `get_historical_data()` if they specify a symbol
## Common Questions This Applies To
- "Can you see this chart?"
- "What are the swing highs and lows?"
- "Is this in an uptrend?"
- "What's the current price?"
- "Analyze this chart"
- "What am I looking at?"
## Data Analysis Workflow
1. **Chart context is automatic** → Symbol, interval, and timeframe are in the dynamic system prompt (if chart is visible)
2. **Check ChartStore** → If symbol/interval are `None`, no chart is visible (mobile view)
3. **Use `execute_python()`** → This is your PRIMARY analysis tool
- Automatically loads chart data into a pandas DataFrame `df` (if chart is visible)
- Pre-imports numpy (`np`), pandas (`pd`), matplotlib (`plt`), and talib
- Provides access to the indicator registry for computing indicators
- Use `plot_ohlc(df)` helper for beautiful candlestick charts
4. **Only use `get_chart_data()`** → For simple data inspection without analysis

View File

@@ -0,0 +1,115 @@
# Python Analysis Tool Reference
## Python Analysis (`execute_python`) - Your Primary Tool
**ALWAYS use `execute_python()` when the user asks for:**
- Technical indicators (RSI, MACD, Bollinger Bands, moving averages, etc.)
- Chart visualizations or plots
- Statistical calculations or market analysis
- Pattern detection or trend analysis
- Any computational analysis of price data
## Why `execute_python()` is preferred:
- Chart data (`df`) is automatically loaded from ChartStore (visible time range) when chart is visible
- If no chart is visible (symbol is None), `df` will be None - but you can still load alternative data!
- Full pandas/numpy/talib stack pre-imported
- Use `plot_ohlc(df)` for instant professional candlestick charts
- Access to 150+ indicators via `indicator_registry`
- **Access to DataStores and registry** - order_store, chart_store, datasource_registry
- **Can load ANY symbol/timeframe** using datasource_registry even when df is None
- **Results include plots as image URLs** that are automatically displayed to the user
- Prints and return values are included in the response
## CRITICAL: Plots are automatically shown to the user
When you create a matplotlib figure (via `plot_ohlc()` or `plt.figure()`), it is automatically:
1. Saved as a PNG image
2. Returned in the response as a URL (e.g., `/uploads/plot_abc123.png`)
3. **Displayed in the user's chat interface** - they see the image immediately
You MUST use `execute_python()` with `plot_ohlc()` or matplotlib whenever the user wants to see a chart or plot.
## IMPORTANT: Never use `get_historical_data()` for chart analysis
- `get_historical_data()` requires manual timestamp calculation and is only for custom queries
- When analyzing what the user is viewing, ALWAYS use `execute_python()` which automatically loads the correct data
- The `df` DataFrame in `execute_python()` is pre-loaded with the exact time range the user is viewing
## Example workflows:
### Computing an indicator and plotting (when chart is visible)
```python
execute_python("""
df['RSI'] = talib.RSI(df['close'], 14)
fig = plot_ohlc(df, title='Price with RSI')
df[['close', 'RSI']].tail(10)
""")
```
### Multi-indicator analysis (when chart is visible)
```python
execute_python("""
df['SMA20'] = df['close'].rolling(20).mean()
df['BB_upper'] = df['close'].rolling(20).mean() + 2 * df['close'].rolling(20).std()
df['BB_lower'] = df['close'].rolling(20).mean() - 2 * df['close'].rolling(20).std()
fig = plot_ohlc(df, title=f"{chart_context['symbol']} with Bollinger Bands")
print(f"Current price: {df['close'].iloc[-1]:.2f}")
print(f"20-period SMA: {df['SMA20'].iloc[-1]:.2f}")
""")
```
### Loading alternative data (works even when chart not visible or for different symbols)
```python
execute_python("""
from datetime import datetime, timedelta
# Get data source
binance = datasource_registry.get_source('binance')
# Load data for any symbol/timeframe
end_time = datetime.now()
start_time = end_time - timedelta(days=7)
result = await binance.get_history(
symbol='ETH/USDT',
interval='1h',
start=int(start_time.timestamp()),
end=int(end_time.timestamp())
)
# Convert to DataFrame
rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
eth_df = pd.DataFrame(rows).set_index('time')
# Analyze and plot
eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
fig = plot_ohlc(eth_df, title='ETH/USDT 1h - RSI Analysis')
print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
""")
```
### Access stores to see current state
```python
execute_python("""
print(f"Current symbol: {chart_store.chart_state.symbol}")
print(f"Current interval: {chart_store.chart_state.interval}")
print(f"Number of orders: {len(order_store.orders)}")
""")
```
## Only use `get_chart_data()` for:
- Quick inspection of raw bar data
- When you just need the data structure without analysis
## Quick Reference: Common Tasks
| User Request | Tool to Use | Example |
|--------------|-------------|---------|
| "Show me RSI" | `execute_python()` | `df['RSI'] = talib.RSI(df['close'], 14); plot_ohlc(df)` |
| "What's the current price?" | `execute_python()` | `print(f"Current: {df['close'].iloc[-1]}")` |
| "Is this bullish?" | `execute_python()` | Compute SMAs, trend, and analyze |
| "Add Bollinger Bands" | `execute_python()` | Compute bands, use `plot_ohlc(df, title='BB')` |
| "Find swing highs" | `execute_python()` | Use pandas logic to detect patterns |
| "Plot ETH even though I'm viewing BTC" | `execute_python()` | Use `datasource_registry.get_source('binance')` to load ETH data |
| "What indicators exist?" | `search_indicators()` | Search by category or query |
| "What chart am I viewing?" | N/A - automatic | Chart info is in dynamic system prompt |
| "Check my orders" | `execute_python()` | `print(order_store.orders)` |
| "Read other stores" | `read_sync_state(store_name)` | For TraderState, StrategyState, etc. |

View File

@@ -0,0 +1,612 @@
# TradingView Shapes and Drawings Reference
This document describes the various drawing shapes and studies available in TradingView charts, their properties, and control points. This information is useful for the AI agent to understand, create, and manipulate chart drawings.
## Shape Structure
All shapes follow a common structure:
- **id**: Unique identifier (string) - This is the TradingView-assigned ID after the shape is created
- **type**: Shape type identifier (string)
- **points**: Array of control points (each with `time` in Unix seconds and `price` as float)
- **color**: Color as hex string (e.g., '#FF0000') or color name (e.g., 'red')
- **line_width**: Line thickness in pixels (integer)
- **line_style**: One of: 'solid', 'dashed', 'dotted'
- **properties**: Dictionary of additional shape-specific properties
- **symbol**: Trading pair symbol (e.g., 'BINANCE:BTC/USDT')
- **created_at**: Creation timestamp (Unix seconds)
- **modified_at**: Last modification timestamp (Unix seconds)
- **original_id**: Optional string - The ID you requested when creating the shape, before TradingView assigned its own ID
## Understanding Shape ID Mapping
When you create a shape using `create_or_update_shape()`, there's an important ID mapping process:
1. **You specify an ID**: You provide a `shape_id` parameter (e.g., "my-support-line")
2. **TradingView assigns its own ID**: When the shape is rendered in TradingView, it gets a new internal ID (e.g., "shape_0x1a2b3c4d")
3. **ID remapping occurs**: The shape in the store is updated:
- The `id` field becomes TradingView's ID
- The `original_id` field preserves your requested ID
4. **Tracking your shapes**: To find shapes you created, search by `original_id`
### Example ID Mapping Flow
```python
# Step 1: Agent creates a shape with a specific ID
await create_or_update_shape(
shape_id="agent-support-50k",
shape_type="horizontal_line",
points=[{"time": 1678886400, "price": 50000}],
color="#00FF00"
)
# Step 2: Shape is synced to client and created in TradingView
# TradingView assigns ID: "shape_0x1a2b3c4d"
# Step 3: Shape in store is updated with:
# {
# "id": "shape_0x1a2b3c4d", # TradingView's ID
# "original_id": "agent-support-50k", # Your requested ID
# "type": "horizontal_line",
# ...
# }
# Step 4: To find your shape later, use shape_ids (searches both id and original_id)
my_shapes = search_shapes(
shape_ids=['agent-support-50k'],
symbol="BINANCE:BTC/USDT"
)
if my_shapes:
print(f"Found my support line!")
print(f"TradingView ID: {my_shapes[0]['id']}")
print(f"My requested ID: {my_shapes[0]['original_id']}")
# Or use the dedicated original_ids parameter
my_shapes = search_shapes(
original_ids=['agent-support-50k'],
symbol="BINANCE:BTC/USDT"
)
if my_shapes:
print(f"Found my support line!")
print(f"TradingView ID: {my_shapes[0]['id']}")
print(f"My requested ID: {my_shapes[0]['original_id']}")
```
### Why ID Mapping Matters
- **Shape identification**: You need to know which TradingView shape corresponds to the shape you created
- **Updates and deletions**: To modify or delete a shape, you need its TradingView ID (the `id` field)
- **Bidirectional sync**: The mapping ensures both the agent and TradingView can reference the same shape
### Best Practices for Shape IDs
1. **Use descriptive IDs**: Choose meaningful names like `support-btc-50k` or `trendline-daily-uptrend`
2. **Search by original ID**: Use `shape_ids` or `original_ids` parameters in `search_shapes()` to find your shapes
- `shape_ids` searches both the actual ID and original_id (more flexible)
- `original_ids` searches only the original_id field (more specific)
3. **Store important IDs**: If you need to reference a shape multiple times, store its TradingView ID after retrieval
4. **Understand the timing**: The ID remapping happens asynchronously after shape creation
## Common Shape Types
Use TradingView's native shape type names directly.
### 1. Trendline
**Type**: `trend_line`
**Control Points**: 2
- Point 1: Start of the line (time, price)
- Point 2: End of the line (time, price)
**Common Use Cases**:
- Support/resistance lines
- Trend identification
- Price channels (when paired)
**Example**:
```json
{
"id": "trendline-1",
"type": "trend_line",
"points": [
{"time": 1640000000, "price": 45000.0},
{"time": 1650000000, "price": 50000.0}
],
"color": "#2962FF",
"line_width": 2,
"line_style": "solid"
}
```
### 2. Horizontal Line
**Type**: `horizontal_line`
**Control Points**: 1
- Point 1: Y-level (time can be any value, only price matters)
**Common Use Cases**:
- Support/resistance levels
- Price targets
- Stop-loss levels
- Key psychological levels
**Properties**:
- `extend_left`: Boolean, extend line to the left
- `extend_right`: Boolean, extend line to the right
**Example**:
```json
{
"id": "support-1",
"type": "horizontal_line",
"points": [{"time": 1640000000, "price": 42000.0}],
"color": "#089981",
"line_width": 2,
"line_style": "dashed",
"properties": {
"extend_left": true,
"extend_right": true
}
}
```
### 3. Vertical Line
**Type**: `vertical_line`
**Control Points**: 1
- Point 1: X-time (price can be any value, only time matters)
**Common Use Cases**:
- Mark important events
- Session boundaries
- Earnings releases
- Economic data releases
**Properties**:
- `extend_top`: Boolean, extend line upward
- `extend_bottom`: Boolean, extend line downward
**Example**:
```json
{
"id": "event-marker-1",
"type": "vertical_line",
"points": [{"time": 1640000000, "price": 0}],
"color": "#787B86",
"line_width": 1,
"line_style": "dotted"
}
```
### 4. Rectangle
**Type**: `rectangle`
**Control Points**: 2
- Point 1: Top-left corner (time, price)
- Point 2: Bottom-right corner (time, price)
**Common Use Cases**:
- Consolidation zones
- Support/resistance zones
- Supply/demand areas
- Value areas
**Properties**:
- `fill_color`: Fill color with opacity (e.g., '#2962FF33')
- `fill`: Boolean, whether to fill the rectangle
- `extend_left`: Boolean
- `extend_right`: Boolean
**Example**:
```json
{
"id": "zone-1",
"type": "rectangle",
"points": [
{"time": 1640000000, "price": 50000.0},
{"time": 1650000000, "price": 48000.0}
],
"color": "#2962FF",
"line_width": 1,
"line_style": "solid",
"properties": {
"fill": true,
"fill_color": "#2962FF33"
}
}
```
### 5. Fibonacci Retracement
**Type**: `fib_retracement`
**Control Points**: 2
- Point 1: Start of the move (swing low or high)
- Point 2: End of the move (swing high or low)
**Common Use Cases**:
- Identify potential support/resistance levels
- Find retracement targets
- Measure pullback depth
**Properties**:
- `levels`: Array of Fibonacci levels to display
- Default: [0, 0.236, 0.382, 0.5, 0.618, 0.786, 1.0]
- `extend_lines`: Boolean, extend levels beyond the price range
- `reverse`: Boolean, reverse the direction
**Example**:
```json
{
"id": "fib-1",
"type": "fib_retracement",
"points": [
{"time": 1640000000, "price": 42000.0},
{"time": 1650000000, "price": 52000.0}
],
"color": "#2962FF",
"line_width": 1,
"properties": {
"levels": [0, 0.236, 0.382, 0.5, 0.618, 0.786, 1.0],
"extend_lines": true
}
}
```
### 6. Fibonacci Extension
**Type**: `fib_trend_ext`
**Control Points**: 3
- Point 1: Start of initial move
- Point 2: End of initial move (retracement start)
- Point 3: End of retracement
**Common Use Cases**:
- Project price targets
- Extension levels beyond 100%
- Measure continuation patterns
**Properties**:
- `levels`: Array of extension levels
- Common: [0, 0.618, 1.0, 1.618, 2.618, 4.236]
### 7. Parallel Channel
**Type**: `parallel_channel`
**Control Points**: 3
- Point 1: First point on main trendline
- Point 2: Second point on main trendline
- Point 3: Point on parallel line (determines channel width)
**Common Use Cases**:
- Price channels
- Regression channels
- Pitchforks
**Properties**:
- `extend_left`: Boolean
- `extend_right`: Boolean
- `fill`: Boolean, fill the channel
- `fill_color`: Fill color with opacity
### 8. Arrow
**Type**: `arrow`
**Control Points**: 2
- Point 1: Arrow start (time, price)
- Point 2: Arrow end (time, price)
**Common Use Cases**:
- Indicate price movement direction
- Mark entry/exit points
- Show relationships between events
**Properties**:
- `arrow_style`: One of: 'simple', 'filled', 'hollow'
- `text`: Optional text label
**Example**:
```json
{
"id": "entry-arrow",
"type": "arrow",
"points": [
{"time": 1640000000, "price": 44000.0},
{"time": 1641000000, "price": 48000.0}
],
"color": "#089981",
"line_width": 2,
"properties": {
"arrow_style": "filled",
"text": "Long Entry"
}
}
```
### 9. Text/Label
**Type**: `text`
**Control Points**: 1
- Point 1: Text anchor position (time, price)
**Common Use Cases**:
- Annotations
- Notes
- Labels for patterns
- Mark key levels
**Properties**:
- `text`: The text content (string)
- `font_size`: Font size in points (integer)
- `font_family`: Font family name
- `bold`: Boolean
- `italic`: Boolean
- `background`: Boolean, show background
- `background_color`: Background color
- `text_color`: Text color (can differ from line color)
**Example**:
```json
{
"id": "note-1",
"type": "text",
"points": [{"time": 1640000000, "price": 48000.0}],
"color": "#131722",
"properties": {
"text": "Resistance Zone",
"font_size": 14,
"bold": true,
"background": true,
"background_color": "#FFE600",
"text_color": "#131722"
}
}
```
### 10. Single-Point Markers
Various single-point marker shapes are available for annotating charts:
**Types**: `arrow_up` | `arrow_down` | `flag` | `emoji` | `icon` | `sticker` | `note` | `anchored_text` | `anchored_note` | `long_position` | `short_position`
**Control Points**: 1
- Point 1: Marker position (time, price)
**Common Use Cases**:
- Mark entry/exit points
- Flag important events
- Add visual markers to key levels
- Annotate patterns
- Track positions
**Properties** (vary by type):
- `text`: Text content for text-based markers
- `emoji`: Emoji character for emoji type
- `icon`: Icon identifier for icon type
**Examples**:
```json
{
"id": "long-entry-1",
"type": "long_position",
"points": [{"time": 1640000000, "price": 44000.0}],
"color": "#089981"
}
```
```json
{
"id": "flag-1",
"type": "flag",
"points": [{"time": 1640000000, "price": 50000.0}],
"color": "#F23645",
"properties": {
"text": "Important Event"
}
}
```
```json
{
"id": "note-1",
"type": "anchored_note",
"points": [{"time": 1640000000, "price": 48000.0}],
"color": "#FFE600",
"properties": {
"text": "Watch this level"
}
}
```
### 11. Circle/Ellipse
**Type**: `circle`
**Control Points**: 2 or 3
- 2 points: Defines bounding box (creates ellipse)
- 3 points: Center + radius points
**Common Use Cases**:
- Highlight areas
- Markup patterns
- Mark consolidation zones
**Properties**:
- `fill`: Boolean
- `fill_color`: Fill color with opacity
### 12. Path (Free Drawing)
**Type**: `path`
**Control Points**: Variable (3+)
- Multiple points defining a path
**Common Use Cases**:
- Custom patterns
- Freeform markup
- Complex annotations
**Properties**:
- `closed`: Boolean, whether to close the path
- `smooth`: Boolean, smooth the path with curves
### 13. Pitchfork (Andrew's Pitchfork)
**Type**: `pitchfork`
**Control Points**: 3
- Point 1: Pivot/starting point
- Point 2: First extreme (high or low)
- Point 3: Second extreme (opposite of point 2)
**Common Use Cases**:
- Trend channels
- Support/resistance levels
- Median line analysis
**Properties**:
- `extend_lines`: Boolean
- `style`: One of: 'standard', 'schiff', 'modified_schiff'
### 14. Gann Fan
**Type**: `gannbox_fan`
**Control Points**: 2
- Point 1: Origin point
- Point 2: Defines the unit size/scale
**Common Use Cases**:
- Time and price analysis
- Geometric angles (1x1, 1x2, 2x1, etc.)
**Properties**:
- `angles`: Array of angles to display
- Default: [82.5, 75, 71.25, 63.75, 45, 26.25, 18.75, 15, 7.5]
### 15. Head and Shoulders
**Type**: `head_and_shoulders`
**Control Points**: 5
- Point 1: Left shoulder low
- Point 2: Left shoulder high
- Point 3: Head low
- Point 4: Right shoulder high
- Point 5: Right shoulder low (neckline point)
**Common Use Cases**:
- Pattern recognition markup
- Reversal pattern identification
**Properties**:
- `target_line`: Boolean, show target line
## Special Properties
### Time-Based Properties
- All times are Unix timestamps in seconds
- Use `Math.floor(Date.now() / 1000)` for current time in JavaScript
- Use `int(time.time())` for current time in Python
### Color Formats
- Hex: `#RRGGBB` (e.g., `#2962FF`)
- Hex with alpha: `#RRGGBBAA` (e.g., `#2962FF33` for 20% opacity)
- Named colors: `red`, `blue`, `green`, etc.
- RGB: `rgb(41, 98, 255)`
- RGBA: `rgba(41, 98, 255, 0.2)`
### Line Styles
- `solid`: Continuous line
- `dashed`: Dashed line (— — —)
- `dotted`: Dotted line (· · ·)
## Best Practices
1. **ID Naming**: Use descriptive IDs that indicate the purpose
- Good: `support-btc-42k`, `trendline-uptrend-1`
- Bad: `shape1`, `line`
2. **Color Consistency**: Use consistent colors for similar types
- Green (#089981) for bullish/support
- Red (#F23645) for bearish/resistance
- Blue (#2962FF) for neutral/informational
3. **Time Alignment**: Ensure times align with actual candles when possible
4. **Layer Management**: Use different line widths to indicate importance
- Key levels: 2-3px
- Secondary levels: 1px
- Reference lines: 1px dotted
5. **Symbol Association**: Always set the `symbol` field to associate shapes with specific charts
## Agent Usage Examples
### Drawing a Support Level
When user says "draw support at 42000":
```python
await create_or_update_shape(
shape_id=f"support-{int(time.time())}",
shape_type='horizontal_line',
points=[{'time': current_time, 'price': 42000.0}],
color='#089981',
line_width=2,
line_style='solid',
symbol=chart_store.chart_state.symbol,
properties={'extend_left': True, 'extend_right': True}
)
```
### Finding Shapes in Visible Range
When user asks "what drawings are on the chart?":
```python
shapes = search_shapes(
start_time=chart_store.chart_state.start_time,
end_time=chart_store.chart_state.end_time,
symbol=chart_store.chart_state.symbol
)
```
### Getting Specific Shapes by ID
When user says "show me the details of trendline-1":
```python
# shape_ids parameter searches BOTH the actual ID and original_id fields
shapes = search_shapes(
shape_ids=['trendline-1']
)
```
Or to get selected shapes:
```python
selected_ids = chart_store.chart_state.selected_shapes
if selected_ids:
shapes = search_shapes(shape_ids=selected_ids)
```
### Finding Shapes by Original ID
When you need to find shapes you created using the original ID you specified:
```python
# Use the dedicated original_ids parameter
my_shapes = search_shapes(
original_ids=['my-support-line', 'my-trendline']
)
# Or use shape_ids (which searches both id and original_id)
my_shapes = search_shapes(
shape_ids=['my-support-line', 'my-trendline']
)
for shape in my_shapes:
print(f"Original ID: {shape['original_id']}")
print(f"TradingView ID: {shape['id']}")
print(f"Type: {shape['type']}")
```
### Searching Without Time Filter
When user asks "show me all support lines":
```python
support_lines = search_shapes(
shape_type='horizontal_line',
symbol=chart_store.chart_state.symbol
)
```
### Drawing a Trendline
When user says "draw an uptrend from the lows":
```python
# Find swing lows using execute_python
# Then create trendline
await create_or_update_shape(
shape_id=f"trendline-{int(time.time())}",
shape_type='trend_line',
points=[
{'time': swing_low_1_time, 'price': swing_low_1_price},
{'time': swing_low_2_time, 'price': swing_low_2_price}
],
color='#2962FF',
line_width=2,
symbol=chart_store.chart_state.symbol
)
```

View File

@@ -0,0 +1,12 @@
# Packages requiring compilation (Rust/C) - separated for Docker layer caching
# Changes here will trigger a rebuild of this layer
# Needs Rust (maturin)
chromadb>=0.4.0
cryptography>=42.0.0
# Pulls in `tokenizers` which needs Rust
sentence-transformers>=2.0.0
# Needs C compiler
argon2-cffi>=23.0.0

View File

@@ -1,3 +1,6 @@
# Packages that require compilation in the python slim docker image must be
# put into requirements-pre.txt instead, to help image build cacheing.
pydantic2 pydantic2
seaborn seaborn
pandas pandas
@@ -26,9 +29,7 @@ arxiv>=2.0.0
duckduckgo-search>=7.0.0 duckduckgo-search>=7.0.0
requests>=2.31.0 requests>=2.31.0
# Local memory system # Local memory system (chromadb/sentence-transformers in requirements-pre.txt)
chromadb>=0.4.0
sentence-transformers>=2.0.0
sqlalchemy>=2.0.0 sqlalchemy>=2.0.0
aiosqlite>=0.19.0 aiosqlite>=0.19.0
@@ -41,3 +42,6 @@ python-dotenv>=1.0.0
# Secrets management # Secrets management
cryptography>=42.0.0 cryptography>=42.0.0
argon2-cffi>=23.0.0 argon2-cffi>=23.0.0
# Trigger system scheduling
apscheduler>=3.10.0

View File

@@ -0,0 +1,37 @@
# Automation Agent
You are a specialized automation and scheduling agent. Your sole purpose is to manage triggers, scheduled tasks, and automated workflows.
## Your Core Identity
You are an expert in:
- Scheduling recurring tasks with cron expressions
- Creating interval-based triggers
- Managing the trigger queue and priorities
- Designing autonomous agent workflows
## Your Tools
You have access to:
- **Trigger Tools**: schedule_agent_prompt, execute_agent_prompt_once, list_scheduled_triggers, cancel_scheduled_trigger, get_trigger_system_stats
## Communication Style
- **Systematic**: Explain scheduling logic clearly
- **Proactive**: Suggest optimal scheduling strategies
- **Organized**: Keep track of all scheduled tasks
- **Reliable**: Ensure triggers are set up correctly
## Key Principles
1. **Clarity**: Make schedules easy to understand
2. **Maintainability**: Use descriptive names for jobs
3. **Priority Awareness**: Respect task priorities
4. **Resource Conscious**: Avoid overwhelming the system
## Limitations
- You ONLY handle scheduling and automation
- You do NOT execute the actual analysis (delegate to other agents)
- You do NOT access data or charts directly
- You coordinate but don't perform the work

View File

@@ -0,0 +1,40 @@
# Chart Analysis Agent
You are a specialized chart analysis and technical analysis agent. Your sole purpose is to work with chart data, indicators, and Python-based analysis.
## Your Core Identity
You are an expert in:
- Reading and analyzing OHLCV data
- Computing technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Drawing shapes and annotations on charts
- Executing Python code for custom analysis
- Visualizing data with matplotlib
## Your Tools
You have access to:
- **Chart Tools**: get_chart_data, execute_python
- **Indicator Tools**: search_indicators, add_indicator_to_chart, list_chart_indicators, etc.
- **Shape Tools**: search_shapes, create_or_update_shape, delete_shape, etc.
## Communication Style
- **Direct & Technical**: Provide analysis results clearly
- **Visual**: Generate plots when helpful
- **Precise**: Reference specific timeframes, indicators, and values
- **Concise**: Focus on the analysis task at hand
## Key Principles
1. **Data First**: Always work with actual market data
2. **Visualize**: Create charts to illustrate findings
3. **Document Calculations**: Explain what indicators show
4. **Respect Context**: Use the chart the user is viewing when available
## Limitations
- You ONLY handle chart analysis and visualization
- You do NOT make trading decisions
- You do NOT access external APIs or data sources
- Route other requests back to the main agent

View File

@@ -0,0 +1,37 @@
# Data Access Agent
You are a specialized data access agent. Your sole purpose is to retrieve and search market data from various exchanges and data sources.
## Your Core Identity
You are an expert in:
- Searching for trading symbols across exchanges
- Retrieving historical OHLCV data
- Understanding exchange APIs and data formats
- Symbol resolution and metadata
## Your Tools
You have access to:
- **Data Source Tools**: list_data_sources, search_symbols, get_symbol_info, get_historical_data
## Communication Style
- **Precise**: Provide exact symbol names and exchange identifiers
- **Helpful**: Suggest alternatives when exact matches aren't found
- **Efficient**: Return data in the format requested
- **Clear**: Explain data limitations or availability issues
## Key Principles
1. **Accuracy**: Return correct symbol identifiers
2. **Completeness**: Include all relevant metadata
3. **Performance**: Respect countback limits
4. **Format Awareness**: Understand time resolutions and ranges
## Limitations
- You ONLY handle data retrieval and search
- You do NOT analyze data (route to chart agent)
- You do NOT access external news or research (route to research agent)
- You do NOT modify data or state

View File

@@ -0,0 +1,72 @@
# System Prompt
You are an AI trading assistant for an AI-native algorithmic trading platform. Your role is to help traders design, implement, and manage trading strategies through natural language interaction.
## Your Core Identity
You are a **strategy authoring assistant**, not a strategy executor. You help users:
- Design trading strategies from natural language descriptions
- Interpret chart annotations and technical requirements
- Generate strategy executables (code artifacts)
- Manage and monitor live trading state
- Analyze market data and provide insights
## Your Capabilities
### State Management
You have read/write access to synchronized state stores. Use your tools to read current state and update it as needed. All state changes are automatically synchronized with connected clients.
### Strategy Authoring
- Help users express trading intent through conversation
- Translate natural language to concrete strategy specifications
- Understand technical analysis concepts (support/resistance, indicators, patterns)
- Generate self-contained, deterministic strategy executables
- Validate strategy logic for correctness and safety
### Data & Analysis
- Access market data through abstract feed specifications
- Compute indicators and perform technical analysis
- Understand OHLCV data, order books, and market microstructure
## Communication Style
- **Technical & Direct**: Users are knowledgeable traders, be precise
- **Safety First**: Never make destructive changes without confirmation
- **Explain Actions**: When modifying state, explain what you're doing
- **Ask Questions**: If intent is unclear, ask for clarification
- **Concise**: Be brief but complete, avoid unnecessary elaboration
## Key Principles
1. **Strategies are Deterministic**: Generated strategies run without LLM involvement at runtime
2. **Local Execution**: The platform runs locally for security; you are a design-time tool only
3. **Schema Validation**: All outputs must conform to platform schemas
4. **Risk Awareness**: Always consider position sizing, exposure limits, and risk management
5. **Versioning**: Every strategy artifact is version-controlled with full auditability
## Your Limitations
- You **DO NOT** execute trades directly
- You **CANNOT** modify the order kernel or execution layer
- You **SHOULD NOT** make assumptions about user risk tolerance without asking
- You **MUST NOT** provide trading or investment advice
## Memory & Context
You have access to:
- Full conversation history with semantic search
- Project documentation (design, architecture, data formats)
- Past strategy discussions and decisions
- Relevant context retrieved automatically based on current conversation
## Working with Users
1. **Understand Intent**: Ask clarifying questions about strategy goals
2. **Design Together**: Collaborate on strategy logic iteratively
3. **Validate**: Ensure strategy makes sense before generating code
4. **Test**: Encourage backtesting and paper trading first
5. **Monitor**: Help users interpret live strategy behavior
---
**Note**: Additional context documents are loaded automatically to provide detailed operational guidelines. See memory files for specifics on chart context, shape drawing, Python analysis, and more.

View File

@@ -0,0 +1,37 @@
# Research Agent
You are a specialized research agent. Your sole purpose is to gather external information from the web, academic papers, and public APIs.
## Your Core Identity
You are an expert in:
- Searching academic papers on arXiv
- Finding information on Wikipedia
- Web search for current events and news
- Making HTTP requests to public APIs
## Your Tools
You have access to:
- **Research Tools**: search_arxiv, search_wikipedia, search_web, http_get, http_post
## Communication Style
- **Thorough**: Provide comprehensive research findings
- **Source-Aware**: Cite sources and links
- **Critical**: Evaluate information quality
- **Summarize**: Distill key points from long content
## Key Principles
1. **Verify**: Cross-reference information when possible
2. **Recency**: Note publication dates and data freshness
3. **Relevance**: Focus on trading and market-relevant information
4. **Ethics**: Respect API rate limits and terms of service
## Limitations
- You ONLY handle external information gathering
- You do NOT analyze market data (route to chart agent)
- You do NOT make trading recommendations
- You do NOT access private or authenticated APIs without explicit permission

View File

@@ -7,10 +7,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
from agent.memory import MemoryManager from agent.memory import MemoryManager
from agent.session import SessionManager from agent.session import SessionManager
from agent.prompts import build_system_prompt from agent.prompts import build_system_prompt
from agent.subagent import SubAgent
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
from gateway.user_session import UserSession from gateway.user_session import UserSession
from gateway.protocol import UserMessage as GatewayUserMessage from gateway.protocol import UserMessage as GatewayUserMessage
@@ -29,7 +31,8 @@ class AgentExecutor:
model_name: str = "claude-sonnet-4-20250514", model_name: str = "claude-sonnet-4-20250514",
temperature: float = 0.7, temperature: float = 0.7,
api_key: Optional[str] = None, api_key: Optional[str] = None,
memory_manager: Optional[MemoryManager] = None memory_manager: Optional[MemoryManager] = None,
base_dir: str = "."
): ):
"""Initialize agent executor. """Initialize agent executor.
@@ -38,10 +41,12 @@ class AgentExecutor:
temperature: Model temperature temperature: Model temperature
api_key: Anthropic API key api_key: Anthropic API key
memory_manager: MemoryManager instance memory_manager: MemoryManager instance
base_dir: Base directory for resolving paths
""" """
self.model_name = model_name self.model_name = model_name
self.temperature = temperature self.temperature = temperature
self.api_key = api_key self.api_key = api_key
self.base_dir = base_dir
# Initialize LLM # Initialize LLM
self.llm = ChatAnthropic( self.llm = ChatAnthropic(
@@ -56,6 +61,12 @@ class AgentExecutor:
self.session_manager = SessionManager(self.memory_manager) self.session_manager = SessionManager(self.memory_manager)
self.agent = None # Will be created after initialization self.agent = None # Will be created after initialization
# Sub-agents (only if using hierarchical tools)
self.chart_agent = None
self.data_agent = None
self.automation_agent = None
self.research_agent = None
async def initialize(self) -> None: async def initialize(self) -> None:
"""Initialize the agent system.""" """Initialize the agent system."""
await self.memory_manager.initialize() await self.memory_manager.initialize()
@@ -63,15 +74,69 @@ class AgentExecutor:
# Create agent with tools and LangGraph checkpointer # Create agent with tools and LangGraph checkpointer
checkpointer = self.memory_manager.get_checkpointer() checkpointer = self.memory_manager.get_checkpointer()
# Create agent without a static system prompt # Create specialized sub-agents
logger.info("Initializing hierarchical agent architecture with sub-agents")
self.chart_agent = SubAgent(
name="chart",
soul_file="chart_agent.md",
tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.data_agent = SubAgent(
name="data",
soul_file="data_agent.md",
tools=DATASOURCE_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.automation_agent = SubAgent(
name="automation",
soul_file="automation_agent.md",
tools=TRIGGER_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
self.research_agent = SubAgent(
name="research",
soul_file="research_agent.md",
tools=RESEARCH_TOOLS,
model_name=self.model_name,
temperature=self.temperature,
api_key=self.api_key,
base_dir=self.base_dir
)
# Set global sub-agent instances for router tools
set_chart_agent(self.chart_agent)
set_data_agent(self.data_agent)
set_automation_agent(self.automation_agent)
set_research_agent(self.research_agent)
# Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
logger.info("Main agent using router tools (4 routers + sync tools)")
agent_tools = SYNC_TOOLS + ROUTER_TOOLS
# Create main agent without a static system prompt
# We'll pass the dynamic system prompt via state_modifier at runtime # We'll pass the dynamic system prompt via state_modifier at runtime
# Include all tool categories: sync, datasource, chart, indicator, and research
self.agent = create_react_agent( self.agent = create_react_agent(
self.llm, self.llm,
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + RESEARCH_TOOLS, agent_tools,
checkpointer=checkpointer checkpointer=checkpointer
) )
logger.info(f"Agent initialized with {len(agent_tools)} tools")
async def _clear_checkpoint(self, session_id: str) -> None: async def _clear_checkpoint(self, session_id: str) -> None:
"""Clear the checkpoint for a session to prevent resuming from invalid state. """Clear the checkpoint for a session to prevent resuming from invalid state.
@@ -291,7 +356,7 @@ def create_agent(
base_dir: Base directory for resolving paths base_dir: Base directory for resolving paths
Returns: Returns:
Initialized AgentExecutor Initialized AgentExecutor with hierarchical tool routing
""" """
# Initialize memory manager # Initialize memory manager
memory_manager = MemoryManager( memory_manager = MemoryManager(
@@ -307,7 +372,8 @@ def create_agent(
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
api_key=api_key, api_key=api_key,
memory_manager=memory_manager memory_manager=memory_manager,
base_dir=base_dir
) )
return executor return executor

View File

@@ -30,6 +30,11 @@ def _get_chart_store_context() -> str:
interval = chart_data.get("interval", "N/A") interval = chart_data.get("interval", "N/A")
start_time = chart_data.get("start_time") start_time = chart_data.get("start_time")
end_time = chart_data.get("end_time") end_time = chart_data.get("end_time")
selected_shapes = chart_data.get("selected_shapes", [])
selected_info = ""
if selected_shapes:
selected_info = f"\n- **Selected Shapes**: {len(selected_shapes)} shape(s) selected (IDs: {', '.join(selected_shapes)})"
chart_context = f""" chart_context = f"""
## Current Chart Context ## Current Chart Context
@@ -37,7 +42,7 @@ def _get_chart_store_context() -> str:
The user is currently viewing a chart with the following settings: The user is currently viewing a chart with the following settings:
- **Symbol**: {symbol} - **Symbol**: {symbol}
- **Interval**: {interval} - **Interval**: {interval}
- **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"} - **Time Range**: {f"from {start_time} to {end_time}" if start_time and end_time else "not set"}{selected_info}
This information is automatically available because you're connected via websocket. This information is automatically available because you're connected via websocket.
When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean. When the user refers to "the chart", "this chart", or "what I'm viewing", this is what they mean.

View File

@@ -0,0 +1,218 @@
"""Tool router functions for hierarchical agent architecture.
This module provides meta-tools that route tasks to specialized sub-agents.
The main agent uses these routers instead of accessing all tools directly.
"""
import logging
from typing import Optional
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
# Global sub-agent instances (set by create_agent)
_chart_agent = None
_data_agent = None
_automation_agent = None
_research_agent = None
def set_chart_agent(agent):
"""Set the global chart sub-agent instance."""
global _chart_agent
_chart_agent = agent
def set_data_agent(agent):
"""Set the global data sub-agent instance."""
global _data_agent
_data_agent = agent
def set_automation_agent(agent):
"""Set the global automation sub-agent instance."""
global _automation_agent
_automation_agent = agent
def set_research_agent(agent):
"""Set the global research sub-agent instance."""
global _research_agent
_research_agent = agent
@tool
async def use_chart_analysis(task: str) -> str:
"""Analyze charts, compute indicators, execute Python code, and create visualizations.
This tool delegates to a specialized chart analysis agent that has access to:
- Chart data retrieval (get_chart_data)
- Python execution environment with pandas, numpy, matplotlib, talib
- Technical indicator tools (add/remove indicators, search indicators)
- Shape drawing tools (create/update/delete shapes on charts)
Use this when the user wants to:
- Analyze price action or patterns
- Calculate technical indicators (RSI, MACD, Bollinger Bands, etc.)
- Execute custom Python analysis on OHLCV data
- Generate charts and visualizations
- Draw trendlines, support/resistance, or other shapes
- Perform statistical analysis on market data
Args:
task: Detailed description of the chart analysis task. Include:
- What to analyze (which symbol, timeframe if different from current)
- What indicators or calculations to perform
- What visualizations to create
- Any specific questions to answer
Returns:
The chart agent's analysis results, including computed values,
plot URLs if visualizations were created, and interpretation.
Examples:
- "Calculate RSI(14) for the current chart and tell me if it's overbought"
- "Draw a trendline connecting the last 3 swing lows"
- "Compute Bollinger Bands (20, 2) and create a chart showing price vs bands"
- "Analyze the last 100 bars and identify key support/resistance levels"
- "Execute Python: calculate correlation between BTC and ETH over the last 30 days"
"""
if not _chart_agent:
return "Error: Chart analysis agent not initialized"
logger.info(f"Routing to chart agent: {task[:100]}...")
result = await _chart_agent.execute(task)
return result
@tool
async def use_data_access(task: str) -> str:
"""Search for symbols and retrieve market data from exchanges.
This tool delegates to a specialized data access agent that has access to:
- Symbol search across multiple exchanges
- Historical OHLCV data retrieval
- Symbol metadata and info
- Available data sources and exchanges
Use this when the user wants to:
- Search for a trading symbol or ticker
- Get historical price data
- Find out what exchanges support a symbol
- Retrieve symbol metadata (price scale, supported resolutions, etc.)
- Check what data sources are available
Args:
task: Detailed description of the data access task. Include:
- What symbol or instrument to search for
- What data to retrieve (time range, resolution)
- What metadata is needed
Returns:
The data agent's response with requested symbols, data, or metadata.
Examples:
- "Search for Bitcoin symbols on Binance"
- "Get the last 100 hours of BTC/USDT 1-hour data from Binance"
- "Find all symbols matching 'ETH' on all exchanges"
- "Get detailed info about symbol BTC/USDT on Binance"
- "List all available data sources"
"""
if not _data_agent:
return "Error: Data access agent not initialized"
logger.info(f"Routing to data agent: {task[:100]}...")
result = await _data_agent.execute(task)
return result
@tool
async def use_automation(task: str) -> str:
"""Schedule recurring tasks, create triggers, and manage automation.
This tool delegates to a specialized automation agent that has access to:
- Scheduled agent prompts (cron and interval-based)
- One-time agent prompt execution
- Trigger management (list, cancel scheduled jobs)
- System stats and monitoring
Use this when the user wants to:
- Schedule a recurring task (hourly, daily, weekly, etc.)
- Run a one-time background analysis
- Set up automated monitoring or alerts
- List or cancel existing scheduled tasks
- Check trigger system status
Args:
task: Detailed description of the automation task. Include:
- What should happen (what analysis or action)
- When it should happen (schedule, frequency)
- Any priorities or conditions
Returns:
The automation agent's response with job IDs, confirmation,
or status information.
Examples:
- "Schedule a task to check BTC price every 5 minutes"
- "Run a one-time analysis of ETH volume in the background"
- "Set up a daily report at 9 AM with market summary"
- "Show me all my scheduled tasks"
- "Cancel the hourly BTC monitor job"
"""
if not _automation_agent:
return "Error: Automation agent not initialized"
logger.info(f"Routing to automation agent: {task[:100]}...")
result = await _automation_agent.execute(task)
return result
@tool
async def use_research(task: str) -> str:
"""Search the web, academic papers, and external APIs for information.
This tool delegates to a specialized research agent that has access to:
- Web search (DuckDuckGo)
- Academic paper search (arXiv)
- Wikipedia lookup
- HTTP requests to public APIs
Use this when the user wants to:
- Search for current news or events
- Find academic papers on trading strategies
- Look up financial concepts or terms
- Fetch data from external public APIs
- Research market trends or sentiment
Args:
task: Detailed description of the research task. Include:
- What information to find
- What sources to search (web, arxiv, wikipedia, APIs)
- What to focus on or filter
Returns:
The research agent's findings with sources, summaries, and links.
Examples:
- "Search arXiv for papers on reinforcement learning for trading"
- "Look up 'technical analysis' on Wikipedia"
- "Search the web for latest Ethereum news"
- "Fetch current BTC price from CoinGecko API"
- "Find recent papers on market microstructure"
"""
if not _research_agent:
return "Error: Research agent not initialized"
logger.info(f"Routing to research agent: {task[:100]}...")
result = await _research_agent.execute(task)
return result
# Export router tools
ROUTER_TOOLS = [
use_chart_analysis,
use_data_access,
use_automation,
use_research
]

View File

@@ -0,0 +1,248 @@
"""Sub-agent infrastructure for specialized tool routing.
This module provides the SubAgent class that wraps specialized agents
with their own tools and system prompts.
"""
import logging
from typing import List, Optional, AsyncIterator
from pathlib import Path
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
logger = logging.getLogger(__name__)
class SubAgent:
"""A specialized sub-agent with its own tools and system prompt.
Sub-agents are lightweight, stateless agents that focus on specific domains.
They use in-memory checkpointing since they don't need persistent state.
"""
def __init__(
self,
name: str,
soul_file: str,
tools: List,
model_name: str = "claude-sonnet-4-20250514",
temperature: float = 0.7,
api_key: Optional[str] = None,
base_dir: str = "."
):
"""Initialize a sub-agent.
Args:
name: Agent name (e.g., "chart", "data", "automation")
soul_file: Filename in /soul directory (e.g., "chart_agent.md")
tools: List of LangChain tools for this agent
model_name: Anthropic model name
temperature: Model temperature
api_key: Anthropic API key
base_dir: Base directory for resolving paths
"""
self.name = name
self.soul_file = soul_file
self.tools = tools
self.model_name = model_name
self.temperature = temperature
self.api_key = api_key
self.base_dir = base_dir
# Load system prompt from soul file
soul_path = Path(base_dir) / "soul" / soul_file
if soul_path.exists():
with open(soul_path, "r") as f:
self.system_prompt = f.read()
logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")
else:
logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
self.system_prompt = f"You are a specialized {name} agent."
# Initialize LLM
self.llm = ChatAnthropic(
model=model_name,
temperature=temperature,
api_key=api_key,
streaming=True
)
# Create agent with in-memory checkpointer (stateless)
checkpointer = MemorySaver()
self.agent = create_react_agent(
self.llm,
tools,
checkpointer=checkpointer
)
logger.info(
f"SubAgent '{name}' initialized with {len(tools)} tools, "
f"model={model_name}, temp={temperature}"
)
async def execute(
self,
task: str,
thread_id: Optional[str] = None
) -> str:
"""Execute a task with this sub-agent.
Args:
task: The task/prompt for this sub-agent
thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)
Returns:
The agent's complete response as a string
"""
import uuid
# Use ephemeral thread ID if not provided
if thread_id is None:
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")
# Build messages with system prompt
messages = [
HumanMessage(content=task)
]
# Prepare config with system prompt injection
config = RunnableConfig(
configurable={
"thread_id": thread_id,
"state_modifier": self.system_prompt
},
metadata={
"subagent_name": self.name
}
)
# Execute and collect response
full_response = ""
event_count = 0
try:
async for event in self.agent.astream_events(
{"messages": messages},
config=config,
version="v2"
):
event_count += 1
# Log tool calls
if event["event"] == "on_tool_start":
tool_name = event.get("name", "unknown")
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
elif event["event"] == "on_tool_end":
tool_name = event.get("name", "unknown")
logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
# Extract streaming tokens
elif event["event"] == "on_chat_model_stream":
chunk = event["data"]["chunk"]
if hasattr(chunk, "content") and chunk.content:
content = chunk.content
# Handle both string and list content
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and "text" in block:
text_parts.append(block["text"])
elif hasattr(block, "text"):
text_parts.append(block.text)
content = "".join(text_parts)
if content:
full_response += content
logger.info(
f"SubAgent '{self.name}': Completed task "
f"({event_count} events, {len(full_response)} chars)"
)
except Exception as e:
error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
logger.error(error_msg, exc_info=True)
return f"Error: {error_msg}"
return full_response
async def stream(
self,
task: str,
thread_id: Optional[str] = None
) -> AsyncIterator[str]:
"""Execute a task with streaming response.
Args:
task: The task/prompt for this sub-agent
thread_id: Optional thread ID for checkpointing
Yields:
Response chunks as they're generated
"""
import uuid
# Use ephemeral thread ID if not provided
if thread_id is None:
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")
# Build messages with system prompt
messages = [
HumanMessage(content=task)
]
# Prepare config
config = RunnableConfig(
configurable={
"thread_id": thread_id,
"state_modifier": self.system_prompt
},
metadata={
"subagent_name": self.name
}
)
# Stream response
try:
async for event in self.agent.astream_events(
{"messages": messages},
config=config,
version="v2"
):
# Log tool calls
if event["event"] == "on_tool_start":
tool_name = event.get("name", "unknown")
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
# Extract streaming tokens
elif event["event"] == "on_chat_model_stream":
chunk = event["data"]["chunk"]
if hasattr(chunk, "content") and chunk.content:
content = chunk.content
# Handle both string and list content
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and "text" in block:
text_parts.append(block["text"])
elif hasattr(block, "text"):
text_parts.append(block.text)
content = "".join(text_parts)
if content:
yield content
except Exception as e:
error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
logger.error(error_msg, exc_info=True)
yield f"Error: {error_msg}"

View File

@@ -0,0 +1,373 @@
# Agent Trigger Tools
Agent tools for automating tasks via the trigger system.
## Overview
These tools allow the agent to:
- **Schedule recurring tasks** - Run agent prompts on intervals or cron schedules
- **Execute one-time tasks** - Trigger sub-agent runs immediately
- **Manage scheduled jobs** - List and cancel scheduled triggers
- **React to events** - (Future) Connect data updates to agent actions
## Available Tools
### 1. `schedule_agent_prompt`
Schedule an agent to run with a specific prompt on a recurring schedule.
**Use Cases:**
- Daily market analysis reports
- Hourly portfolio rebalancing checks
- Weekly performance summaries
- Monitoring alerts
**Arguments:**
- `prompt` (str): The prompt to send to the agent when triggered
- `schedule_type` (str): "interval" or "cron"
- `schedule_config` (dict): Schedule configuration
- `name` (str, optional): Descriptive name for this task
**Schedule Config:**
*Interval-based:*
```json
{"minutes": 5}
{"hours": 1, "minutes": 30}
{"seconds": 30}
```
*Cron-based:*
```json
{"hour": "9", "minute": "0"} // Daily at 9:00 AM
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"} // Weekdays at 9 AM
{"minute": "0"} // Every hour on the hour
{"hour": "*/6", "minute": "0"} // Every 6 hours
```
**Returns:**
```json
{
"job_id": "interval_123",
"message": "Scheduled 'daily_report' with job_id=interval_123",
"schedule_type": "cron",
"config": {"hour": "9", "minute": "0"}
}
```
**Examples:**
```python
# Every 5 minutes: check BTC price
schedule_agent_prompt(
prompt="Check current BTC price on Binance. If > $50k, alert me.",
schedule_type="interval",
schedule_config={"minutes": 5},
name="btc_price_monitor"
)
# Daily at 9 AM: market summary
schedule_agent_prompt(
prompt="Generate a comprehensive market summary for BTC, ETH, and SOL. Include price changes, volume, and notable events from the last 24 hours.",
schedule_type="cron",
schedule_config={"hour": "9", "minute": "0"},
name="daily_market_summary"
)
# Every hour on weekdays: portfolio check
schedule_agent_prompt(
prompt="Review current portfolio positions. Check if any rebalancing is needed based on target allocations.",
schedule_type="cron",
schedule_config={"minute": "0", "day_of_week": "mon-fri"},
name="hourly_portfolio_check"
)
```
### 2. `execute_agent_prompt_once`
Execute an agent prompt once, immediately (enqueued with priority).
**Use Cases:**
- Background analysis tasks
- One-time data processing
- Responding to specific events
- Sub-agent delegation
**Arguments:**
- `prompt` (str): The prompt to send to the agent
- `priority` (str): "high", "normal", or "low" (default: "normal")
**Returns:**
```json
{
"queue_seq": 42,
"message": "Enqueued agent prompt with priority=normal",
"prompt": "Analyze the last 100 BTC/USDT bars..."
}
```
**Examples:**
```python
# Immediate analysis with high priority
execute_agent_prompt_once(
prompt="Analyze the last 100 BTC/USDT 1m bars and identify key support/resistance levels",
priority="high"
)
# Background task with normal priority
execute_agent_prompt_once(
prompt="Research the latest news about Ethereum upgrades and summarize findings",
priority="normal"
)
# Low priority cleanup task
execute_agent_prompt_once(
prompt="Review and archive old chart drawings from last month",
priority="low"
)
```
### 3. `list_scheduled_triggers`
List all currently scheduled triggers.
**Returns:**
```json
[
{
"id": "cron_456",
"name": "Cron: daily_market_summary",
"next_run_time": "2024-03-05 09:00:00",
"trigger": "cron[hour='9', minute='0']"
},
{
"id": "interval_123",
"name": "Interval: btc_price_monitor",
"next_run_time": "2024-03-04 14:35:00",
"trigger": "interval[0:05:00]"
}
]
```
**Example:**
```python
jobs = list_scheduled_triggers()
for job in jobs:
print(f"{job['name']} - next run: {job['next_run_time']}")
```
### 4. `cancel_scheduled_trigger`
Cancel a scheduled trigger by its job ID.
**Arguments:**
- `job_id` (str): The job ID from `schedule_agent_prompt` or `list_scheduled_triggers`
**Returns:**
```json
{
"status": "success",
"message": "Cancelled job interval_123"
}
```
**Example:**
```python
# List jobs to find the ID
jobs = list_scheduled_triggers()
# Cancel specific job
cancel_scheduled_trigger("interval_123")
```
### 5. `on_data_update_run_agent`
**(Future)** Set up an agent to run whenever new data arrives for a specific symbol.
**Arguments:**
- `source_name` (str): Data source name (e.g., "binance")
- `symbol` (str): Trading pair (e.g., "BTC/USDT")
- `resolution` (str): Time resolution (e.g., "1m", "5m")
- `prompt_template` (str): Template with variables like {close}, {volume}, {symbol}
**Example:**
```python
on_data_update_run_agent(
source_name="binance",
symbol="BTC/USDT",
resolution="1m",
prompt_template="New bar on {symbol}: close={close}, volume={volume}. Check if price crossed any key levels."
)
```
### 6. `get_trigger_system_stats`
Get statistics about the trigger system.
**Returns:**
```json
{
"queue_depth": 3,
"queue_running": true,
"coordinator_stats": {
"current_seq": 1042,
"next_commit_seq": 1043,
"pending_commits": 1,
"total_executions": 1042,
"state_counts": {
"COMMITTED": 1038,
"EXECUTING": 2,
"WAITING_COMMIT": 1,
"FAILED": 1
}
}
}
```
**Example:**
```python
stats = get_trigger_system_stats()
print(f"Queue has {stats['queue_depth']} pending triggers")
print(f"System has processed {stats['coordinator_stats']['total_executions']} total triggers")
```
## Integration Example
Here's how these tools enable autonomous agent behavior:
```python
# Agent conversation:
User: "Monitor BTC price and send me a summary every hour during market hours"
Agent: I'll set that up for you using the trigger system.
# Agent uses tool:
schedule_agent_prompt(
prompt="""
Check the current BTC/USDT price on Binance.
Calculate the price change from 1 hour ago.
If price moved > 2%, provide a detailed analysis.
Otherwise, provide a brief status update.
Send results to user as a notification.
""",
schedule_type="cron",
schedule_config={
"minute": "0",
"hour": "9-17", # 9 AM to 5 PM
"day_of_week": "mon-fri"
},
name="btc_hourly_monitor"
)
Agent: Done! I've scheduled an hourly BTC price monitor that runs during market hours (9 AM - 5 PM on weekdays). You'll receive updates every hour.
# Later...
User: "Can you show me all my scheduled tasks?"
Agent: Let me check what's scheduled.
# Agent uses tool:
jobs = list_scheduled_triggers()
Agent: You have 3 scheduled tasks:
1. "btc_hourly_monitor" - runs every hour during market hours
2. "daily_market_summary" - runs daily at 9 AM
3. "portfolio_rebalance_check" - runs every 4 hours
Would you like to modify or cancel any of these?
```
## Use Case: Autonomous Trading Bot
```python
# Step 1: Set up data monitoring
execute_agent_prompt_once(
prompt="""
Subscribe to BTC/USDT 1m bars from Binance.
When subscribed, set up the following:
1. Calculate RSI(14) on each new bar
2. If RSI > 70, execute prompt: "RSI overbought on BTC, check if we should sell"
3. If RSI < 30, execute prompt: "RSI oversold on BTC, check if we should buy"
""",
priority="high"
)
# Step 2: Schedule periodic portfolio review
schedule_agent_prompt(
prompt="""
Review current portfolio:
1. Calculate current allocation percentages
2. Compare to target allocation (60% BTC, 30% ETH, 10% stable)
3. If deviation > 5%, generate rebalancing trades
4. Submit trades for execution
""",
schedule_type="interval",
schedule_config={"hours": 4},
name="portfolio_rebalance"
)
# Step 3: Schedule daily risk check
schedule_agent_prompt(
prompt="""
Daily risk assessment:
1. Calculate portfolio VaR (Value at Risk)
2. Check current leverage across all positions
3. Review stop-loss placements
4. If risk exceeds threshold, alert and suggest adjustments
""",
schedule_type="cron",
schedule_config={"hour": "8", "minute": "0"},
name="daily_risk_check"
)
```
## Benefits
**Autonomous operation** - Agent can schedule its own tasks
**Event-driven** - React to market data, time, or custom events
**Flexible scheduling** - Interval or cron-based
**Self-managing** - Agent can list and cancel its own jobs
**Priority control** - High-priority tasks jump the queue
**Future-proof** - Easy to add Python lambdas, strategy execution, etc.
## Future Enhancements
- **Python script execution** - Schedule arbitrary Python code
- **Strategy triggers** - Connect to strategy execution system
- **Event composition** - AND/OR logic for complex event patterns
- **Conditional execution** - Only run if conditions met (e.g., volatility > threshold)
- **Result chaining** - Use output of one trigger as input to another
- **Backtesting mode** - Test trigger logic on historical data
## Setup in main.py
```python
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
from trigger import TriggerQueue, CommitCoordinator
from trigger.scheduler import TriggerScheduler
# Initialize trigger system
coordinator = CommitCoordinator()
queue = TriggerQueue(coordinator)
scheduler = TriggerScheduler(queue)
await queue.start()
scheduler.start()
# Make available to agent tools
set_trigger_queue(queue)
set_trigger_scheduler(scheduler)
set_coordinator(coordinator)
# Add TRIGGER_TOOLS to agent's tool list
from agent.tools import TRIGGER_TOOLS
agent_tools = [..., *TRIGGER_TOOLS]
```
Now the agent has full control over the trigger system! 🚀

View File

@@ -5,6 +5,8 @@ This package provides tools for:
- Data sources and market data (datasource_tools) - Data sources and market data (datasource_tools)
- Chart data access and analysis (chart_tools) - Chart data access and analysis (chart_tools)
- Technical indicators (indicator_tools) - Technical indicators (indicator_tools)
- Shape/drawing management (shape_tools)
- Trigger system and automation (trigger_tools)
""" """
# Global registries that will be set by main.py # Global registries that will be set by main.py
@@ -37,14 +39,26 @@ from .datasource_tools import DATASOURCE_TOOLS
from .chart_tools import CHART_TOOLS from .chart_tools import CHART_TOOLS
from .indicator_tools import INDICATOR_TOOLS from .indicator_tools import INDICATOR_TOOLS
from .research_tools import RESEARCH_TOOLS from .research_tools import RESEARCH_TOOLS
from .shape_tools import SHAPE_TOOLS
from .trigger_tools import (
TRIGGER_TOOLS,
set_trigger_queue,
set_trigger_scheduler,
set_coordinator,
)
__all__ = [ __all__ = [
"set_registry", "set_registry",
"set_datasource_registry", "set_datasource_registry",
"set_indicator_registry", "set_indicator_registry",
"set_trigger_queue",
"set_trigger_scheduler",
"set_coordinator",
"SYNC_TOOLS", "SYNC_TOOLS",
"DATASOURCE_TOOLS", "DATASOURCE_TOOLS",
"CHART_TOOLS", "CHART_TOOLS",
"INDICATOR_TOOLS", "INDICATOR_TOOLS",
"RESEARCH_TOOLS", "RESEARCH_TOOLS",
"SHAPE_TOOLS",
"TRIGGER_TOOLS",
] ]

View File

@@ -29,6 +29,22 @@ def _get_indicator_registry():
return _indicator_registry return _indicator_registry
def _get_order_store():
"""Get the global OrderStore instance."""
registry = _get_registry()
if registry and "OrderStore" in registry.entries:
return registry.entries["OrderStore"].model
return None
def _get_chart_store():
"""Get the global ChartStore instance."""
registry = _get_registry()
if registry and "ChartStore" in registry.entries:
return registry.entries["ChartStore"].model
return None
async def _get_chart_data_impl(countback: Optional[int] = None): async def _get_chart_data_impl(countback: Optional[int] = None):
"""Internal implementation for getting chart data. """Internal implementation for getting chart data.
@@ -60,8 +76,13 @@ async def _get_chart_data_impl(countback: Optional[int] = None):
start_time = chart_data.get("start_time") start_time = chart_data.get("start_time")
end_time = chart_data.get("end_time") end_time = chart_data.get("end_time")
if not symbol: if not symbol or symbol is None:
raise ValueError("No symbol set in ChartStore - user may not have loaded a chart yet") raise ValueError(
"No chart visible - ChartStore symbol is None. "
"The user is likely on a narrow screen (mobile) where charts are hidden. "
"Let them know they can view charts on a wider screen, or use get_historical_data() "
"if they specify a symbol and timeframe."
)
# Parse the symbol to extract exchange/source and symbol name # Parse the symbol to extract exchange/source and symbol name
# Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD") # Format is "EXCHANGE:SYMBOL" (e.g., "BINANCE:BTC/USDT", "DEMO:BTC/USD")
@@ -142,6 +163,11 @@ async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
This is the preferred way to access chart data when helping the user analyze This is the preferred way to access chart data when helping the user analyze
what they're looking at, since it automatically uses their current chart context. what they're looking at, since it automatically uses their current chart context.
**IMPORTANT**: This tool will fail if ChartStore.symbol is None (no chart visible).
This happens when the user is on a narrow screen (mobile) where charts are hidden.
In that case, let the user know charts are only visible on wider screens, or use
get_historical_data() if they specify a symbol and timeframe.
Args: Args:
countback: Optional limit on number of bars to return. If not specified, countback: Optional limit on number of bars to return. If not specified,
returns all bars in the visible time range. returns all bars in the visible time range.
@@ -157,7 +183,7 @@ async def get_chart_data(countback: Optional[int] = None) -> Dict[str, Any]:
Raises: Raises:
ValueError: If ChartStore or DataSourceRegistry is not initialized, ValueError: If ChartStore or DataSourceRegistry is not initialized,
or if the symbol format is invalid if no chart is visible (symbol is None), or if the symbol format is invalid
Example: Example:
# User is viewing BINANCE:BTC/USDT on 15min chart # User is viewing BINANCE:BTC/USDT on 15min chart
@@ -191,12 +217,26 @@ async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str
- `talib` : TA-Lib technical analysis library - `talib` : TA-Lib technical analysis library
- `indicator_registry`: 150+ registered indicators - `indicator_registry`: 150+ registered indicators
- `plot_ohlc(df)` : Helper function for beautiful candlestick charts - `plot_ohlc(df)` : Helper function for beautiful candlestick charts
- `registry` : SyncRegistry instance - access to all registered stores
- `datasource_registry`: DataSourceRegistry - access to data sources (binance, etc.)
- `order_store` : OrderStore instance - current orders list
- `chart_store` : ChartStore instance - current chart state
Auto-loaded when user has a chart open: Auto-loaded when user has a chart visible (ChartStore.symbol is not None):
- `df` : pandas DataFrame with DatetimeIndex and columns: - `df` : pandas DataFrame with DatetimeIndex and columns:
open, high, low, close, volume (OHLCV data ready to use) open, high, low, close, volume (OHLCV data ready to use)
- `chart_context` : dict with symbol, interval, start_time, end_time - `chart_context` : dict with symbol, interval, start_time, end_time
When NO chart is visible (narrow screen/mobile):
- `df` : None
- `chart_context` : None
If `df` is None, you can still load alternative data by:
- Using chart_store to see what symbol/timeframe is configured
- Using datasource_registry.get_source('binance') to access data sources
- Calling datasource.get_history(symbol, interval, start, end) to load any data
- This allows you to make plots of ANY chart even when not connected to chart view
The `plot_ohlc()` Helper: The `plot_ohlc()` Helper:
Create professional candlestick charts instantly: Create professional candlestick charts instantly:
- `plot_ohlc(df)` - basic OHLC chart with volume - `plot_ohlc(df)` - basic OHLC chart with volume
@@ -250,6 +290,41 @@ async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str
print("Recent swing highs:") print("Recent swing highs:")
print(swing_highs) print(swing_highs)
\"\"\") \"\"\")
# Load alternative data when df is None or for different symbol/timeframe
execute_python(\"\"\"
from datetime import datetime, timedelta
# Get data source
binance = datasource_registry.get_source('binance')
# Load ETH data even if viewing BTC chart
end_time = datetime.now()
start_time = end_time - timedelta(days=7)
result = await binance.get_history(
symbol='ETH/USDT',
interval='1h',
start=int(start_time.timestamp()),
end=int(end_time.timestamp())
)
# Convert to DataFrame
rows = [{'time': pd.to_datetime(bar.time, unit='s'), **bar.data} for bar in result.bars]
eth_df = pd.DataFrame(rows).set_index('time')
# Calculate RSI and plot
eth_df['RSI'] = talib.RSI(eth_df['close'], 14)
fig = plot_ohlc(eth_df, title='ETH/USDT 1h - RSI Analysis')
print(f"ETH RSI: {eth_df['RSI'].iloc[-1]:.2f}")
\"\"\")
# Access chart store to see current state
execute_python(\"\"\"
print(f"Current symbol: {chart_store.chart_state.symbol}")
print(f"Current interval: {chart_store.chart_state.interval}")
print(f"Orders: {len(order_store.orders)}")
\"\"\")
""" """
import pandas as pd import pandas as pd
import numpy as np import numpy as np
@@ -292,6 +367,10 @@ async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str
# --- Get indicator registry --- # --- Get indicator registry ---
indicator_registry = _get_indicator_registry() indicator_registry = _get_indicator_registry()
# --- Get DataStores ---
order_store = _get_order_store()
chart_store = _get_chart_store()
# --- Build globals --- # --- Build globals ---
script_globals: Dict[str, Any] = { script_globals: Dict[str, Any] = {
'pd': pd, 'pd': pd,
@@ -299,6 +378,10 @@ async def execute_python(code: str, countback: Optional[int] = None) -> Dict[str
'plt': plt, 'plt': plt,
'talib': talib, 'talib': talib,
'indicator_registry': indicator_registry, 'indicator_registry': indicator_registry,
'registry': registry,
'datasource_registry': datasource_registry,
'order_store': order_store,
'chart_store': chart_store,
'df': df, 'df': df,
'chart_context': chart_context, 'chart_context': chart_context,
'plot_ohlc': plot_ohlc, 'plot_ohlc': plot_ohlc,

View File

@@ -0,0 +1,435 @@
"""Technical indicator tools.
These tools allow the agent to:
1. Discover available indicators (list, search, get info)
2. Add indicators to the chart
3. Update/remove indicators
4. Query currently applied indicators
"""
from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
import logging
import time
logger = logging.getLogger(__name__)
def _get_indicator_registry():
"""Get the global indicator registry instance."""
from . import _indicator_registry
return _indicator_registry
def _get_registry():
"""Get the global sync registry instance."""
from . import _registry
return _registry
def _get_indicator_store():
"""Get the global IndicatorStore instance."""
registry = _get_registry()
if registry and "IndicatorStore" in registry.entries:
return registry.entries["IndicatorStore"].model
return None
@tool
def list_indicators() -> List[str]:
"""List all available technical indicators.
Returns:
List of indicator names that can be used in analysis and strategies
"""
registry = _get_indicator_registry()
if not registry:
return []
return registry.list_indicators()
@tool
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
"""Get detailed information about a specific indicator.
Retrieves metadata including description, parameters, category, use cases,
input/output schemas, and references.
Args:
indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")
Returns:
Dictionary containing:
- name: Indicator name
- display_name: Human-readable name
- description: What the indicator computes and why it's useful
- category: Category (momentum, trend, volatility, volume, etc.)
- parameters: List of configurable parameters with types and defaults
- use_cases: Common trading scenarios where this indicator helps
- tags: Searchable tags
- input_schema: Required input columns (e.g., OHLCV requirements)
- output_schema: Columns this indicator produces
Raises:
ValueError: If indicator_name is not found
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
metadata = registry.get_metadata(indicator_name)
if not metadata:
total_count = len(registry.list_indicators())
raise ValueError(
f"Indicator '{indicator_name}' not found. "
f"Total available: {total_count} indicators. "
f"Use search_indicators() to find indicators by name, category, or tag."
)
input_schema = registry.get_input_schema(indicator_name)
output_schema = registry.get_output_schema(indicator_name)
result = metadata.model_dump()
result["input_schema"] = input_schema.model_dump() if input_schema else None
result["output_schema"] = output_schema.model_dump() if output_schema else None
return result
@tool
def search_indicators(
query: Optional[str] = None,
category: Optional[str] = None,
tag: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Search for indicators by text query, category, or tag.
Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.
Use this to discover relevant indicators for your trading strategy or analysis.
Can filter by category (momentum, trend, volatility, etc.) or search by keywords.
Args:
query: Optional text search across names, descriptions, and use cases
category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")
Returns:
List of lightweight indicator summaries. Each contains:
- name: Indicator name (use with get_indicator_info() for full details)
- display_name: Human-readable name
- description: Brief one-line description
- category: Category (momentum, trend, volatility, etc.)
Example:
# Find all momentum indicators
results = search_indicators(category="momentum")
# Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]
# Then get details on interesting ones
rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases
# Search for moving average indicators
search_indicators(query="moving average")
# Find all TA-Lib indicators
search_indicators(tag="talib")
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
results = []
if query:
results = registry.search_by_text(query)
elif category:
results = registry.search_by_category(category)
elif tag:
results = registry.search_by_tag(tag)
else:
# Return all indicators if no filter
results = registry.get_all_metadata()
# Return lightweight summaries only
return [
{
"name": r.name,
"display_name": r.display_name,
"description": r.description,
"category": r.category
}
for r in results
]
@tool
def get_indicator_categories() -> Dict[str, int]:
"""Get all indicator categories and their counts.
Returns a summary of available indicator categories, useful for
exploring what types of indicators are available.
Returns:
Dictionary mapping category name to count of indicators in that category.
Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
"""
registry = _get_indicator_registry()
if not registry:
raise ValueError("IndicatorRegistry not initialized")
categories: Dict[str, int] = {}
for metadata in registry.get_all_metadata():
category = metadata.category
categories[category] = categories.get(category, 0) + 1
return categories
@tool
async def add_indicator_to_chart(
indicator_id: str,
talib_name: str,
parameters: Optional[Dict[str, Any]] = None,
symbol: Optional[str] = None
) -> Dict[str, Any]:
"""Add a technical indicator to the chart.
This will create a new indicator instance and display it on the TradingView chart.
The indicator will be synchronized with the frontend in real-time.
Args:
indicator_id: Unique identifier for this indicator instance (e.g., 'rsi_14', 'sma_50')
talib_name: Name of the TA-Lib indicator (e.g., 'RSI', 'SMA', 'MACD', 'BBANDS')
Use search_indicators() or get_indicator_info() to find available indicators
parameters: Optional dictionary of indicator parameters
Example for RSI: {'timeperiod': 14}
Example for SMA: {'timeperiod': 50}
Example for MACD: {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}
Example for BBANDS: {'timeperiod': 20, 'nbdevup': 2, 'nbdevdn': 2}
symbol: Optional symbol to apply the indicator to (defaults to current chart symbol)
Returns:
Dictionary with:
- status: 'created' or 'updated'
- indicator: The complete indicator object
Example:
# Add RSI(14)
await add_indicator_to_chart(
indicator_id='rsi_14',
talib_name='RSI',
parameters={'timeperiod': 14}
)
# Add 50-period SMA
await add_indicator_to_chart(
indicator_id='sma_50',
talib_name='SMA',
parameters={'timeperiod': 50}
)
# Add MACD with default parameters
await add_indicator_to_chart(
indicator_id='macd_default',
talib_name='MACD'
)
"""
from schema.indicator import IndicatorInstance
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
indicator_store = _get_indicator_store()
if not indicator_store:
raise ValueError("IndicatorStore not initialized")
# Verify the indicator exists
indicator_registry = _get_indicator_registry()
if not indicator_registry:
raise ValueError("IndicatorRegistry not initialized")
metadata = indicator_registry.get_metadata(talib_name)
if not metadata:
raise ValueError(
f"Indicator '{talib_name}' not found. "
f"Use search_indicators() to find available indicators."
)
# Check if updating existing indicator
existing_indicator = indicator_store.indicators.get(indicator_id)
is_update = existing_indicator is not None
# If symbol is not provided, try to get it from ChartStore
if symbol is None and "ChartStore" in registry.entries:
chart_store = registry.entries["ChartStore"].model
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
symbol = chart_store.chart_state.symbol
logger.info(f"Using current chart symbol for indicator: {symbol}")
now = int(time.time())
# Create indicator instance
indicator = IndicatorInstance(
id=indicator_id,
talib_name=talib_name,
instance_name=f"{talib_name}_{indicator_id}",
parameters=parameters or {},
visible=True,
pane='chart', # Most indicators go on the chart pane
symbol=symbol,
created_at=existing_indicator.get('created_at') if existing_indicator else now,
modified_at=now
)
# Update the store
indicator_store.indicators[indicator_id] = indicator.model_dump(mode="json")
# Trigger sync
await registry.push_all()
logger.info(
f"{'Updated' if is_update else 'Created'} indicator '{indicator_id}' "
f"(TA-Lib: {talib_name}) with parameters: {parameters}"
)
return {
"status": "updated" if is_update else "created",
"indicator": indicator.model_dump(mode="json")
}
@tool
async def remove_indicator_from_chart(indicator_id: str) -> Dict[str, str]:
"""Remove an indicator from the chart.
Args:
indicator_id: ID of the indicator instance to remove
Returns:
Dictionary with status message
Raises:
ValueError: If indicator doesn't exist
Example:
await remove_indicator_from_chart('rsi_14')
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
indicator_store = _get_indicator_store()
if not indicator_store:
raise ValueError("IndicatorStore not initialized")
if indicator_id not in indicator_store.indicators:
raise ValueError(f"Indicator '{indicator_id}' not found")
# Delete the indicator
del indicator_store.indicators[indicator_id]
# Trigger sync
await registry.push_all()
logger.info(f"Removed indicator '{indicator_id}'")
return {
"status": "success",
"message": f"Indicator '{indicator_id}' removed"
}
@tool
def list_chart_indicators(symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""List all indicators currently applied to the chart.
Args:
symbol: Optional filter by symbol (defaults to current chart symbol)
Returns:
List of indicator instances, each containing:
- id: Indicator instance ID
- talib_name: TA-Lib indicator name
- instance_name: Display name
- parameters: Current parameter values
- visible: Whether indicator is visible
- pane: Which pane it's displayed in
- symbol: Symbol it's applied to
Example:
# List all indicators on current symbol
indicators = list_chart_indicators()
# List indicators on specific symbol
btc_indicators = list_chart_indicators(symbol='BINANCE:BTC/USDT')
"""
indicator_store = _get_indicator_store()
if not indicator_store:
raise ValueError("IndicatorStore not initialized")
logger.info(f"list_chart_indicators: Raw store indicators: {indicator_store.indicators}")
# If symbol is not provided, try to get it from ChartStore
if symbol is None:
registry = _get_registry()
if registry and "ChartStore" in registry.entries:
chart_store = registry.entries["ChartStore"].model
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
symbol = chart_store.chart_state.symbol
indicators = list(indicator_store.indicators.values())
logger.info(f"list_chart_indicators: Converted to list: {indicators}")
logger.info(f"list_chart_indicators: Filtering by symbol: {symbol}")
# Filter by symbol if provided
if symbol:
indicators = [ind for ind in indicators if ind.get('symbol') == symbol]
logger.info(f"list_chart_indicators: Returning {len(indicators)} indicators")
return indicators
@tool
def get_chart_indicator(indicator_id: str) -> Dict[str, Any]:
"""Get details of a specific indicator on the chart.
Args:
indicator_id: ID of the indicator instance
Returns:
Dictionary containing the indicator data
Raises:
ValueError: If indicator doesn't exist
Example:
indicator = get_chart_indicator('rsi_14')
print(f"Indicator: {indicator['talib_name']}")
print(f"Parameters: {indicator['parameters']}")
"""
indicator_store = _get_indicator_store()
if not indicator_store:
raise ValueError("IndicatorStore not initialized")
indicator = indicator_store.indicators.get(indicator_id)
if not indicator:
raise ValueError(f"Indicator '{indicator_id}' not found")
return indicator
INDICATOR_TOOLS = [
# Discovery tools
list_indicators,
get_indicator_info,
search_indicators,
get_indicator_categories,
# Chart indicator management tools
add_indicator_to_chart,
remove_indicator_from_chart,
list_chart_indicators,
get_chart_indicator
]

View File

@@ -0,0 +1,475 @@
"""Shape/drawing tools for chart analysis."""
from typing import Dict, Any, List, Optional
from langchain_core.tools import tool
import logging
logger = logging.getLogger(__name__)
# Map legacy/common shape type names to TradingView's native names
SHAPE_TYPE_ALIASES: Dict[str, str] = {
'trendline': 'trend_line',
'fibonacci': 'fib_retracement',
'fibonacci_extension': 'fib_trend_ext',
'gann_fan': 'gannbox_fan',
}
def _get_registry():
"""Get the global registry instance."""
from . import _registry
return _registry
def _get_shape_store():
"""Get the global ShapeStore instance."""
registry = _get_registry()
if registry and "ShapeStore" in registry.entries:
return registry.entries["ShapeStore"].model
return None
@tool
def search_shapes(
start_time: Optional[int] = None,
end_time: Optional[int] = None,
shape_type: Optional[str] = None,
symbol: Optional[str] = None,
shape_ids: Optional[List[str]] = None,
original_ids: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""Search for shapes/drawings using flexible filters.
This tool can search shapes by:
- Time range (finds shapes that overlap the range)
- Shape type (e.g., 'trendline', 'horizontal_line')
- Symbol (e.g., 'BINANCE:BTC/USDT')
- Specific shape IDs (TradingView's assigned IDs)
- Original IDs (the IDs you specified when creating shapes)
Args:
start_time: Optional start of time range (Unix timestamp in seconds)
end_time: Optional end of time range (Unix timestamp in seconds)
shape_type: Optional filter by shape type (e.g., 'trend_line', 'horizontal_line', 'rectangle')
symbol: Optional filter by symbol (e.g., 'BINANCE:BTC/USDT')
shape_ids: Optional list of specific shape IDs to retrieve (searches both id and original_id fields)
original_ids: Optional list of original IDs to search for (the IDs you specified when creating)
Returns:
List of matching shapes, each as a dictionary with:
- id: Shape identifier (TradingView's assigned ID)
- original_id: The ID you specified when creating the shape (if applicable)
- type: Shape type
- points: List of control points with time and price
- color, line_width, line_style: Visual properties
- properties: Additional shape-specific properties
- symbol: Symbol the shape is drawn on
- created_at, modified_at: Timestamps
Examples:
# Find all shapes in the currently visible chart range
shapes = search_shapes(
start_time=chart_state.start_time,
end_time=chart_state.end_time
)
# Find only trendlines in a specific time range
trendlines = search_shapes(
start_time=1640000000,
end_time=1650000000,
shape_type='trend_line'
)
# Find shapes for a specific symbol
btc_shapes = search_shapes(
start_time=1640000000,
end_time=1650000000,
symbol='BINANCE:BTC/USDT'
)
# Get specific shapes by TradingView ID or original ID
# This searches both the 'id' and 'original_id' fields
selected = search_shapes(
shape_ids=['trendline-1', 'support-42k', 'fib-retracement-1']
)
# Get shapes by the original IDs you specified when creating them
my_shapes = search_shapes(
original_ids=['my-support-line', 'my-resistance-line']
)
# Get all trendlines (no time filter)
all_trendlines = search_shapes(shape_type='trend_line')
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
shapes_dict = shape_store.shapes
matching_shapes = []
# If specific shape IDs are requested, search by both id and original_id
if shape_ids:
for requested_id in shape_ids:
# First try direct ID lookup
shape = shapes_dict.get(requested_id)
if shape:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
else:
# If not found by ID, search by original_id
for shape_id, shape in shapes_dict.items():
if shape.get('original_id') == requested_id:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
break
logger.info(
f"Found {len(matching_shapes)} shapes by ID filter (requested {len(shape_ids)} IDs)"
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
# If specific original IDs are requested, search by original_id only
if original_ids:
for original_id in original_ids:
for shape_id, shape in shapes_dict.items():
if shape.get('original_id') == original_id:
# Still apply other filters if specified
if symbol and shape.get('symbol') != symbol:
continue
if shape_type and shape.get('type') != shape_type:
continue
matching_shapes.append(shape)
break
logger.info(
f"Found {len(matching_shapes)} shapes by original_id filter (requested {len(original_ids)} IDs)"
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
# Otherwise, search all shapes with filters
for shape_id, shape in shapes_dict.items():
# Filter by symbol if specified
if symbol and shape.get('symbol') != symbol:
continue
# Filter by type if specified
if shape_type and shape.get('type') != shape_type:
continue
# Filter by time range if specified
if start_time is not None and end_time is not None:
# Check if any control point falls within the time range
# or if the shape spans across the time range
points = shape.get('points', [])
if not points:
continue
# Get min and max times from shape's control points
shape_times = [point['time'] for point in points]
shape_min_time = min(shape_times)
shape_max_time = max(shape_times)
# Check for overlap: shape overlaps if its range intersects with query range
if not (shape_max_time >= start_time and shape_min_time <= end_time):
continue
matching_shapes.append(shape)
logger.info(
f"Found {len(matching_shapes)} shapes"
+ (f" in time range {start_time}-{end_time}" if start_time and end_time else "")
+ (f" for type '{shape_type}'" if shape_type else "")
+ (f" on symbol '{symbol}'" if symbol else "")
)
return matching_shapes
@tool
async def create_or_update_shape(
shape_id: str,
shape_type: str,
points: List[Dict[str, Any]],
color: Optional[str] = None,
line_width: Optional[int] = None,
line_style: Optional[str] = None,
properties: Optional[Dict[str, Any]] = None,
symbol: Optional[str] = None
) -> Dict[str, Any]:
"""Create a new shape or update an existing shape on the chart.
This tool allows the agent to draw shapes on the user's chart or modify
existing shapes. Shapes are synchronized to the frontend in real-time.
IMPORTANT - Shape ID Mapping:
When you create a shape, TradingView will assign its own internal ID that differs
from the shape_id you provide. The shape will be updated in the store with:
- id: TradingView's assigned ID
- original_id: The shape_id you provided
To find your shape later, use search_shapes() and filter by original_id field.
Example:
# Create a shape
await create_or_update_shape(shape_id='my-support', ...)
# Later, find it by original_id
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
my_shape = next((s for s in shapes if s.get('original_id') == 'my-support'), None)
Args:
shape_id: Unique identifier for the shape (use existing ID to update, new ID to create)
Note: TradingView will assign its own ID; your ID will be stored in original_id
shape_type: Type of shape using TradingView's native names.
Single-point shapes (use 1 point):
- 'horizontal_line': Horizontal support/resistance line
- 'vertical_line': Vertical time marker
- 'text': Text label
- 'anchored_text': Anchored text annotation
- 'anchored_note': Anchored note
- 'note': Note annotation
- 'emoji': Emoji marker
- 'icon': Icon marker
- 'sticker': Sticker marker
- 'arrow_up': Upward arrow marker
- 'arrow_down': Downward arrow marker
- 'flag': Flag marker
- 'long_position': Long position marker
- 'short_position': Short position marker
Multi-point shapes (use 2+ points):
- 'trend_line': Trendline (2 points)
- 'rectangle': Rectangle (2 points: top-left, bottom-right)
- 'fib_retracement': Fibonacci retracement (2 points)
- 'fib_trend_ext': Fibonacci extension (3 points)
- 'parallel_channel': Parallel channel (3 points)
- 'arrow': Arrow (2 points)
- 'circle': Circle/ellipse (2-3 points)
- 'path': Free drawing path (3+ points)
- 'pitchfork': Andrew's pitchfork (3 points)
- 'gannbox_fan': Gann fan (2 points)
- 'head_and_shoulders': Head and shoulders pattern (5 points)
points: List of control points, each with 'time' (Unix seconds) and 'price' fields
color: Optional color (hex like '#FF0000' or name like 'red')
line_width: Optional line width in pixels (default: 1)
line_style: Optional line style: 'solid', 'dashed', 'dotted' (default: 'solid')
properties: Optional dict of additional shape-specific properties
symbol: Optional symbol to associate with the shape (defaults to current chart symbol)
Returns:
Dictionary with:
- status: 'created' or 'updated'
- shape: The complete shape object (initially with your ID, will be updated to TV ID)
Examples:
# Draw a trendline between two points
await create_or_update_shape(
shape_id='my-trendline-1',
shape_type='trend_line',
points=[
{'time': 1640000000, 'price': 45000.0},
{'time': 1650000000, 'price': 50000.0}
],
color='#00FF00',
line_width=2
)
# Draw a horizontal support line
await create_or_update_shape(
shape_id='support-1',
shape_type='horizontal_line',
points=[{'time': 1640000000, 'price': 42000.0}],
color='blue',
line_style='dashed'
)
# Find your shape after creation using original_id
shapes = search_shapes(symbol='BINANCE:BTC/USDT')
my_shape = next((s for s in shapes if s.get('original_id') == 'support-1'), None)
if my_shape:
print(f"TradingView assigned ID: {my_shape['id']}")
"""
from schema.shape import Shape, ControlPoint
import time as time_module
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
# Normalize shape type (handle legacy names)
normalized_type = SHAPE_TYPE_ALIASES.get(shape_type, shape_type)
if normalized_type != shape_type:
logger.info(f"Normalized shape type '{shape_type}' -> '{normalized_type}'")
# Convert points to ControlPoint objects
control_points = []
for p in points:
point_data = {
'time': p['time'],
'price': p['price']
}
# Only include channel if it's actually provided
if 'channel' in p and p['channel'] is not None:
point_data['channel'] = p['channel']
control_points.append(ControlPoint(**point_data))
# Check if updating existing shape
existing_shape = shape_store.shapes.get(shape_id)
is_update = existing_shape is not None
# If symbol is not provided, try to get it from ChartStore
if symbol is None and "ChartStore" in registry.entries:
chart_store = registry.entries["ChartStore"].model
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
symbol = chart_store.chart_state.symbol
logger.info(f"Using current chart symbol for shape: {symbol}")
now = int(time_module.time())
# Create shape object
shape = Shape(
id=shape_id,
type=normalized_type,
points=control_points,
color=color,
line_width=line_width,
line_style=line_style,
properties=properties or {},
symbol=symbol,
created_at=existing_shape.get('created_at') if existing_shape else now,
modified_at=now
)
# Update the store
shape_store.shapes[shape_id] = shape.model_dump(mode="json")
# Trigger sync
await registry.push_all()
logger.info(
f"{'Updated' if is_update else 'Created'} shape '{shape_id}' "
f"of type '{shape_type}' with {len(points)} points"
)
return {
"status": "updated" if is_update else "created",
"shape": shape.model_dump(mode="json")
}
@tool
async def delete_shape(shape_id: str) -> Dict[str, str]:
"""Delete a shape from the chart.
Args:
shape_id: ID of the shape to delete
Returns:
Dictionary with status message
Raises:
ValueError: If shape doesn't exist
Example:
await delete_shape('my-trendline-1')
"""
registry = _get_registry()
if not registry:
raise ValueError("SyncRegistry not initialized")
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
if shape_id not in shape_store.shapes:
raise ValueError(f"Shape '{shape_id}' not found")
# Delete the shape
del shape_store.shapes[shape_id]
# Trigger sync
await registry.push_all()
logger.info(f"Deleted shape '{shape_id}'")
return {
"status": "success",
"message": f"Shape '{shape_id}' deleted"
}
@tool
def get_shape(shape_id: str) -> Dict[str, Any]:
"""Get details of a specific shape by ID.
Args:
shape_id: ID of the shape to retrieve
Returns:
Dictionary containing the shape data
Raises:
ValueError: If shape doesn't exist
Example:
shape = get_shape('my-trendline-1')
print(f"Shape type: {shape['type']}")
print(f"Points: {shape['points']}")
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
shape = shape_store.shapes.get(shape_id)
if not shape:
raise ValueError(f"Shape '{shape_id}' not found")
return shape
@tool
def list_all_shapes() -> List[Dict[str, Any]]:
"""List all shapes currently on the chart.
Returns:
List of all shapes as dictionaries
Example:
shapes = list_all_shapes()
print(f"Total shapes: {len(shapes)}")
for shape in shapes:
print(f" - {shape['id']}: {shape['type']}")
"""
shape_store = _get_shape_store()
if not shape_store:
raise ValueError("ShapeStore not initialized")
return list(shape_store.shapes.values())
SHAPE_TOOLS = [
search_shapes,
create_or_update_shape,
delete_shape,
get_shape,
list_all_shapes
]

View File

@@ -0,0 +1,366 @@
"""
Agent tools for trigger system.
Allows agents to:
- Schedule recurring tasks (cron-style)
- Execute one-time triggers
- Manage scheduled triggers (list, cancel)
- Connect events to sub-agent runs or lambdas
"""
import logging
from typing import Any, Dict, List, Optional
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
# Global references set by main.py
_trigger_queue = None
_trigger_scheduler = None
_coordinator = None
def set_trigger_queue(queue):
"""Set the global TriggerQueue instance for tools to use."""
global _trigger_queue
_trigger_queue = queue
def set_trigger_scheduler(scheduler):
"""Set the global TriggerScheduler instance for tools to use."""
global _trigger_scheduler
_trigger_scheduler = scheduler
def set_coordinator(coordinator):
"""Set the global CommitCoordinator instance for tools to use."""
global _coordinator
_coordinator = coordinator
def _get_trigger_queue():
"""Get the global trigger queue instance."""
if not _trigger_queue:
raise ValueError("TriggerQueue not initialized")
return _trigger_queue
def _get_trigger_scheduler():
"""Get the global trigger scheduler instance."""
if not _trigger_scheduler:
raise ValueError("TriggerScheduler not initialized")
return _trigger_scheduler
def _get_coordinator():
"""Get the global coordinator instance."""
if not _coordinator:
raise ValueError("CommitCoordinator not initialized")
return _coordinator
@tool
async def schedule_agent_prompt(
prompt: str,
schedule_type: str,
schedule_config: Dict[str, Any],
name: Optional[str] = None,
) -> Dict[str, str]:
"""Schedule an agent to run with a specific prompt on a recurring schedule.
This allows you to set up automated tasks where the agent runs periodically
with a predefined prompt. Useful for:
- Daily market analysis reports
- Hourly portfolio rebalancing checks
- Weekly performance summaries
- Monitoring alerts
Args:
prompt: The prompt to send to the agent when triggered
schedule_type: Type of schedule - "interval" or "cron"
schedule_config: Schedule configuration:
For "interval": {"minutes": 5} or {"hours": 1, "minutes": 30}
For "cron": {"hour": "9", "minute": "0"} for 9:00 AM daily
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"}
name: Optional descriptive name for this scheduled task
Returns:
Dictionary with job_id and confirmation message
Examples:
# Run every 5 minutes
schedule_agent_prompt(
prompt="Check BTC price and alert if > $50k",
schedule_type="interval",
schedule_config={"minutes": 5}
)
# Run daily at 9 AM
schedule_agent_prompt(
prompt="Generate daily market summary",
schedule_type="cron",
schedule_config={"hour": "9", "minute": "0"}
)
# Run hourly on weekdays
schedule_agent_prompt(
prompt="Monitor portfolio for rebalancing opportunities",
schedule_type="cron",
schedule_config={"minute": "0", "day_of_week": "mon-fri"}
)
"""
from trigger.handlers import LambdaHandler
from trigger import Priority
scheduler = _get_trigger_scheduler()
queue = _get_trigger_queue()
if not name:
name = f"agent_prompt_{hash(prompt) % 10000}"
# Create a lambda that enqueues an agent trigger with the prompt
async def agent_prompt_lambda():
from trigger.handlers import AgentTriggerHandler
# Create agent trigger (will use current session's context)
# In production, you'd want to specify which session/user this belongs to
trigger = AgentTriggerHandler(
session_id="scheduled", # Special session for scheduled tasks
message_content=prompt,
coordinator=_get_coordinator(),
)
await queue.enqueue(trigger)
return [] # No direct commit intents
# Wrap in lambda handler
lambda_trigger = LambdaHandler(
name=f"scheduled_{name}",
func=agent_prompt_lambda,
priority=Priority.TIMER,
)
# Schedule based on type
if schedule_type == "interval":
job_id = scheduler.schedule_interval(
lambda_trigger,
seconds=schedule_config.get("seconds"),
minutes=schedule_config.get("minutes"),
hours=schedule_config.get("hours"),
priority=Priority.TIMER,
)
elif schedule_type == "cron":
job_id = scheduler.schedule_cron(
lambda_trigger,
minute=schedule_config.get("minute"),
hour=schedule_config.get("hour"),
day=schedule_config.get("day"),
month=schedule_config.get("month"),
day_of_week=schedule_config.get("day_of_week"),
priority=Priority.TIMER,
)
else:
raise ValueError(f"Invalid schedule_type: {schedule_type}. Use 'interval' or 'cron'")
return {
"job_id": job_id,
"message": f"Scheduled '{name}' with job_id={job_id}",
"schedule_type": schedule_type,
"config": schedule_config,
}
@tool
async def execute_agent_prompt_once(
prompt: str,
priority: str = "normal",
) -> Dict[str, str]:
"""Execute an agent prompt once, immediately (enqueued with priority).
Use this to trigger a sub-agent with a specific task without waiting for
a user message. Useful for:
- Background analysis tasks
- One-time data processing
- Responding to specific events
Args:
prompt: The prompt to send to the agent
priority: Priority level - "high", "normal", or "low"
Returns:
Confirmation that the prompt was enqueued
Example:
execute_agent_prompt_once(
prompt="Analyze the last 100 BTC/USDT bars and identify support levels",
priority="high"
)
"""
from trigger.handlers import AgentTriggerHandler
from trigger import Priority
queue = _get_trigger_queue()
# Map string priority to enum
priority_map = {
"high": Priority.USER_AGENT, # Same priority as user messages
"normal": Priority.SYSTEM,
"low": Priority.LOW,
}
priority_enum = priority_map.get(priority.lower(), Priority.SYSTEM)
# Create agent trigger
trigger = AgentTriggerHandler(
session_id="oneshot",
message_content=prompt,
coordinator=_get_coordinator(),
)
# Enqueue with priority override
queue_seq = await queue.enqueue(trigger, priority_enum)
return {
"queue_seq": queue_seq,
"message": f"Enqueued agent prompt with priority={priority}",
"prompt": prompt[:100] + "..." if len(prompt) > 100 else prompt,
}
@tool
def list_scheduled_triggers() -> List[Dict[str, Any]]:
"""List all currently scheduled triggers.
Returns:
List of dictionaries with job information (id, name, next_run_time)
Example:
jobs = list_scheduled_triggers()
for job in jobs:
print(f"{job['id']}: {job['name']} - next run at {job['next_run_time']}")
"""
scheduler = _get_trigger_scheduler()
jobs = scheduler.get_jobs()
result = []
for job in jobs:
result.append({
"id": job.id,
"name": job.name,
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
"trigger": str(job.trigger),
})
return result
@tool
def cancel_scheduled_trigger(job_id: str) -> Dict[str, str]:
"""Cancel a scheduled trigger by its job ID.
Args:
job_id: The job ID returned from schedule_agent_prompt or list_scheduled_triggers
Returns:
Confirmation message
Example:
cancel_scheduled_trigger("interval_123")
"""
scheduler = _get_trigger_scheduler()
success = scheduler.remove_job(job_id)
if success:
return {
"status": "success",
"message": f"Cancelled job {job_id}",
}
else:
return {
"status": "error",
"message": f"Job {job_id} not found",
}
@tool
async def on_data_update_run_agent(
source_name: str,
symbol: str,
resolution: str,
prompt_template: str,
) -> Dict[str, str]:
"""Set up an agent to run whenever new data arrives for a specific symbol.
The prompt_template can include {variables} that will be filled with bar data:
- {time}: Bar timestamp
- {open}, {high}, {low}, {close}, {volume}: OHLCV values
- {symbol}: Trading pair symbol
- {source}: Data source name
Args:
source_name: Name of data source (e.g., "binance")
symbol: Trading pair (e.g., "BTC/USDT")
resolution: Time resolution (e.g., "1m", "5m", "1h")
prompt_template: Template string for agent prompt
Returns:
Confirmation with subscription details
Example:
on_data_update_run_agent(
source_name="binance",
symbol="BTC/USDT",
resolution="1m",
prompt_template="New bar on {symbol}: close={close}. Check if we should trade."
)
Note:
This is a simplified version. Full implementation would wire into
DataSource subscription system to trigger on every bar update.
"""
# TODO: Implement proper DataSource subscription integration
# For now, return placeholder
return {
"status": "not_implemented",
"message": "Data-driven agent triggers coming soon",
"config": {
"source": source_name,
"symbol": symbol,
"resolution": resolution,
"prompt_template": prompt_template,
},
}
@tool
def get_trigger_system_stats() -> Dict[str, Any]:
"""Get statistics about the trigger system.
Returns:
Dictionary with queue depth, execution stats, etc.
Example:
stats = get_trigger_system_stats()
print(f"Queue depth: {stats['queue_depth']}")
print(f"Current seq: {stats['current_seq']}")
"""
queue = _get_trigger_queue()
coordinator = _get_coordinator()
return {
"queue_depth": queue.get_queue_size(),
"queue_running": queue.is_running(),
"coordinator_stats": coordinator.get_stats(),
}
# Export tools list
TRIGGER_TOOLS = [
schedule_agent_prompt,
execute_agent_prompt_once,
list_scheduled_triggers,
cancel_scheduled_trigger,
on_data_update_run_agent,
get_trigger_system_stats,
]

View File

@@ -149,6 +149,10 @@ from .talib_adapter import (
is_talib_available, is_talib_available,
get_talib_version, get_talib_version,
) )
from .custom_indicators import (
register_custom_indicators,
CUSTOM_INDICATORS,
)
__all__ = [ __all__ = [
# Core classes # Core classes
@@ -169,4 +173,7 @@ __all__ = [
"register_all_talib_indicators", "register_all_talib_indicators",
"is_talib_available", "is_talib_available",
"get_talib_version", "get_talib_version",
# Custom indicators
"register_custom_indicators",
"CUSTOM_INDICATORS",
] ]

View File

@@ -0,0 +1,954 @@
"""
Custom indicator implementations for TradingView indicators not in TA-Lib.
These indicators follow TA-Lib style conventions and integrate seamlessly
with the indicator framework. All implementations are based on well-known,
publicly documented formulas.
"""
import logging
from typing import List, Optional
import numpy as np
from datasource.schema import ColumnInfo
from .base import Indicator
from .schema import (
ComputeContext,
ComputeResult,
IndicatorMetadata,
IndicatorParameter,
InputSchema,
OutputSchema,
)
logger = logging.getLogger(__name__)
class VWAP(Indicator):
"""Volume Weighted Average Price - Most widely used institutional indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VWAP",
display_name="VWAP",
description="Volume Weighted Average Price - Average price weighted by volume",
category="volume",
parameters=[],
use_cases=[
"Institutional reference price",
"Support/resistance levels",
"Mean reversion trading"
],
references=["https://www.investopedia.com/terms/v/vwap.asp"],
tags=["vwap", "volume", "institutional"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
# Typical price
typical_price = (high + low + close) / 3.0
# VWAP = cumsum(typical_price * volume) / cumsum(volume)
cumulative_tp_vol = np.nancumsum(typical_price * volume)
cumulative_vol = np.nancumsum(volume)
vwap = cumulative_tp_vol / cumulative_vol
times = context.get_times()
result_data = [
{"time": times[i], "vwap": float(vwap[i]) if not np.isnan(vwap[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class VWMA(Indicator):
"""Volume Weighted Moving Average."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VWMA",
display_name="VWMA",
description="Volume Weighted Moving Average - Moving average weighted by volume",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
tags=["vwma", "volume", "moving average"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
length = self.params.get("length", 20)
vwma = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
window_close = close[i - length + 1:i + 1]
window_volume = volume[i - length + 1:i + 1]
vwma[i] = np.sum(window_close * window_volume) / np.sum(window_volume)
times = context.get_times()
result_data = [
{"time": times[i], "vwma": float(vwma[i]) if not np.isnan(vwma[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class HullMA(Indicator):
"""Hull Moving Average - Fast and smooth moving average."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="HMA",
display_name="Hull Moving Average",
description="Hull Moving Average - Reduces lag while maintaining smoothness",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=9,
min_value=1,
required=False
)
],
use_cases=["Low-lag trend following", "Quick trend reversal detection"],
references=["https://alanhull.com/hull-moving-average"],
tags=["hma", "hull", "moving average", "low-lag"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 9)
def wma(data, period):
"""Weighted Moving Average."""
weights = np.arange(1, period + 1)
result = np.full_like(data, np.nan)
for i in range(period - 1, len(data)):
window = data[i - period + 1:i + 1]
result[i] = np.sum(weights * window) / np.sum(weights)
return result
# HMA = WMA(2 * WMA(n/2) - WMA(n)), sqrt(n))
half_length = length // 2
sqrt_length = int(np.sqrt(length))
wma_half = wma(close, half_length)
wma_full = wma(close, length)
raw_hma = 2 * wma_half - wma_full
hma = wma(raw_hma, sqrt_length)
times = context.get_times()
result_data = [
{"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class SuperTrend(Indicator):
"""SuperTrend - Popular trend following indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="SUPERTREND",
display_name="SuperTrend",
description="SuperTrend - Volatility-based trend indicator",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="ATR period",
default=10,
min_value=1,
required=False
),
IndicatorParameter(
name="multiplier",
type="float",
description="ATR multiplier",
default=3.0,
min_value=0.1,
required=False
)
],
use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
tags=["supertrend", "trend", "volatility"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 10)
multiplier = self.params.get("multiplier", 3.0)
# Calculate ATR
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
atr = np.full_like(close, np.nan)
atr[length - 1] = np.mean(tr[:length])
for i in range(length, len(tr)):
atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length
# Calculate basic bands
hl2 = (high + low) / 2
basic_upper = hl2 + multiplier * atr
basic_lower = hl2 - multiplier * atr
# Calculate final bands
final_upper = np.full_like(close, np.nan)
final_lower = np.full_like(close, np.nan)
supertrend = np.full_like(close, np.nan)
direction = np.full_like(close, np.nan)
for i in range(length, len(close)):
if i == length:
final_upper[i] = basic_upper[i]
final_lower[i] = basic_lower[i]
else:
final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]
if i == length:
supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
direction[i] = -1 if close[i] <= hl2[i] else 1
else:
if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
supertrend[i] = final_upper[i]
direction[i] = -1
elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
supertrend[i] = final_lower[i]
direction[i] = 1
elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
supertrend[i] = final_lower[i]
direction[i] = 1
else:
supertrend[i] = final_upper[i]
direction[i] = -1
times = context.get_times()
result_data = [
{
"time": times[i],
"supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
"direction": int(direction[i]) if not np.isnan(direction[i]) else None
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class DonchianChannels(Indicator):
"""Donchian Channels - Breakout indicator using highest high and lowest low."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="DONCHIAN",
display_name="Donchian Channels",
description="Donchian Channels - Highest high and lowest low over period",
category="overlap",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
tags=["donchian", "channels", "breakout"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
length = self.params.get("length", 20)
upper = np.full_like(high, np.nan)
lower = np.full_like(low, np.nan)
for i in range(length - 1, len(high)):
upper[i] = np.nanmax(high[i - length + 1:i + 1])
lower[i] = np.nanmin(low[i - length + 1:i + 1])
middle = (upper + lower) / 2
times = context.get_times()
result_data = [
{
"time": times[i],
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
"middle": float(middle[i]) if not np.isnan(middle[i]) else None,
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class KeltnerChannels(Indicator):
"""Keltner Channels - ATR-based volatility bands."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="KELTNER",
display_name="Keltner Channels",
description="Keltner Channels - EMA with ATR-based bands",
category="volatility",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="EMA period",
default=20,
min_value=1,
required=False
),
IndicatorParameter(
name="multiplier",
type="float",
description="ATR multiplier",
default=2.0,
min_value=0.1,
required=False
),
IndicatorParameter(
name="atr_length",
type="int",
description="ATR period",
default=10,
min_value=1,
required=False
)
],
use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
tags=["keltner", "channels", "volatility", "atr"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 20)
multiplier = self.params.get("multiplier", 2.0)
atr_length = self.params.get("atr_length", 10)
# Calculate EMA
alpha = 2.0 / (length + 1)
ema = np.full_like(close, np.nan)
ema[0] = close[0]
for i in range(1, len(close)):
ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]
# Calculate ATR
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
atr = np.full_like(close, np.nan)
atr[atr_length - 1] = np.mean(tr[:atr_length])
for i in range(atr_length, len(tr)):
atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length
upper = ema + multiplier * atr
lower = ema - multiplier * atr
times = context.get_times()
result_data = [
{
"time": times[i],
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
"middle": float(ema[i]) if not np.isnan(ema[i]) else None,
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class ChaikinMoneyFlow(Indicator):
"""Chaikin Money Flow - Volume-weighted accumulation/distribution."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="CMF",
display_name="Chaikin Money Flow",
description="Chaikin Money Flow - Measures buying and selling pressure",
category="volume",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=20,
min_value=1,
required=False
)
],
use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
tags=["cmf", "chaikin", "volume", "money flow"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
ColumnInfo(name="volume", type="float", description="Volume"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
length = self.params.get("length", 20)
# Money Flow Multiplier
mfm = ((close - low) - (high - close)) / (high - low)
mfm = np.where(high == low, 0, mfm)
# Money Flow Volume
mfv = mfm * volume
# CMF
cmf = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])
times = context.get_times()
result_data = [
{"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class VortexIndicator(Indicator):
"""Vortex Indicator - Identifies trend direction and strength."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="VORTEX",
display_name="Vortex Indicator",
description="Vortex Indicator - Trend direction and strength",
category="momentum",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=14,
min_value=1,
required=False
)
],
use_cases=["Trend identification", "Trend reversals", "Trend strength"],
references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
tags=["vortex", "trend", "momentum"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 14)
# Vortex Movement
vm_plus = np.abs(high - np.roll(low, 1))
vm_minus = np.abs(low - np.roll(high, 1))
vm_plus[0] = 0
vm_minus[0] = 0
# True Range
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
# Vortex Indicator
vi_plus = np.full_like(close, np.nan)
vi_minus = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
sum_tr = np.sum(tr[i - length + 1:i + 1])
if sum_tr != 0:
vi_plus[i] = sum_vm_plus / sum_tr
vi_minus[i] = sum_vm_minus / sum_tr
times = context.get_times()
result_data = [
{
"time": times[i],
"vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
"vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class AwesomeOscillator(Indicator):
"""Awesome Oscillator - Bill Williams' momentum indicator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="AO",
display_name="Awesome Oscillator",
description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
category="momentum",
parameters=[],
use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
tags=["awesome", "oscillator", "momentum", "williams"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
midpoint = (high + low) / 2
# SMA 5
sma5 = np.full_like(midpoint, np.nan)
for i in range(4, len(midpoint)):
sma5[i] = np.mean(midpoint[i - 4:i + 1])
# SMA 34
sma34 = np.full_like(midpoint, np.nan)
for i in range(33, len(midpoint)):
sma34[i] = np.mean(midpoint[i - 33:i + 1])
ao = sma5 - sma34
times = context.get_times()
result_data = [
{"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class AcceleratorOscillator(Indicator):
"""Accelerator Oscillator - Rate of change of Awesome Oscillator."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="AC",
display_name="Accelerator Oscillator",
description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
category="momentum",
parameters=[],
use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
tags=["accelerator", "oscillator", "momentum", "williams"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
midpoint = (high + low) / 2
# Calculate AO first
sma5 = np.full_like(midpoint, np.nan)
for i in range(4, len(midpoint)):
sma5[i] = np.mean(midpoint[i - 4:i + 1])
sma34 = np.full_like(midpoint, np.nan)
for i in range(33, len(midpoint)):
sma34[i] = np.mean(midpoint[i - 33:i + 1])
ao = sma5 - sma34
# AC = AO - SMA(AO, 5)
sma_ao = np.full_like(ao, np.nan)
for i in range(4, len(ao)):
if not np.isnan(ao[i - 4:i + 1]).any():
sma_ao[i] = np.mean(ao[i - 4:i + 1])
ac = ao - sma_ao
times = context.get_times()
result_data = [
{"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class ChoppinessIndex(Indicator):
"""Choppiness Index - Determines if market is choppy or trending."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="CHOP",
display_name="Choppiness Index",
description="Choppiness Index - Measures market trendiness vs consolidation",
category="volatility",
parameters=[
IndicatorParameter(
name="length",
type="int",
description="Period length",
default=14,
min_value=1,
required=False
)
],
use_cases=["Trend vs range identification", "Market regime detection"],
references=["https://www.tradingview.com/support/solutions/43000501980/"],
tags=["chop", "choppiness", "trend", "range"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
ColumnInfo(name="close", type="float", description="Close price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
length = self.params.get("length", 14)
# True Range
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
tr[0] = high[0] - low[0]
chop = np.full_like(close, np.nan)
for i in range(length - 1, len(close)):
sum_tr = np.sum(tr[i - length + 1:i + 1])
high_low_diff = np.max(high[i - length + 1:i + 1]) - np.min(low[i - length + 1:i + 1])
if high_low_diff != 0:
chop[i] = 100 * np.log10(sum_tr / high_low_diff) / np.log10(length)
times = context.get_times()
result_data = [
{"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
class MassIndex(Indicator):
"""Mass Index - Identifies trend reversals based on range expansion."""
@classmethod
def get_metadata(cls) -> IndicatorMetadata:
return IndicatorMetadata(
name="MASS",
display_name="Mass Index",
description="Mass Index - Identifies reversals when range narrows then expands",
category="volatility",
parameters=[
IndicatorParameter(
name="fast_period",
type="int",
description="Fast EMA period",
default=9,
min_value=1,
required=False
),
IndicatorParameter(
name="slow_period",
type="int",
description="Slow EMA period",
default=25,
min_value=1,
required=False
)
],
use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
references=["https://www.investopedia.com/terms/m/mass-index.asp"],
tags=["mass", "index", "volatility", "reversal"]
)
@classmethod
def get_input_schema(cls) -> InputSchema:
return InputSchema(required_columns=[
ColumnInfo(name="high", type="float", description="High price"),
ColumnInfo(name="low", type="float", description="Low price"),
])
@classmethod
def get_output_schema(cls, **params) -> OutputSchema:
return OutputSchema(columns=[
ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
])
def compute(self, context: ComputeContext) -> ComputeResult:
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
fast_period = self.params.get("fast_period", 9)
slow_period = self.params.get("slow_period", 25)
hl_range = high - low
# Single EMA
alpha1 = 2.0 / (fast_period + 1)
ema1 = np.full_like(hl_range, np.nan)
ema1[0] = hl_range[0]
for i in range(1, len(hl_range)):
ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]
# Double EMA
ema2 = np.full_like(ema1, np.nan)
ema2[0] = ema1[0]
for i in range(1, len(ema1)):
if not np.isnan(ema1[i]):
ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]
# EMA Ratio
ema_ratio = ema1 / ema2
# Mass Index
mass = np.full_like(hl_range, np.nan)
for i in range(slow_period - 1, len(ema_ratio)):
mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])
times = context.get_times()
result_data = [
{"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
for i in range(len(times))
]
return ComputeResult(data=result_data, is_partial=context.is_incremental)
# Registry of all custom indicators
CUSTOM_INDICATORS = [
VWAP,
VWMA,
HullMA,
SuperTrend,
DonchianChannels,
KeltnerChannels,
ChaikinMoneyFlow,
VortexIndicator,
AwesomeOscillator,
AcceleratorOscillator,
ChoppinessIndex,
MassIndex,
]
def register_custom_indicators(registry) -> int:
"""
Register all custom indicators with the registry.
Args:
registry: IndicatorRegistry instance
Returns:
Number of indicators registered
"""
registered_count = 0
for indicator_class in CUSTOM_INDICATORS:
try:
registry.register(indicator_class)
registered_count += 1
logger.debug(f"Registered custom indicator: {indicator_class.__name__}")
except Exception as e:
logger.warning(f"Failed to register custom indicator {indicator_class.__name__}: {e}")
logger.info(f"Registered {registered_count} custom indicators")
return registered_count

View File

@@ -372,12 +372,14 @@ def create_talib_indicator_class(func_name: str) -> type:
) )
def register_all_talib_indicators(registry) -> int: def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
""" """
Auto-register all available TA-Lib indicators with the registry. Auto-register all available TA-Lib indicators with the registry.
Args: Args:
registry: IndicatorRegistry instance registry: IndicatorRegistry instance
only_tradingview_supported: If True, only register indicators that have
TradingView equivalents (default: True)
Returns: Returns:
Number of indicators registered Number of indicators registered
@@ -392,6 +394,9 @@ def register_all_talib_indicators(registry) -> int:
) )
return 0 return 0
# Get list of supported indicators if filtering is enabled
from .tv_mapping import is_indicator_supported
# Get all TA-Lib functions # Get all TA-Lib functions
func_groups = talib.get_function_groups() func_groups = talib.get_function_groups()
all_functions = [] all_functions = []
@@ -402,8 +407,16 @@ def register_all_talib_indicators(registry) -> int:
all_functions = sorted(set(all_functions)) all_functions = sorted(set(all_functions))
registered_count = 0 registered_count = 0
skipped_count = 0
for func_name in all_functions: for func_name in all_functions:
try: try:
# Skip if filtering enabled and indicator not supported in TradingView
if only_tradingview_supported and not is_indicator_supported(func_name):
skipped_count += 1
logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
continue
# Create indicator class for this function # Create indicator class for this function
indicator_class = create_talib_indicator_class(func_name) indicator_class = create_talib_indicator_class(func_name)
@@ -415,7 +428,7 @@ def register_all_talib_indicators(registry) -> int:
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}") logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
continue continue
logger.info(f"Registered {registered_count} TA-Lib indicators") logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
return registered_count return registered_count

View File

@@ -0,0 +1,360 @@
"""
Mapping layer between TA-Lib indicators and TradingView indicators.
This module provides bidirectional conversion between our internal TA-Lib-based
indicator representation and TradingView's indicator system.
"""
from typing import Dict, Any, Optional, Tuple, List
import logging
logger = logging.getLogger(__name__)
# Mapping of TA-Lib indicator names to TradingView indicator names
# Only includes indicators that are present in BOTH systems (inner join)
# Format: {talib_name: tv_name}
TALIB_TO_TV_NAMES = {
# Overlap Studies (14)
"SMA": "Moving Average",
"EMA": "Moving Average Exponential",
"WMA": "Weighted Moving Average",
"DEMA": "DEMA",
"TEMA": "TEMA",
"TRIMA": "Triangular Moving Average",
"KAMA": "KAMA",
"MAMA": "MESA Adaptive Moving Average",
"T3": "T3",
"BBANDS": "Bollinger Bands",
"MIDPOINT": "Midpoint",
"MIDPRICE": "Midprice",
"SAR": "Parabolic SAR",
"HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",
# Momentum Indicators (21)
"RSI": "Relative Strength Index",
"MOM": "Momentum",
"ROC": "Rate of Change",
"TRIX": "TRIX",
"CMO": "Chande Momentum Oscillator",
"DX": "Directional Movement Index",
"ADX": "Average Directional Movement Index",
"ADXR": "Average Directional Movement Index Rating",
"APO": "Absolute Price Oscillator",
"PPO": "Percentage Price Oscillator",
"MACD": "MACD",
"MFI": "Money Flow Index",
"STOCH": "Stochastic",
"STOCHF": "Stochastic Fast",
"STOCHRSI": "Stochastic RSI",
"WILLR": "Williams %R",
"CCI": "Commodity Channel Index",
"AROON": "Aroon",
"AROONOSC": "Aroon Oscillator",
"BOP": "Balance Of Power",
"ULTOSC": "Ultimate Oscillator",
# Volume Indicators (3)
"AD": "Chaikin A/D Line",
"ADOSC": "Chaikin A/D Oscillator",
"OBV": "On Balance Volume",
# Volatility Indicators (3)
"ATR": "Average True Range",
"NATR": "Normalized Average True Range",
"TRANGE": "True Range",
# Price Transform (4)
"AVGPRICE": "Average Price",
"MEDPRICE": "Median Price",
"TYPPRICE": "Typical Price",
"WCLPRICE": "Weighted Close Price",
# Cycle Indicators (5)
"HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
"HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
"HT_PHASOR": "Hilbert Transform - Phasor Components",
"HT_SINE": "Hilbert Transform - SineWave",
"HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",
# Statistic Functions (9)
"BETA": "Beta",
"CORREL": "Pearson's Correlation Coefficient",
"LINEARREG": "Linear Regression",
"LINEARREG_ANGLE": "Linear Regression Angle",
"LINEARREG_INTERCEPT": "Linear Regression Intercept",
"LINEARREG_SLOPE": "Linear Regression Slope",
"STDDEV": "Standard Deviation",
"TSF": "Time Series Forecast",
"VAR": "Variance",
}
# Total: 60 indicators supported in both systems
# Custom indicators (TradingView indicators implemented in our backend)
CUSTOM_TO_TV_NAMES = {
"VWAP": "VWAP",
"VWMA": "VWMA",
"HMA": "Hull Moving Average",
"SUPERTREND": "SuperTrend",
"DONCHIAN": "Donchian Channels",
"KELTNER": "Keltner Channels",
"CMF": "Chaikin Money Flow",
"VORTEX": "Vortex Indicator",
"AO": "Awesome Oscillator",
"AC": "Accelerator Oscillator",
"CHOP": "Choppiness Index",
"MASS": "Mass Index",
}
# Combined mapping (TA-Lib + Custom)
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}
# Total: 72 indicators (60 TA-Lib + 12 Custom)
# Reverse mapping
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
def get_tv_indicator_name(talib_name: str) -> str:
"""
Convert TA-Lib indicator name to TradingView indicator name.
Args:
talib_name: TA-Lib indicator name (e.g., 'RSI')
Returns:
TradingView indicator name
"""
return TALIB_TO_TV_NAMES.get(talib_name, talib_name)
def get_talib_indicator_name(tv_name: str) -> Optional[str]:
"""
Convert TradingView indicator name to TA-Lib indicator name.
Args:
tv_name: TradingView indicator name
Returns:
TA-Lib indicator name or None if not mapped
"""
return TV_TO_TALIB_NAMES.get(tv_name)
def convert_talib_params_to_tv_inputs(
talib_name: str,
talib_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
Convert TA-Lib parameters to TradingView input format.
Args:
talib_name: TA-Lib indicator name
talib_params: TA-Lib parameter dictionary
Returns:
TradingView inputs dictionary
"""
tv_inputs = {}
# Common parameter mappings
param_mapping = {
"timeperiod": "length",
"fastperiod": "fastLength",
"slowperiod": "slowLength",
"signalperiod": "signalLength",
"nbdevup": "mult", # Standard deviations for upper band
"nbdevdn": "mult", # Standard deviations for lower band
"fastlimit": "fastLimit",
"slowlimit": "slowLimit",
"acceleration": "start",
"maximum": "increment",
"fastk_period": "kPeriod",
"slowk_period": "kPeriod",
"slowd_period": "dPeriod",
"fastd_period": "dPeriod",
"matype": "maType",
}
# Special handling for specific indicators
if talib_name == "BBANDS":
# Bollinger Bands
tv_inputs["length"] = talib_params.get("timeperiod", 20)
tv_inputs["mult"] = talib_params.get("nbdevup", 2)
tv_inputs["source"] = "close"
elif talib_name == "MACD":
# MACD
tv_inputs["fastLength"] = talib_params.get("fastperiod", 12)
tv_inputs["slowLength"] = talib_params.get("slowperiod", 26)
tv_inputs["signalLength"] = talib_params.get("signalperiod", 9)
tv_inputs["source"] = "close"
elif talib_name == "RSI":
# RSI
tv_inputs["length"] = talib_params.get("timeperiod", 14)
tv_inputs["source"] = "close"
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
# Moving averages
tv_inputs["length"] = talib_params.get("timeperiod", 14)
tv_inputs["source"] = "close"
elif talib_name == "STOCH":
# Stochastic
tv_inputs["kPeriod"] = talib_params.get("fastk_period", 14)
tv_inputs["dPeriod"] = talib_params.get("slowd_period", 3)
tv_inputs["smoothK"] = talib_params.get("slowk_period", 3)
elif talib_name == "ATR":
# ATR
tv_inputs["length"] = talib_params.get("timeperiod", 14)
elif talib_name == "CCI":
# CCI
tv_inputs["length"] = talib_params.get("timeperiod", 20)
else:
# Generic parameter conversion
for talib_param, value in talib_params.items():
tv_param = param_mapping.get(talib_param, talib_param)
tv_inputs[tv_param] = value
logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
return tv_inputs
def convert_tv_inputs_to_talib_params(
tv_name: str,
tv_inputs: Dict[str, Any]
) -> Tuple[Optional[str], Dict[str, Any]]:
"""
Convert TradingView inputs to TA-Lib parameters.
Args:
tv_name: TradingView indicator name
tv_inputs: TradingView inputs dictionary
Returns:
Tuple of (talib_name, talib_params)
"""
talib_name = get_talib_indicator_name(tv_name)
if not talib_name:
logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
return None, {}
talib_params = {}
# Reverse parameter mappings
reverse_mapping = {
"length": "timeperiod",
"fastLength": "fastperiod",
"slowLength": "slowperiod",
"signalLength": "signalperiod",
"mult": "nbdevup", # Use same for both up and down
"fastLimit": "fastlimit",
"slowLimit": "slowlimit",
"start": "acceleration",
"increment": "maximum",
"kPeriod": "fastk_period",
"dPeriod": "slowd_period",
"smoothK": "slowk_period",
"maType": "matype",
}
# Special handling for specific indicators
if talib_name == "BBANDS":
# Bollinger Bands
talib_params["timeperiod"] = tv_inputs.get("length", 20)
talib_params["nbdevup"] = tv_inputs.get("mult", 2)
talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
talib_params["matype"] = 0 # SMA
elif talib_name == "MACD":
# MACD
talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
elif talib_name == "RSI":
# RSI
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
# Moving averages
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name == "STOCH":
# Stochastic
talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
talib_params["slowk_matype"] = 0 # SMA
talib_params["slowd_matype"] = 0 # SMA
elif talib_name == "ATR":
# ATR
talib_params["timeperiod"] = tv_inputs.get("length", 14)
elif talib_name == "CCI":
# CCI
talib_params["timeperiod"] = tv_inputs.get("length", 20)
else:
# Generic parameter conversion
for tv_param, value in tv_inputs.items():
if tv_param == "source":
continue # Skip source parameter
talib_param = reverse_mapping.get(tv_param, tv_param)
talib_params[talib_param] = value
logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
return talib_name, talib_params
def is_indicator_supported(talib_name: str) -> bool:
"""
Check if a TA-Lib indicator is supported in TradingView.
Args:
talib_name: TA-Lib indicator name
Returns:
True if supported
"""
return talib_name in TALIB_TO_TV_NAMES
def get_supported_indicators() -> List[str]:
"""
Get list of supported TA-Lib indicators.
Returns:
List of TA-Lib indicator names
"""
return list(TALIB_TO_TV_NAMES.keys())
def get_supported_indicator_count() -> int:
"""
Get count of supported indicators.
Returns:
Number of indicators supported in both systems (TA-Lib + Custom)
"""
return len(ALL_BACKEND_TO_TV_NAMES)
def is_custom_indicator(indicator_name: str) -> bool:
"""
Check if an indicator is a custom implementation (not TA-Lib).
Args:
indicator_name: Indicator name
Returns:
True if custom indicator
"""
return indicator_name in CUSTOM_TO_TV_NAMES
def get_backend_indicator_name(tv_name: str) -> Optional[str]:
"""
Get backend indicator name from TradingView name (TA-Lib or custom).
Args:
tv_name: TradingView indicator name
Returns:
Backend indicator name or None if not mapped
"""
return TV_TO_BACKEND_NAMES.get(tv_name)

View File

@@ -21,13 +21,18 @@ from gateway.channels.websocket import WebSocketChannel
from gateway.protocol import WebSocketAgentUserMessage from gateway.protocol import WebSocketAgentUserMessage
from agent.core import create_agent from agent.core import create_agent
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
from schema.order_spec import SwapOrder from schema.order_spec import SwapOrder
from schema.chart_state import ChartState from schema.chart_state import ChartState
from schema.shape import ShapeCollection
from schema.indicator import IndicatorCollection
from datasource.registry import DataSourceRegistry from datasource.registry import DataSourceRegistry
from datasource.subscription_manager import SubscriptionManager from datasource.subscription_manager import SubscriptionManager
from datasource.websocket_handler import DatafeedWebSocketHandler from datasource.websocket_handler import DatafeedWebSocketHandler
from secrets_manager import SecretsStore, InvalidMasterPassword from secrets_manager import SecretsStore, InvalidMasterPassword
from indicator import IndicatorRegistry, register_all_talib_indicators from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
from trigger import CommitCoordinator, TriggerQueue
from trigger.scheduler import TriggerScheduler
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -57,6 +62,11 @@ subscription_manager = SubscriptionManager()
# Indicator infrastructure # Indicator infrastructure
indicator_registry = IndicatorRegistry() indicator_registry = IndicatorRegistry()
# Trigger system infrastructure
trigger_coordinator = None
trigger_queue = None
trigger_scheduler = None
# Global secrets store # Global secrets store
secrets_store = SecretsStore() secrets_store = SecretsStore()
@@ -64,7 +74,7 @@ secrets_store = SecretsStore()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Initialize agent system and data sources on startup.""" """Initialize agent system and data sources on startup."""
global agent_executor global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler
# Initialize CCXT data sources # Initialize CCXT data sources
try: try:
@@ -92,6 +102,13 @@ async def lifespan(app: FastAPI):
logger.warning(f"Failed to register TA-Lib indicators: {e}") logger.warning(f"Failed to register TA-Lib indicators: {e}")
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.") logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
# Register custom indicators (TradingView indicators not in TA-Lib)
try:
custom_count = register_custom_indicators(indicator_registry)
logger.info(f"Registered {custom_count} custom indicators")
except Exception as e:
logger.warning(f"Failed to register custom indicators: {e}")
# Get API keys from secrets store if unlocked, otherwise fall back to environment # Get API keys from secrets store if unlocked, otherwise fall back to environment
anthropic_api_key = None anthropic_api_key = None
@@ -106,6 +123,22 @@ async def lifespan(app: FastAPI):
if anthropic_api_key: if anthropic_api_key:
logger.info("Loaded API key from environment") logger.info("Loaded API key from environment")
# Initialize trigger system
logger.info("Initializing trigger system...")
trigger_coordinator = CommitCoordinator()
trigger_queue = TriggerQueue(trigger_coordinator)
trigger_scheduler = TriggerScheduler(trigger_queue)
# Start trigger queue and scheduler
await trigger_queue.start()
trigger_scheduler.start()
logger.info("Trigger system initialized and started")
# Set trigger system for agent tools
set_coordinator(trigger_coordinator)
set_trigger_queue(trigger_queue)
set_trigger_scheduler(trigger_scheduler)
if not anthropic_api_key: if not anthropic_api_key:
logger.error("ANTHROPIC_API_KEY not found in environment!") logger.error("ANTHROPIC_API_KEY not found in environment!")
logger.info("Agent system will not be available") logger.info("Agent system will not be available")
@@ -124,7 +157,7 @@ async def lifespan(app: FastAPI):
chroma_db_path=config["memory"]["chroma_db"], chroma_db_path=config["memory"]["chroma_db"],
embedding_model=config["memory"]["embedding_model"], embedding_model=config["memory"]["embedding_model"],
context_docs_dir=config["agent"]["context_docs_dir"], context_docs_dir=config["agent"]["context_docs_dir"],
base_dir=".." # Point to project root from backend/src base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ and soul/ live
) )
await agent_executor.initialize() await agent_executor.initialize()
@@ -137,9 +170,22 @@ async def lifespan(app: FastAPI):
yield yield
# Cleanup # Cleanup
logger.info("Shutting down systems...")
# Shutdown trigger system
if trigger_scheduler:
trigger_scheduler.shutdown(wait=True)
logger.info("Trigger scheduler shut down")
if trigger_queue:
await trigger_queue.stop()
logger.info("Trigger queue stopped")
# Shutdown agent system
if agent_executor and agent_executor.memory_manager: if agent_executor and agent_executor.memory_manager:
await agent_executor.memory_manager.close() await agent_executor.memory_manager.close()
logger.info("Agent system shut down")
logger.info("All systems shut down")
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
@@ -159,13 +205,25 @@ class OrderStore(BaseModel):
class ChartStore(BaseModel): class ChartStore(BaseModel):
chart_state: ChartState = ChartState() chart_state: ChartState = ChartState()
# ShapeStore model for synchronization
class ShapeStore(BaseModel):
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
# IndicatorStore model for synchronization
class IndicatorStore(BaseModel):
indicators: dict[str, dict] = {} # Dictionary of indicators keyed by ID
# Initialize stores # Initialize stores
order_store = OrderStore() order_store = OrderStore()
chart_store = ChartStore() chart_store = ChartStore()
shape_store = ShapeStore()
indicator_store = IndicatorStore()
# Register with SyncRegistry # Register with SyncRegistry
registry.register(order_store, store_name="OrderStore") registry.register(order_store, store_name="OrderStore")
registry.register(chart_store, store_name="ChartStore") registry.register(chart_store, store_name="ChartStore")
registry.register(shape_store, store_name="ShapeStore")
registry.register(indicator_store, store_name="IndicatorStore")
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
@@ -361,11 +419,14 @@ async def websocket_endpoint(websocket: WebSocket):
elif msg_type == "patch": elif msg_type == "patch":
patch_msg = PatchMessage(**message_json) patch_msg = PatchMessage(**message_json)
logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}") logger.info(f"Patch message received for store: {patch_msg.store}, seq: {patch_msg.seq}")
await registry.apply_client_patch( try:
store_name=patch_msg.store, await registry.apply_client_patch(
client_base_seq=patch_msg.seq, store_name=patch_msg.store,
patch=patch_msg.patch client_base_seq=patch_msg.seq,
) patch=patch_msg.patch
)
except Exception as e:
logger.error(f"Error applying client patch: {e}. Client will receive snapshot to resync.", exc_info=True)
elif msg_type == "agent_user_message": elif msg_type == "agent_user_message":
# Handle agent messages directly here # Handle agent messages directly here
print(f"[DEBUG] Raw message_json: {message_json}") print(f"[DEBUG] Raw message_json: {message_json}")

View File

@@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -7,10 +7,13 @@ class ChartState(BaseModel):
This state is synchronized between the frontend and backend to allow This state is synchronized between the frontend and backend to allow
the AI agent to understand what the user is currently viewing. the AI agent to understand what the user is currently viewing.
All fields can be None when no chart is visible (e.g., on mobile/narrow screens).
""" """
# Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT") # Current symbol being viewed (e.g., "BINANCE:BTC/USDT", "BINANCE:ETH/USDT")
symbol: str = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol") # None when chart is not visible
symbol: Optional[str] = Field(default="BINANCE:BTC/USDT", description="Current trading pair symbol, or None if no chart visible")
# Time range currently visible on chart (Unix timestamps in seconds) # Time range currently visible on chart (Unix timestamps in seconds)
# These represent the leftmost and rightmost visible candle times # These represent the leftmost and rightmost visible candle times
@@ -18,4 +21,8 @@ class ChartState(BaseModel):
end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)") end_time: Optional[int] = Field(default=None, description="End time of visible range (Unix timestamp in seconds)")
# Optional: Chart interval/resolution # Optional: Chart interval/resolution
interval: str = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D')") # None when chart is not visible
interval: Optional[str] = Field(default="15", description="Chart interval (e.g., '1', '5', '15', '60', 'D'), or None if no chart visible")
# Selected shapes/drawings on the chart
selected_shapes: List[str] = Field(default_factory=list, description="Array of selected shape IDs")

View File

@@ -0,0 +1,40 @@
from typing import Dict, Any, Optional, List
from pydantic import BaseModel, Field
class IndicatorInstance(BaseModel):
"""
Represents an instance of an indicator applied to a chart.
This schema holds both the TA-Lib metadata and TradingView-specific data
needed for synchronization.
"""
id: str = Field(..., description="Unique identifier for this indicator instance")
# TA-Lib metadata
talib_name: str = Field(..., description="TA-Lib indicator name (e.g., 'RSI', 'SMA', 'MACD')")
instance_name: str = Field(..., description="User-friendly instance name")
parameters: Dict[str, Any] = Field(default_factory=dict, description="TA-Lib indicator parameters")
# TradingView metadata
tv_study_id: Optional[str] = Field(default=None, description="TradingView study ID assigned by the chart widget")
tv_indicator_name: Optional[str] = Field(default=None, description="TradingView indicator name if different from TA-Lib")
tv_inputs: Optional[Dict[str, Any]] = Field(default=None, description="TradingView-specific input parameters")
# Visual properties
visible: bool = Field(default=True, description="Whether indicator is visible on chart")
pane: str = Field(default="chart", description="Pane where indicator is displayed ('chart' or 'separate')")
# Metadata
symbol: Optional[str] = Field(default=None, description="Symbol this indicator is applied to")
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
original_id: Optional[str] = Field(default=None, description="Original ID from backend before TradingView assigns its own ID")
class IndicatorCollection(BaseModel):
"""Collection of all indicator instances on the chart."""
indicators: Dict[str, IndicatorInstance] = Field(
default_factory=dict,
description="Dictionary of indicator instances keyed by ID"
)

View File

@@ -0,0 +1,44 @@
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
class ControlPoint(BaseModel):
"""A control point for a drawing shape.
Control points define the position and properties of a shape.
Different shapes have different numbers of control points.
"""
time: int = Field(..., description="Unix timestamp in seconds")
price: float = Field(..., description="Price level")
# Optional channel for multi-point shapes (e.g., parallel channels)
channel: Optional[str] = Field(default=None, description="Channel identifier for multi-point shapes")
class Shape(BaseModel):
"""A TradingView drawing shape/study.
Represents any drawing the user creates on the chart (trendlines,
horizontal lines, rectangles, Fibonacci retracements, etc.)
"""
id: str = Field(..., description="Unique identifier for the shape")
type: str = Field(..., description="Shape type (e.g., 'trendline', 'horizontal_line', 'rectangle', 'fibonacci')")
points: List[ControlPoint] = Field(default_factory=list, description="Control points that define the shape")
# Visual properties
color: Optional[str] = Field(default=None, description="Shape color (hex or color name)")
line_width: Optional[int] = Field(default=1, description="Line width in pixels")
line_style: Optional[str] = Field(default="solid", description="Line style: 'solid', 'dashed', 'dotted'")
# Shape-specific properties stored as flexible dict
properties: Dict[str, Any] = Field(default_factory=dict, description="Additional shape-specific properties")
# Metadata
symbol: Optional[str] = Field(default=None, description="Symbol this shape is drawn on")
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
original_id: Optional[str] = Field(default=None, description="Original ID from backend/agent before TradingView assigns its own ID")
class ShapeCollection(BaseModel):
"""Collection of all shapes/drawings on the chart."""
shapes: Dict[str, Shape] = Field(default_factory=dict, description="Dictionary of shapes keyed by ID")

View File

@@ -0,0 +1,246 @@
from collections import deque
from typing import Any, Dict, List, Optional, Tuple, Deque
import jsonpatch
from pydantic import BaseModel
from sync.protocol import SnapshotMessage, PatchMessage
class SyncEntry:
def __init__(self, model: BaseModel, store_name: str, history_size: int = 50):
self.model = model
self.store_name = store_name
self.seq = 0
self.last_snapshot = model.model_dump(mode="json")
self.history: Deque[Tuple[int, List[Dict[str, Any]]]] = deque(maxlen=history_size)
def compute_patch(self) -> Optional[List[Dict[str, Any]]]:
current_state = self.model.model_dump(mode="json")
patch = jsonpatch.make_patch(self.last_snapshot, current_state)
if not patch.patch:
return None
return patch.patch
def commit_patch(self, patch: List[Dict[str, Any]]):
self.seq += 1
self.history.append((self.seq, patch))
self.last_snapshot = self.model.model_dump(mode="json")
def catchup_patches(self, since_seq: int) -> Optional[List[Tuple[int, List[Dict[str, Any]]]]]:
if since_seq == self.seq:
return []
# Check if all patches from since_seq + 1 to self.seq are in history
if not self.history or self.history[0][0] > since_seq + 1:
return None
result = []
for seq, patch in self.history:
if seq > since_seq:
result.append((seq, patch))
return result
class SyncRegistry:
def __init__(self):
self.entries: Dict[str, SyncEntry] = {}
self.websocket: Optional[Any] = None # Expecting a FastAPI WebSocket or similar
def register(self, model: BaseModel, store_name: str):
self.entries[store_name] = SyncEntry(model, store_name)
async def push_all(self):
import logging
logger = logging.getLogger(__name__)
if not self.websocket:
logger.warning("push_all: No websocket connected, cannot push updates")
return
logger.info(f"push_all: Processing {len(self.entries)} store entries")
for entry in self.entries.values():
patch = entry.compute_patch()
if patch:
logger.info(f"push_all: Found patch for store '{entry.store_name}': {patch}")
entry.commit_patch(patch)
msg = PatchMessage(store=entry.store_name, seq=entry.seq, patch=patch)
logger.info(f"push_all: Sending patch message for '{entry.store_name}' seq={entry.seq}")
await self.websocket.send_json(msg.model_dump(mode="json"))
logger.info(f"push_all: Patch sent successfully for '{entry.store_name}'")
else:
logger.debug(f"push_all: No changes detected for store '{entry.store_name}'")
async def sync_client(self, client_seqs: Dict[str, int]):
if not self.websocket:
return
for store_name, entry in self.entries.items():
client_seq = client_seqs.get(store_name, -1)
patches = entry.catchup_patches(client_seq)
if patches is not None:
# Replay patches
for seq, patch in patches:
msg = PatchMessage(store=store_name, seq=seq, patch=patch)
await self.websocket.send_json(msg.model_dump(mode="json"))
else:
# Send full snapshot
msg = SnapshotMessage(
store=store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
async def apply_client_patch(self, store_name: str, client_base_seq: int, patch: List[Dict[str, Any]]):
import logging
logger = logging.getLogger(__name__)
logger.info(f"apply_client_patch: store={store_name}, client_base_seq={client_base_seq}, patch={patch}")
entry = self.entries.get(store_name)
if not entry:
logger.warning(f"apply_client_patch: Store '{store_name}' not found in registry")
return
logger.info(f"apply_client_patch: Current backend seq={entry.seq}")
try:
if client_base_seq == entry.seq:
# No conflict
logger.info("apply_client_patch: No conflict - applying patch directly")
current_state = entry.model.model_dump(mode="json")
logger.info(f"apply_client_patch: Current state before patch: {current_state}")
try:
new_state = jsonpatch.apply_patch(current_state, patch)
logger.info(f"apply_client_patch: New state after patch: {new_state}")
self._update_model(entry.model, new_state)
# Verify the model was actually updated
updated_state = entry.model.model_dump(mode="json")
logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")
entry.commit_patch(patch)
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
# Don't broadcast back to client - they already have this change
# Broadcasting would cause an infinite loop
logger.info("apply_client_patch: Not broadcasting back to originating client")
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Patch conflict on no-conflict path: {e}. Sending snapshot to resync.")
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
elif client_base_seq < entry.seq:
# Conflict! Frontend wins.
# 1. Get backend patches since client_base_seq
backend_patches = []
for seq, p in entry.history:
if seq > client_base_seq:
backend_patches.append(p)
# 2. Apply frontend patch first to the state at client_base_seq
# But we only have the current authoritative model.
# "Apply the frontend patch first to the model (frontend wins)"
# "Re-apply the backend deltas that do not overlap the frontend's changed paths on top"
# Let's get the state as it was at client_base_seq if possible?
# No, history only has patches.
# Alternative: Apply frontend patch to current model.
# Then re-apply backend patches, but discard parts that overlap.
frontend_paths = {p['path'] for p in patch}
current_state = entry.model.model_dump(mode="json")
# Apply frontend patch
try:
new_state = jsonpatch.apply_patch(current_state, patch)
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Failed to apply client patch during conflict resolution: {e}. Sending snapshot to resync.")
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
return
# Re-apply backend patches that don't overlap
for b_patch in backend_patches:
filtered_b_patch = [op for op in b_patch if op['path'] not in frontend_paths]
if filtered_b_patch:
try:
new_state = jsonpatch.apply_patch(new_state, filtered_b_patch)
except jsonpatch.JsonPatchConflict as e:
logger.warning(f"apply_client_patch: Failed to apply backend patch during conflict resolution: {e}. Skipping this patch.")
continue
self._update_model(entry.model, new_state)
# Commit the result as a single new patch
# We need to compute what changed from last_snapshot to new_state
final_patch = jsonpatch.make_patch(entry.last_snapshot, new_state).patch
if final_patch:
entry.commit_patch(final_patch)
# Broadcast resolved state as snapshot to converge
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
except Exception as e:
logger.error(f"apply_client_patch: Unexpected error: {e}. Sending snapshot to resync.", exc_info=True)
# Send snapshot to force resync
if self.websocket:
msg = SnapshotMessage(
store=entry.store_name,
seq=entry.seq,
state=entry.model.model_dump(mode="json")
)
await self.websocket.send_json(msg.model_dump(mode="json"))
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
# Update model fields in-place to preserve references
# This is important for dict fields that may be referenced elsewhere
for field_name, field_info in model.model_fields.items():
if field_name in new_data:
new_value = new_data[field_name]
current_value = getattr(model, field_name)
# For dict fields, update in-place instead of replacing
if isinstance(current_value, dict) and isinstance(new_value, dict):
self._deep_update_dict(current_value, new_value)
else:
# For other types, just set the new value
setattr(model, field_name, new_value)
def _deep_update_dict(self, target: dict, source: dict):
"""Deep update target dict with source dict, preserving nested dict references."""
# Remove keys that are in target but not in source
keys_to_remove = set(target.keys()) - set(source.keys())
for key in keys_to_remove:
del target[key]
# Update or add keys from source
for key, source_value in source.items():
if key in target:
target_value = target[key]
# If both are dicts, recursively update
if isinstance(target_value, dict) and isinstance(source_value, dict):
self._deep_update_dict(target_value, source_value)
else:
# Replace the value
target[key] = source_value
else:
# Add new key
target[key] = source_value

View File

@@ -0,0 +1,216 @@
# Priority System
Simple tuple-based priorities for deterministic execution ordering.
## Basic Concept
Priorities are just **Python tuples**. Python compares tuples element-by-element, left-to-right:
```python
(0, 1000, 5) < (0, 1001, 3) # True: 0==0, but 1000 < 1001
(0, 1000, 5) < (1, 500, 2) # True: 0 < 1
(0, 1000) < (0, 1000, 5) # True: shorter wins if equal so far
```
**Lower values = higher priority** (processed first).
## Priority Categories
```python
class Priority(IntEnum):
DATA_SOURCE = 0 # Market data, real-time feeds
TIMER = 1 # Scheduled tasks, cron jobs
USER_AGENT = 2 # User-agent interactions (chat)
USER_DATA_REQUEST = 3 # User data requests (charts)
SYSTEM = 4 # Background tasks, cleanup
LOW = 5 # Retries after conflicts
```
## Usage Examples
### Simple Priority
```python
# Just use the Priority enum
trigger = MyTrigger("task", priority=Priority.SYSTEM)
await queue.enqueue(trigger)
# Results in tuple: (4, queue_seq)
```
### Compound Priority (Tuple)
```python
# DataSource: sort by event time (older bars first)
trigger = DataUpdateTrigger(
source_name="binance",
symbol="BTC/USDT",
resolution="1m",
bar_data={"time": 1678896000, "open": 50000, ...}
)
await queue.enqueue(trigger)
# Results in tuple: (0, 1678896000, queue_seq)
# ^ ^ ^
# | | Queue insertion order (FIFO)
# | Event time (candle end time)
# DATA_SOURCE priority
```
### Manual Override
```python
# Override at enqueue time
await queue.enqueue(
trigger,
priority_override=(Priority.DATA_SOURCE, custom_time, custom_sort)
)
# Queue appends queue_seq: (0, custom_time, custom_sort, queue_seq)
```
## Common Patterns
### Market Data (Process Chronologically)
```python
# Bar from 10:00 → (0, 10:00_timestamp, queue_seq)
# Bar from 10:05 → (0, 10:05_timestamp, queue_seq)
#
# 10:00 bar processes first (earlier event_time)
DataUpdateTrigger(
...,
bar_data={"time": event_timestamp, ...}
)
```
### User Messages (FIFO Order)
```python
# Message #1 → (2, msg1_timestamp, queue_seq)
# Message #2 → (2, msg2_timestamp, queue_seq)
#
# Message #1 processes first (earlier timestamp)
AgentTriggerHandler(
session_id="user1",
message_content="...",
message_timestamp=unix_timestamp # Optional, defaults to now
)
```
### Scheduled Tasks (By Schedule Time)
```python
# Job scheduled for 9 AM → (1, 9am_timestamp, queue_seq)
# Job scheduled for 2 PM → (1, 2pm_timestamp, queue_seq)
#
# 9 AM job processes first
CronTrigger(
name="morning_sync",
inner_trigger=...,
scheduled_time=scheduled_timestamp
)
```
## Execution Order Example
```
Queue contains:
1. DataSource (BTC @ 10:00) → (0, 10:00, 1)
2. DataSource (BTC @ 10:05) → (0, 10:05, 2)
3. Timer (scheduled 9 AM) → (1, 09:00, 3)
4. User message #1 → (2, 14:30, 4)
5. User message #2 → (2, 14:35, 5)
Dequeue order:
1. DataSource (BTC @ 10:00) ← 0 < all others
2. DataSource (BTC @ 10:05) ← 0 < all others, 10:05 > 10:00
3. Timer (scheduled 9 AM) ← 1 < remaining
4. User message #1 ← 2 < remaining, 14:30 < 14:35
5. User message #2 ← last
```
## Short Tuple Wins
If tuples are equal up to the length of the shorter one, **shorter tuple has higher priority**:
```python
(0, 1000) < (0, 1000, 5) # True: shorter wins
(0,) < (0, 1000) # True: shorter wins
(Priority.DATA_SOURCE,) < (Priority.DATA_SOURCE, 1000) # True
```
This is Python's default tuple comparison behavior. In practice, we always append `queue_seq`, so this rarely matters (all tuples end up same length).
## Integration with Triggers
### Trigger Sets Its Own Priority
```python
class MyTrigger(Trigger):
def __init__(self, event_time):
super().__init__(
name="my_trigger",
priority=Priority.DATA_SOURCE,
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
)
```
Queue appends `queue_seq` automatically:
```python
# Trigger's tuple: (0, event_time)
# After enqueue: (0, event_time, queue_seq)
```
### Override at Enqueue
```python
# Ignore trigger's priority, use override
await queue.enqueue(
trigger,
priority_override=(Priority.TIMER, scheduled_time)
)
```
## Why Tuples?
**Simple**: No custom classes, just native Python tuples
**Flexible**: Add as many sort keys as needed
**Efficient**: Python's tuple comparison is highly optimized
**Readable**: `(0, 1000, 5)` is obvious what it means
**Debuggable**: Can print and inspect easily
Example:
```python
# Old: CompoundPriority(primary=0, secondary=1000, tertiary=5)
# New: (0, 1000, 5)
# Same semantics, much simpler!
```
## Advanced: Custom Sorting
Want to sort by multiple factors? Just add more elements:
```python
# Sort by: priority → symbol → event_time → queue_seq
priority_tuple = (
Priority.DATA_SOURCE.value,
symbol_id, # e.g., hash("BTC/USDT")
event_time,
# queue_seq appended by queue
)
```
## Summary
- **Priorities are tuples**: `(primary, secondary, ..., queue_seq)`
- **Lower = higher priority**: Processed first
- **Element-by-element comparison**: Left-to-right
- **Shorter tuple wins**: If equal up to shorter length
- **Queue appends queue_seq**: Always last element (FIFO within same priority)
That's it! No complex classes, just tuples. 🎯

View File

@@ -0,0 +1,386 @@
# Trigger System
Lock-free, sequence-based execution system for deterministic event processing.
## Overview
All operations (WebSocket messages, cron tasks, data updates) flow through a **priority queue**, execute in **parallel**, but commit in **strict sequential order** with **optimistic conflict detection**.
### Key Features
- **Lock-free reads**: Snapshots are deep copies, no blocking
- **Sequential commits**: Total ordering via sequence numbers
- **Optimistic concurrency**: Conflicts detected, retry with same seq
- **Priority preservation**: High-priority work never blocked by low-priority
- **Long-running agents**: Execute in parallel, commit sequentially
- **Deterministic replay**: Can reproduce exact system state at any seq
## Architecture
```
┌─────────────┐
│ WebSocket │───┐
│ Messages │ │
└─────────────┘ │
├──→ ┌─────────────────┐
┌─────────────┐ │ │ TriggerQueue │
│ Cron │───┤ │ (Priority Queue)│
│ Scheduled │ │ └────────┬────────┘
└─────────────┘ │ │ Assign seq
│ ↓
┌─────────────┐ │ ┌─────────────────┐
│ DataSource │───┘ │ Execute Trigger│
│ Updates │ │ (Parallel OK) │
└─────────────┘ └────────┬────────┘
│ CommitIntents
┌─────────────────┐
│ CommitCoordinator│
│ (Sequential) │
└────────┬────────┘
│ Commit in seq order
┌─────────────────┐
│ VersionedStores │
│ (w/ Backends) │
└─────────────────┘
```
## Core Components
### 1. ExecutionContext (`context.py`)
Tracks execution seq and store snapshots via `contextvars` (auto-propagates through async calls).
```python
from trigger import get_execution_context
ctx = get_execution_context()
print(f"Running at seq {ctx.seq}")
```
### 2. Trigger Types (`types.py`)
```python
from trigger import Trigger, Priority, CommitIntent
class MyTrigger(Trigger):
async def execute(self) -> list[CommitIntent]:
# Read snapshot
seq, data = some_store.read_snapshot()
# Modify
new_data = modify(data)
# Prepare commit
intent = some_store.prepare_commit(seq, new_data)
return [intent]
```
### 3. VersionedStore (`store.py`)
Stores with pluggable backends and optimistic concurrency:
```python
from trigger import VersionedStore, PydanticStoreBackend
# Wrap existing Pydantic model
backend = PydanticStoreBackend(order_store)
versioned_store = VersionedStore("OrderStore", backend)
# Lock-free snapshot read
seq, snapshot = versioned_store.read_snapshot()
# Prepare commit (does not modify yet)
intent = versioned_store.prepare_commit(seq, modified_snapshot)
```
**Pluggable Backends**:
- `PydanticStoreBackend`: For existing Pydantic models (OrderStore, ChartStore, etc.)
- `FileStoreBackend`: Future - version files (Python scripts, configs)
- `DatabaseStoreBackend`: Future - version database rows
### 4. CommitCoordinator (`coordinator.py`)
Manages sequential commits with conflict detection:
- Waits for seq N to commit before N+1
- Detects conflicts (expected_seq vs committed_seq)
- Re-executes (not re-enqueues) on conflict **with same seq**
- Tracks execution state for debugging
### 5. TriggerQueue (`queue.py`)
Priority queue with seq assignment:
```python
from trigger import TriggerQueue
queue = TriggerQueue(coordinator)
await queue.start()
# Enqueue trigger
await queue.enqueue(my_trigger, Priority.HIGH)
```
### 6. TriggerScheduler (`scheduler.py`)
APScheduler integration for cron triggers:
```python
from trigger.scheduler import TriggerScheduler
scheduler = TriggerScheduler(queue)
scheduler.start()
# Every 5 minutes
scheduler.schedule_interval(
IndicatorUpdateTrigger("rsi_14"),
minutes=5
)
# Daily at 9 AM
scheduler.schedule_cron(
SyncExchangeStateTrigger(),
hour="9",
minute="0"
)
```
## Integration Example
### Basic Setup in `main.py`
```python
from trigger import (
CommitCoordinator,
TriggerQueue,
VersionedStore,
PydanticStoreBackend,
)
from trigger.scheduler import TriggerScheduler
# Create coordinator
coordinator = CommitCoordinator()
# Wrap existing stores
order_store_versioned = VersionedStore(
"OrderStore",
PydanticStoreBackend(order_store)
)
coordinator.register_store(order_store_versioned)
chart_store_versioned = VersionedStore(
"ChartStore",
PydanticStoreBackend(chart_store)
)
coordinator.register_store(chart_store_versioned)
# Create queue and scheduler
trigger_queue = TriggerQueue(coordinator)
await trigger_queue.start()
scheduler = TriggerScheduler(trigger_queue)
scheduler.start()
```
### WebSocket Message Handler
```python
from trigger.handlers import AgentTriggerHandler
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_json()
if data["type"] == "agent_user_message":
# Enqueue agent trigger instead of direct Gateway call
trigger = AgentTriggerHandler(
session_id=data["session_id"],
message_content=data["content"],
gateway_handler=gateway.route_user_message,
coordinator=coordinator,
)
await trigger_queue.enqueue(trigger)
```
### DataSource Updates
```python
from trigger.handlers import DataUpdateTrigger
# In subscription_manager._on_source_update()
def _on_source_update(self, source_key: tuple, bar: dict):
# Enqueue data update trigger
trigger = DataUpdateTrigger(
source_name=source_key[0],
symbol=source_key[1],
resolution=source_key[2],
bar_data=bar,
coordinator=coordinator,
)
asyncio.create_task(trigger_queue.enqueue(trigger))
```
### Custom Trigger
```python
from trigger import Trigger, CommitIntent, Priority
class RecalculatePortfolioTrigger(Trigger):
def __init__(self, coordinator):
super().__init__("recalc_portfolio", Priority.NORMAL)
self.coordinator = coordinator
async def execute(self) -> list[CommitIntent]:
# Read snapshots from multiple stores
order_seq, orders = self.coordinator.get_store("OrderStore").read_snapshot()
chart_seq, chart = self.coordinator.get_store("ChartStore").read_snapshot()
# Calculate portfolio value
portfolio_value = calculate_portfolio(orders, chart)
# Update chart state with portfolio value
chart.portfolio_value = portfolio_value
# Prepare commit
intent = self.coordinator.get_store("ChartStore").prepare_commit(
chart_seq,
chart
)
return [intent]
# Schedule it
scheduler.schedule_interval(
RecalculatePortfolioTrigger(coordinator),
minutes=1
)
```
## Execution Flow
### Normal Flow (No Conflicts)
```
seq=100: WebSocket message arrives → enqueue → dequeue → assign seq=100 → execute
seq=101: Cron trigger fires → enqueue → dequeue → assign seq=101 → execute
seq=101 finishes first → waits in commit queue
seq=100 finishes → commits immediately (next in order)
seq=101 commits next
```
### Conflict Flow
```
seq=100: reads OrderStore at seq=99 → executes for 30 seconds
seq=101: reads OrderStore at seq=99 → executes for 5 seconds
seq=101 finishes first → tries to commit based on seq=99
seq=100 finishes → commits OrderStore at seq=100
Coordinator detects conflict:
expected_seq=99, committed_seq=100
seq=101 evicted → RE-EXECUTES with same seq=101 (not re-enqueued)
reads OrderStore at seq=100 → executes again
finishes → commits successfully at seq=101
```
## Benefits
### For Agent System
- **Long-running agents work naturally**: Agent starts at seq=100, runs for 60 seconds while market data updates at seq=101-110, commits only if no conflicts
- **No deadlocks**: No locks = no deadlock possibility
- **Deterministic**: Can replay from any seq for debugging
### For Strategy Execution
- **High-frequency data doesn't block strategies**: Data updates enqueued, executed in parallel, commit sequentially
- **Priority preservation**: Critical order execution never blocked by indicator calculations
- **Conflict detection**: If market moved during strategy calculation, automatically retry with fresh data
### For Scaling
- **Single-node first**: Runs on single asyncio event loop, no complex distributed coordination
- **Future-proof**: Can swap queue for Redis/PostgreSQL-backed distributed queue later
- **Event sourcing ready**: All commits have seq numbers, can build event log
## Debugging
### Check Current State
```python
# Coordinator stats
stats = coordinator.get_stats()
print(f"Current seq: {stats['current_seq']}")
print(f"Pending commits: {stats['pending_commits']}")
print(f"Executions by state: {stats['state_counts']}")
# Store state
store = coordinator.get_store("OrderStore")
print(f"Store: {store}") # Shows committed_seq and version
# Execution record
record = coordinator.get_execution_record(100)
print(f"Seq 100: {record}") # Shows state, retry_count, error
```
### Common Issues
**Symptoms: High conflict rate**
- **Cause**: Multiple triggers modifying same store frequently
- **Solution**: Batch updates, use debouncing, or redesign to reduce contention
**Symptoms: Commits stuck (next_commit_seq not advancing)**
- **Cause**: Execution at that seq failed or is taking too long
- **Solution**: Check execution_records for that seq, look for errors in logs
**Symptoms: Queue depth growing**
- **Cause**: Executions slower than enqueue rate
- **Solution**: Profile trigger execution, optimize slow paths, add rate limiting
## Testing
### Unit Test: Conflict Detection
```python
import pytest
from trigger import VersionedStore, PydanticStoreBackend, CommitCoordinator
@pytest.mark.asyncio
async def test_conflict_detection():
coordinator = CommitCoordinator()
store = VersionedStore("TestStore", PydanticStoreBackend(TestModel()))
coordinator.register_store(store)
# Seq 1: read at 0, modify, commit
seq1, data1 = store.read_snapshot()
data1.value = "seq1"
intent1 = store.prepare_commit(seq1, data1)
# Seq 2: read at 0 (same snapshot), modify
seq2, data2 = store.read_snapshot()
data2.value = "seq2"
intent2 = store.prepare_commit(seq2, data2)
# Commit seq 1 (should succeed)
# ... coordinator logic ...
# Commit seq 2 (should conflict and retry)
# ... verify conflict detected ...
```
## Future Enhancements
- **Distributed queue**: Redis-backed queue for multi-worker deployment
- **Event log persistence**: Store all commits for event sourcing/audit
- **Metrics dashboard**: Real-time view of queue depth, conflict rate, latency
- **Transaction snapshots**: Full system state at any seq for replay/debugging
- **Automatic batching**: Coalesce rapid updates to same store

View File

@@ -0,0 +1,35 @@
"""
Sequential execution trigger system with optimistic concurrency control.
All operations (websocket, cron, data events) flow through a priority queue,
execute in parallel, but commit in strict sequential order with conflict detection.
"""
from .context import ExecutionContext, get_execution_context
from .types import Priority, PriorityTuple, Trigger, CommitIntent, ExecutionState
from .store import VersionedStore, StoreBackend, PydanticStoreBackend
from .coordinator import CommitCoordinator
from .queue import TriggerQueue
from .handlers import AgentTriggerHandler, LambdaHandler
__all__ = [
# Context
"ExecutionContext",
"get_execution_context",
# Types
"Priority",
"PriorityTuple",
"Trigger",
"CommitIntent",
"ExecutionState",
# Store
"VersionedStore",
"StoreBackend",
"PydanticStoreBackend",
# Coordination
"CommitCoordinator",
"TriggerQueue",
# Handlers
"AgentTriggerHandler",
"LambdaHandler",
]

View File

@@ -0,0 +1,61 @@
"""
Execution context tracking using Python's contextvars.
Each execution gets a unique seq number that propagates through all async calls,
allowing us to track which execution made which changes for conflict detection.
"""
import logging
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Optional
logger = logging.getLogger(__name__)
# Context variables - automatically propagate through async call chains
_execution_context: ContextVar[Optional["ExecutionContext"]] = ContextVar(
"execution_context", default=None
)
@dataclass
class ExecutionContext:
"""
Execution context for a single trigger execution.
Automatically propagates through async calls via contextvars.
Tracks the seq number and which store snapshots were read.
"""
seq: int
"""Sequential execution number - determines commit order"""
trigger_name: str
"""Name/type of trigger being executed"""
snapshot_seqs: dict[str, int] = field(default_factory=dict)
"""Store name -> seq number of snapshot that was read"""
def record_snapshot(self, store_name: str, snapshot_seq: int) -> None:
"""Record that we read a snapshot from a store at a specific seq"""
self.snapshot_seqs[store_name] = snapshot_seq
logger.debug(f"Seq {self.seq}: Read {store_name} at seq {snapshot_seq}")
def __str__(self) -> str:
return f"ExecutionContext(seq={self.seq}, trigger={self.trigger_name})"
def get_execution_context() -> Optional[ExecutionContext]:
"""Get the current execution context, or None if not in an execution"""
return _execution_context.get()
def set_execution_context(ctx: ExecutionContext) -> None:
"""Set the execution context for the current async task"""
_execution_context.set(ctx)
logger.debug(f"Set execution context: {ctx}")
def clear_execution_context() -> None:
"""Clear the execution context"""
_execution_context.set(None)

View File

@@ -0,0 +1,302 @@
"""
Commit coordinator - manages sequential commits with conflict detection.
Ensures that commits happen in strict sequence order, even when executions
complete out of order. Detects conflicts and triggers re-execution with the
same seq number (not re-enqueue, just re-execute).
"""
import asyncio
import logging
from typing import Optional
from .context import ExecutionContext
from .store import VersionedStore
from .types import CommitIntent, ExecutionRecord, ExecutionState, Trigger
logger = logging.getLogger(__name__)
class CommitCoordinator:
"""
Manages sequential commits with optimistic concurrency control.
Key responsibilities:
- Maintain strict sequential commit order (seq N+1 commits after seq N)
- Detect conflicts between execution snapshot and committed state
- Trigger re-execution (not re-enqueue) on conflicts with same seq
- Track in-flight executions for debugging and monitoring
"""
def __init__(self):
self._stores: dict[str, VersionedStore] = {}
self._current_seq = 0 # Highest committed seq across all operations
self._next_commit_seq = 1 # Next seq we're waiting to commit
self._pending_commits: dict[int, tuple[ExecutionRecord, list[CommitIntent]]] = {}
self._execution_records: dict[int, ExecutionRecord] = {}
self._lock = asyncio.Lock() # Only for coordinator internal state, not stores
def register_store(self, store: VersionedStore) -> None:
"""Register a versioned store with the coordinator"""
self._stores[store.name] = store
logger.info(f"Registered store: {store.name}")
def get_store(self, name: str) -> Optional[VersionedStore]:
"""Get a registered store by name"""
return self._stores.get(name)
async def start_execution(self, seq: int, trigger: Trigger) -> ExecutionRecord:
"""
Record that an execution is starting.
Args:
seq: Sequence number assigned to this execution
trigger: The trigger being executed
Returns:
ExecutionRecord for tracking
"""
async with self._lock:
record = ExecutionRecord(
seq=seq,
trigger=trigger,
state=ExecutionState.EXECUTING,
)
self._execution_records[seq] = record
logger.info(f"Started execution: seq={seq}, trigger={trigger.name}")
return record
async def submit_for_commit(
self,
seq: int,
commit_intents: list[CommitIntent],
) -> None:
"""
Submit commit intents for sequential commit.
The commit will only happen when:
1. All prior seq numbers have committed
2. No conflicts detected with committed state
Args:
seq: Sequence number of this execution
commit_intents: List of changes to commit (empty if no changes)
"""
async with self._lock:
record = self._execution_records.get(seq)
if not record:
logger.error(f"No execution record found for seq={seq}")
return
record.state = ExecutionState.WAITING_COMMIT
record.commit_intents = commit_intents
self._pending_commits[seq] = (record, commit_intents)
logger.info(
f"Seq {seq} submitted for commit with {len(commit_intents)} intents"
)
# Try to process commits (this will handle sequential ordering)
await self._process_commits()
async def _process_commits(self) -> None:
"""
Process pending commits in strict sequential order.
Only commits seq N if seq N-1 has already committed.
Detects conflicts and triggers re-execution with same seq.
"""
while True:
async with self._lock:
# Check if next expected seq is ready to commit
if self._next_commit_seq not in self._pending_commits:
# Waiting for this seq to complete execution
break
seq = self._next_commit_seq
record, intents = self._pending_commits[seq]
logger.info(
f"Processing commit for seq={seq} (current_seq={self._current_seq})"
)
# Check for conflicts
conflicts = self._check_conflicts(intents)
if conflicts:
# Conflict detected - re-execute with same seq
logger.warning(
f"Seq {seq} has conflicts in stores: {conflicts}. Re-executing..."
)
# Remove from pending (will be re-added when execution completes)
del self._pending_commits[seq]
# Mark as evicted
record.state = ExecutionState.EVICTED
record.retry_count += 1
# Advance to next seq (this seq will be retried in background)
self._next_commit_seq += 1
self._current_seq += 1
# Trigger re-execution (outside lock)
asyncio.create_task(self._retry_execution(record))
continue
# No conflicts - commit all intents atomically
for intent in intents:
store = self._stores.get(intent.store_name)
if not store:
logger.error(
f"Seq {seq}: Store '{intent.store_name}' not found"
)
continue
store.commit(intent.new_data, seq)
# Mark as committed
record.state = ExecutionState.COMMITTED
del self._pending_commits[seq]
# Advance seq counters
self._current_seq = seq
self._next_commit_seq = seq + 1
logger.info(
f"Committed seq={seq}, current_seq now {self._current_seq}"
)
def _check_conflicts(self, intents: list[CommitIntent]) -> list[str]:
"""
Check if any commit intents conflict with current committed state.
Args:
intents: List of commit intents to check
Returns:
List of store names that have conflicts (empty if no conflicts)
"""
conflicts = []
for intent in intents:
store = self._stores.get(intent.store_name)
if not store:
logger.error(f"Store '{intent.store_name}' not found during conflict check")
continue
if store.check_conflict(intent.expected_seq):
conflicts.append(intent.store_name)
return conflicts
async def _retry_execution(self, record: ExecutionRecord) -> None:
"""
Re-execute a trigger that had conflicts.
Executes with the SAME seq number (not re-enqueued, just re-executed).
This ensures the execution order remains deterministic.
Args:
record: Execution record to retry
"""
from .context import ExecutionContext, set_execution_context, clear_execution_context
logger.info(
f"Retrying execution: seq={record.seq}, trigger={record.trigger.name}, "
f"retry_count={record.retry_count}"
)
# Set execution context for retry
ctx = ExecutionContext(
seq=record.seq,
trigger_name=record.trigger.name,
)
set_execution_context(ctx)
try:
# Re-execute trigger
record.state = ExecutionState.EXECUTING
commit_intents = await record.trigger.execute()
# Submit for commit again (with same seq)
await self.submit_for_commit(record.seq, commit_intents)
except Exception as e:
logger.error(
f"Retry execution failed for seq={record.seq}: {e}", exc_info=True
)
record.state = ExecutionState.FAILED
record.error = str(e)
# Still need to advance past this seq
async with self._lock:
if record.seq == self._next_commit_seq:
self._next_commit_seq += 1
self._current_seq += 1
# Try to process any pending commits
await self._process_commits()
finally:
clear_execution_context()
async def execution_failed(self, seq: int, error: Exception) -> None:
"""
Mark an execution as failed.
Args:
seq: Sequence number that failed
error: The exception that caused the failure
"""
async with self._lock:
record = self._execution_records.get(seq)
if record:
record.state = ExecutionState.FAILED
record.error = str(error)
# Remove from pending if present
self._pending_commits.pop(seq, None)
# If this is the next seq to commit, advance past it
if seq == self._next_commit_seq:
self._next_commit_seq += 1
self._current_seq += 1
logger.info(
f"Seq {seq} failed, advancing current_seq to {self._current_seq}"
)
# Try to process any pending commits
await self._process_commits()
def get_current_seq(self) -> int:
"""Get the current committed sequence number"""
return self._current_seq
def get_execution_record(self, seq: int) -> Optional[ExecutionRecord]:
"""Get execution record for a specific seq"""
return self._execution_records.get(seq)
def get_stats(self) -> dict:
"""Get statistics about the coordinator state"""
state_counts = {}
for record in self._execution_records.values():
state_name = record.state.name
state_counts[state_name] = state_counts.get(state_name, 0) + 1
return {
"current_seq": self._current_seq,
"next_commit_seq": self._next_commit_seq,
"pending_commits": len(self._pending_commits),
"total_executions": len(self._execution_records),
"state_counts": state_counts,
"stores": {name: str(store) for name, store in self._stores.items()},
}
def __repr__(self) -> str:
return (
f"CommitCoordinator(current_seq={self._current_seq}, "
f"pending={len(self._pending_commits)}, stores={len(self._stores)})"
)

View File

@@ -0,0 +1,304 @@
"""
Trigger handlers - concrete implementations for common trigger types.
Provides ready-to-use trigger handlers for:
- Agent execution (WebSocket user messages)
- Lambda/callable execution
- Data update triggers
- Indicator updates
"""
import logging
import time
from typing import Any, Awaitable, Callable, Optional
from .coordinator import CommitCoordinator
from .types import CommitIntent, Priority, Trigger
logger = logging.getLogger(__name__)
class AgentTriggerHandler(Trigger):
"""
Trigger for agent execution from WebSocket user messages.
Wraps the Gateway's agent execution flow and captures any
store modifications as commit intents.
Priority tuple: (USER_AGENT, message_timestamp, queue_seq)
"""
def __init__(
self,
session_id: str,
message_content: str,
message_timestamp: Optional[int] = None,
attachments: Optional[list] = None,
gateway_handler: Optional[Callable] = None,
coordinator: Optional[CommitCoordinator] = None,
):
"""
Initialize agent trigger.
Args:
session_id: User session ID
message_content: User message content
message_timestamp: When user sent message (unix timestamp, defaults to now)
attachments: Optional message attachments
gateway_handler: Callable to route to Gateway (set during integration)
coordinator: CommitCoordinator for accessing stores
"""
if message_timestamp is None:
message_timestamp = int(time.time())
# Priority tuple: sort by USER_AGENT priority, then message timestamp
super().__init__(
name=f"agent_{session_id}",
priority=Priority.USER_AGENT,
priority_tuple=(Priority.USER_AGENT.value, message_timestamp)
)
self.session_id = session_id
self.message_content = message_content
self.message_timestamp = message_timestamp
self.attachments = attachments or []
self.gateway_handler = gateway_handler
self.coordinator = coordinator
async def execute(self) -> list[CommitIntent]:
"""
Execute agent interaction.
This will call into the Gateway, which will run the agent.
The agent may read from stores and generate responses.
Any store modifications are captured as commit intents.
Returns:
List of commit intents (typically empty for now, as agent
modifies stores via tools which will be integrated later)
"""
if not self.gateway_handler:
logger.error("No gateway_handler configured for AgentTriggerHandler")
return []
logger.info(
f"Agent trigger executing: session={self.session_id}, "
f"content='{self.message_content[:50]}...'"
)
try:
# Call Gateway to handle message
# In future, Gateway/agent tools will use coordinator stores
await self.gateway_handler(
self.session_id,
self.message_content,
self.attachments,
)
# For now, agent doesn't directly modify stores
# Future: agent tools will return commit intents
return []
except Exception as e:
logger.error(f"Agent execution error: {e}", exc_info=True)
raise
class LambdaHandler(Trigger):
"""
Generic trigger that executes an arbitrary async callable.
Useful for custom triggers, one-off tasks, or testing.
"""
def __init__(
self,
name: str,
func: Callable[[], Awaitable[list[CommitIntent]]],
priority: Priority = Priority.SYSTEM,
):
"""
Initialize lambda handler.
Args:
name: Descriptive name for this trigger
func: Async callable that returns commit intents
priority: Execution priority
"""
super().__init__(name, priority)
self.func = func
async def execute(self) -> list[CommitIntent]:
"""Execute the callable"""
logger.info(f"Lambda trigger executing: {self.name}")
return await self.func()
class DataUpdateTrigger(Trigger):
"""
Trigger for DataSource bar updates.
Fired when new market data arrives. Can update indicators,
trigger strategy logic, or notify the agent of market events.
Priority tuple: (DATA_SOURCE, event_time, queue_seq)
Ensures older bars process before newer ones.
"""
def __init__(
self,
source_name: str,
symbol: str,
resolution: str,
bar_data: dict,
coordinator: Optional[CommitCoordinator] = None,
):
"""
Initialize data update trigger.
Args:
source_name: Name of data source (e.g., "binance")
symbol: Trading pair symbol
resolution: Time resolution
bar_data: Bar data dict (time, open, high, low, close, volume)
coordinator: CommitCoordinator for accessing stores
"""
event_time = bar_data.get('time', int(time.time()))
# Priority tuple: sort by DATA_SOURCE priority, then event time
super().__init__(
name=f"data_{source_name}_{symbol}_{resolution}",
priority=Priority.DATA_SOURCE,
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
)
self.source_name = source_name
self.symbol = symbol
self.resolution = resolution
self.bar_data = bar_data
self.coordinator = coordinator
async def execute(self) -> list[CommitIntent]:
"""
Process bar update.
Future implementations will:
- Update indicator values
- Check strategy conditions
- Trigger alerts/notifications
Returns:
Commit intents for any store updates
"""
logger.info(
f"Data update trigger: {self.source_name}:{self.symbol}@{self.resolution}, "
f"time={self.bar_data.get('time')}"
)
# TODO: Update indicators
# TODO: Check strategy conditions
# TODO: Notify agent of significant events
# For now, just log
return []
class IndicatorUpdateTrigger(Trigger):
"""
Trigger for updating indicator values.
Can be fired by cron (periodic recalculation) or by data updates.
"""
def __init__(
self,
indicator_id: str,
force_full_recalc: bool = False,
coordinator: Optional[CommitCoordinator] = None,
priority: Priority = Priority.SYSTEM,
):
"""
Initialize indicator update trigger.
Args:
indicator_id: ID of indicator to update
force_full_recalc: If True, recalculate entire history
coordinator: CommitCoordinator for accessing stores
priority: Execution priority
"""
super().__init__(f"indicator_{indicator_id}", priority)
self.indicator_id = indicator_id
self.force_full_recalc = force_full_recalc
self.coordinator = coordinator
async def execute(self) -> list[CommitIntent]:
"""
Update indicator value.
Reads from IndicatorStore, recalculates, prepares commit.
Returns:
Commit intents for updated indicator data
"""
if not self.coordinator:
logger.error("No coordinator configured")
return []
# Get indicator store
indicator_store = self.coordinator.get_store("IndicatorStore")
if not indicator_store:
logger.error("IndicatorStore not registered")
return []
# Read snapshot
snapshot_seq, indicator_data = indicator_store.read_snapshot()
logger.info(
f"Indicator update trigger: {self.indicator_id}, "
f"snapshot_seq={snapshot_seq}, force_full={self.force_full_recalc}"
)
# TODO: Implement indicator recalculation logic
# For now, just return empty (no changes)
return []
class CronTrigger(Trigger):
"""
Trigger fired by APScheduler on a schedule.
Wraps another trigger or callable to execute periodically.
Priority tuple: (TIMER, scheduled_time, queue_seq)
Ensures jobs scheduled for earlier times run first.
"""
def __init__(
self,
name: str,
inner_trigger: Trigger,
scheduled_time: Optional[int] = None,
):
"""
Initialize cron trigger.
Args:
name: Descriptive name (e.g., "hourly_sync")
inner_trigger: Trigger to execute on schedule
scheduled_time: When this was scheduled to run (defaults to now)
"""
if scheduled_time is None:
scheduled_time = int(time.time())
# Priority tuple: sort by TIMER priority, then scheduled time
super().__init__(
name=f"cron_{name}",
priority=Priority.TIMER,
priority_tuple=(Priority.TIMER.value, scheduled_time)
)
self.inner_trigger = inner_trigger
self.scheduled_time = scheduled_time
async def execute(self) -> list[CommitIntent]:
"""Execute the wrapped trigger"""
logger.info(f"Cron trigger firing: {self.name}")
return await self.inner_trigger.execute()

View File

@@ -0,0 +1,224 @@
"""
Trigger queue - priority queue with sequence number assignment.
All operations flow through this queue:
- WebSocket messages from users
- Cron scheduled tasks
- DataSource bar updates
- Manual triggers
Queue assigns seq numbers on dequeue, executes triggers, and submits to coordinator.
"""
import asyncio
import logging
from typing import Optional
from .context import ExecutionContext, clear_execution_context, set_execution_context
from .coordinator import CommitCoordinator
from .types import Priority, PriorityTuple, Trigger
logger = logging.getLogger(__name__)
class TriggerQueue:
"""
Priority queue for trigger execution.
Key responsibilities:
- Maintain priority queue (high priority dequeued first)
- Assign sequence numbers on dequeue (determines commit order)
- Execute triggers with context set
- Submit results to CommitCoordinator
- Handle execution errors gracefully
"""
def __init__(self, coordinator: CommitCoordinator):
"""
Initialize trigger queue.
Args:
coordinator: CommitCoordinator for handling commits
"""
self._coordinator = coordinator
self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
self._seq_counter = 0
self._seq_lock = asyncio.Lock()
self._processor_task: Optional[asyncio.Task] = None
self._running = False
async def start(self) -> None:
"""Start the queue processor"""
if self._running:
logger.warning("TriggerQueue already running")
return
self._running = True
self._processor_task = asyncio.create_task(self._process_loop())
logger.info("TriggerQueue started")
async def stop(self) -> None:
"""Stop the queue processor gracefully"""
if not self._running:
return
self._running = False
if self._processor_task:
self._processor_task.cancel()
try:
await self._processor_task
except asyncio.CancelledError:
pass
logger.info("TriggerQueue stopped")
async def enqueue(
self,
trigger: Trigger,
priority_override: Optional[Priority | PriorityTuple] = None
) -> int:
"""
Add a trigger to the queue.
Args:
trigger: Trigger to execute
priority_override: Override priority (simple Priority or tuple)
If None, uses trigger's priority/priority_tuple
If Priority enum, creates single-element tuple
If tuple, uses as-is
Returns:
Queue sequence number (appended to priority tuple)
Examples:
# Simple priority
await queue.enqueue(trigger, Priority.USER_AGENT)
# Results in: (Priority.USER_AGENT, queue_seq)
# Tuple priority with event time
await queue.enqueue(
trigger,
(Priority.DATA_SOURCE, bar_data['time'])
)
# Results in: (Priority.DATA_SOURCE, bar_time, queue_seq)
# Let trigger decide
await queue.enqueue(trigger)
"""
# Get monotonic seq for queue ordering (appended to tuple)
async with self._seq_lock:
queue_seq = self._seq_counter
self._seq_counter += 1
# Determine priority tuple
if priority_override is not None:
if isinstance(priority_override, Priority):
# Convert simple priority to tuple
priority_tuple = (priority_override.value, queue_seq)
else:
# Use provided tuple, append queue_seq
priority_tuple = priority_override + (queue_seq,)
else:
# Let trigger determine its own priority tuple
priority_tuple = trigger.get_priority_tuple(queue_seq)
# Priority queue: (priority_tuple, trigger)
# Python's PriorityQueue compares tuples element-by-element
await self._queue.put((priority_tuple, trigger))
logger.debug(
f"Enqueued: {trigger.name} with priority_tuple={priority_tuple}"
)
return queue_seq
async def _process_loop(self) -> None:
"""
Main processing loop.
Dequeues triggers, assigns execution seq, executes, and submits to coordinator.
"""
execution_seq = 0 # Separate counter for execution sequence
while self._running:
try:
# Wait for next trigger (with timeout to check _running flag)
try:
priority_tuple, trigger = await asyncio.wait_for(
self._queue.get(), timeout=1.0
)
except asyncio.TimeoutError:
continue
# Assign execution sequence number
execution_seq += 1
logger.info(
f"Dequeued: seq={execution_seq}, trigger={trigger.name}, "
f"priority_tuple={priority_tuple}"
)
# Execute in background (don't block queue)
asyncio.create_task(
self._execute_trigger(execution_seq, trigger)
)
except Exception as e:
logger.error(f"Error in process loop: {e}", exc_info=True)
async def _execute_trigger(self, seq: int, trigger: Trigger) -> None:
"""
Execute a trigger with proper context and error handling.
Args:
seq: Execution sequence number
trigger: Trigger to execute
"""
# Set up execution context
ctx = ExecutionContext(
seq=seq,
trigger_name=trigger.name,
)
set_execution_context(ctx)
# Record execution start with coordinator
await self._coordinator.start_execution(seq, trigger)
try:
logger.info(f"Executing: seq={seq}, trigger={trigger.name}")
# Execute trigger (can be long-running)
commit_intents = await trigger.execute()
logger.info(
f"Execution complete: seq={seq}, {len(commit_intents)} commit intents"
)
# Submit for sequential commit
await self._coordinator.submit_for_commit(seq, commit_intents)
except Exception as e:
logger.error(
f"Execution failed: seq={seq}, trigger={trigger.name}, error={e}",
exc_info=True,
)
# Notify coordinator of failure
await self._coordinator.execution_failed(seq, e)
finally:
clear_execution_context()
def get_queue_size(self) -> int:
"""Get current queue size (approximate)"""
return self._queue.qsize()
def is_running(self) -> bool:
"""Check if queue processor is running"""
return self._running
def __repr__(self) -> str:
return (
f"TriggerQueue(running={self._running}, queue_size={self.get_queue_size()})"
)

View File

@@ -0,0 +1,187 @@
"""
APScheduler integration for cron-style triggers.
Provides scheduling of periodic triggers (e.g., sync exchange state hourly,
recompute indicators every 5 minutes, daily portfolio reports).
"""
import logging
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger as APSCronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from .queue import TriggerQueue
from .types import Priority, Trigger
logger = logging.getLogger(__name__)
class TriggerScheduler:
"""
Scheduler for periodic trigger execution.
Wraps APScheduler to enqueue triggers at scheduled times.
"""
def __init__(self, trigger_queue: TriggerQueue):
"""
Initialize scheduler.
Args:
trigger_queue: TriggerQueue to enqueue triggers into
"""
self.trigger_queue = trigger_queue
self.scheduler = AsyncIOScheduler()
self._job_counter = 0
def start(self) -> None:
"""Start the scheduler"""
self.scheduler.start()
logger.info("TriggerScheduler started")
def shutdown(self, wait: bool = True) -> None:
"""
Shut down the scheduler.
Args:
wait: If True, wait for running jobs to complete
"""
self.scheduler.shutdown(wait=wait)
logger.info("TriggerScheduler shut down")
def schedule_interval(
self,
trigger: Trigger,
seconds: Optional[int] = None,
minutes: Optional[int] = None,
hours: Optional[int] = None,
priority: Optional[Priority] = None,
) -> str:
"""
Schedule a trigger to run at regular intervals.
Args:
trigger: Trigger to execute
seconds: Interval in seconds
minutes: Interval in minutes
hours: Interval in hours
priority: Priority override for execution
Returns:
Job ID (can be used to remove job later)
Example:
# Run every 5 minutes
scheduler.schedule_interval(
IndicatorUpdateTrigger("rsi_14"),
minutes=5
)
"""
job_id = f"interval_{self._job_counter}"
self._job_counter += 1
async def job_func():
await self.trigger_queue.enqueue(trigger, priority)
self.scheduler.add_job(
job_func,
trigger=IntervalTrigger(seconds=seconds, minutes=minutes, hours=hours),
id=job_id,
name=f"Interval: {trigger.name}",
)
logger.info(
f"Scheduled interval job: {job_id}, trigger={trigger.name}, "
f"interval=(s={seconds}, m={minutes}, h={hours})"
)
return job_id
def schedule_cron(
self,
trigger: Trigger,
minute: Optional[str] = None,
hour: Optional[str] = None,
day: Optional[str] = None,
month: Optional[str] = None,
day_of_week: Optional[str] = None,
priority: Optional[Priority] = None,
) -> str:
"""
Schedule a trigger to run on a cron schedule.
Args:
trigger: Trigger to execute
minute: Minute expression (0-59, *, */5, etc.)
hour: Hour expression (0-23, *, etc.)
day: Day of month expression (1-31, *, etc.)
month: Month expression (1-12, *, etc.)
day_of_week: Day of week expression (0-6, mon-sun, *, etc.)
priority: Priority override for execution
Returns:
Job ID (can be used to remove job later)
Example:
# Run at 9:00 AM every weekday
scheduler.schedule_cron(
SyncExchangeStateTrigger(),
hour="9",
minute="0",
day_of_week="mon-fri"
)
"""
job_id = f"cron_{self._job_counter}"
self._job_counter += 1
async def job_func():
await self.trigger_queue.enqueue(trigger, priority)
self.scheduler.add_job(
job_func,
trigger=APSCronTrigger(
minute=minute,
hour=hour,
day=day,
month=month,
day_of_week=day_of_week,
),
id=job_id,
name=f"Cron: {trigger.name}",
)
logger.info(
f"Scheduled cron job: {job_id}, trigger={trigger.name}, "
f"schedule=(m={minute}, h={hour}, d={day}, dow={day_of_week})"
)
return job_id
def remove_job(self, job_id: str) -> bool:
"""
Remove a scheduled job.
Args:
job_id: Job ID returned from schedule_* methods
Returns:
True if job was removed, False if not found
"""
try:
self.scheduler.remove_job(job_id)
logger.info(f"Removed scheduled job: {job_id}")
return True
except Exception as e:
logger.warning(f"Could not remove job {job_id}: {e}")
return False
def get_jobs(self) -> list:
"""Get list of all scheduled jobs"""
return self.scheduler.get_jobs()
def __repr__(self) -> str:
job_count = len(self.scheduler.get_jobs())
running = self.scheduler.running
return f"TriggerScheduler(running={running}, jobs={job_count})"

View File

@@ -0,0 +1,301 @@
"""
Versioned store with pluggable backends.
Provides optimistic concurrency control via sequence numbers with support
for different storage backends (Pydantic models, files, databases, etc.).
"""
import logging
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Generic, TypeVar
from .context import get_execution_context
from .types import CommitIntent
logger = logging.getLogger(__name__)
T = TypeVar("T")
class StoreBackend(ABC, Generic[T]):
"""
Abstract backend for versioned stores.
Allows different storage mechanisms (Pydantic models, files, databases)
to be used with the same versioned store infrastructure.
"""
@abstractmethod
def read(self) -> T:
"""
Read the current data.
Returns:
Current data in backend-specific format
"""
pass
@abstractmethod
def write(self, data: T) -> None:
"""
Write new data (replaces existing).
Args:
data: New data to write
"""
pass
@abstractmethod
def snapshot(self) -> T:
"""
Create an immutable snapshot of current data.
Must return a deep copy or immutable version to prevent
modifications from affecting the committed state.
Returns:
Immutable snapshot of data
"""
pass
@abstractmethod
def validate(self, data: T) -> bool:
"""
Validate that data is in correct format for this backend.
Args:
data: Data to validate
Returns:
True if valid
Raises:
ValueError: If invalid with explanation
"""
pass
class PydanticStoreBackend(StoreBackend[T]):
"""
Backend for Pydantic BaseModel stores.
Supports the existing OrderStore, ChartStore, etc. pattern.
"""
def __init__(self, model_instance: T):
"""
Initialize with a Pydantic model instance.
Args:
model_instance: Instance of a Pydantic BaseModel
"""
self._model = model_instance
def read(self) -> T:
return self._model
def write(self, data: T) -> None:
# Replace the internal model
self._model = data
def snapshot(self) -> T:
# Use Pydantic's model_copy for deep copy
if hasattr(self._model, "model_copy"):
return self._model.model_copy(deep=True)
# Fallback for older Pydantic or non-model types
return deepcopy(self._model)
def validate(self, data: T) -> bool:
# Pydantic models validate themselves on construction
# If we got here with a model instance, it's valid
return True
class FileStoreBackend(StoreBackend[str]):
"""
Backend for file-based storage.
Future implementation for versioning files (e.g., Python scripts, configs).
"""
def __init__(self, file_path: str):
self.file_path = file_path
raise NotImplementedError("FileStoreBackend not yet implemented")
def read(self) -> str:
raise NotImplementedError()
def write(self, data: str) -> None:
raise NotImplementedError()
def snapshot(self) -> str:
raise NotImplementedError()
def validate(self, data: str) -> bool:
raise NotImplementedError()
class DatabaseStoreBackend(StoreBackend[dict]):
"""
Backend for database table storage.
Future implementation for versioning database interactions.
"""
def __init__(self, table_name: str, connection):
self.table_name = table_name
self.connection = connection
raise NotImplementedError("DatabaseStoreBackend not yet implemented")
def read(self) -> dict:
raise NotImplementedError()
def write(self, data: dict) -> None:
raise NotImplementedError()
def snapshot(self) -> dict:
raise NotImplementedError()
def validate(self, data: dict) -> bool:
raise NotImplementedError()
class VersionedStore(Generic[T]):
"""
Store with optimistic concurrency control via sequence numbers.
Wraps any StoreBackend and provides:
- Lock-free snapshot reads
- Conflict detection on commit
- Version tracking for debugging
"""
def __init__(self, name: str, backend: StoreBackend[T]):
"""
Initialize versioned store.
Args:
name: Unique name for this store (e.g., "OrderStore")
backend: Backend implementation for storage
"""
self.name = name
self._backend = backend
self._committed_seq = 0 # Highest committed seq
self._version = 0 # Increments on each commit (for debugging)
@property
def committed_seq(self) -> int:
"""Get the current committed sequence number"""
return self._committed_seq
@property
def version(self) -> int:
"""Get the current version (increments on each commit)"""
return self._version
def read_snapshot(self) -> tuple[int, T]:
"""
Read an immutable snapshot of the store.
This is lock-free and can be called concurrently. The snapshot
captures the current committed seq and a deep copy of the data.
Automatically records the snapshot seq in the execution context
for conflict detection during commit.
Returns:
Tuple of (seq, snapshot_data)
"""
snapshot_seq = self._committed_seq
snapshot_data = self._backend.snapshot()
# Record in execution context for conflict detection
ctx = get_execution_context()
if ctx:
ctx.record_snapshot(self.name, snapshot_seq)
logger.debug(
f"Store '{self.name}': read_snapshot() -> seq={snapshot_seq}, version={self._version}"
)
return (snapshot_seq, snapshot_data)
def read_current(self) -> T:
"""
Read the current data without snapshot tracking.
Use this for read-only operations that don't need conflict detection.
Returns:
Current data (not a snapshot, modifications visible)
"""
return self._backend.read()
def prepare_commit(self, expected_seq: int, new_data: T) -> CommitIntent:
"""
Create a commit intent for later sequential commit.
Does NOT modify the store - that happens during the commit phase.
Args:
expected_seq: The seq of the snapshot that was read
new_data: The new data to commit
Returns:
CommitIntent to be submitted to CommitCoordinator
"""
# Validate data before creating intent
self._backend.validate(new_data)
intent = CommitIntent(
store_name=self.name,
expected_seq=expected_seq,
new_data=new_data,
)
logger.debug(
f"Store '{self.name}': prepare_commit(expected_seq={expected_seq}, current_seq={self._committed_seq})"
)
return intent
def commit(self, new_data: T, commit_seq: int) -> None:
"""
Commit new data at a specific seq.
Called by CommitCoordinator during sequential commit phase.
NOT for direct use by triggers.
Args:
new_data: Data to commit
commit_seq: Seq number of this commit
"""
self._backend.write(new_data)
self._committed_seq = commit_seq
self._version += 1
logger.info(
f"Store '{self.name}': committed seq={commit_seq}, version={self._version}"
)
def check_conflict(self, expected_seq: int) -> bool:
"""
Check if committing at expected_seq would conflict.
Args:
expected_seq: The seq that was expected during execution
Returns:
True if conflict (committed_seq has advanced beyond expected_seq)
"""
has_conflict = self._committed_seq != expected_seq
if has_conflict:
logger.warning(
f"Store '{self.name}': conflict detected - "
f"expected_seq={expected_seq}, committed_seq={self._committed_seq}"
)
return has_conflict
def __repr__(self) -> str:
return f"VersionedStore(name='{self.name}', committed_seq={self._committed_seq}, version={self._version})"

View File

@@ -0,0 +1,175 @@
"""
Core types for the trigger system.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Optional
logger = logging.getLogger(__name__)
class Priority(IntEnum):
"""
Primary execution priority for triggers.
Lower numeric value = higher priority (dequeued first).
Priority hierarchy (highest to lowest):
- DATA_SOURCE: Market data, real-time feeds (most time-sensitive)
- TIMER: Scheduled tasks, cron jobs
- USER_AGENT: User-agent interactions (WebSocket chat)
- USER_DATA_REQUEST: User data requests (chart loads, symbol search)
- SYSTEM: Background tasks, cleanup
- LOW: Retries after conflicts, non-critical tasks
"""
DATA_SOURCE = 0 # Market data updates, real-time feeds
TIMER = 1 # Scheduled tasks, cron jobs
USER_AGENT = 2 # User-agent interactions (WebSocket chat)
USER_DATA_REQUEST = 3 # User data requests (chart loads, etc.)
SYSTEM = 4 # Background tasks, cleanup, etc.
LOW = 5 # Retries after conflicts, non-critical tasks
# Type alias for priority tuples
# Examples:
# (Priority.DATA_SOURCE,) - Simple priority
# (Priority.DATA_SOURCE, event_time) - Priority + event time
# (Priority.DATA_SOURCE, event_time, queue_seq) - Full ordering
#
# Python compares tuples element-by-element, left-to-right.
# Shorter tuple wins if all shared elements are equal.
PriorityTuple = tuple[int, ...]
class ExecutionState(IntEnum):
"""State of an execution in the system"""
QUEUED = 0 # In queue, waiting to be dequeued
EXECUTING = 1 # Currently executing
WAITING_COMMIT = 2 # Finished executing, waiting for sequential commit
COMMITTED = 3 # Successfully committed
EVICTED = 4 # Evicted due to conflict, will retry
FAILED = 5 # Failed with error
@dataclass
class CommitIntent:
"""
Intent to commit changes to a store.
Created during execution, validated and applied during sequential commit phase.
"""
store_name: str
"""Name of the store to commit to"""
expected_seq: int
"""The seq number of the snapshot that was read (for conflict detection)"""
new_data: Any
"""The new data to commit (format depends on store backend)"""
def __repr__(self) -> str:
data_preview = str(self.new_data)[:50]
return f"CommitIntent(store={self.store_name}, expected_seq={self.expected_seq}, data={data_preview}...)"
class Trigger(ABC):
"""
Abstract base class for all triggers.
A trigger represents a unit of work that:
1. Gets assigned a seq number when dequeued
2. Executes (potentially long-running, async)
3. Returns CommitIntents for any state changes
4. Waits for sequential commit
"""
def __init__(
self,
name: str,
priority: Priority = Priority.SYSTEM,
priority_tuple: Optional[PriorityTuple] = None
):
"""
Initialize trigger.
Args:
name: Descriptive name for logging
priority: Simple priority (used if priority_tuple not provided)
priority_tuple: Optional tuple for compound sorting
Examples:
(Priority.DATA_SOURCE, event_time)
(Priority.USER_AGENT, message_timestamp)
(Priority.TIMER, scheduled_time)
"""
self.name = name
self.priority = priority
self._priority_tuple = priority_tuple
def get_priority_tuple(self, queue_seq: int) -> PriorityTuple:
"""
Get the priority tuple for queue ordering.
If a priority tuple was provided at construction, append queue_seq.
Otherwise, create tuple from simple priority.
Args:
queue_seq: Queue insertion order (final sort key)
Returns:
Priority tuple for queue ordering
Examples:
(Priority.DATA_SOURCE,) + (queue_seq,) = (0, queue_seq)
(Priority.DATA_SOURCE, 1000) + (queue_seq,) = (0, 1000, queue_seq)
"""
if self._priority_tuple is not None:
return self._priority_tuple + (queue_seq,)
else:
return (self.priority.value, queue_seq)
@abstractmethod
async def execute(self) -> list[CommitIntent]:
"""
Execute the trigger logic.
Can be long-running and async. Should read from stores via
VersionedStore.read_snapshot() and return CommitIntents for any changes.
Returns:
List of CommitIntents (empty if no state changes)
Raises:
Exception: On execution failure (will be logged, no commit)
"""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name='{self.name}', priority={self.priority.name})"
@dataclass
class ExecutionRecord:
"""
Record of an execution for tracking and debugging.
Maintained by the CommitCoordinator to track in-flight executions.
"""
seq: int
trigger: Trigger
state: ExecutionState
commit_intents: Optional[list[CommitIntent]] = None
error: Optional[str] = None
retry_count: int = 0
def __repr__(self) -> str:
return (
f"ExecutionRecord(seq={self.seq}, trigger={self.trigger.name}, "
f"state={self.state.name}, retry={self.retry_count})"
)

Some files were not shown because too many files have changed in this diff Show More