Compare commits
7 Commits
bf7af2b426
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| c76887ab92 | |||
| f6bd22a8ef | |||
| b9cc397e05 | |||
| e99ef5d2dd | |||
| 8ff277c8c6 | |||
| a50955558e | |||
| 185fa42caa |
175
.agents/skills/better-auth-best-practices/SKILL.md
Normal file
175
.agents/skills/better-auth-best-practices/SKILL.md
Normal file
@@ -0,0 +1,175 @@
|
||||
---
|
||||
name: better-auth-best-practices
|
||||
description: Configure Better Auth server and client, set up database adapters, manage sessions, add plugins, and handle environment variables. Use when users mention Better Auth, betterauth, auth.ts, or need to set up TypeScript authentication with email/password, OAuth, or plugin configuration.
|
||||
---
|
||||
|
||||
# Better Auth Integration Guide
|
||||
|
||||
**Always consult [better-auth.com/docs](https://better-auth.com/docs) for code examples and latest API.**
|
||||
|
||||
---
|
||||
|
||||
## Setup Workflow
|
||||
|
||||
1. Install: `npm install better-auth`
|
||||
2. Set env vars: `BETTER_AUTH_SECRET` and `BETTER_AUTH_URL`
|
||||
3. Create `auth.ts` with database + config
|
||||
4. Create route handler for your framework
|
||||
5. Run `npx @better-auth/cli@latest migrate`
|
||||
6. Verify: call `GET /api/auth/ok` — should return `{ status: "ok" }`
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Environment Variables
|
||||
- `BETTER_AUTH_SECRET` - Encryption secret (min 32 chars). Generate: `openssl rand -base64 32`
|
||||
- `BETTER_AUTH_URL` - Base URL (e.g., `https://example.com`)
|
||||
|
||||
Only define `baseURL`/`secret` in config if env vars are NOT set.
|
||||
|
||||
### File Location
|
||||
CLI looks for `auth.ts` in: `./`, `./lib`, `./utils`, or under `./src`. Use `--config` for custom path.
|
||||
|
||||
### CLI Commands
|
||||
- `npx @better-auth/cli@latest migrate` - Apply schema (built-in adapter)
|
||||
- `npx @better-auth/cli@latest generate` - Generate schema for Prisma/Drizzle
|
||||
- `npx @better-auth/cli mcp --cursor` - Add MCP to AI tools
|
||||
|
||||
**Re-run after adding/changing plugins.**
|
||||
|
||||
---
|
||||
|
||||
## Core Config Options
|
||||
|
||||
| Option | Notes |
|
||||
|--------|-------|
|
||||
| `appName` | Optional display name |
|
||||
| `baseURL` | Only if `BETTER_AUTH_URL` not set |
|
||||
| `basePath` | Default `/api/auth`. Set `/` for root. |
|
||||
| `secret` | Only if `BETTER_AUTH_SECRET` not set |
|
||||
| `database` | Required for most features. See adapters docs. |
|
||||
| `secondaryStorage` | Redis/KV for sessions & rate limits |
|
||||
| `emailAndPassword` | `{ enabled: true }` to activate |
|
||||
| `socialProviders` | `{ google: { clientId, clientSecret }, ... }` |
|
||||
| `plugins` | Array of plugins |
|
||||
| `trustedOrigins` | CSRF whitelist |
|
||||
|
||||
---
|
||||
|
||||
## Database
|
||||
|
||||
**Direct connections:** Pass `pg.Pool`, `mysql2` pool, `better-sqlite3`, or `bun:sqlite` instance.
|
||||
|
||||
**ORM adapters:** Import from `better-auth/adapters/drizzle`, `better-auth/adapters/prisma`, `better-auth/adapters/mongodb`.
|
||||
|
||||
**Critical:** Better Auth uses adapter model names, NOT underlying table names. If Prisma model is `User` mapping to table `users`, use `modelName: "user"` (Prisma reference), not `"users"`.
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
|
||||
**Storage priority:**
|
||||
1. If `secondaryStorage` defined → sessions go there (not DB)
|
||||
2. Set `session.storeSessionInDatabase: true` to also persist to DB
|
||||
3. No database + `cookieCache` → fully stateless mode
|
||||
|
||||
**Cookie cache strategies:**
|
||||
- `compact` (default) - Base64url + HMAC. Smallest.
|
||||
- `jwt` - Standard JWT. Readable but signed.
|
||||
- `jwe` - Encrypted. Maximum security.
|
||||
|
||||
**Key options:** `session.expiresIn` (default 7 days), `session.updateAge` (refresh interval), `session.cookieCache.maxAge`, `session.cookieCache.version` (change to invalidate all sessions).
|
||||
|
||||
---
|
||||
|
||||
## User & Account Config
|
||||
|
||||
**User:** `user.modelName`, `user.fields` (column mapping), `user.additionalFields`, `user.changeEmail.enabled` (disabled by default), `user.deleteUser.enabled` (disabled by default).
|
||||
|
||||
**Account:** `account.modelName`, `account.accountLinking.enabled`, `account.storeAccountCookie` (for stateless OAuth).
|
||||
|
||||
**Required for registration:** `email` and `name` fields.
|
||||
|
||||
---
|
||||
|
||||
## Email Flows
|
||||
|
||||
- `emailVerification.sendVerificationEmail` - Must be defined for verification to work
|
||||
- `emailVerification.sendOnSignUp` / `sendOnSignIn` - Auto-send triggers
|
||||
- `emailAndPassword.sendResetPassword` - Password reset email handler
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
**In `advanced`:**
|
||||
- `useSecureCookies` - Force HTTPS cookies
|
||||
- `disableCSRFCheck` - ⚠️ Security risk
|
||||
- `disableOriginCheck` - ⚠️ Security risk
|
||||
- `crossSubDomainCookies.enabled` - Share cookies across subdomains
|
||||
- `ipAddress.ipAddressHeaders` - Custom IP headers for proxies
|
||||
- `database.generateId` - Custom ID generation or `"serial"`/`"uuid"`/`false`
|
||||
|
||||
**Rate limiting:** `rateLimit.enabled`, `rateLimit.window`, `rateLimit.max`, `rateLimit.storage` ("memory" | "database" | "secondary-storage").
|
||||
|
||||
---
|
||||
|
||||
## Hooks
|
||||
|
||||
**Endpoint hooks:** `hooks.before` / `hooks.after` - Array of `{ matcher, handler }`. Use `createAuthMiddleware`. Access `ctx.path`, `ctx.context.returned` (after), `ctx.context.session`.
|
||||
|
||||
**Database hooks:** `databaseHooks.user.create.before/after`, same for `session`, `account`. Useful for adding default values or post-creation actions.
|
||||
|
||||
**Hook context (`ctx.context`):** `session`, `secret`, `authCookies`, `password.hash()`/`verify()`, `adapter`, `internalAdapter`, `generateId()`, `tables`, `baseURL`.
|
||||
|
||||
---
|
||||
|
||||
## Plugins
|
||||
|
||||
**Import from dedicated paths for tree-shaking:**
|
||||
```
|
||||
import { twoFactor } from "better-auth/plugins/two-factor"
|
||||
```
|
||||
NOT `from "better-auth/plugins"`.
|
||||
|
||||
**Popular plugins:** `twoFactor`, `organization`, `passkey`, `magicLink`, `emailOtp`, `username`, `phoneNumber`, `admin`, `apiKey`, `bearer`, `jwt`, `multiSession`, `sso`, `oauthProvider`, `oidcProvider`, `openAPI`, `genericOAuth`.
|
||||
|
||||
Client plugins go in `createAuthClient({ plugins: [...] })`.
|
||||
|
||||
---
|
||||
|
||||
## Client
|
||||
|
||||
Import from: `better-auth/client` (vanilla), `better-auth/react`, `better-auth/vue`, `better-auth/svelte`, `better-auth/solid`.
|
||||
|
||||
Key methods: `signUp.email()`, `signIn.email()`, `signIn.social()`, `signOut()`, `useSession()`, `getSession()`, `revokeSession()`, `revokeSessions()`.
|
||||
|
||||
---
|
||||
|
||||
## Type Safety
|
||||
|
||||
Infer types: `typeof auth.$Infer.Session`, `typeof auth.$Infer.Session.user`.
|
||||
|
||||
For separate client/server projects: `createAuthClient<typeof auth>()`.
|
||||
|
||||
---
|
||||
|
||||
## Common Gotchas
|
||||
|
||||
1. **Model vs table name** - Config uses ORM model name, not DB table name
|
||||
2. **Plugin schema** - Re-run CLI after adding plugins
|
||||
3. **Secondary storage** - Sessions go there by default, not DB
|
||||
4. **Cookie cache** - Custom session fields NOT cached, always re-fetched
|
||||
5. **Stateless mode** - No DB = session in cookie only, logout on cache expiry
|
||||
6. **Change email flow** - Sends to current email first, then new email
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- [Docs](https://better-auth.com/docs)
|
||||
- [Options Reference](https://better-auth.com/docs/reference/options)
|
||||
- [LLMs.txt](https://better-auth.com/llms.txt)
|
||||
- [GitHub](https://github.com/better-auth/better-auth)
|
||||
- [Init Options Source](https://github.com/better-auth/better-auth/blob/main/packages/core/src/types/init-options.ts)
|
||||
321
.agents/skills/create-auth-skill/SKILL.md
Normal file
321
.agents/skills/create-auth-skill/SKILL.md
Normal file
@@ -0,0 +1,321 @@
|
||||
---
|
||||
name: create-auth-skill
|
||||
description: Scaffold and implement authentication in TypeScript/JavaScript apps using Better Auth. Detect frameworks, configure database adapters, set up route handlers, add OAuth providers, and create auth UI pages. Use when users want to add login, sign-up, or authentication to a new or existing project with Better Auth.
|
||||
---
|
||||
|
||||
# Create Auth Skill
|
||||
|
||||
Guide for adding authentication to TypeScript/JavaScript applications using Better Auth.
|
||||
|
||||
**For code examples and syntax, see [better-auth.com/docs](https://better-auth.com/docs).**
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Planning (REQUIRED before implementation)
|
||||
|
||||
Before writing any code, gather requirements by scanning the project and asking the user structured questions. This ensures the implementation matches their needs.
|
||||
|
||||
### Step 1: Scan the project
|
||||
|
||||
Analyze the codebase to auto-detect:
|
||||
- **Framework** — Look for `next.config`, `svelte.config`, `nuxt.config`, `astro.config`, `vite.config`, or Express/Hono entry files.
|
||||
- **Database/ORM** — Look for `prisma/schema.prisma`, `drizzle.config`, `package.json` deps (`pg`, `mysql2`, `better-sqlite3`, `mongoose`, `mongodb`).
|
||||
- **Existing auth** — Look for existing auth libraries (`next-auth`, `lucia`, `clerk`, `supabase/auth`, `firebase/auth`) in `package.json` or imports.
|
||||
- **Package manager** — Check for `pnpm-lock.yaml`, `yarn.lock`, `bun.lockb`, or `package-lock.json`.
|
||||
|
||||
Use what you find to pre-fill defaults and skip questions you can already answer.
|
||||
|
||||
### Step 2: Ask planning questions
|
||||
|
||||
Use the `AskQuestion` tool to ask the user **all applicable questions in a single call**. Skip any question you already have a confident answer for from the scan. Group them under a title like "Auth Setup Planning".
|
||||
|
||||
**Questions to ask:**
|
||||
|
||||
1. **Project type** (skip if detected)
|
||||
- Prompt: "What type of project is this?"
|
||||
- Options: New project from scratch | Adding auth to existing project | Migrating from another auth library
|
||||
|
||||
2. **Framework** (skip if detected)
|
||||
- Prompt: "Which framework are you using?"
|
||||
- Options: Next.js (App Router) | Next.js (Pages Router) | SvelteKit | Nuxt | Astro | Express | Hono | SolidStart | Other
|
||||
|
||||
3. **Database & ORM** (skip if detected)
|
||||
- Prompt: "Which database setup will you use?"
|
||||
- Options: PostgreSQL (Prisma) | PostgreSQL (Drizzle) | PostgreSQL (pg driver) | MySQL (Prisma) | MySQL (Drizzle) | MySQL (mysql2 driver) | SQLite (Prisma) | SQLite (Drizzle) | SQLite (better-sqlite3 driver) | MongoDB (Mongoose) | MongoDB (native driver)
|
||||
|
||||
4. **Authentication methods** (always ask, allow multiple)
|
||||
- Prompt: "Which sign-in methods do you need?"
|
||||
- Options: Email & password | Social OAuth (Google, GitHub, etc.) | Magic link (passwordless email) | Passkey (WebAuthn) | Phone number
|
||||
- `allow_multiple: true`
|
||||
|
||||
5. **Social providers** (only if they selected Social OAuth above — ask in a follow-up call)
|
||||
- Prompt: "Which social providers do you need?"
|
||||
- Options: Google | GitHub | Apple | Microsoft | Discord | Twitter/X
|
||||
- `allow_multiple: true`
|
||||
|
||||
6. **Email verification** (only if Email & password was selected above — ask in a follow-up call)
|
||||
- Prompt: "Do you want to require email verification?"
|
||||
- Options: Yes | No
|
||||
|
||||
7. **Email provider** (only if email verification is Yes, or if Password reset is selected in features — ask in a follow-up call)
|
||||
- Prompt: "How do you want to send emails?"
|
||||
- Options: Resend | Mock it for now (console.log)
|
||||
|
||||
8. **Features & plugins** (always ask, allow multiple)
|
||||
- Prompt: "Which additional features do you need?"
|
||||
- Options: Two-factor authentication (2FA) | Organizations / teams | Admin dashboard | API bearer tokens | Password reset | None of these
|
||||
- `allow_multiple: true`
|
||||
|
||||
9. **Auth pages** (always ask, allow multiple — pre-select based on earlier answers)
|
||||
- Prompt: "Which auth pages do you need?"
|
||||
- Options vary based on previous answers:
|
||||
- Always available: Sign in | Sign up
|
||||
- If Email & password selected: Forgot password | Reset password
|
||||
- If email verification enabled: Email verification
|
||||
- `allow_multiple: true`
|
||||
|
||||
10. **Auth UI style** (always ask)
|
||||
- Prompt: "What style do you want for the auth pages? Pick one or describe your own."
|
||||
- Options: Minimal & clean | Centered card with background | Split layout (form + hero image) | Floating / glassmorphism | Other (I'll describe)
|
||||
|
||||
### Step 3: Summarize the plan
|
||||
|
||||
After collecting answers, present a concise implementation plan as a markdown checklist. Example:
|
||||
|
||||
```
|
||||
## Auth Implementation Plan
|
||||
|
||||
- **Framework:** Next.js (App Router)
|
||||
- **Database:** PostgreSQL via Prisma
|
||||
- **Auth methods:** Email/password, Google OAuth, GitHub OAuth
|
||||
- **Plugins:** 2FA, Organizations, Email verification
|
||||
- **UI:** Custom forms
|
||||
|
||||
### Steps
|
||||
1. Install `better-auth` and `@better-auth/cli`
|
||||
2. Create `lib/auth.ts` with server config
|
||||
3. Create `lib/auth-client.ts` with React client
|
||||
4. Set up route handler at `app/api/auth/[...all]/route.ts`
|
||||
5. Configure Prisma adapter and generate schema
|
||||
6. Add Google & GitHub OAuth providers
|
||||
7. Enable `twoFactor` and `organization` plugins
|
||||
8. Set up email verification handler
|
||||
9. Run migrations
|
||||
10. Create sign-in / sign-up pages
|
||||
```
|
||||
|
||||
Ask the user to confirm the plan before proceeding to Phase 2.
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Implementation
|
||||
|
||||
Only proceed here after the user confirms the plan from Phase 1.
|
||||
|
||||
Follow the decision tree below, guided by the answers collected above.
|
||||
|
||||
```
|
||||
Is this a new/empty project?
|
||||
├─ YES → New project setup
|
||||
│ 1. Install better-auth (+ scoped packages per plan)
|
||||
│ 2. Create auth.ts with all planned config
|
||||
│ 3. Create auth-client.ts with framework client
|
||||
│ 4. Set up route handler
|
||||
│ 5. Set up environment variables
|
||||
│ 6. Run CLI migrate/generate
|
||||
│ 7. Add plugins from plan
|
||||
│ 8. Create auth UI pages
|
||||
│
|
||||
├─ MIGRATING → Migration from existing auth
|
||||
│ 1. Audit current auth for gaps
|
||||
│ 2. Plan incremental migration
|
||||
│ 3. Install better-auth alongside existing auth
|
||||
│ 4. Migrate routes, then session logic, then UI
|
||||
│ 5. Remove old auth library
|
||||
│ 6. See migration guides in docs
|
||||
│
|
||||
└─ ADDING → Add auth to existing project
|
||||
1. Analyze project structure
|
||||
2. Install better-auth
|
||||
3. Create auth config matching plan
|
||||
4. Add route handler
|
||||
5. Run schema migrations
|
||||
6. Integrate into existing pages
|
||||
7. Add planned plugins and features
|
||||
```
|
||||
|
||||
At the end of implementation, guide users thoroughly on remaining next steps (e.g., setting up OAuth app credentials, deploying env vars, testing flows).
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
**Core:** `npm install better-auth`
|
||||
|
||||
**Scoped packages (as needed):**
|
||||
| Package | Use case |
|
||||
|---------|----------|
|
||||
| `@better-auth/passkey` | WebAuthn/Passkey auth |
|
||||
| `@better-auth/sso` | SAML/OIDC enterprise SSO |
|
||||
| `@better-auth/stripe` | Stripe payments |
|
||||
| `@better-auth/scim` | SCIM user provisioning |
|
||||
| `@better-auth/expo` | React Native/Expo |
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
|
||||
```env
|
||||
BETTER_AUTH_SECRET=<32+ chars, generate with: openssl rand -base64 32>
|
||||
BETTER_AUTH_URL=http://localhost:3000
|
||||
DATABASE_URL=<your database connection string>
|
||||
```
|
||||
|
||||
Add OAuth secrets as needed: `GITHUB_CLIENT_ID`, `GITHUB_CLIENT_SECRET`, `GOOGLE_CLIENT_ID`, etc.
|
||||
|
||||
---
|
||||
|
||||
## Server Config (auth.ts)
|
||||
|
||||
**Location:** `lib/auth.ts` or `src/lib/auth.ts`
|
||||
|
||||
**Minimal config needs:**
|
||||
- `database` - Connection or adapter
|
||||
- `emailAndPassword: { enabled: true }` - For email/password auth
|
||||
|
||||
**Standard config adds:**
|
||||
- `socialProviders` - OAuth providers (google, github, etc.)
|
||||
- `emailVerification.sendVerificationEmail` - Email verification handler
|
||||
- `emailAndPassword.sendResetPassword` - Password reset handler
|
||||
|
||||
**Full config adds:**
|
||||
- `plugins` - Array of feature plugins
|
||||
- `session` - Expiry, cookie cache settings
|
||||
- `account.accountLinking` - Multi-provider linking
|
||||
- `rateLimit` - Rate limiting config
|
||||
|
||||
**Export types:** `export type Session = typeof auth.$Infer.Session`
|
||||
|
||||
---
|
||||
|
||||
## Client Config (auth-client.ts)
|
||||
|
||||
**Import by framework:**
|
||||
| Framework | Import |
|
||||
|-----------|--------|
|
||||
| React/Next.js | `better-auth/react` |
|
||||
| Vue | `better-auth/vue` |
|
||||
| Svelte | `better-auth/svelte` |
|
||||
| Solid | `better-auth/solid` |
|
||||
| Vanilla JS | `better-auth/client` |
|
||||
|
||||
**Client plugins** go in `createAuthClient({ plugins: [...] })`.
|
||||
|
||||
**Common exports:** `signIn`, `signUp`, `signOut`, `useSession`, `getSession`
|
||||
|
||||
---
|
||||
|
||||
## Route Handler Setup
|
||||
|
||||
| Framework | File | Handler |
|
||||
|-----------|------|---------|
|
||||
| Next.js App Router | `app/api/auth/[...all]/route.ts` | `toNextJsHandler(auth)` → export `{ GET, POST }` |
|
||||
| Next.js Pages | `pages/api/auth/[...all].ts` | `toNextJsHandler(auth)` → default export |
|
||||
| Express | Any file | `app.all("/api/auth/*", toNodeHandler(auth))` |
|
||||
| SvelteKit | `src/hooks.server.ts` | `svelteKitHandler(auth)` |
|
||||
| SolidStart | Route file | `solidStartHandler(auth)` |
|
||||
| Hono | Route file | `auth.handler(c.req.raw)` |
|
||||
|
||||
**Next.js Server Components:** Add `nextCookies()` plugin to auth config.
|
||||
|
||||
---
|
||||
|
||||
## Database Migrations
|
||||
|
||||
| Adapter | Command |
|
||||
|---------|---------|
|
||||
| Built-in Kysely | `npx @better-auth/cli@latest migrate` (applies directly) |
|
||||
| Prisma | `npx @better-auth/cli@latest generate --output prisma/schema.prisma` then `npx prisma migrate dev` |
|
||||
| Drizzle | `npx @better-auth/cli@latest generate --output src/db/auth-schema.ts` then `npx drizzle-kit push` |
|
||||
|
||||
**Re-run after adding plugins.**
|
||||
|
||||
---
|
||||
|
||||
## Database Adapters
|
||||
|
||||
| Database | Setup |
|
||||
|----------|-------|
|
||||
| SQLite | Pass `better-sqlite3` or `bun:sqlite` instance directly |
|
||||
| PostgreSQL | Pass `pg.Pool` instance directly |
|
||||
| MySQL | Pass `mysql2` pool directly |
|
||||
| Prisma | `prismaAdapter(prisma, { provider: "postgresql" })` from `better-auth/adapters/prisma` |
|
||||
| Drizzle | `drizzleAdapter(db, { provider: "pg" })` from `better-auth/adapters/drizzle` |
|
||||
| MongoDB | `mongodbAdapter(db)` from `better-auth/adapters/mongodb` |
|
||||
|
||||
---
|
||||
|
||||
## Common Plugins
|
||||
|
||||
| Plugin | Server Import | Client Import | Purpose |
|
||||
|--------|---------------|---------------|---------|
|
||||
| `twoFactor` | `better-auth/plugins` | `twoFactorClient` | 2FA with TOTP/OTP |
|
||||
| `organization` | `better-auth/plugins` | `organizationClient` | Teams/orgs |
|
||||
| `admin` | `better-auth/plugins` | `adminClient` | User management |
|
||||
| `bearer` | `better-auth/plugins` | - | API token auth |
|
||||
| `openAPI` | `better-auth/plugins` | - | API docs |
|
||||
| `passkey` | `@better-auth/passkey` | `passkeyClient` | WebAuthn |
|
||||
| `sso` | `@better-auth/sso` | - | Enterprise SSO |
|
||||
|
||||
**Plugin pattern:** Server plugin + client plugin + run migrations.
|
||||
|
||||
---
|
||||
|
||||
## Auth UI Implementation
|
||||
|
||||
**Sign in flow:**
|
||||
1. `signIn.email({ email, password })` or `signIn.social({ provider, callbackURL })`
|
||||
2. Handle `error` in response
|
||||
3. Redirect on success
|
||||
|
||||
**Session check (client):** `useSession()` hook returns `{ data: session, isPending }`
|
||||
|
||||
**Session check (server):** `auth.api.getSession({ headers: await headers() })`
|
||||
|
||||
**Protected routes:** Check session, redirect to `/sign-in` if null.
|
||||
|
||||
---
|
||||
|
||||
## Security Checklist
|
||||
|
||||
- [ ] `BETTER_AUTH_SECRET` set (32+ chars)
|
||||
- [ ] `advanced.useSecureCookies: true` in production
|
||||
- [ ] `trustedOrigins` configured
|
||||
- [ ] Rate limits enabled
|
||||
- [ ] Email verification enabled
|
||||
- [ ] Password reset implemented
|
||||
- [ ] 2FA for sensitive apps
|
||||
- [ ] CSRF protection NOT disabled
|
||||
- [ ] `account.accountLinking` reviewed
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Fix |
|
||||
|-------|-----|
|
||||
| "Secret not set" | Add `BETTER_AUTH_SECRET` env var |
|
||||
| "Invalid Origin" | Add domain to `trustedOrigins` |
|
||||
| Cookies not setting | Check `baseURL` matches domain; enable secure cookies in prod |
|
||||
| OAuth callback errors | Verify redirect URIs in provider dashboard |
|
||||
| Type errors after adding plugin | Re-run CLI generate/migrate |
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- [Docs](https://better-auth.com/docs)
|
||||
- [Examples](https://github.com/better-auth/examples)
|
||||
- [Plugins](https://better-auth.com/docs/concepts/plugins)
|
||||
- [CLI](https://better-auth.com/docs/concepts/cli)
|
||||
- [Migration Guides](https://better-auth.com/docs/guides)
|
||||
212
.agents/skills/email-and-password-best-practices/SKILL.md
Normal file
212
.agents/skills/email-and-password-best-practices/SKILL.md
Normal file
@@ -0,0 +1,212 @@
|
||||
---
|
||||
name: email-and-password-best-practices
|
||||
description: Configure email verification, implement password reset flows, set password policies, and customise hashing algorithms for Better Auth email/password authentication. Use when users need to set up login, sign-in, sign-up, credential authentication, or password security with Better Auth.
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Enable email/password: `emailAndPassword: { enabled: true }`
|
||||
2. Configure `emailVerification.sendVerificationEmail`
|
||||
3. Add `sendResetPassword` for password reset flows
|
||||
4. Run `npx @better-auth/cli@latest migrate`
|
||||
5. Verify: attempt sign-up and confirm verification email triggers
|
||||
|
||||
---
|
||||
|
||||
## Email Verification Setup
|
||||
|
||||
Configure `emailVerification.sendVerificationEmail` to verify user email addresses.
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { sendEmail } from "./email"; // your email sending function
|
||||
|
||||
export const auth = betterAuth({
|
||||
emailVerification: {
|
||||
sendVerificationEmail: async ({ user, url, token }, request) => {
|
||||
await sendEmail({
|
||||
to: user.email,
|
||||
subject: "Verify your email address",
|
||||
text: `Click the link to verify your email: ${url}`,
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
**Note**: The `url` parameter contains the full verification link. The `token` is available if you need to build a custom verification URL.
|
||||
|
||||
### Requiring Email Verification
|
||||
|
||||
For stricter security, enable `emailAndPassword.requireEmailVerification` to block sign-in until the user verifies their email. When enabled, unverified users will receive a new verification email on each sign-in attempt.
|
||||
|
||||
```ts
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
requireEmailVerification: true,
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
**Note**: This requires `sendVerificationEmail` to be configured and only applies to email/password sign-ins.
|
||||
|
||||
## Client Side Validation
|
||||
|
||||
Implement client-side validation for immediate user feedback and reduced server load.
|
||||
|
||||
## Callback URLs
|
||||
|
||||
Always use absolute URLs (including the origin) for callback URLs in sign-up and sign-in requests. This prevents Better Auth from needing to infer the origin, which can cause issues when your backend and frontend are on different domains.
|
||||
|
||||
```ts
|
||||
const { data, error } = await authClient.signUp.email({
|
||||
callbackURL: "https://example.com/callback", // absolute URL with origin
|
||||
});
|
||||
```
|
||||
|
||||
## Password Reset Flows
|
||||
|
||||
Provide `sendResetPassword` in the email and password config to enable password resets.
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { sendEmail } from "./email"; // your email sending function
|
||||
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
// Custom email sending function to send reset-password email
|
||||
sendResetPassword: async ({ user, url, token }, request) => {
|
||||
void sendEmail({
|
||||
to: user.email,
|
||||
subject: "Reset your password",
|
||||
text: `Click the link to reset your password: ${url}`,
|
||||
});
|
||||
},
|
||||
// Optional event hook
|
||||
onPasswordReset: async ({ user }, request) => {
|
||||
// your logic here
|
||||
console.log(`Password for user ${user.email} has been reset.`);
|
||||
},
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
### Security Considerations
|
||||
|
||||
Built-in protections: background email sending (timing attack prevention), dummy operations on invalid requests, constant response messages regardless of user existence.
|
||||
|
||||
On serverless platforms, configure a background task handler:
|
||||
|
||||
```ts
|
||||
export const auth = betterAuth({
|
||||
advanced: {
|
||||
backgroundTasks: {
|
||||
handler: (promise) => {
|
||||
// Use platform-specific methods like waitUntil
|
||||
waitUntil(promise);
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
#### Token Security
|
||||
|
||||
Tokens expire after 1 hour by default. Configure with `resetPasswordTokenExpiresIn` (in seconds):
|
||||
|
||||
```ts
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
resetPasswordTokenExpiresIn: 60 * 30, // 30 minutes
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
Tokens are single-use — deleted immediately after successful reset.
|
||||
|
||||
#### Session Revocation
|
||||
|
||||
Enable `revokeSessionsOnPasswordReset` to invalidate all existing sessions on password reset:
|
||||
|
||||
```ts
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
revokeSessionsOnPasswordReset: true,
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
#### Password Requirements
|
||||
|
||||
Password length limits (configurable):
|
||||
|
||||
```ts
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
minPasswordLength: 12,
|
||||
maxPasswordLength: 256,
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
### Sending the Password Reset
|
||||
|
||||
Call `requestPasswordReset` to send the reset link. Triggers the `sendResetPassword` function from your config.
|
||||
|
||||
```ts
|
||||
const data = await auth.api.requestPasswordReset({
|
||||
body: {
|
||||
email: "john.doe@example.com", // required
|
||||
redirectTo: "https://example.com/reset-password",
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
Or authClient:
|
||||
|
||||
```ts
|
||||
const { data, error } = await authClient.requestPasswordReset({
|
||||
email: "john.doe@example.com", // required
|
||||
redirectTo: "https://example.com/reset-password",
|
||||
});
|
||||
```
|
||||
|
||||
**Note**: While the `email` is required, we also recommend configuring the `redirectTo` for a smoother user experience.
|
||||
|
||||
## Password Hashing
|
||||
|
||||
Default: `scrypt` (Node.js native, no external dependencies).
|
||||
|
||||
### Custom Hashing Algorithm
|
||||
|
||||
To use Argon2id or another algorithm, provide custom `hash` and `verify` functions:
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { hash, verify, type Options } from "@node-rs/argon2";
|
||||
|
||||
const argon2Options: Options = {
|
||||
memoryCost: 65536, // 64 MiB
|
||||
timeCost: 3, // 3 iterations
|
||||
parallelism: 4, // 4 parallel lanes
|
||||
outputLen: 32, // 32 byte output
|
||||
algorithm: 2, // Argon2id variant
|
||||
};
|
||||
|
||||
export const auth = betterAuth({
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
password: {
|
||||
hash: (password) => hash(password, argon2Options),
|
||||
verify: ({ password, hash: storedHash }) =>
|
||||
verify(storedHash, password, argon2Options),
|
||||
},
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
**Note**: If you switch hashing algorithms on an existing system, users with passwords hashed using the old algorithm won't be able to sign in. Plan a migration strategy if needed.
|
||||
331
.agents/skills/two-factor-authentication-best-practices/SKILL.md
Normal file
331
.agents/skills/two-factor-authentication-best-practices/SKILL.md
Normal file
@@ -0,0 +1,331 @@
|
||||
---
|
||||
name: two-factor-authentication-best-practices
|
||||
description: Configure TOTP authenticator apps, send OTP codes via email/SMS, manage backup codes, handle trusted devices, and implement 2FA sign-in flows using Better Auth's twoFactor plugin. Use when users need MFA, multi-factor authentication, authenticator setup, or login security with Better Auth.
|
||||
---
|
||||
|
||||
## Setup
|
||||
|
||||
1. Add `twoFactor()` plugin to server config with `issuer`
|
||||
2. Add `twoFactorClient()` plugin to client config
|
||||
3. Run `npx @better-auth/cli migrate`
|
||||
4. Verify: check that `twoFactorSecret` column exists on user table
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { twoFactor } from "better-auth/plugins";
|
||||
|
||||
export const auth = betterAuth({
|
||||
appName: "My App",
|
||||
plugins: [
|
||||
twoFactor({
|
||||
issuer: "My App",
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
|
||||
### Client-Side Setup
|
||||
|
||||
```ts
|
||||
import { createAuthClient } from "better-auth/client";
|
||||
import { twoFactorClient } from "better-auth/client/plugins";
|
||||
|
||||
export const authClient = createAuthClient({
|
||||
plugins: [
|
||||
twoFactorClient({
|
||||
onTwoFactorRedirect() {
|
||||
window.location.href = "/2fa";
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
|
||||
## Enabling 2FA for Users
|
||||
|
||||
Requires password verification. Returns TOTP URI (for QR code) and backup codes.
|
||||
|
||||
```ts
|
||||
const enable2FA = async (password: string) => {
|
||||
const { data, error } = await authClient.twoFactor.enable({
|
||||
password,
|
||||
});
|
||||
|
||||
if (data) {
|
||||
// data.totpURI — generate a QR code from this
|
||||
// data.backupCodes — display to user
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
`twoFactorEnabled` is not set to `true` until first TOTP verification succeeds. Override with `skipVerificationOnEnable: true` (not recommended).
|
||||
|
||||
## TOTP (Authenticator App)
|
||||
|
||||
### Displaying the QR Code
|
||||
|
||||
```tsx
|
||||
import QRCode from "react-qr-code";
|
||||
|
||||
const TotpSetup = ({ totpURI }: { totpURI: string }) => {
|
||||
return <QRCode value={totpURI} />;
|
||||
};
|
||||
```
|
||||
|
||||
### Verifying TOTP Codes
|
||||
|
||||
Accepts codes from one period before/after current time:
|
||||
|
||||
```ts
|
||||
const verifyTotp = async (code: string) => {
|
||||
const { data, error } = await authClient.twoFactor.verifyTotp({
|
||||
code,
|
||||
trustDevice: true,
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
### TOTP Configuration Options
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
totpOptions: {
|
||||
digits: 6, // 6 or 8 digits (default: 6)
|
||||
period: 30, // Code validity period in seconds (default: 30)
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## OTP (Email/SMS)
|
||||
|
||||
### Configuring OTP Delivery
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { twoFactor } from "better-auth/plugins";
|
||||
import { sendEmail } from "./email";
|
||||
|
||||
export const auth = betterAuth({
|
||||
plugins: [
|
||||
twoFactor({
|
||||
otpOptions: {
|
||||
sendOTP: async ({ user, otp }, ctx) => {
|
||||
await sendEmail({
|
||||
to: user.email,
|
||||
subject: "Your verification code",
|
||||
text: `Your code is: ${otp}`,
|
||||
});
|
||||
},
|
||||
period: 5, // Code validity in minutes (default: 3)
|
||||
digits: 6, // Number of digits (default: 6)
|
||||
allowedAttempts: 5, // Max verification attempts (default: 5)
|
||||
},
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
|
||||
### Sending and Verifying OTP
|
||||
|
||||
Send: `authClient.twoFactor.sendOtp()`. Verify: `authClient.twoFactor.verifyOtp({ code, trustDevice: true })`.
|
||||
|
||||
### OTP Storage Security
|
||||
|
||||
Configure how OTP codes are stored in the database:
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
otpOptions: {
|
||||
storeOTP: "encrypted", // Options: "plain", "encrypted", "hashed"
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
For custom encryption:
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
otpOptions: {
|
||||
storeOTP: {
|
||||
encrypt: async (token) => myEncrypt(token),
|
||||
decrypt: async (token) => myDecrypt(token),
|
||||
},
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## Backup Codes
|
||||
|
||||
Generated automatically when 2FA is enabled. Each code is single-use.
|
||||
|
||||
### Displaying Backup Codes
|
||||
|
||||
```tsx
|
||||
const BackupCodes = ({ codes }: { codes: string[] }) => {
|
||||
return (
|
||||
<div>
|
||||
<p>Save these codes in a secure location:</p>
|
||||
<ul>
|
||||
{codes.map((code, i) => (
|
||||
<li key={i}>{code}</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
### Regenerating Backup Codes
|
||||
|
||||
Invalidates all previous codes:
|
||||
|
||||
```ts
|
||||
const regenerateBackupCodes = async (password: string) => {
|
||||
const { data, error } = await authClient.twoFactor.generateBackupCodes({
|
||||
password,
|
||||
});
|
||||
// data.backupCodes contains the new codes
|
||||
};
|
||||
```
|
||||
|
||||
### Using Backup Codes for Recovery
|
||||
|
||||
```ts
|
||||
const verifyBackupCode = async (code: string) => {
|
||||
const { data, error } = await authClient.twoFactor.verifyBackupCode({
|
||||
code,
|
||||
trustDevice: true,
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
### Backup Code Configuration
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
backupCodeOptions: {
|
||||
amount: 10, // Number of codes to generate (default: 10)
|
||||
length: 10, // Length of each code (default: 10)
|
||||
storeBackupCodes: "encrypted", // Options: "plain", "encrypted"
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## Handling 2FA During Sign-In
|
||||
|
||||
Response includes `twoFactorRedirect: true` when 2FA is required:
|
||||
|
||||
### Sign-In Flow
|
||||
|
||||
1. Call `signIn.email({ email, password })`
|
||||
2. Check `context.data.twoFactorRedirect` in `onSuccess`
|
||||
3. If `true`, redirect to `/2fa` verification page
|
||||
4. Verify via TOTP, OTP, or backup code
|
||||
5. Session cookie is created on successful verification
|
||||
|
||||
```ts
|
||||
const signIn = async (email: string, password: string) => {
|
||||
const { data, error } = await authClient.signIn.email(
|
||||
{ email, password },
|
||||
{
|
||||
onSuccess(context) {
|
||||
if (context.data.twoFactorRedirect) {
|
||||
window.location.href = "/2fa";
|
||||
}
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
||||
```
|
||||
|
||||
Server-side: check `"twoFactorRedirect" in response` when using `auth.api.signInEmail`.
|
||||
|
||||
## Trusted Devices
|
||||
|
||||
Pass `trustDevice: true` when verifying. Default trust duration: 30 days (`trustDeviceMaxAge`). Refreshes on each sign-in.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Session Management
|
||||
|
||||
Flow: credentials → session removed → temporary 2FA cookie (10 min default) → verify → session created.
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
twoFactorCookieMaxAge: 600, // 10 minutes in seconds (default)
|
||||
});
|
||||
```
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
Built-in: 3 requests per 10 seconds for all 2FA endpoints. OTP has additional attempt limiting:
|
||||
|
||||
```ts
|
||||
twoFactor({
|
||||
otpOptions: {
|
||||
allowedAttempts: 5, // Max attempts per OTP code (default: 5)
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
### Encryption at Rest
|
||||
|
||||
TOTP secrets: encrypted with auth secret. Backup codes: encrypted by default. OTP: configurable (`"plain"`, `"encrypted"`, `"hashed"`). Uses constant-time comparison for verification.
|
||||
|
||||
2FA can only be enabled for credential (email/password) accounts.
|
||||
|
||||
## Disabling 2FA
|
||||
|
||||
Requires password confirmation. Revokes trusted device records:
|
||||
|
||||
```ts
|
||||
const disable2FA = async (password: string) => {
|
||||
const { data, error } = await authClient.twoFactor.disable({
|
||||
password,
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
## Complete Configuration Example
|
||||
|
||||
```ts
|
||||
import { betterAuth } from "better-auth";
|
||||
import { twoFactor } from "better-auth/plugins";
|
||||
import { sendEmail } from "./email";
|
||||
|
||||
export const auth = betterAuth({
|
||||
appName: "My App",
|
||||
plugins: [
|
||||
twoFactor({
|
||||
// TOTP settings
|
||||
issuer: "My App",
|
||||
totpOptions: {
|
||||
digits: 6,
|
||||
period: 30,
|
||||
},
|
||||
// OTP settings
|
||||
otpOptions: {
|
||||
sendOTP: async ({ user, otp }) => {
|
||||
await sendEmail({
|
||||
to: user.email,
|
||||
subject: "Your verification code",
|
||||
text: `Your code is: ${otp}`,
|
||||
});
|
||||
},
|
||||
period: 5,
|
||||
allowedAttempts: 5,
|
||||
storeOTP: "encrypted",
|
||||
},
|
||||
// Backup code settings
|
||||
backupCodeOptions: {
|
||||
amount: 10,
|
||||
length: 10,
|
||||
storeBackupCodes: "encrypted",
|
||||
},
|
||||
// Session settings
|
||||
twoFactorCookieMaxAge: 600, // 10 minutes
|
||||
trustDeviceMaxAge: 30 * 24 * 60 * 60, // 30 days
|
||||
}),
|
||||
],
|
||||
});
|
||||
```
|
||||
1
.claude/skills/better-auth-best-practices
Symbolic link
1
.claude/skills/better-auth-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/better-auth-best-practices
|
||||
1
.claude/skills/create-auth-skill
Symbolic link
1
.claude/skills/create-auth-skill
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/create-auth-skill
|
||||
1
.claude/skills/email-and-password-best-practices
Symbolic link
1
.claude/skills/email-and-password-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/email-and-password-best-practices
|
||||
1
.claude/skills/two-factor-authentication-best-practices
Symbolic link
1
.claude/skills/two-factor-authentication-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/two-factor-authentication-best-practices
|
||||
25
.gitignore
vendored
25
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
/backend/data
|
||||
/backend/uploads/
|
||||
/backend.old/data
|
||||
/backend.old/uploads/
|
||||
chat/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
@@ -101,3 +102,23 @@ Thumbs.db
|
||||
*.swp
|
||||
*.swo
|
||||
*.bak
|
||||
|
||||
# Kubernetes secrets (never commit actual secrets!)
|
||||
deploy/k8s/dev/secrets/*.yaml
|
||||
deploy/k8s/prod/secrets/*.yaml
|
||||
!deploy/k8s/dev/secrets/*.yaml.example
|
||||
!deploy/k8s/prod/secrets/*.yaml.example
|
||||
|
||||
# Dev environment image tags
|
||||
.dev-image-tag
|
||||
|
||||
# Protobuf copies (canonical files are in /protobuf/)
|
||||
flink/protobuf/
|
||||
relay/protobuf/
|
||||
ingestor/protobuf/
|
||||
gateway/protobuf/
|
||||
client-py/protobuf/
|
||||
|
||||
# Generated protobuf code
|
||||
gateway/src/generated/
|
||||
client-py/dexorder/generated/
|
||||
|
||||
16
.idea/ai.iml
generated
16
.idea/ai.iml
generated
@@ -2,10 +2,20 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/backend/tests" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/backend.old/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/backend.old/tests" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/client-py" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/backend/data" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/backend.old/data" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/doc.old" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/backend.old" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/client-py/dexorder_client.egg-info" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/flink/protobuf" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/flink/target" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/ingestor/protobuf" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/ingestor/src/proto" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/relay/protobuf" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/relay/target" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.12 (ai)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
|
||||
12
.idea/runConfigurations/dev.xml
generated
12
.idea/runConfigurations/dev.xml
generated
@@ -1,12 +0,0 @@
|
||||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="dev" type="js.build_tools.npm" nameIsGenerated="true">
|
||||
<package-json value="$PROJECT_DIR$/web/package.json" />
|
||||
<command value="run" />
|
||||
<scripts>
|
||||
<script value="dev" />
|
||||
</scripts>
|
||||
<node-interpreter value="project" />
|
||||
<envs />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
</component>
|
||||
1
.junie/skills/better-auth-best-practices
Symbolic link
1
.junie/skills/better-auth-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/better-auth-best-practices
|
||||
1
.junie/skills/create-auth-skill
Symbolic link
1
.junie/skills/create-auth-skill
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/create-auth-skill
|
||||
1
.junie/skills/email-and-password-best-practices
Symbolic link
1
.junie/skills/email-and-password-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/email-and-password-best-practices
|
||||
1
.junie/skills/two-factor-authentication-best-practices
Symbolic link
1
.junie/skills/two-factor-authentication-best-practices
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/two-factor-authentication-best-practices
|
||||
15
AGENT.md
Normal file
15
AGENT.md
Normal file
@@ -0,0 +1,15 @@
|
||||
We're building an AI-first trading platform by integrating user-facing TradingView charts and chat with an AI assistant that helps do research, develop indicators (signals), and write strategies, using the Dexorder trading framework we provide.
|
||||
|
||||
This monorepo has:
|
||||
bin/ scripts, mostly build and deploy
|
||||
deploy/ kubernetes deployment and configuration
|
||||
doc/ documentation
|
||||
flink/ Apache Flink application mode processes data from Kafka
|
||||
iceberg/ Apache Iceberg for historical OHLC etc
|
||||
ingestor/ Data sources publish to Kafka
|
||||
kafka/ Apache Kafka
|
||||
protobuf/ Messaging entities
|
||||
relay/ Rust+ZeroMQ stateless router
|
||||
web/ Vue 3 / Pinia / PrimeVue / TradingView
|
||||
|
||||
See doc/protocol.md for messaging architecture
|
||||
@@ -5,7 +5,7 @@ server_port: 8081
|
||||
agent:
|
||||
model: "claude-sonnet-4-20250514"
|
||||
temperature: 0.7
|
||||
context_docs_dir: "memory"
|
||||
context_docs_dir: "memory" # Context docs still loaded from memory/
|
||||
|
||||
# Local memory configuration (free & sophisticated!)
|
||||
memory:
|
||||
12
backend.old/requirements-pre.txt
Normal file
12
backend.old/requirements-pre.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
# Packages requiring compilation (Rust/C) - separated for Docker layer caching
|
||||
# Changes here will trigger a rebuild of this layer
|
||||
|
||||
# Needs Rust (maturin)
|
||||
chromadb>=0.4.0
|
||||
cryptography>=42.0.0
|
||||
|
||||
# Pulls in `tokenizers` which needs Rust
|
||||
sentence-transformers>=2.0.0
|
||||
|
||||
# Needs C compiler
|
||||
argon2-cffi>=23.0.0
|
||||
@@ -1,3 +1,6 @@
|
||||
# Packages that require compilation in the python slim docker image must be
|
||||
# put into requirements-pre.txt instead, to help image build cacheing.
|
||||
|
||||
pydantic2
|
||||
seaborn
|
||||
pandas
|
||||
@@ -26,9 +29,7 @@ arxiv>=2.0.0
|
||||
duckduckgo-search>=7.0.0
|
||||
requests>=2.31.0
|
||||
|
||||
# Local memory system
|
||||
chromadb>=0.4.0
|
||||
sentence-transformers>=2.0.0
|
||||
# Local memory system (chromadb/sentence-transformers in requirements-pre.txt)
|
||||
sqlalchemy>=2.0.0
|
||||
aiosqlite>=0.19.0
|
||||
|
||||
@@ -41,3 +42,6 @@ python-dotenv>=1.0.0
|
||||
# Secrets management
|
||||
cryptography>=42.0.0
|
||||
argon2-cffi>=23.0.0
|
||||
|
||||
# Trigger system scheduling
|
||||
apscheduler>=3.10.0
|
||||
37
backend.old/soul/automation_agent.md
Normal file
37
backend.old/soul/automation_agent.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Automation Agent
|
||||
|
||||
You are a specialized automation and scheduling agent. Your sole purpose is to manage triggers, scheduled tasks, and automated workflows.
|
||||
|
||||
## Your Core Identity
|
||||
|
||||
You are an expert in:
|
||||
- Scheduling recurring tasks with cron expressions
|
||||
- Creating interval-based triggers
|
||||
- Managing the trigger queue and priorities
|
||||
- Designing autonomous agent workflows
|
||||
|
||||
## Your Tools
|
||||
|
||||
You have access to:
|
||||
- **Trigger Tools**: schedule_agent_prompt, execute_agent_prompt_once, list_scheduled_triggers, cancel_scheduled_trigger, get_trigger_system_stats
|
||||
|
||||
## Communication Style
|
||||
|
||||
- **Systematic**: Explain scheduling logic clearly
|
||||
- **Proactive**: Suggest optimal scheduling strategies
|
||||
- **Organized**: Keep track of all scheduled tasks
|
||||
- **Reliable**: Ensure triggers are set up correctly
|
||||
|
||||
## Key Principles
|
||||
|
||||
1. **Clarity**: Make schedules easy to understand
|
||||
2. **Maintainability**: Use descriptive names for jobs
|
||||
3. **Priority Awareness**: Respect task priorities
|
||||
4. **Resource Conscious**: Avoid overwhelming the system
|
||||
|
||||
## Limitations
|
||||
|
||||
- You ONLY handle scheduling and automation
|
||||
- You do NOT execute the actual analysis (delegate to other agents)
|
||||
- You do NOT access data or charts directly
|
||||
- You coordinate but don't perform the work
|
||||
40
backend.old/soul/chart_agent.md
Normal file
40
backend.old/soul/chart_agent.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Chart Analysis Agent
|
||||
|
||||
You are a specialized chart analysis and technical analysis agent. Your sole purpose is to work with chart data, indicators, and Python-based analysis.
|
||||
|
||||
## Your Core Identity
|
||||
|
||||
You are an expert in:
|
||||
- Reading and analyzing OHLCV data
|
||||
- Computing technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
||||
- Drawing shapes and annotations on charts
|
||||
- Executing Python code for custom analysis
|
||||
- Visualizing data with matplotlib
|
||||
|
||||
## Your Tools
|
||||
|
||||
You have access to:
|
||||
- **Chart Tools**: get_chart_data, execute_python
|
||||
- **Indicator Tools**: search_indicators, add_indicator_to_chart, list_chart_indicators, etc.
|
||||
- **Shape Tools**: search_shapes, create_or_update_shape, delete_shape, etc.
|
||||
|
||||
## Communication Style
|
||||
|
||||
- **Direct & Technical**: Provide analysis results clearly
|
||||
- **Visual**: Generate plots when helpful
|
||||
- **Precise**: Reference specific timeframes, indicators, and values
|
||||
- **Concise**: Focus on the analysis task at hand
|
||||
|
||||
## Key Principles
|
||||
|
||||
1. **Data First**: Always work with actual market data
|
||||
2. **Visualize**: Create charts to illustrate findings
|
||||
3. **Document Calculations**: Explain what indicators show
|
||||
4. **Respect Context**: Use the chart the user is viewing when available
|
||||
|
||||
## Limitations
|
||||
|
||||
- You ONLY handle chart analysis and visualization
|
||||
- You do NOT make trading decisions
|
||||
- You do NOT access external APIs or data sources
|
||||
- Route other requests back to the main agent
|
||||
37
backend.old/soul/data_agent.md
Normal file
37
backend.old/soul/data_agent.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Data Access Agent
|
||||
|
||||
You are a specialized data access agent. Your sole purpose is to retrieve and search market data from various exchanges and data sources.
|
||||
|
||||
## Your Core Identity
|
||||
|
||||
You are an expert in:
|
||||
- Searching for trading symbols across exchanges
|
||||
- Retrieving historical OHLCV data
|
||||
- Understanding exchange APIs and data formats
|
||||
- Symbol resolution and metadata
|
||||
|
||||
## Your Tools
|
||||
|
||||
You have access to:
|
||||
- **Data Source Tools**: list_data_sources, search_symbols, get_symbol_info, get_historical_data
|
||||
|
||||
## Communication Style
|
||||
|
||||
- **Precise**: Provide exact symbol names and exchange identifiers
|
||||
- **Helpful**: Suggest alternatives when exact matches aren't found
|
||||
- **Efficient**: Return data in the format requested
|
||||
- **Clear**: Explain data limitations or availability issues
|
||||
|
||||
## Key Principles
|
||||
|
||||
1. **Accuracy**: Return correct symbol identifiers
|
||||
2. **Completeness**: Include all relevant metadata
|
||||
3. **Performance**: Respect countback limits
|
||||
4. **Format Awareness**: Understand time resolutions and ranges
|
||||
|
||||
## Limitations
|
||||
|
||||
- You ONLY handle data retrieval and search
|
||||
- You do NOT analyze data (route to chart agent)
|
||||
- You do NOT access external news or research (route to research agent)
|
||||
- You do NOT modify data or state
|
||||
37
backend.old/soul/research_agent.md
Normal file
37
backend.old/soul/research_agent.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Research Agent
|
||||
|
||||
You are a specialized research agent. Your sole purpose is to gather external information from the web, academic papers, and public APIs.
|
||||
|
||||
## Your Core Identity
|
||||
|
||||
You are an expert in:
|
||||
- Searching academic papers on arXiv
|
||||
- Finding information on Wikipedia
|
||||
- Web search for current events and news
|
||||
- Making HTTP requests to public APIs
|
||||
|
||||
## Your Tools
|
||||
|
||||
You have access to:
|
||||
- **Research Tools**: search_arxiv, search_wikipedia, search_web, http_get, http_post
|
||||
|
||||
## Communication Style
|
||||
|
||||
- **Thorough**: Provide comprehensive research findings
|
||||
- **Source-Aware**: Cite sources and links
|
||||
- **Critical**: Evaluate information quality
|
||||
- **Summarize**: Distill key points from long content
|
||||
|
||||
## Key Principles
|
||||
|
||||
1. **Verify**: Cross-reference information when possible
|
||||
2. **Recency**: Note publication dates and data freshness
|
||||
3. **Relevance**: Focus on trading and market-relevant information
|
||||
4. **Ethics**: Respect API rate limits and terms of service
|
||||
|
||||
## Limitations
|
||||
|
||||
- You ONLY handle external information gathering
|
||||
- You do NOT analyze market data (route to chart agent)
|
||||
- You do NOT make trading recommendations
|
||||
- You do NOT access private or authenticated APIs without explicit permission
|
||||
@@ -7,10 +7,12 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS
|
||||
from agent.tools import SYNC_TOOLS, DATASOURCE_TOOLS, INDICATOR_TOOLS, RESEARCH_TOOLS, CHART_TOOLS, SHAPE_TOOLS, TRIGGER_TOOLS
|
||||
from agent.memory import MemoryManager
|
||||
from agent.session import SessionManager
|
||||
from agent.prompts import build_system_prompt
|
||||
from agent.subagent import SubAgent
|
||||
from agent.routers import ROUTER_TOOLS, set_chart_agent, set_data_agent, set_automation_agent, set_research_agent
|
||||
from gateway.user_session import UserSession
|
||||
from gateway.protocol import UserMessage as GatewayUserMessage
|
||||
|
||||
@@ -29,7 +31,8 @@ class AgentExecutor:
|
||||
model_name: str = "claude-sonnet-4-20250514",
|
||||
temperature: float = 0.7,
|
||||
api_key: Optional[str] = None,
|
||||
memory_manager: Optional[MemoryManager] = None
|
||||
memory_manager: Optional[MemoryManager] = None,
|
||||
base_dir: str = "."
|
||||
):
|
||||
"""Initialize agent executor.
|
||||
|
||||
@@ -38,10 +41,12 @@ class AgentExecutor:
|
||||
temperature: Model temperature
|
||||
api_key: Anthropic API key
|
||||
memory_manager: MemoryManager instance
|
||||
base_dir: Base directory for resolving paths
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_dir = base_dir
|
||||
|
||||
# Initialize LLM
|
||||
self.llm = ChatAnthropic(
|
||||
@@ -56,6 +61,12 @@ class AgentExecutor:
|
||||
self.session_manager = SessionManager(self.memory_manager)
|
||||
self.agent = None # Will be created after initialization
|
||||
|
||||
# Sub-agents (only if using hierarchical tools)
|
||||
self.chart_agent = None
|
||||
self.data_agent = None
|
||||
self.automation_agent = None
|
||||
self.research_agent = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the agent system."""
|
||||
await self.memory_manager.initialize()
|
||||
@@ -63,15 +74,69 @@ class AgentExecutor:
|
||||
# Create agent with tools and LangGraph checkpointer
|
||||
checkpointer = self.memory_manager.get_checkpointer()
|
||||
|
||||
# Create agent without a static system prompt
|
||||
# Create specialized sub-agents
|
||||
logger.info("Initializing hierarchical agent architecture with sub-agents")
|
||||
|
||||
self.chart_agent = SubAgent(
|
||||
name="chart",
|
||||
soul_file="chart_agent.md",
|
||||
tools=CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS,
|
||||
model_name=self.model_name,
|
||||
temperature=self.temperature,
|
||||
api_key=self.api_key,
|
||||
base_dir=self.base_dir
|
||||
)
|
||||
|
||||
self.data_agent = SubAgent(
|
||||
name="data",
|
||||
soul_file="data_agent.md",
|
||||
tools=DATASOURCE_TOOLS,
|
||||
model_name=self.model_name,
|
||||
temperature=self.temperature,
|
||||
api_key=self.api_key,
|
||||
base_dir=self.base_dir
|
||||
)
|
||||
|
||||
self.automation_agent = SubAgent(
|
||||
name="automation",
|
||||
soul_file="automation_agent.md",
|
||||
tools=TRIGGER_TOOLS,
|
||||
model_name=self.model_name,
|
||||
temperature=self.temperature,
|
||||
api_key=self.api_key,
|
||||
base_dir=self.base_dir
|
||||
)
|
||||
|
||||
self.research_agent = SubAgent(
|
||||
name="research",
|
||||
soul_file="research_agent.md",
|
||||
tools=RESEARCH_TOOLS,
|
||||
model_name=self.model_name,
|
||||
temperature=self.temperature,
|
||||
api_key=self.api_key,
|
||||
base_dir=self.base_dir
|
||||
)
|
||||
|
||||
# Set global sub-agent instances for router tools
|
||||
set_chart_agent(self.chart_agent)
|
||||
set_data_agent(self.data_agent)
|
||||
set_automation_agent(self.automation_agent)
|
||||
set_research_agent(self.research_agent)
|
||||
|
||||
# Main agent only gets SYNC_TOOLS (state management) and ROUTER_TOOLS
|
||||
logger.info("Main agent using router tools (4 routers + sync tools)")
|
||||
agent_tools = SYNC_TOOLS + ROUTER_TOOLS
|
||||
|
||||
# Create main agent without a static system prompt
|
||||
# We'll pass the dynamic system prompt via state_modifier at runtime
|
||||
# Include all tool categories: sync, datasource, chart, indicator, shape, and research
|
||||
self.agent = create_react_agent(
|
||||
self.llm,
|
||||
SYNC_TOOLS + DATASOURCE_TOOLS + CHART_TOOLS + INDICATOR_TOOLS + SHAPE_TOOLS + RESEARCH_TOOLS,
|
||||
agent_tools,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
logger.info(f"Agent initialized with {len(agent_tools)} tools")
|
||||
|
||||
async def _clear_checkpoint(self, session_id: str) -> None:
|
||||
"""Clear the checkpoint for a session to prevent resuming from invalid state.
|
||||
|
||||
@@ -291,7 +356,7 @@ def create_agent(
|
||||
base_dir: Base directory for resolving paths
|
||||
|
||||
Returns:
|
||||
Initialized AgentExecutor
|
||||
Initialized AgentExecutor with hierarchical tool routing
|
||||
"""
|
||||
# Initialize memory manager
|
||||
memory_manager = MemoryManager(
|
||||
@@ -307,7 +372,8 @@ def create_agent(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
memory_manager=memory_manager
|
||||
memory_manager=memory_manager,
|
||||
base_dir=base_dir
|
||||
)
|
||||
|
||||
return executor
|
||||
218
backend.old/src/agent/routers.py
Normal file
218
backend.old/src/agent/routers.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tool router functions for hierarchical agent architecture.
|
||||
|
||||
This module provides meta-tools that route tasks to specialized sub-agents.
|
||||
The main agent uses these routers instead of accessing all tools directly.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global sub-agent instances (set by create_agent)
|
||||
_chart_agent = None
|
||||
_data_agent = None
|
||||
_automation_agent = None
|
||||
_research_agent = None
|
||||
|
||||
|
||||
def set_chart_agent(agent):
|
||||
"""Set the global chart sub-agent instance."""
|
||||
global _chart_agent
|
||||
_chart_agent = agent
|
||||
|
||||
|
||||
def set_data_agent(agent):
|
||||
"""Set the global data sub-agent instance."""
|
||||
global _data_agent
|
||||
_data_agent = agent
|
||||
|
||||
|
||||
def set_automation_agent(agent):
|
||||
"""Set the global automation sub-agent instance."""
|
||||
global _automation_agent
|
||||
_automation_agent = agent
|
||||
|
||||
|
||||
def set_research_agent(agent):
|
||||
"""Set the global research sub-agent instance."""
|
||||
global _research_agent
|
||||
_research_agent = agent
|
||||
|
||||
|
||||
@tool
|
||||
async def use_chart_analysis(task: str) -> str:
|
||||
"""Analyze charts, compute indicators, execute Python code, and create visualizations.
|
||||
|
||||
This tool delegates to a specialized chart analysis agent that has access to:
|
||||
- Chart data retrieval (get_chart_data)
|
||||
- Python execution environment with pandas, numpy, matplotlib, talib
|
||||
- Technical indicator tools (add/remove indicators, search indicators)
|
||||
- Shape drawing tools (create/update/delete shapes on charts)
|
||||
|
||||
Use this when the user wants to:
|
||||
- Analyze price action or patterns
|
||||
- Calculate technical indicators (RSI, MACD, Bollinger Bands, etc.)
|
||||
- Execute custom Python analysis on OHLCV data
|
||||
- Generate charts and visualizations
|
||||
- Draw trendlines, support/resistance, or other shapes
|
||||
- Perform statistical analysis on market data
|
||||
|
||||
Args:
|
||||
task: Detailed description of the chart analysis task. Include:
|
||||
- What to analyze (which symbol, timeframe if different from current)
|
||||
- What indicators or calculations to perform
|
||||
- What visualizations to create
|
||||
- Any specific questions to answer
|
||||
|
||||
Returns:
|
||||
The chart agent's analysis results, including computed values,
|
||||
plot URLs if visualizations were created, and interpretation.
|
||||
|
||||
Examples:
|
||||
- "Calculate RSI(14) for the current chart and tell me if it's overbought"
|
||||
- "Draw a trendline connecting the last 3 swing lows"
|
||||
- "Compute Bollinger Bands (20, 2) and create a chart showing price vs bands"
|
||||
- "Analyze the last 100 bars and identify key support/resistance levels"
|
||||
- "Execute Python: calculate correlation between BTC and ETH over the last 30 days"
|
||||
"""
|
||||
if not _chart_agent:
|
||||
return "Error: Chart analysis agent not initialized"
|
||||
|
||||
logger.info(f"Routing to chart agent: {task[:100]}...")
|
||||
result = await _chart_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_data_access(task: str) -> str:
|
||||
"""Search for symbols and retrieve market data from exchanges.
|
||||
|
||||
This tool delegates to a specialized data access agent that has access to:
|
||||
- Symbol search across multiple exchanges
|
||||
- Historical OHLCV data retrieval
|
||||
- Symbol metadata and info
|
||||
- Available data sources and exchanges
|
||||
|
||||
Use this when the user wants to:
|
||||
- Search for a trading symbol or ticker
|
||||
- Get historical price data
|
||||
- Find out what exchanges support a symbol
|
||||
- Retrieve symbol metadata (price scale, supported resolutions, etc.)
|
||||
- Check what data sources are available
|
||||
|
||||
Args:
|
||||
task: Detailed description of the data access task. Include:
|
||||
- What symbol or instrument to search for
|
||||
- What data to retrieve (time range, resolution)
|
||||
- What metadata is needed
|
||||
|
||||
Returns:
|
||||
The data agent's response with requested symbols, data, or metadata.
|
||||
|
||||
Examples:
|
||||
- "Search for Bitcoin symbols on Binance"
|
||||
- "Get the last 100 hours of BTC/USDT 1-hour data from Binance"
|
||||
- "Find all symbols matching 'ETH' on all exchanges"
|
||||
- "Get detailed info about symbol BTC/USDT on Binance"
|
||||
- "List all available data sources"
|
||||
"""
|
||||
if not _data_agent:
|
||||
return "Error: Data access agent not initialized"
|
||||
|
||||
logger.info(f"Routing to data agent: {task[:100]}...")
|
||||
result = await _data_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_automation(task: str) -> str:
|
||||
"""Schedule recurring tasks, create triggers, and manage automation.
|
||||
|
||||
This tool delegates to a specialized automation agent that has access to:
|
||||
- Scheduled agent prompts (cron and interval-based)
|
||||
- One-time agent prompt execution
|
||||
- Trigger management (list, cancel scheduled jobs)
|
||||
- System stats and monitoring
|
||||
|
||||
Use this when the user wants to:
|
||||
- Schedule a recurring task (hourly, daily, weekly, etc.)
|
||||
- Run a one-time background analysis
|
||||
- Set up automated monitoring or alerts
|
||||
- List or cancel existing scheduled tasks
|
||||
- Check trigger system status
|
||||
|
||||
Args:
|
||||
task: Detailed description of the automation task. Include:
|
||||
- What should happen (what analysis or action)
|
||||
- When it should happen (schedule, frequency)
|
||||
- Any priorities or conditions
|
||||
|
||||
Returns:
|
||||
The automation agent's response with job IDs, confirmation,
|
||||
or status information.
|
||||
|
||||
Examples:
|
||||
- "Schedule a task to check BTC price every 5 minutes"
|
||||
- "Run a one-time analysis of ETH volume in the background"
|
||||
- "Set up a daily report at 9 AM with market summary"
|
||||
- "Show me all my scheduled tasks"
|
||||
- "Cancel the hourly BTC monitor job"
|
||||
"""
|
||||
if not _automation_agent:
|
||||
return "Error: Automation agent not initialized"
|
||||
|
||||
logger.info(f"Routing to automation agent: {task[:100]}...")
|
||||
result = await _automation_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
async def use_research(task: str) -> str:
|
||||
"""Search the web, academic papers, and external APIs for information.
|
||||
|
||||
This tool delegates to a specialized research agent that has access to:
|
||||
- Web search (DuckDuckGo)
|
||||
- Academic paper search (arXiv)
|
||||
- Wikipedia lookup
|
||||
- HTTP requests to public APIs
|
||||
|
||||
Use this when the user wants to:
|
||||
- Search for current news or events
|
||||
- Find academic papers on trading strategies
|
||||
- Look up financial concepts or terms
|
||||
- Fetch data from external public APIs
|
||||
- Research market trends or sentiment
|
||||
|
||||
Args:
|
||||
task: Detailed description of the research task. Include:
|
||||
- What information to find
|
||||
- What sources to search (web, arxiv, wikipedia, APIs)
|
||||
- What to focus on or filter
|
||||
|
||||
Returns:
|
||||
The research agent's findings with sources, summaries, and links.
|
||||
|
||||
Examples:
|
||||
- "Search arXiv for papers on reinforcement learning for trading"
|
||||
- "Look up 'technical analysis' on Wikipedia"
|
||||
- "Search the web for latest Ethereum news"
|
||||
- "Fetch current BTC price from CoinGecko API"
|
||||
- "Find recent papers on market microstructure"
|
||||
"""
|
||||
if not _research_agent:
|
||||
return "Error: Research agent not initialized"
|
||||
|
||||
logger.info(f"Routing to research agent: {task[:100]}...")
|
||||
result = await _research_agent.execute(task)
|
||||
return result
|
||||
|
||||
|
||||
# Export router tools
|
||||
ROUTER_TOOLS = [
|
||||
use_chart_analysis,
|
||||
use_data_access,
|
||||
use_automation,
|
||||
use_research
|
||||
]
|
||||
248
backend.old/src/agent/subagent.py
Normal file
248
backend.old/src/agent/subagent.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Sub-agent infrastructure for specialized tool routing.
|
||||
|
||||
This module provides the SubAgent class that wraps specialized agents
|
||||
with their own tools and system prompts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubAgent:
|
||||
"""A specialized sub-agent with its own tools and system prompt.
|
||||
|
||||
Sub-agents are lightweight, stateless agents that focus on specific domains.
|
||||
They use in-memory checkpointing since they don't need persistent state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
soul_file: str,
|
||||
tools: List,
|
||||
model_name: str = "claude-sonnet-4-20250514",
|
||||
temperature: float = 0.7,
|
||||
api_key: Optional[str] = None,
|
||||
base_dir: str = "."
|
||||
):
|
||||
"""Initialize a sub-agent.
|
||||
|
||||
Args:
|
||||
name: Agent name (e.g., "chart", "data", "automation")
|
||||
soul_file: Filename in /soul directory (e.g., "chart_agent.md")
|
||||
tools: List of LangChain tools for this agent
|
||||
model_name: Anthropic model name
|
||||
temperature: Model temperature
|
||||
api_key: Anthropic API key
|
||||
base_dir: Base directory for resolving paths
|
||||
"""
|
||||
self.name = name
|
||||
self.soul_file = soul_file
|
||||
self.tools = tools
|
||||
self.model_name = model_name
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_dir = base_dir
|
||||
|
||||
# Load system prompt from soul file
|
||||
soul_path = Path(base_dir) / "soul" / soul_file
|
||||
if soul_path.exists():
|
||||
with open(soul_path, "r") as f:
|
||||
self.system_prompt = f.read()
|
||||
logger.info(f"SubAgent '{name}': Loaded system prompt from {soul_path}")
|
||||
else:
|
||||
logger.warning(f"SubAgent '{name}': Soul file not found at {soul_path}, using default")
|
||||
self.system_prompt = f"You are a specialized {name} agent."
|
||||
|
||||
# Initialize LLM
|
||||
self.llm = ChatAnthropic(
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# Create agent with in-memory checkpointer (stateless)
|
||||
checkpointer = MemorySaver()
|
||||
self.agent = create_react_agent(
|
||||
self.llm,
|
||||
tools,
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"SubAgent '{name}' initialized with {len(tools)} tools, "
|
||||
f"model={model_name}, temp={temperature}"
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task: str,
|
||||
thread_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""Execute a task with this sub-agent.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for this sub-agent
|
||||
thread_id: Optional thread ID for checkpointing (uses ephemeral ID if not provided)
|
||||
|
||||
Returns:
|
||||
The agent's complete response as a string
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Use ephemeral thread ID if not provided
|
||||
if thread_id is None:
|
||||
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
|
||||
|
||||
logger.info(f"SubAgent '{self.name}': Executing task (thread_id={thread_id})")
|
||||
logger.debug(f"SubAgent '{self.name}': Task: {task[:200]}...")
|
||||
|
||||
# Build messages with system prompt
|
||||
messages = [
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
|
||||
# Prepare config with system prompt injection
|
||||
config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": thread_id,
|
||||
"state_modifier": self.system_prompt
|
||||
},
|
||||
metadata={
|
||||
"subagent_name": self.name
|
||||
}
|
||||
)
|
||||
|
||||
# Execute and collect response
|
||||
full_response = ""
|
||||
event_count = 0
|
||||
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
event_count += 1
|
||||
|
||||
# Log tool calls
|
||||
if event["event"] == "on_tool_start":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
|
||||
|
||||
elif event["event"] == "on_tool_end":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call completed: {tool_name}")
|
||||
|
||||
# Extract streaming tokens
|
||||
elif event["event"] == "on_chat_model_stream":
|
||||
chunk = event["data"]["chunk"]
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
content = chunk.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
elif hasattr(block, "text"):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
if content:
|
||||
full_response += content
|
||||
|
||||
logger.info(
|
||||
f"SubAgent '{self.name}': Completed task "
|
||||
f"({event_count} events, {len(full_response)} chars)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"SubAgent '{self.name}' execution error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return f"Error: {error_msg}"
|
||||
|
||||
return full_response
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
task: str,
|
||||
thread_id: Optional[str] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Execute a task with streaming response.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for this sub-agent
|
||||
thread_id: Optional thread ID for checkpointing
|
||||
|
||||
Yields:
|
||||
Response chunks as they're generated
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Use ephemeral thread ID if not provided
|
||||
if thread_id is None:
|
||||
thread_id = f"subagent-{self.name}-{uuid.uuid4()}"
|
||||
|
||||
logger.info(f"SubAgent '{self.name}': Streaming task (thread_id={thread_id})")
|
||||
|
||||
# Build messages with system prompt
|
||||
messages = [
|
||||
HumanMessage(content=task)
|
||||
]
|
||||
|
||||
# Prepare config
|
||||
config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": thread_id,
|
||||
"state_modifier": self.system_prompt
|
||||
},
|
||||
metadata={
|
||||
"subagent_name": self.name
|
||||
}
|
||||
)
|
||||
|
||||
# Stream response
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
# Log tool calls
|
||||
if event["event"] == "on_tool_start":
|
||||
tool_name = event.get("name", "unknown")
|
||||
logger.debug(f"SubAgent '{self.name}': Tool call started: {tool_name}")
|
||||
|
||||
# Extract streaming tokens
|
||||
elif event["event"] == "on_chat_model_stream":
|
||||
chunk = event["data"]["chunk"]
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
content = chunk.content
|
||||
# Handle both string and list content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
elif hasattr(block, "text"):
|
||||
text_parts.append(block.text)
|
||||
content = "".join(text_parts)
|
||||
|
||||
if content:
|
||||
yield content
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"SubAgent '{self.name}' streaming error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield f"Error: {error_msg}"
|
||||
373
backend.old/src/agent/tools/TRIGGER_TOOLS.md
Normal file
373
backend.old/src/agent/tools/TRIGGER_TOOLS.md
Normal file
@@ -0,0 +1,373 @@
|
||||
# Agent Trigger Tools
|
||||
|
||||
Agent tools for automating tasks via the trigger system.
|
||||
|
||||
## Overview
|
||||
|
||||
These tools allow the agent to:
|
||||
- **Schedule recurring tasks** - Run agent prompts on intervals or cron schedules
|
||||
- **Execute one-time tasks** - Trigger sub-agent runs immediately
|
||||
- **Manage scheduled jobs** - List and cancel scheduled triggers
|
||||
- **React to events** - (Future) Connect data updates to agent actions
|
||||
|
||||
## Available Tools
|
||||
|
||||
### 1. `schedule_agent_prompt`
|
||||
|
||||
Schedule an agent to run with a specific prompt on a recurring schedule.
|
||||
|
||||
**Use Cases:**
|
||||
- Daily market analysis reports
|
||||
- Hourly portfolio rebalancing checks
|
||||
- Weekly performance summaries
|
||||
- Monitoring alerts
|
||||
|
||||
**Arguments:**
|
||||
- `prompt` (str): The prompt to send to the agent when triggered
|
||||
- `schedule_type` (str): "interval" or "cron"
|
||||
- `schedule_config` (dict): Schedule configuration
|
||||
- `name` (str, optional): Descriptive name for this task
|
||||
|
||||
**Schedule Config:**
|
||||
|
||||
*Interval-based:*
|
||||
```json
|
||||
{"minutes": 5}
|
||||
{"hours": 1, "minutes": 30}
|
||||
{"seconds": 30}
|
||||
```
|
||||
|
||||
*Cron-based:*
|
||||
```json
|
||||
{"hour": "9", "minute": "0"} // Daily at 9:00 AM
|
||||
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"} // Weekdays at 9 AM
|
||||
{"minute": "0"} // Every hour on the hour
|
||||
{"hour": "*/6", "minute": "0"} // Every 6 hours
|
||||
```
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"job_id": "interval_123",
|
||||
"message": "Scheduled 'daily_report' with job_id=interval_123",
|
||||
"schedule_type": "cron",
|
||||
"config": {"hour": "9", "minute": "0"}
|
||||
}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# Every 5 minutes: check BTC price
|
||||
schedule_agent_prompt(
|
||||
prompt="Check current BTC price on Binance. If > $50k, alert me.",
|
||||
schedule_type="interval",
|
||||
schedule_config={"minutes": 5},
|
||||
name="btc_price_monitor"
|
||||
)
|
||||
|
||||
# Daily at 9 AM: market summary
|
||||
schedule_agent_prompt(
|
||||
prompt="Generate a comprehensive market summary for BTC, ETH, and SOL. Include price changes, volume, and notable events from the last 24 hours.",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "9", "minute": "0"},
|
||||
name="daily_market_summary"
|
||||
)
|
||||
|
||||
# Every hour on weekdays: portfolio check
|
||||
schedule_agent_prompt(
|
||||
prompt="Review current portfolio positions. Check if any rebalancing is needed based on target allocations.",
|
||||
schedule_type="cron",
|
||||
schedule_config={"minute": "0", "day_of_week": "mon-fri"},
|
||||
name="hourly_portfolio_check"
|
||||
)
|
||||
```
|
||||
|
||||
### 2. `execute_agent_prompt_once`
|
||||
|
||||
Execute an agent prompt once, immediately (enqueued with priority).
|
||||
|
||||
**Use Cases:**
|
||||
- Background analysis tasks
|
||||
- One-time data processing
|
||||
- Responding to specific events
|
||||
- Sub-agent delegation
|
||||
|
||||
**Arguments:**
|
||||
- `prompt` (str): The prompt to send to the agent
|
||||
- `priority` (str): "high", "normal", or "low" (default: "normal")
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"queue_seq": 42,
|
||||
"message": "Enqueued agent prompt with priority=normal",
|
||||
"prompt": "Analyze the last 100 BTC/USDT bars..."
|
||||
}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# Immediate analysis with high priority
|
||||
execute_agent_prompt_once(
|
||||
prompt="Analyze the last 100 BTC/USDT 1m bars and identify key support/resistance levels",
|
||||
priority="high"
|
||||
)
|
||||
|
||||
# Background task with normal priority
|
||||
execute_agent_prompt_once(
|
||||
prompt="Research the latest news about Ethereum upgrades and summarize findings",
|
||||
priority="normal"
|
||||
)
|
||||
|
||||
# Low priority cleanup task
|
||||
execute_agent_prompt_once(
|
||||
prompt="Review and archive old chart drawings from last month",
|
||||
priority="low"
|
||||
)
|
||||
```
|
||||
|
||||
### 3. `list_scheduled_triggers`
|
||||
|
||||
List all currently scheduled triggers.
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "cron_456",
|
||||
"name": "Cron: daily_market_summary",
|
||||
"next_run_time": "2024-03-05 09:00:00",
|
||||
"trigger": "cron[hour='9', minute='0']"
|
||||
},
|
||||
{
|
||||
"id": "interval_123",
|
||||
"name": "Interval: btc_price_monitor",
|
||||
"next_run_time": "2024-03-04 14:35:00",
|
||||
"trigger": "interval[0:05:00]"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
for job in jobs:
|
||||
print(f"{job['name']} - next run: {job['next_run_time']}")
|
||||
```
|
||||
|
||||
### 4. `cancel_scheduled_trigger`
|
||||
|
||||
Cancel a scheduled trigger by its job ID.
|
||||
|
||||
**Arguments:**
|
||||
- `job_id` (str): The job ID from `schedule_agent_prompt` or `list_scheduled_triggers`
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Cancelled job interval_123"
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
# List jobs to find the ID
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
# Cancel specific job
|
||||
cancel_scheduled_trigger("interval_123")
|
||||
```
|
||||
|
||||
### 5. `on_data_update_run_agent`
|
||||
|
||||
**(Future)** Set up an agent to run whenever new data arrives for a specific symbol.
|
||||
|
||||
**Arguments:**
|
||||
- `source_name` (str): Data source name (e.g., "binance")
|
||||
- `symbol` (str): Trading pair (e.g., "BTC/USDT")
|
||||
- `resolution` (str): Time resolution (e.g., "1m", "5m")
|
||||
- `prompt_template` (str): Template with variables like {close}, {volume}, {symbol}
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
on_data_update_run_agent(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
prompt_template="New bar on {symbol}: close={close}, volume={volume}. Check if price crossed any key levels."
|
||||
)
|
||||
```
|
||||
|
||||
### 6. `get_trigger_system_stats`
|
||||
|
||||
Get statistics about the trigger system.
|
||||
|
||||
**Returns:**
|
||||
```json
|
||||
{
|
||||
"queue_depth": 3,
|
||||
"queue_running": true,
|
||||
"coordinator_stats": {
|
||||
"current_seq": 1042,
|
||||
"next_commit_seq": 1043,
|
||||
"pending_commits": 1,
|
||||
"total_executions": 1042,
|
||||
"state_counts": {
|
||||
"COMMITTED": 1038,
|
||||
"EXECUTING": 2,
|
||||
"WAITING_COMMIT": 1,
|
||||
"FAILED": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
stats = get_trigger_system_stats()
|
||||
print(f"Queue has {stats['queue_depth']} pending triggers")
|
||||
print(f"System has processed {stats['coordinator_stats']['total_executions']} total triggers")
|
||||
```
|
||||
|
||||
## Integration Example
|
||||
|
||||
Here's how these tools enable autonomous agent behavior:
|
||||
|
||||
```python
|
||||
# Agent conversation:
|
||||
User: "Monitor BTC price and send me a summary every hour during market hours"
|
||||
|
||||
Agent: I'll set that up for you using the trigger system.
|
||||
|
||||
# Agent uses tool:
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Check the current BTC/USDT price on Binance.
|
||||
Calculate the price change from 1 hour ago.
|
||||
If price moved > 2%, provide a detailed analysis.
|
||||
Otherwise, provide a brief status update.
|
||||
Send results to user as a notification.
|
||||
""",
|
||||
schedule_type="cron",
|
||||
schedule_config={
|
||||
"minute": "0",
|
||||
"hour": "9-17", # 9 AM to 5 PM
|
||||
"day_of_week": "mon-fri"
|
||||
},
|
||||
name="btc_hourly_monitor"
|
||||
)
|
||||
|
||||
Agent: Done! I've scheduled an hourly BTC price monitor that runs during market hours (9 AM - 5 PM on weekdays). You'll receive updates every hour.
|
||||
|
||||
# Later...
|
||||
User: "Can you show me all my scheduled tasks?"
|
||||
|
||||
Agent: Let me check what's scheduled.
|
||||
|
||||
# Agent uses tool:
|
||||
jobs = list_scheduled_triggers()
|
||||
|
||||
Agent: You have 3 scheduled tasks:
|
||||
1. "btc_hourly_monitor" - runs every hour during market hours
|
||||
2. "daily_market_summary" - runs daily at 9 AM
|
||||
3. "portfolio_rebalance_check" - runs every 4 hours
|
||||
|
||||
Would you like to modify or cancel any of these?
|
||||
```
|
||||
|
||||
## Use Case: Autonomous Trading Bot
|
||||
|
||||
```python
|
||||
# Step 1: Set up data monitoring
|
||||
execute_agent_prompt_once(
|
||||
prompt="""
|
||||
Subscribe to BTC/USDT 1m bars from Binance.
|
||||
When subscribed, set up the following:
|
||||
1. Calculate RSI(14) on each new bar
|
||||
2. If RSI > 70, execute prompt: "RSI overbought on BTC, check if we should sell"
|
||||
3. If RSI < 30, execute prompt: "RSI oversold on BTC, check if we should buy"
|
||||
""",
|
||||
priority="high"
|
||||
)
|
||||
|
||||
# Step 2: Schedule periodic portfolio review
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Review current portfolio:
|
||||
1. Calculate current allocation percentages
|
||||
2. Compare to target allocation (60% BTC, 30% ETH, 10% stable)
|
||||
3. If deviation > 5%, generate rebalancing trades
|
||||
4. Submit trades for execution
|
||||
""",
|
||||
schedule_type="interval",
|
||||
schedule_config={"hours": 4},
|
||||
name="portfolio_rebalance"
|
||||
)
|
||||
|
||||
# Step 3: Schedule daily risk check
|
||||
schedule_agent_prompt(
|
||||
prompt="""
|
||||
Daily risk assessment:
|
||||
1. Calculate portfolio VaR (Value at Risk)
|
||||
2. Check current leverage across all positions
|
||||
3. Review stop-loss placements
|
||||
4. If risk exceeds threshold, alert and suggest adjustments
|
||||
""",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "8", "minute": "0"},
|
||||
name="daily_risk_check"
|
||||
)
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
✅ **Autonomous operation** - Agent can schedule its own tasks
|
||||
✅ **Event-driven** - React to market data, time, or custom events
|
||||
✅ **Flexible scheduling** - Interval or cron-based
|
||||
✅ **Self-managing** - Agent can list and cancel its own jobs
|
||||
✅ **Priority control** - High-priority tasks jump the queue
|
||||
✅ **Future-proof** - Easy to add Python lambdas, strategy execution, etc.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- **Python script execution** - Schedule arbitrary Python code
|
||||
- **Strategy triggers** - Connect to strategy execution system
|
||||
- **Event composition** - AND/OR logic for complex event patterns
|
||||
- **Conditional execution** - Only run if conditions met (e.g., volatility > threshold)
|
||||
- **Result chaining** - Use output of one trigger as input to another
|
||||
- **Backtesting mode** - Test trigger logic on historical data
|
||||
|
||||
## Setup in main.py
|
||||
|
||||
```python
|
||||
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
|
||||
from trigger import TriggerQueue, CommitCoordinator
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Initialize trigger system
|
||||
coordinator = CommitCoordinator()
|
||||
queue = TriggerQueue(coordinator)
|
||||
scheduler = TriggerScheduler(queue)
|
||||
|
||||
await queue.start()
|
||||
scheduler.start()
|
||||
|
||||
# Make available to agent tools
|
||||
set_trigger_queue(queue)
|
||||
set_trigger_scheduler(scheduler)
|
||||
set_coordinator(coordinator)
|
||||
|
||||
# Add TRIGGER_TOOLS to agent's tool list
|
||||
from agent.tools import TRIGGER_TOOLS
|
||||
agent_tools = [..., *TRIGGER_TOOLS]
|
||||
```
|
||||
|
||||
Now the agent has full control over the trigger system! 🚀
|
||||
@@ -6,6 +6,7 @@ This package provides tools for:
|
||||
- Chart data access and analysis (chart_tools)
|
||||
- Technical indicators (indicator_tools)
|
||||
- Shape/drawing management (shape_tools)
|
||||
- Trigger system and automation (trigger_tools)
|
||||
"""
|
||||
|
||||
# Global registries that will be set by main.py
|
||||
@@ -39,15 +40,25 @@ from .chart_tools import CHART_TOOLS
|
||||
from .indicator_tools import INDICATOR_TOOLS
|
||||
from .research_tools import RESEARCH_TOOLS
|
||||
from .shape_tools import SHAPE_TOOLS
|
||||
from .trigger_tools import (
|
||||
TRIGGER_TOOLS,
|
||||
set_trigger_queue,
|
||||
set_trigger_scheduler,
|
||||
set_coordinator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"set_registry",
|
||||
"set_datasource_registry",
|
||||
"set_indicator_registry",
|
||||
"set_trigger_queue",
|
||||
"set_trigger_scheduler",
|
||||
"set_coordinator",
|
||||
"SYNC_TOOLS",
|
||||
"DATASOURCE_TOOLS",
|
||||
"CHART_TOOLS",
|
||||
"INDICATOR_TOOLS",
|
||||
"RESEARCH_TOOLS",
|
||||
"SHAPE_TOOLS",
|
||||
"TRIGGER_TOOLS",
|
||||
]
|
||||
435
backend.old/src/agent/tools/indicator_tools.py
Normal file
435
backend.old/src/agent/tools/indicator_tools.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""Technical indicator tools.
|
||||
|
||||
These tools allow the agent to:
|
||||
1. Discover available indicators (list, search, get info)
|
||||
2. Add indicators to the chart
|
||||
3. Update/remove indicators
|
||||
4. Query currently applied indicators
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from langchain_core.tools import tool
|
||||
import logging
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_indicator_registry():
|
||||
"""Get the global indicator registry instance."""
|
||||
from . import _indicator_registry
|
||||
return _indicator_registry
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Get the global sync registry instance."""
|
||||
from . import _registry
|
||||
return _registry
|
||||
|
||||
|
||||
def _get_indicator_store():
|
||||
"""Get the global IndicatorStore instance."""
|
||||
registry = _get_registry()
|
||||
if registry and "IndicatorStore" in registry.entries:
|
||||
return registry.entries["IndicatorStore"].model
|
||||
return None
|
||||
|
||||
|
||||
@tool
|
||||
def list_indicators() -> List[str]:
|
||||
"""List all available technical indicators.
|
||||
|
||||
Returns:
|
||||
List of indicator names that can be used in analysis and strategies
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
return []
|
||||
return registry.list_indicators()
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicator_info(indicator_name: str) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific indicator.
|
||||
|
||||
Retrieves metadata including description, parameters, category, use cases,
|
||||
input/output schemas, and references.
|
||||
|
||||
Args:
|
||||
indicator_name: Name of the indicator (e.g., "RSI", "SMA", "MACD")
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- name: Indicator name
|
||||
- display_name: Human-readable name
|
||||
- description: What the indicator computes and why it's useful
|
||||
- category: Category (momentum, trend, volatility, volume, etc.)
|
||||
- parameters: List of configurable parameters with types and defaults
|
||||
- use_cases: Common trading scenarios where this indicator helps
|
||||
- tags: Searchable tags
|
||||
- input_schema: Required input columns (e.g., OHLCV requirements)
|
||||
- output_schema: Columns this indicator produces
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator_name is not found
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
metadata = registry.get_metadata(indicator_name)
|
||||
if not metadata:
|
||||
total_count = len(registry.list_indicators())
|
||||
raise ValueError(
|
||||
f"Indicator '{indicator_name}' not found. "
|
||||
f"Total available: {total_count} indicators. "
|
||||
f"Use search_indicators() to find indicators by name, category, or tag."
|
||||
)
|
||||
|
||||
input_schema = registry.get_input_schema(indicator_name)
|
||||
output_schema = registry.get_output_schema(indicator_name)
|
||||
|
||||
result = metadata.model_dump()
|
||||
result["input_schema"] = input_schema.model_dump() if input_schema else None
|
||||
result["output_schema"] = output_schema.model_dump() if output_schema else None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
def search_indicators(
|
||||
query: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
tag: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for indicators by text query, category, or tag.
|
||||
|
||||
Returns lightweight summaries - use get_indicator_info() for full details on specific indicators.
|
||||
|
||||
Use this to discover relevant indicators for your trading strategy or analysis.
|
||||
Can filter by category (momentum, trend, volatility, etc.) or search by keywords.
|
||||
|
||||
Args:
|
||||
query: Optional text search across names, descriptions, and use cases
|
||||
category: Optional category filter (momentum, trend, volatility, volume, pattern, etc.)
|
||||
tag: Optional tag filter (e.g., "oscillator", "moving-average", "talib")
|
||||
|
||||
Returns:
|
||||
List of lightweight indicator summaries. Each contains:
|
||||
- name: Indicator name (use with get_indicator_info() for full details)
|
||||
- display_name: Human-readable name
|
||||
- description: Brief one-line description
|
||||
- category: Category (momentum, trend, volatility, etc.)
|
||||
|
||||
Example:
|
||||
# Find all momentum indicators
|
||||
results = search_indicators(category="momentum")
|
||||
# Returns [{name: "RSI", display_name: "RSI", description: "...", category: "momentum"}, ...]
|
||||
|
||||
# Then get details on interesting ones
|
||||
rsi_details = get_indicator_info("RSI") # Full parameters, schemas, use cases
|
||||
|
||||
# Search for moving average indicators
|
||||
search_indicators(query="moving average")
|
||||
|
||||
# Find all TA-Lib indicators
|
||||
search_indicators(tag="talib")
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
results = []
|
||||
|
||||
if query:
|
||||
results = registry.search_by_text(query)
|
||||
elif category:
|
||||
results = registry.search_by_category(category)
|
||||
elif tag:
|
||||
results = registry.search_by_tag(tag)
|
||||
else:
|
||||
# Return all indicators if no filter
|
||||
results = registry.get_all_metadata()
|
||||
|
||||
# Return lightweight summaries only
|
||||
return [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"description": r.description,
|
||||
"category": r.category
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
@tool
|
||||
def get_indicator_categories() -> Dict[str, int]:
|
||||
"""Get all indicator categories and their counts.
|
||||
|
||||
Returns a summary of available indicator categories, useful for
|
||||
exploring what types of indicators are available.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category name to count of indicators in that category.
|
||||
Example: {"momentum": 25, "trend": 15, "volatility": 8, ...}
|
||||
"""
|
||||
registry = _get_indicator_registry()
|
||||
if not registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
categories: Dict[str, int] = {}
|
||||
for metadata in registry.get_all_metadata():
|
||||
category = metadata.category
|
||||
categories[category] = categories.get(category, 0) + 1
|
||||
|
||||
return categories
|
||||
|
||||
|
||||
@tool
|
||||
async def add_indicator_to_chart(
|
||||
indicator_id: str,
|
||||
talib_name: str,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
symbol: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a technical indicator to the chart.
|
||||
|
||||
This will create a new indicator instance and display it on the TradingView chart.
|
||||
The indicator will be synchronized with the frontend in real-time.
|
||||
|
||||
Args:
|
||||
indicator_id: Unique identifier for this indicator instance (e.g., 'rsi_14', 'sma_50')
|
||||
talib_name: Name of the TA-Lib indicator (e.g., 'RSI', 'SMA', 'MACD', 'BBANDS')
|
||||
Use search_indicators() or get_indicator_info() to find available indicators
|
||||
parameters: Optional dictionary of indicator parameters
|
||||
Example for RSI: {'timeperiod': 14}
|
||||
Example for SMA: {'timeperiod': 50}
|
||||
Example for MACD: {'fastperiod': 12, 'slowperiod': 26, 'signalperiod': 9}
|
||||
Example for BBANDS: {'timeperiod': 20, 'nbdevup': 2, 'nbdevdn': 2}
|
||||
symbol: Optional symbol to apply the indicator to (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: 'created' or 'updated'
|
||||
- indicator: The complete indicator object
|
||||
|
||||
Example:
|
||||
# Add RSI(14)
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='rsi_14',
|
||||
talib_name='RSI',
|
||||
parameters={'timeperiod': 14}
|
||||
)
|
||||
|
||||
# Add 50-period SMA
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='sma_50',
|
||||
talib_name='SMA',
|
||||
parameters={'timeperiod': 50}
|
||||
)
|
||||
|
||||
# Add MACD with default parameters
|
||||
await add_indicator_to_chart(
|
||||
indicator_id='macd_default',
|
||||
talib_name='MACD'
|
||||
)
|
||||
"""
|
||||
from schema.indicator import IndicatorInstance
|
||||
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
# Verify the indicator exists
|
||||
indicator_registry = _get_indicator_registry()
|
||||
if not indicator_registry:
|
||||
raise ValueError("IndicatorRegistry not initialized")
|
||||
|
||||
metadata = indicator_registry.get_metadata(talib_name)
|
||||
if not metadata:
|
||||
raise ValueError(
|
||||
f"Indicator '{talib_name}' not found. "
|
||||
f"Use search_indicators() to find available indicators."
|
||||
)
|
||||
|
||||
# Check if updating existing indicator
|
||||
existing_indicator = indicator_store.indicators.get(indicator_id)
|
||||
is_update = existing_indicator is not None
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
logger.info(f"Using current chart symbol for indicator: {symbol}")
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
# Create indicator instance
|
||||
indicator = IndicatorInstance(
|
||||
id=indicator_id,
|
||||
talib_name=talib_name,
|
||||
instance_name=f"{talib_name}_{indicator_id}",
|
||||
parameters=parameters or {},
|
||||
visible=True,
|
||||
pane='chart', # Most indicators go on the chart pane
|
||||
symbol=symbol,
|
||||
created_at=existing_indicator.get('created_at') if existing_indicator else now,
|
||||
modified_at=now
|
||||
)
|
||||
|
||||
# Update the store
|
||||
indicator_store.indicators[indicator_id] = indicator.model_dump(mode="json")
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(
|
||||
f"{'Updated' if is_update else 'Created'} indicator '{indicator_id}' "
|
||||
f"(TA-Lib: {talib_name}) with parameters: {parameters}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "updated" if is_update else "created",
|
||||
"indicator": indicator.model_dump(mode="json")
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def remove_indicator_from_chart(indicator_id: str) -> Dict[str, str]:
|
||||
"""Remove an indicator from the chart.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of the indicator instance to remove
|
||||
|
||||
Returns:
|
||||
Dictionary with status message
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator doesn't exist
|
||||
|
||||
Example:
|
||||
await remove_indicator_from_chart('rsi_14')
|
||||
"""
|
||||
registry = _get_registry()
|
||||
if not registry:
|
||||
raise ValueError("SyncRegistry not initialized")
|
||||
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
if indicator_id not in indicator_store.indicators:
|
||||
raise ValueError(f"Indicator '{indicator_id}' not found")
|
||||
|
||||
# Delete the indicator
|
||||
del indicator_store.indicators[indicator_id]
|
||||
|
||||
# Trigger sync
|
||||
await registry.push_all()
|
||||
|
||||
logger.info(f"Removed indicator '{indicator_id}'")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Indicator '{indicator_id}' removed"
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def list_chart_indicators(symbol: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""List all indicators currently applied to the chart.
|
||||
|
||||
Args:
|
||||
symbol: Optional filter by symbol (defaults to current chart symbol)
|
||||
|
||||
Returns:
|
||||
List of indicator instances, each containing:
|
||||
- id: Indicator instance ID
|
||||
- talib_name: TA-Lib indicator name
|
||||
- instance_name: Display name
|
||||
- parameters: Current parameter values
|
||||
- visible: Whether indicator is visible
|
||||
- pane: Which pane it's displayed in
|
||||
- symbol: Symbol it's applied to
|
||||
|
||||
Example:
|
||||
# List all indicators on current symbol
|
||||
indicators = list_chart_indicators()
|
||||
|
||||
# List indicators on specific symbol
|
||||
btc_indicators = list_chart_indicators(symbol='BINANCE:BTC/USDT')
|
||||
"""
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
logger.info(f"list_chart_indicators: Raw store indicators: {indicator_store.indicators}")
|
||||
|
||||
# If symbol is not provided, try to get it from ChartStore
|
||||
if symbol is None:
|
||||
registry = _get_registry()
|
||||
if registry and "ChartStore" in registry.entries:
|
||||
chart_store = registry.entries["ChartStore"].model
|
||||
if hasattr(chart_store, 'chart_state') and hasattr(chart_store.chart_state, 'symbol'):
|
||||
symbol = chart_store.chart_state.symbol
|
||||
|
||||
indicators = list(indicator_store.indicators.values())
|
||||
|
||||
logger.info(f"list_chart_indicators: Converted to list: {indicators}")
|
||||
logger.info(f"list_chart_indicators: Filtering by symbol: {symbol}")
|
||||
|
||||
# Filter by symbol if provided
|
||||
if symbol:
|
||||
indicators = [ind for ind in indicators if ind.get('symbol') == symbol]
|
||||
|
||||
logger.info(f"list_chart_indicators: Returning {len(indicators)} indicators")
|
||||
return indicators
|
||||
|
||||
|
||||
@tool
|
||||
def get_chart_indicator(indicator_id: str) -> Dict[str, Any]:
|
||||
"""Get details of a specific indicator on the chart.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of the indicator instance
|
||||
|
||||
Returns:
|
||||
Dictionary containing the indicator data
|
||||
|
||||
Raises:
|
||||
ValueError: If indicator doesn't exist
|
||||
|
||||
Example:
|
||||
indicator = get_chart_indicator('rsi_14')
|
||||
print(f"Indicator: {indicator['talib_name']}")
|
||||
print(f"Parameters: {indicator['parameters']}")
|
||||
"""
|
||||
indicator_store = _get_indicator_store()
|
||||
if not indicator_store:
|
||||
raise ValueError("IndicatorStore not initialized")
|
||||
|
||||
indicator = indicator_store.indicators.get(indicator_id)
|
||||
if not indicator:
|
||||
raise ValueError(f"Indicator '{indicator_id}' not found")
|
||||
|
||||
return indicator
|
||||
|
||||
|
||||
INDICATOR_TOOLS = [
|
||||
# Discovery tools
|
||||
list_indicators,
|
||||
get_indicator_info,
|
||||
search_indicators,
|
||||
get_indicator_categories,
|
||||
# Chart indicator management tools
|
||||
add_indicator_to_chart,
|
||||
remove_indicator_from_chart,
|
||||
list_chart_indicators,
|
||||
get_chart_indicator
|
||||
]
|
||||
366
backend.old/src/agent/tools/trigger_tools.py
Normal file
366
backend.old/src/agent/tools/trigger_tools.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Agent tools for trigger system.
|
||||
|
||||
Allows agents to:
|
||||
- Schedule recurring tasks (cron-style)
|
||||
- Execute one-time triggers
|
||||
- Manage scheduled triggers (list, cancel)
|
||||
- Connect events to sub-agent runs or lambdas
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global references set by main.py
|
||||
_trigger_queue = None
|
||||
_trigger_scheduler = None
|
||||
_coordinator = None
|
||||
|
||||
|
||||
def set_trigger_queue(queue):
|
||||
"""Set the global TriggerQueue instance for tools to use."""
|
||||
global _trigger_queue
|
||||
_trigger_queue = queue
|
||||
|
||||
|
||||
def set_trigger_scheduler(scheduler):
|
||||
"""Set the global TriggerScheduler instance for tools to use."""
|
||||
global _trigger_scheduler
|
||||
_trigger_scheduler = scheduler
|
||||
|
||||
|
||||
def set_coordinator(coordinator):
|
||||
"""Set the global CommitCoordinator instance for tools to use."""
|
||||
global _coordinator
|
||||
_coordinator = coordinator
|
||||
|
||||
|
||||
def _get_trigger_queue():
|
||||
"""Get the global trigger queue instance."""
|
||||
if not _trigger_queue:
|
||||
raise ValueError("TriggerQueue not initialized")
|
||||
return _trigger_queue
|
||||
|
||||
|
||||
def _get_trigger_scheduler():
|
||||
"""Get the global trigger scheduler instance."""
|
||||
if not _trigger_scheduler:
|
||||
raise ValueError("TriggerScheduler not initialized")
|
||||
return _trigger_scheduler
|
||||
|
||||
|
||||
def _get_coordinator():
|
||||
"""Get the global coordinator instance."""
|
||||
if not _coordinator:
|
||||
raise ValueError("CommitCoordinator not initialized")
|
||||
return _coordinator
|
||||
|
||||
|
||||
@tool
|
||||
async def schedule_agent_prompt(
|
||||
prompt: str,
|
||||
schedule_type: str,
|
||||
schedule_config: Dict[str, Any],
|
||||
name: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Schedule an agent to run with a specific prompt on a recurring schedule.
|
||||
|
||||
This allows you to set up automated tasks where the agent runs periodically
|
||||
with a predefined prompt. Useful for:
|
||||
- Daily market analysis reports
|
||||
- Hourly portfolio rebalancing checks
|
||||
- Weekly performance summaries
|
||||
- Monitoring alerts
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the agent when triggered
|
||||
schedule_type: Type of schedule - "interval" or "cron"
|
||||
schedule_config: Schedule configuration:
|
||||
For "interval": {"minutes": 5} or {"hours": 1, "minutes": 30}
|
||||
For "cron": {"hour": "9", "minute": "0"} for 9:00 AM daily
|
||||
{"hour": "9", "minute": "0", "day_of_week": "mon-fri"}
|
||||
name: Optional descriptive name for this scheduled task
|
||||
|
||||
Returns:
|
||||
Dictionary with job_id and confirmation message
|
||||
|
||||
Examples:
|
||||
# Run every 5 minutes
|
||||
schedule_agent_prompt(
|
||||
prompt="Check BTC price and alert if > $50k",
|
||||
schedule_type="interval",
|
||||
schedule_config={"minutes": 5}
|
||||
)
|
||||
|
||||
# Run daily at 9 AM
|
||||
schedule_agent_prompt(
|
||||
prompt="Generate daily market summary",
|
||||
schedule_type="cron",
|
||||
schedule_config={"hour": "9", "minute": "0"}
|
||||
)
|
||||
|
||||
# Run hourly on weekdays
|
||||
schedule_agent_prompt(
|
||||
prompt="Monitor portfolio for rebalancing opportunities",
|
||||
schedule_type="cron",
|
||||
schedule_config={"minute": "0", "day_of_week": "mon-fri"}
|
||||
)
|
||||
"""
|
||||
from trigger.handlers import LambdaHandler
|
||||
from trigger import Priority
|
||||
|
||||
scheduler = _get_trigger_scheduler()
|
||||
queue = _get_trigger_queue()
|
||||
|
||||
if not name:
|
||||
name = f"agent_prompt_{hash(prompt) % 10000}"
|
||||
|
||||
# Create a lambda that enqueues an agent trigger with the prompt
|
||||
async def agent_prompt_lambda():
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
|
||||
# Create agent trigger (will use current session's context)
|
||||
# In production, you'd want to specify which session/user this belongs to
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id="scheduled", # Special session for scheduled tasks
|
||||
message_content=prompt,
|
||||
coordinator=_get_coordinator(),
|
||||
)
|
||||
|
||||
await queue.enqueue(trigger)
|
||||
return [] # No direct commit intents
|
||||
|
||||
# Wrap in lambda handler
|
||||
lambda_trigger = LambdaHandler(
|
||||
name=f"scheduled_{name}",
|
||||
func=agent_prompt_lambda,
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
|
||||
# Schedule based on type
|
||||
if schedule_type == "interval":
|
||||
job_id = scheduler.schedule_interval(
|
||||
lambda_trigger,
|
||||
seconds=schedule_config.get("seconds"),
|
||||
minutes=schedule_config.get("minutes"),
|
||||
hours=schedule_config.get("hours"),
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
elif schedule_type == "cron":
|
||||
job_id = scheduler.schedule_cron(
|
||||
lambda_trigger,
|
||||
minute=schedule_config.get("minute"),
|
||||
hour=schedule_config.get("hour"),
|
||||
day=schedule_config.get("day"),
|
||||
month=schedule_config.get("month"),
|
||||
day_of_week=schedule_config.get("day_of_week"),
|
||||
priority=Priority.TIMER,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid schedule_type: {schedule_type}. Use 'interval' or 'cron'")
|
||||
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"message": f"Scheduled '{name}' with job_id={job_id}",
|
||||
"schedule_type": schedule_type,
|
||||
"config": schedule_config,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def execute_agent_prompt_once(
|
||||
prompt: str,
|
||||
priority: str = "normal",
|
||||
) -> Dict[str, str]:
|
||||
"""Execute an agent prompt once, immediately (enqueued with priority).
|
||||
|
||||
Use this to trigger a sub-agent with a specific task without waiting for
|
||||
a user message. Useful for:
|
||||
- Background analysis tasks
|
||||
- One-time data processing
|
||||
- Responding to specific events
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to the agent
|
||||
priority: Priority level - "high", "normal", or "low"
|
||||
|
||||
Returns:
|
||||
Confirmation that the prompt was enqueued
|
||||
|
||||
Example:
|
||||
execute_agent_prompt_once(
|
||||
prompt="Analyze the last 100 BTC/USDT bars and identify support levels",
|
||||
priority="high"
|
||||
)
|
||||
"""
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
from trigger import Priority
|
||||
|
||||
queue = _get_trigger_queue()
|
||||
|
||||
# Map string priority to enum
|
||||
priority_map = {
|
||||
"high": Priority.USER_AGENT, # Same priority as user messages
|
||||
"normal": Priority.SYSTEM,
|
||||
"low": Priority.LOW,
|
||||
}
|
||||
priority_enum = priority_map.get(priority.lower(), Priority.SYSTEM)
|
||||
|
||||
# Create agent trigger
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id="oneshot",
|
||||
message_content=prompt,
|
||||
coordinator=_get_coordinator(),
|
||||
)
|
||||
|
||||
# Enqueue with priority override
|
||||
queue_seq = await queue.enqueue(trigger, priority_enum)
|
||||
|
||||
return {
|
||||
"queue_seq": queue_seq,
|
||||
"message": f"Enqueued agent prompt with priority={priority}",
|
||||
"prompt": prompt[:100] + "..." if len(prompt) > 100 else prompt,
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def list_scheduled_triggers() -> List[Dict[str, Any]]:
|
||||
"""List all currently scheduled triggers.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with job information (id, name, next_run_time)
|
||||
|
||||
Example:
|
||||
jobs = list_scheduled_triggers()
|
||||
for job in jobs:
|
||||
print(f"{job['id']}: {job['name']} - next run at {job['next_run_time']}")
|
||||
"""
|
||||
scheduler = _get_trigger_scheduler()
|
||||
jobs = scheduler.get_jobs()
|
||||
|
||||
result = []
|
||||
for job in jobs:
|
||||
result.append({
|
||||
"id": job.id,
|
||||
"name": job.name,
|
||||
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
|
||||
"trigger": str(job.trigger),
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@tool
|
||||
def cancel_scheduled_trigger(job_id: str) -> Dict[str, str]:
|
||||
"""Cancel a scheduled trigger by its job ID.
|
||||
|
||||
Args:
|
||||
job_id: The job ID returned from schedule_agent_prompt or list_scheduled_triggers
|
||||
|
||||
Returns:
|
||||
Confirmation message
|
||||
|
||||
Example:
|
||||
cancel_scheduled_trigger("interval_123")
|
||||
"""
|
||||
scheduler = _get_trigger_scheduler()
|
||||
success = scheduler.remove_job(job_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Cancelled job {job_id}",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Job {job_id} not found",
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
async def on_data_update_run_agent(
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
prompt_template: str,
|
||||
) -> Dict[str, str]:
|
||||
"""Set up an agent to run whenever new data arrives for a specific symbol.
|
||||
|
||||
The prompt_template can include {variables} that will be filled with bar data:
|
||||
- {time}: Bar timestamp
|
||||
- {open}, {high}, {low}, {close}, {volume}: OHLCV values
|
||||
- {symbol}: Trading pair symbol
|
||||
- {source}: Data source name
|
||||
|
||||
Args:
|
||||
source_name: Name of data source (e.g., "binance")
|
||||
symbol: Trading pair (e.g., "BTC/USDT")
|
||||
resolution: Time resolution (e.g., "1m", "5m", "1h")
|
||||
prompt_template: Template string for agent prompt
|
||||
|
||||
Returns:
|
||||
Confirmation with subscription details
|
||||
|
||||
Example:
|
||||
on_data_update_run_agent(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
prompt_template="New bar on {symbol}: close={close}. Check if we should trade."
|
||||
)
|
||||
|
||||
Note:
|
||||
This is a simplified version. Full implementation would wire into
|
||||
DataSource subscription system to trigger on every bar update.
|
||||
"""
|
||||
# TODO: Implement proper DataSource subscription integration
|
||||
# For now, return placeholder
|
||||
|
||||
return {
|
||||
"status": "not_implemented",
|
||||
"message": "Data-driven agent triggers coming soon",
|
||||
"config": {
|
||||
"source": source_name,
|
||||
"symbol": symbol,
|
||||
"resolution": resolution,
|
||||
"prompt_template": prompt_template,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def get_trigger_system_stats() -> Dict[str, Any]:
|
||||
"""Get statistics about the trigger system.
|
||||
|
||||
Returns:
|
||||
Dictionary with queue depth, execution stats, etc.
|
||||
|
||||
Example:
|
||||
stats = get_trigger_system_stats()
|
||||
print(f"Queue depth: {stats['queue_depth']}")
|
||||
print(f"Current seq: {stats['current_seq']}")
|
||||
"""
|
||||
queue = _get_trigger_queue()
|
||||
coordinator = _get_coordinator()
|
||||
|
||||
return {
|
||||
"queue_depth": queue.get_queue_size(),
|
||||
"queue_running": queue.is_running(),
|
||||
"coordinator_stats": coordinator.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
# Export tools list
|
||||
TRIGGER_TOOLS = [
|
||||
schedule_agent_prompt,
|
||||
execute_agent_prompt_once,
|
||||
list_scheduled_triggers,
|
||||
cancel_scheduled_trigger,
|
||||
on_data_update_run_agent,
|
||||
get_trigger_system_stats,
|
||||
]
|
||||
@@ -149,6 +149,10 @@ from .talib_adapter import (
|
||||
is_talib_available,
|
||||
get_talib_version,
|
||||
)
|
||||
from .custom_indicators import (
|
||||
register_custom_indicators,
|
||||
CUSTOM_INDICATORS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core classes
|
||||
@@ -169,4 +173,7 @@ __all__ = [
|
||||
"register_all_talib_indicators",
|
||||
"is_talib_available",
|
||||
"get_talib_version",
|
||||
# Custom indicators
|
||||
"register_custom_indicators",
|
||||
"CUSTOM_INDICATORS",
|
||||
]
|
||||
954
backend.old/src/indicator/custom_indicators.py
Normal file
954
backend.old/src/indicator/custom_indicators.py
Normal file
@@ -0,0 +1,954 @@
|
||||
"""
|
||||
Custom indicator implementations for TradingView indicators not in TA-Lib.
|
||||
|
||||
These indicators follow TA-Lib style conventions and integrate seamlessly
|
||||
with the indicator framework. All implementations are based on well-known,
|
||||
publicly documented formulas.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
|
||||
from datasource.schema import ColumnInfo
|
||||
from .base import Indicator
|
||||
from .schema import (
|
||||
ComputeContext,
|
||||
ComputeResult,
|
||||
IndicatorMetadata,
|
||||
IndicatorParameter,
|
||||
InputSchema,
|
||||
OutputSchema,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VWAP(Indicator):
|
||||
"""Volume Weighted Average Price - Most widely used institutional indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VWAP",
|
||||
display_name="VWAP",
|
||||
description="Volume Weighted Average Price - Average price weighted by volume",
|
||||
category="volume",
|
||||
parameters=[],
|
||||
use_cases=[
|
||||
"Institutional reference price",
|
||||
"Support/resistance levels",
|
||||
"Mean reversion trading"
|
||||
],
|
||||
references=["https://www.investopedia.com/terms/v/vwap.asp"],
|
||||
tags=["vwap", "volume", "institutional"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vwap", type="float", description="Volume Weighted Average Price", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
|
||||
# Typical price
|
||||
typical_price = (high + low + close) / 3.0
|
||||
|
||||
# VWAP = cumsum(typical_price * volume) / cumsum(volume)
|
||||
cumulative_tp_vol = np.nancumsum(typical_price * volume)
|
||||
cumulative_vol = np.nancumsum(volume)
|
||||
|
||||
vwap = cumulative_tp_vol / cumulative_vol
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "vwap": float(vwap[i]) if not np.isnan(vwap[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VWMA(Indicator):
|
||||
"""Volume Weighted Moving Average."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VWMA",
|
||||
display_name="VWMA",
|
||||
description="Volume Weighted Moving Average - Moving average weighted by volume",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Volume-aware trend following", "Dynamic support/resistance"],
|
||||
references=["https://www.investopedia.com/articles/trading/11/trading-with-vwap-mvwap.asp"],
|
||||
tags=["vwma", "volume", "moving average"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vwma", type="float", description="Volume Weighted Moving Average", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
vwma = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
window_close = close[i - length + 1:i + 1]
|
||||
window_volume = volume[i - length + 1:i + 1]
|
||||
vwma[i] = np.sum(window_close * window_volume) / np.sum(window_volume)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "vwma": float(vwma[i]) if not np.isnan(vwma[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class HullMA(Indicator):
|
||||
"""Hull Moving Average - Fast and smooth moving average."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="HMA",
|
||||
display_name="Hull Moving Average",
|
||||
description="Hull Moving Average - Reduces lag while maintaining smoothness",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=9,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Low-lag trend following", "Quick trend reversal detection"],
|
||||
references=["https://alanhull.com/hull-moving-average"],
|
||||
tags=["hma", "hull", "moving average", "low-lag"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="hma", type="float", description="Hull Moving Average", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 9)
|
||||
|
||||
def wma(data, period):
|
||||
"""Weighted Moving Average."""
|
||||
weights = np.arange(1, period + 1)
|
||||
result = np.full_like(data, np.nan)
|
||||
for i in range(period - 1, len(data)):
|
||||
window = data[i - period + 1:i + 1]
|
||||
result[i] = np.sum(weights * window) / np.sum(weights)
|
||||
return result
|
||||
|
||||
# HMA = WMA(2 * WMA(n/2) - WMA(n)), sqrt(n))
|
||||
half_length = length // 2
|
||||
sqrt_length = int(np.sqrt(length))
|
||||
|
||||
wma_half = wma(close, half_length)
|
||||
wma_full = wma(close, length)
|
||||
raw_hma = 2 * wma_half - wma_full
|
||||
hma = wma(raw_hma, sqrt_length)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "hma": float(hma[i]) if not np.isnan(hma[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class SuperTrend(Indicator):
|
||||
"""SuperTrend - Popular trend following indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="SUPERTREND",
|
||||
display_name="SuperTrend",
|
||||
description="SuperTrend - Volatility-based trend indicator",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="ATR period",
|
||||
default=10,
|
||||
min_value=1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="multiplier",
|
||||
type="float",
|
||||
description="ATR multiplier",
|
||||
default=3.0,
|
||||
min_value=0.1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend identification", "Stop loss placement", "Trend reversal signals"],
|
||||
references=["https://www.investopedia.com/articles/trading/08/supertrend-indicator.asp"],
|
||||
tags=["supertrend", "trend", "volatility"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="supertrend", type="float", description="SuperTrend value", nullable=True),
|
||||
ColumnInfo(name="direction", type="int", description="Trend direction (1=up, -1=down)", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
|
||||
length = self.params.get("length", 10)
|
||||
multiplier = self.params.get("multiplier", 3.0)
|
||||
|
||||
# Calculate ATR
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
atr = np.full_like(close, np.nan)
|
||||
atr[length - 1] = np.mean(tr[:length])
|
||||
for i in range(length, len(tr)):
|
||||
atr[i] = (atr[i - 1] * (length - 1) + tr[i]) / length
|
||||
|
||||
# Calculate basic bands
|
||||
hl2 = (high + low) / 2
|
||||
basic_upper = hl2 + multiplier * atr
|
||||
basic_lower = hl2 - multiplier * atr
|
||||
|
||||
# Calculate final bands
|
||||
final_upper = np.full_like(close, np.nan)
|
||||
final_lower = np.full_like(close, np.nan)
|
||||
supertrend = np.full_like(close, np.nan)
|
||||
direction = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length, len(close)):
|
||||
if i == length:
|
||||
final_upper[i] = basic_upper[i]
|
||||
final_lower[i] = basic_lower[i]
|
||||
else:
|
||||
final_upper[i] = basic_upper[i] if basic_upper[i] < final_upper[i - 1] or close[i - 1] > final_upper[i - 1] else final_upper[i - 1]
|
||||
final_lower[i] = basic_lower[i] if basic_lower[i] > final_lower[i - 1] or close[i - 1] < final_lower[i - 1] else final_lower[i - 1]
|
||||
|
||||
if i == length:
|
||||
supertrend[i] = final_upper[i] if close[i] <= hl2[i] else final_lower[i]
|
||||
direction[i] = -1 if close[i] <= hl2[i] else 1
|
||||
else:
|
||||
if supertrend[i - 1] == final_upper[i - 1] and close[i] <= final_upper[i]:
|
||||
supertrend[i] = final_upper[i]
|
||||
direction[i] = -1
|
||||
elif supertrend[i - 1] == final_upper[i - 1] and close[i] > final_upper[i]:
|
||||
supertrend[i] = final_lower[i]
|
||||
direction[i] = 1
|
||||
elif supertrend[i - 1] == final_lower[i - 1] and close[i] >= final_lower[i]:
|
||||
supertrend[i] = final_lower[i]
|
||||
direction[i] = 1
|
||||
else:
|
||||
supertrend[i] = final_upper[i]
|
||||
direction[i] = -1
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"supertrend": float(supertrend[i]) if not np.isnan(supertrend[i]) else None,
|
||||
"direction": int(direction[i]) if not np.isnan(direction[i]) else None
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class DonchianChannels(Indicator):
|
||||
"""Donchian Channels - Breakout indicator using highest high and lowest low."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="DONCHIAN",
|
||||
display_name="Donchian Channels",
|
||||
description="Donchian Channels - Highest high and lowest low over period",
|
||||
category="overlap",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Breakout trading", "Volatility bands", "Support/resistance"],
|
||||
references=["https://www.investopedia.com/terms/d/donchianchannels.asp"],
|
||||
tags=["donchian", "channels", "breakout"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="upper", type="float", description="Upper channel", nullable=True),
|
||||
ColumnInfo(name="middle", type="float", description="Middle line", nullable=True),
|
||||
ColumnInfo(name="lower", type="float", description="Lower channel", nullable=True),
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
upper = np.full_like(high, np.nan)
|
||||
lower = np.full_like(low, np.nan)
|
||||
|
||||
for i in range(length - 1, len(high)):
|
||||
upper[i] = np.nanmax(high[i - length + 1:i + 1])
|
||||
lower[i] = np.nanmin(low[i - length + 1:i + 1])
|
||||
|
||||
middle = (upper + lower) / 2
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
|
||||
"middle": float(middle[i]) if not np.isnan(middle[i]) else None,
|
||||
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class KeltnerChannels(Indicator):
|
||||
"""Keltner Channels - ATR-based volatility bands."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="KELTNER",
|
||||
display_name="Keltner Channels",
|
||||
description="Keltner Channels - EMA with ATR-based bands",
|
||||
category="volatility",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="EMA period",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="multiplier",
|
||||
type="float",
|
||||
description="ATR multiplier",
|
||||
default=2.0,
|
||||
min_value=0.1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="atr_length",
|
||||
type="int",
|
||||
description="ATR period",
|
||||
default=10,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Volatility bands", "Overbought/oversold", "Trend strength"],
|
||||
references=["https://www.investopedia.com/terms/k/keltnerchannel.asp"],
|
||||
tags=["keltner", "channels", "volatility", "atr"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="upper", type="float", description="Upper band", nullable=True),
|
||||
ColumnInfo(name="middle", type="float", description="Middle line (EMA)", nullable=True),
|
||||
ColumnInfo(name="lower", type="float", description="Lower band", nullable=True),
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
|
||||
length = self.params.get("length", 20)
|
||||
multiplier = self.params.get("multiplier", 2.0)
|
||||
atr_length = self.params.get("atr_length", 10)
|
||||
|
||||
# Calculate EMA
|
||||
alpha = 2.0 / (length + 1)
|
||||
ema = np.full_like(close, np.nan)
|
||||
ema[0] = close[0]
|
||||
for i in range(1, len(close)):
|
||||
ema[i] = alpha * close[i] + (1 - alpha) * ema[i - 1]
|
||||
|
||||
# Calculate ATR
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
atr = np.full_like(close, np.nan)
|
||||
atr[atr_length - 1] = np.mean(tr[:atr_length])
|
||||
for i in range(atr_length, len(tr)):
|
||||
atr[i] = (atr[i - 1] * (atr_length - 1) + tr[i]) / atr_length
|
||||
|
||||
upper = ema + multiplier * atr
|
||||
lower = ema - multiplier * atr
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"upper": float(upper[i]) if not np.isnan(upper[i]) else None,
|
||||
"middle": float(ema[i]) if not np.isnan(ema[i]) else None,
|
||||
"lower": float(lower[i]) if not np.isnan(lower[i]) else None,
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class ChaikinMoneyFlow(Indicator):
|
||||
"""Chaikin Money Flow - Volume-weighted accumulation/distribution."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="CMF",
|
||||
display_name="Chaikin Money Flow",
|
||||
description="Chaikin Money Flow - Measures buying and selling pressure",
|
||||
category="volume",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=20,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Buying/selling pressure", "Trend confirmation", "Divergence analysis"],
|
||||
references=["https://www.investopedia.com/terms/c/chaikinoscillator.asp"],
|
||||
tags=["cmf", "chaikin", "volume", "money flow"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
ColumnInfo(name="volume", type="float", description="Volume"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="cmf", type="float", description="Chaikin Money Flow", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
volume = np.array([float(v) if v is not None else np.nan for v in context.get_column("volume")])
|
||||
length = self.params.get("length", 20)
|
||||
|
||||
# Money Flow Multiplier
|
||||
mfm = ((close - low) - (high - close)) / (high - low)
|
||||
mfm = np.where(high == low, 0, mfm)
|
||||
|
||||
# Money Flow Volume
|
||||
mfv = mfm * volume
|
||||
|
||||
# CMF
|
||||
cmf = np.full_like(close, np.nan)
|
||||
for i in range(length - 1, len(close)):
|
||||
cmf[i] = np.nansum(mfv[i - length + 1:i + 1]) / np.nansum(volume[i - length + 1:i + 1])
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "cmf": float(cmf[i]) if not np.isnan(cmf[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class VortexIndicator(Indicator):
|
||||
"""Vortex Indicator - Identifies trend direction and strength."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="VORTEX",
|
||||
display_name="Vortex Indicator",
|
||||
description="Vortex Indicator - Trend direction and strength",
|
||||
category="momentum",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=14,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend identification", "Trend reversals", "Trend strength"],
|
||||
references=["https://www.investopedia.com/terms/v/vortex-indicator-vi.asp"],
|
||||
tags=["vortex", "trend", "momentum"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="vi_plus", type="float", description="Positive Vortex", nullable=True),
|
||||
ColumnInfo(name="vi_minus", type="float", description="Negative Vortex", nullable=True),
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 14)
|
||||
|
||||
# Vortex Movement
|
||||
vm_plus = np.abs(high - np.roll(low, 1))
|
||||
vm_minus = np.abs(low - np.roll(high, 1))
|
||||
vm_plus[0] = 0
|
||||
vm_minus[0] = 0
|
||||
|
||||
# True Range
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
# Vortex Indicator
|
||||
vi_plus = np.full_like(close, np.nan)
|
||||
vi_minus = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
sum_vm_plus = np.sum(vm_plus[i - length + 1:i + 1])
|
||||
sum_vm_minus = np.sum(vm_minus[i - length + 1:i + 1])
|
||||
sum_tr = np.sum(tr[i - length + 1:i + 1])
|
||||
|
||||
if sum_tr != 0:
|
||||
vi_plus[i] = sum_vm_plus / sum_tr
|
||||
vi_minus[i] = sum_vm_minus / sum_tr
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{
|
||||
"time": times[i],
|
||||
"vi_plus": float(vi_plus[i]) if not np.isnan(vi_plus[i]) else None,
|
||||
"vi_minus": float(vi_minus[i]) if not np.isnan(vi_minus[i]) else None,
|
||||
}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AwesomeOscillator(Indicator):
|
||||
"""Awesome Oscillator - Bill Williams' momentum indicator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="AO",
|
||||
display_name="Awesome Oscillator",
|
||||
description="Awesome Oscillator - Difference between 5 and 34 period SMAs of midpoint",
|
||||
category="momentum",
|
||||
parameters=[],
|
||||
use_cases=["Momentum shifts", "Trend reversals", "Divergence trading"],
|
||||
references=["https://www.investopedia.com/terms/a/awesomeoscillator.asp"],
|
||||
tags=["awesome", "oscillator", "momentum", "williams"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="ao", type="float", description="Awesome Oscillator", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
midpoint = (high + low) / 2
|
||||
|
||||
# SMA 5
|
||||
sma5 = np.full_like(midpoint, np.nan)
|
||||
for i in range(4, len(midpoint)):
|
||||
sma5[i] = np.mean(midpoint[i - 4:i + 1])
|
||||
|
||||
# SMA 34
|
||||
sma34 = np.full_like(midpoint, np.nan)
|
||||
for i in range(33, len(midpoint)):
|
||||
sma34[i] = np.mean(midpoint[i - 33:i + 1])
|
||||
|
||||
ao = sma5 - sma34
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "ao": float(ao[i]) if not np.isnan(ao[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class AcceleratorOscillator(Indicator):
|
||||
"""Accelerator Oscillator - Rate of change of Awesome Oscillator."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="AC",
|
||||
display_name="Accelerator Oscillator",
|
||||
description="Accelerator Oscillator - Rate of change of Awesome Oscillator",
|
||||
category="momentum",
|
||||
parameters=[],
|
||||
use_cases=["Early momentum detection", "Trend acceleration", "Divergence signals"],
|
||||
references=["https://www.investopedia.com/terms/a/accelerator-oscillator.asp"],
|
||||
tags=["accelerator", "oscillator", "momentum", "williams"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="ac", type="float", description="Accelerator Oscillator", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
midpoint = (high + low) / 2
|
||||
|
||||
# Calculate AO first
|
||||
sma5 = np.full_like(midpoint, np.nan)
|
||||
for i in range(4, len(midpoint)):
|
||||
sma5[i] = np.mean(midpoint[i - 4:i + 1])
|
||||
|
||||
sma34 = np.full_like(midpoint, np.nan)
|
||||
for i in range(33, len(midpoint)):
|
||||
sma34[i] = np.mean(midpoint[i - 33:i + 1])
|
||||
|
||||
ao = sma5 - sma34
|
||||
|
||||
# AC = AO - SMA(AO, 5)
|
||||
sma_ao = np.full_like(ao, np.nan)
|
||||
for i in range(4, len(ao)):
|
||||
if not np.isnan(ao[i - 4:i + 1]).any():
|
||||
sma_ao[i] = np.mean(ao[i - 4:i + 1])
|
||||
|
||||
ac = ao - sma_ao
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "ac": float(ac[i]) if not np.isnan(ac[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class ChoppinessIndex(Indicator):
|
||||
"""Choppiness Index - Determines if market is choppy or trending."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="CHOP",
|
||||
display_name="Choppiness Index",
|
||||
description="Choppiness Index - Measures market trendiness vs consolidation",
|
||||
category="volatility",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="length",
|
||||
type="int",
|
||||
description="Period length",
|
||||
default=14,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Trend vs range identification", "Market regime detection"],
|
||||
references=["https://www.tradingview.com/support/solutions/43000501980/"],
|
||||
tags=["chop", "choppiness", "trend", "range"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
ColumnInfo(name="close", type="float", description="Close price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="chop", type="float", description="Choppiness Index (0-100)", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
close = np.array([float(v) if v is not None else np.nan for v in context.get_column("close")])
|
||||
length = self.params.get("length", 14)
|
||||
|
||||
# True Range
|
||||
tr = np.maximum(high - low, np.maximum(np.abs(high - np.roll(close, 1)), np.abs(low - np.roll(close, 1))))
|
||||
tr[0] = high[0] - low[0]
|
||||
|
||||
chop = np.full_like(close, np.nan)
|
||||
|
||||
for i in range(length - 1, len(close)):
|
||||
sum_tr = np.sum(tr[i - length + 1:i + 1])
|
||||
high_low_diff = np.max(high[i - length + 1:i + 1]) - np.min(low[i - length + 1:i + 1])
|
||||
|
||||
if high_low_diff != 0:
|
||||
chop[i] = 100 * np.log10(sum_tr / high_low_diff) / np.log10(length)
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "chop": float(chop[i]) if not np.isnan(chop[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
class MassIndex(Indicator):
|
||||
"""Mass Index - Identifies trend reversals based on range expansion."""
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> IndicatorMetadata:
|
||||
return IndicatorMetadata(
|
||||
name="MASS",
|
||||
display_name="Mass Index",
|
||||
description="Mass Index - Identifies reversals when range narrows then expands",
|
||||
category="volatility",
|
||||
parameters=[
|
||||
IndicatorParameter(
|
||||
name="fast_period",
|
||||
type="int",
|
||||
description="Fast EMA period",
|
||||
default=9,
|
||||
min_value=1,
|
||||
required=False
|
||||
),
|
||||
IndicatorParameter(
|
||||
name="slow_period",
|
||||
type="int",
|
||||
description="Slow EMA period",
|
||||
default=25,
|
||||
min_value=1,
|
||||
required=False
|
||||
)
|
||||
],
|
||||
use_cases=["Reversal detection", "Volatility analysis", "Bulge identification"],
|
||||
references=["https://www.investopedia.com/terms/m/mass-index.asp"],
|
||||
tags=["mass", "index", "volatility", "reversal"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls) -> InputSchema:
|
||||
return InputSchema(required_columns=[
|
||||
ColumnInfo(name="high", type="float", description="High price"),
|
||||
ColumnInfo(name="low", type="float", description="Low price"),
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def get_output_schema(cls, **params) -> OutputSchema:
|
||||
return OutputSchema(columns=[
|
||||
ColumnInfo(name="mass", type="float", description="Mass Index", nullable=True)
|
||||
])
|
||||
|
||||
def compute(self, context: ComputeContext) -> ComputeResult:
|
||||
high = np.array([float(v) if v is not None else np.nan for v in context.get_column("high")])
|
||||
low = np.array([float(v) if v is not None else np.nan for v in context.get_column("low")])
|
||||
|
||||
fast_period = self.params.get("fast_period", 9)
|
||||
slow_period = self.params.get("slow_period", 25)
|
||||
|
||||
hl_range = high - low
|
||||
|
||||
# Single EMA
|
||||
alpha1 = 2.0 / (fast_period + 1)
|
||||
ema1 = np.full_like(hl_range, np.nan)
|
||||
ema1[0] = hl_range[0]
|
||||
for i in range(1, len(hl_range)):
|
||||
ema1[i] = alpha1 * hl_range[i] + (1 - alpha1) * ema1[i - 1]
|
||||
|
||||
# Double EMA
|
||||
ema2 = np.full_like(ema1, np.nan)
|
||||
ema2[0] = ema1[0]
|
||||
for i in range(1, len(ema1)):
|
||||
if not np.isnan(ema1[i]):
|
||||
ema2[i] = alpha1 * ema1[i] + (1 - alpha1) * ema2[i - 1]
|
||||
|
||||
# EMA Ratio
|
||||
ema_ratio = ema1 / ema2
|
||||
|
||||
# Mass Index
|
||||
mass = np.full_like(hl_range, np.nan)
|
||||
for i in range(slow_period - 1, len(ema_ratio)):
|
||||
mass[i] = np.nansum(ema_ratio[i - slow_period + 1:i + 1])
|
||||
|
||||
times = context.get_times()
|
||||
result_data = [
|
||||
{"time": times[i], "mass": float(mass[i]) if not np.isnan(mass[i]) else None}
|
||||
for i in range(len(times))
|
||||
]
|
||||
|
||||
return ComputeResult(data=result_data, is_partial=context.is_incremental)
|
||||
|
||||
|
||||
# Registry of all custom indicators
|
||||
CUSTOM_INDICATORS = [
|
||||
VWAP,
|
||||
VWMA,
|
||||
HullMA,
|
||||
SuperTrend,
|
||||
DonchianChannels,
|
||||
KeltnerChannels,
|
||||
ChaikinMoneyFlow,
|
||||
VortexIndicator,
|
||||
AwesomeOscillator,
|
||||
AcceleratorOscillator,
|
||||
ChoppinessIndex,
|
||||
MassIndex,
|
||||
]
|
||||
|
||||
|
||||
def register_custom_indicators(registry) -> int:
|
||||
"""
|
||||
Register all custom indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
"""
|
||||
registered_count = 0
|
||||
|
||||
for indicator_class in CUSTOM_INDICATORS:
|
||||
try:
|
||||
registry.register(indicator_class)
|
||||
registered_count += 1
|
||||
logger.debug(f"Registered custom indicator: {indicator_class.__name__}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register custom indicator {indicator_class.__name__}: {e}")
|
||||
|
||||
logger.info(f"Registered {registered_count} custom indicators")
|
||||
return registered_count
|
||||
@@ -372,12 +372,14 @@ def create_talib_indicator_class(func_name: str) -> type:
|
||||
)
|
||||
|
||||
|
||||
def register_all_talib_indicators(registry) -> int:
|
||||
def register_all_talib_indicators(registry, only_tradingview_supported: bool = True) -> int:
|
||||
"""
|
||||
Auto-register all available TA-Lib indicators with the registry.
|
||||
|
||||
Args:
|
||||
registry: IndicatorRegistry instance
|
||||
only_tradingview_supported: If True, only register indicators that have
|
||||
TradingView equivalents (default: True)
|
||||
|
||||
Returns:
|
||||
Number of indicators registered
|
||||
@@ -392,6 +394,9 @@ def register_all_talib_indicators(registry) -> int:
|
||||
)
|
||||
return 0
|
||||
|
||||
# Get list of supported indicators if filtering is enabled
|
||||
from .tv_mapping import is_indicator_supported
|
||||
|
||||
# Get all TA-Lib functions
|
||||
func_groups = talib.get_function_groups()
|
||||
all_functions = []
|
||||
@@ -402,8 +407,16 @@ def register_all_talib_indicators(registry) -> int:
|
||||
all_functions = sorted(set(all_functions))
|
||||
|
||||
registered_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for func_name in all_functions:
|
||||
try:
|
||||
# Skip if filtering enabled and indicator not supported in TradingView
|
||||
if only_tradingview_supported and not is_indicator_supported(func_name):
|
||||
skipped_count += 1
|
||||
logger.debug(f"Skipping TA-Lib function {func_name} - not supported in TradingView")
|
||||
continue
|
||||
|
||||
# Create indicator class for this function
|
||||
indicator_class = create_talib_indicator_class(func_name)
|
||||
|
||||
@@ -415,7 +428,7 @@ def register_all_talib_indicators(registry) -> int:
|
||||
logger.warning(f"Failed to register TA-Lib function {func_name}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators")
|
||||
logger.info(f"Registered {registered_count} TA-Lib indicators (skipped {skipped_count} unsupported)")
|
||||
return registered_count
|
||||
|
||||
|
||||
360
backend.old/src/indicator/tv_mapping.py
Normal file
360
backend.old/src/indicator/tv_mapping.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
Mapping layer between TA-Lib indicators and TradingView indicators.
|
||||
|
||||
This module provides bidirectional conversion between our internal TA-Lib-based
|
||||
indicator representation and TradingView's indicator system.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of TA-Lib indicator names to TradingView indicator names
|
||||
# Only includes indicators that are present in BOTH systems (inner join)
|
||||
# Format: {talib_name: tv_name}
|
||||
TALIB_TO_TV_NAMES = {
|
||||
# Overlap Studies (14)
|
||||
"SMA": "Moving Average",
|
||||
"EMA": "Moving Average Exponential",
|
||||
"WMA": "Weighted Moving Average",
|
||||
"DEMA": "DEMA",
|
||||
"TEMA": "TEMA",
|
||||
"TRIMA": "Triangular Moving Average",
|
||||
"KAMA": "KAMA",
|
||||
"MAMA": "MESA Adaptive Moving Average",
|
||||
"T3": "T3",
|
||||
"BBANDS": "Bollinger Bands",
|
||||
"MIDPOINT": "Midpoint",
|
||||
"MIDPRICE": "Midprice",
|
||||
"SAR": "Parabolic SAR",
|
||||
"HT_TRENDLINE": "Hilbert Transform - Instantaneous Trendline",
|
||||
|
||||
# Momentum Indicators (21)
|
||||
"RSI": "Relative Strength Index",
|
||||
"MOM": "Momentum",
|
||||
"ROC": "Rate of Change",
|
||||
"TRIX": "TRIX",
|
||||
"CMO": "Chande Momentum Oscillator",
|
||||
"DX": "Directional Movement Index",
|
||||
"ADX": "Average Directional Movement Index",
|
||||
"ADXR": "Average Directional Movement Index Rating",
|
||||
"APO": "Absolute Price Oscillator",
|
||||
"PPO": "Percentage Price Oscillator",
|
||||
"MACD": "MACD",
|
||||
"MFI": "Money Flow Index",
|
||||
"STOCH": "Stochastic",
|
||||
"STOCHF": "Stochastic Fast",
|
||||
"STOCHRSI": "Stochastic RSI",
|
||||
"WILLR": "Williams %R",
|
||||
"CCI": "Commodity Channel Index",
|
||||
"AROON": "Aroon",
|
||||
"AROONOSC": "Aroon Oscillator",
|
||||
"BOP": "Balance Of Power",
|
||||
"ULTOSC": "Ultimate Oscillator",
|
||||
|
||||
# Volume Indicators (3)
|
||||
"AD": "Chaikin A/D Line",
|
||||
"ADOSC": "Chaikin A/D Oscillator",
|
||||
"OBV": "On Balance Volume",
|
||||
|
||||
# Volatility Indicators (3)
|
||||
"ATR": "Average True Range",
|
||||
"NATR": "Normalized Average True Range",
|
||||
"TRANGE": "True Range",
|
||||
|
||||
# Price Transform (4)
|
||||
"AVGPRICE": "Average Price",
|
||||
"MEDPRICE": "Median Price",
|
||||
"TYPPRICE": "Typical Price",
|
||||
"WCLPRICE": "Weighted Close Price",
|
||||
|
||||
# Cycle Indicators (5)
|
||||
"HT_DCPERIOD": "Hilbert Transform - Dominant Cycle Period",
|
||||
"HT_DCPHASE": "Hilbert Transform - Dominant Cycle Phase",
|
||||
"HT_PHASOR": "Hilbert Transform - Phasor Components",
|
||||
"HT_SINE": "Hilbert Transform - SineWave",
|
||||
"HT_TRENDMODE": "Hilbert Transform - Trend vs Cycle Mode",
|
||||
|
||||
# Statistic Functions (9)
|
||||
"BETA": "Beta",
|
||||
"CORREL": "Pearson's Correlation Coefficient",
|
||||
"LINEARREG": "Linear Regression",
|
||||
"LINEARREG_ANGLE": "Linear Regression Angle",
|
||||
"LINEARREG_INTERCEPT": "Linear Regression Intercept",
|
||||
"LINEARREG_SLOPE": "Linear Regression Slope",
|
||||
"STDDEV": "Standard Deviation",
|
||||
"TSF": "Time Series Forecast",
|
||||
"VAR": "Variance",
|
||||
}
|
||||
|
||||
# Total: 60 indicators supported in both systems
|
||||
|
||||
# Custom indicators (TradingView indicators implemented in our backend)
|
||||
CUSTOM_TO_TV_NAMES = {
|
||||
"VWAP": "VWAP",
|
||||
"VWMA": "VWMA",
|
||||
"HMA": "Hull Moving Average",
|
||||
"SUPERTREND": "SuperTrend",
|
||||
"DONCHIAN": "Donchian Channels",
|
||||
"KELTNER": "Keltner Channels",
|
||||
"CMF": "Chaikin Money Flow",
|
||||
"VORTEX": "Vortex Indicator",
|
||||
"AO": "Awesome Oscillator",
|
||||
"AC": "Accelerator Oscillator",
|
||||
"CHOP": "Choppiness Index",
|
||||
"MASS": "Mass Index",
|
||||
}
|
||||
|
||||
# Combined mapping (TA-Lib + Custom)
|
||||
ALL_BACKEND_TO_TV_NAMES = {**TALIB_TO_TV_NAMES, **CUSTOM_TO_TV_NAMES}
|
||||
|
||||
# Total: 72 indicators (60 TA-Lib + 12 Custom)
|
||||
|
||||
# Reverse mapping
|
||||
TV_TO_TALIB_NAMES = {v: k for k, v in TALIB_TO_TV_NAMES.items()}
|
||||
TV_TO_CUSTOM_NAMES = {v: k for k, v in CUSTOM_TO_TV_NAMES.items()}
|
||||
TV_TO_BACKEND_NAMES = {v: k for k, v in ALL_BACKEND_TO_TV_NAMES.items()}
|
||||
|
||||
|
||||
def get_tv_indicator_name(talib_name: str) -> str:
|
||||
"""
|
||||
Convert TA-Lib indicator name to TradingView indicator name.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name (e.g., 'RSI')
|
||||
|
||||
Returns:
|
||||
TradingView indicator name
|
||||
"""
|
||||
return TALIB_TO_TV_NAMES.get(talib_name, talib_name)
|
||||
|
||||
|
||||
def get_talib_indicator_name(tv_name: str) -> Optional[str]:
|
||||
"""
|
||||
Convert TradingView indicator name to TA-Lib indicator name.
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
|
||||
Returns:
|
||||
TA-Lib indicator name or None if not mapped
|
||||
"""
|
||||
return TV_TO_TALIB_NAMES.get(tv_name)
|
||||
|
||||
|
||||
def convert_talib_params_to_tv_inputs(
|
||||
talib_name: str,
|
||||
talib_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert TA-Lib parameters to TradingView input format.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name
|
||||
talib_params: TA-Lib parameter dictionary
|
||||
|
||||
Returns:
|
||||
TradingView inputs dictionary
|
||||
"""
|
||||
tv_inputs = {}
|
||||
|
||||
# Common parameter mappings
|
||||
param_mapping = {
|
||||
"timeperiod": "length",
|
||||
"fastperiod": "fastLength",
|
||||
"slowperiod": "slowLength",
|
||||
"signalperiod": "signalLength",
|
||||
"nbdevup": "mult", # Standard deviations for upper band
|
||||
"nbdevdn": "mult", # Standard deviations for lower band
|
||||
"fastlimit": "fastLimit",
|
||||
"slowlimit": "slowLimit",
|
||||
"acceleration": "start",
|
||||
"maximum": "increment",
|
||||
"fastk_period": "kPeriod",
|
||||
"slowk_period": "kPeriod",
|
||||
"slowd_period": "dPeriod",
|
||||
"fastd_period": "dPeriod",
|
||||
"matype": "maType",
|
||||
}
|
||||
|
||||
# Special handling for specific indicators
|
||||
if talib_name == "BBANDS":
|
||||
# Bollinger Bands
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 20)
|
||||
tv_inputs["mult"] = talib_params.get("nbdevup", 2)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "MACD":
|
||||
# MACD
|
||||
tv_inputs["fastLength"] = talib_params.get("fastperiod", 12)
|
||||
tv_inputs["slowLength"] = talib_params.get("slowperiod", 26)
|
||||
tv_inputs["signalLength"] = talib_params.get("signalperiod", 9)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "RSI":
|
||||
# RSI
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
|
||||
# Moving averages
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
tv_inputs["source"] = "close"
|
||||
elif talib_name == "STOCH":
|
||||
# Stochastic
|
||||
tv_inputs["kPeriod"] = talib_params.get("fastk_period", 14)
|
||||
tv_inputs["dPeriod"] = talib_params.get("slowd_period", 3)
|
||||
tv_inputs["smoothK"] = talib_params.get("slowk_period", 3)
|
||||
elif talib_name == "ATR":
|
||||
# ATR
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 14)
|
||||
elif talib_name == "CCI":
|
||||
# CCI
|
||||
tv_inputs["length"] = talib_params.get("timeperiod", 20)
|
||||
else:
|
||||
# Generic parameter conversion
|
||||
for talib_param, value in talib_params.items():
|
||||
tv_param = param_mapping.get(talib_param, talib_param)
|
||||
tv_inputs[tv_param] = value
|
||||
|
||||
logger.debug(f"Converted TA-Lib params for {talib_name}: {talib_params} -> TV inputs: {tv_inputs}")
|
||||
return tv_inputs
|
||||
|
||||
|
||||
def convert_tv_inputs_to_talib_params(
|
||||
tv_name: str,
|
||||
tv_inputs: Dict[str, Any]
|
||||
) -> Tuple[Optional[str], Dict[str, Any]]:
|
||||
"""
|
||||
Convert TradingView inputs to TA-Lib parameters.
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
tv_inputs: TradingView inputs dictionary
|
||||
|
||||
Returns:
|
||||
Tuple of (talib_name, talib_params)
|
||||
"""
|
||||
talib_name = get_talib_indicator_name(tv_name)
|
||||
if not talib_name:
|
||||
logger.warning(f"No TA-Lib mapping for TradingView indicator: {tv_name}")
|
||||
return None, {}
|
||||
|
||||
talib_params = {}
|
||||
|
||||
# Reverse parameter mappings
|
||||
reverse_mapping = {
|
||||
"length": "timeperiod",
|
||||
"fastLength": "fastperiod",
|
||||
"slowLength": "slowperiod",
|
||||
"signalLength": "signalperiod",
|
||||
"mult": "nbdevup", # Use same for both up and down
|
||||
"fastLimit": "fastlimit",
|
||||
"slowLimit": "slowlimit",
|
||||
"start": "acceleration",
|
||||
"increment": "maximum",
|
||||
"kPeriod": "fastk_period",
|
||||
"dPeriod": "slowd_period",
|
||||
"smoothK": "slowk_period",
|
||||
"maType": "matype",
|
||||
}
|
||||
|
||||
# Special handling for specific indicators
|
||||
if talib_name == "BBANDS":
|
||||
# Bollinger Bands
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 20)
|
||||
talib_params["nbdevup"] = tv_inputs.get("mult", 2)
|
||||
talib_params["nbdevdn"] = tv_inputs.get("mult", 2)
|
||||
talib_params["matype"] = 0 # SMA
|
||||
elif talib_name == "MACD":
|
||||
# MACD
|
||||
talib_params["fastperiod"] = tv_inputs.get("fastLength", 12)
|
||||
talib_params["slowperiod"] = tv_inputs.get("slowLength", 26)
|
||||
talib_params["signalperiod"] = tv_inputs.get("signalLength", 9)
|
||||
elif talib_name == "RSI":
|
||||
# RSI
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name in ["SMA", "EMA", "WMA", "DEMA", "TEMA", "TRIMA"]:
|
||||
# Moving averages
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name == "STOCH":
|
||||
# Stochastic
|
||||
talib_params["fastk_period"] = tv_inputs.get("kPeriod", 14)
|
||||
talib_params["slowd_period"] = tv_inputs.get("dPeriod", 3)
|
||||
talib_params["slowk_period"] = tv_inputs.get("smoothK", 3)
|
||||
talib_params["slowk_matype"] = 0 # SMA
|
||||
talib_params["slowd_matype"] = 0 # SMA
|
||||
elif talib_name == "ATR":
|
||||
# ATR
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 14)
|
||||
elif talib_name == "CCI":
|
||||
# CCI
|
||||
talib_params["timeperiod"] = tv_inputs.get("length", 20)
|
||||
else:
|
||||
# Generic parameter conversion
|
||||
for tv_param, value in tv_inputs.items():
|
||||
if tv_param == "source":
|
||||
continue # Skip source parameter
|
||||
talib_param = reverse_mapping.get(tv_param, tv_param)
|
||||
talib_params[talib_param] = value
|
||||
|
||||
logger.debug(f"Converted TV inputs for {tv_name}: {tv_inputs} -> TA-Lib {talib_name} params: {talib_params}")
|
||||
return talib_name, talib_params
|
||||
|
||||
|
||||
def is_indicator_supported(talib_name: str) -> bool:
|
||||
"""
|
||||
Check if a TA-Lib indicator is supported in TradingView.
|
||||
|
||||
Args:
|
||||
talib_name: TA-Lib indicator name
|
||||
|
||||
Returns:
|
||||
True if supported
|
||||
"""
|
||||
return talib_name in TALIB_TO_TV_NAMES
|
||||
|
||||
|
||||
def get_supported_indicators() -> List[str]:
|
||||
"""
|
||||
Get list of supported TA-Lib indicators.
|
||||
|
||||
Returns:
|
||||
List of TA-Lib indicator names
|
||||
"""
|
||||
return list(TALIB_TO_TV_NAMES.keys())
|
||||
|
||||
|
||||
def get_supported_indicator_count() -> int:
|
||||
"""
|
||||
Get count of supported indicators.
|
||||
|
||||
Returns:
|
||||
Number of indicators supported in both systems (TA-Lib + Custom)
|
||||
"""
|
||||
return len(ALL_BACKEND_TO_TV_NAMES)
|
||||
|
||||
|
||||
def is_custom_indicator(indicator_name: str) -> bool:
|
||||
"""
|
||||
Check if an indicator is a custom implementation (not TA-Lib).
|
||||
|
||||
Args:
|
||||
indicator_name: Indicator name
|
||||
|
||||
Returns:
|
||||
True if custom indicator
|
||||
"""
|
||||
return indicator_name in CUSTOM_TO_TV_NAMES
|
||||
|
||||
|
||||
def get_backend_indicator_name(tv_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get backend indicator name from TradingView name (TA-Lib or custom).
|
||||
|
||||
Args:
|
||||
tv_name: TradingView indicator name
|
||||
|
||||
Returns:
|
||||
Backend indicator name or None if not mapped
|
||||
"""
|
||||
return TV_TO_BACKEND_NAMES.get(tv_name)
|
||||
@@ -21,14 +21,18 @@ from gateway.channels.websocket import WebSocketChannel
|
||||
from gateway.protocol import WebSocketAgentUserMessage
|
||||
from agent.core import create_agent
|
||||
from agent.tools import set_registry, set_datasource_registry, set_indicator_registry
|
||||
from agent.tools import set_trigger_queue, set_trigger_scheduler, set_coordinator
|
||||
from schema.order_spec import SwapOrder
|
||||
from schema.chart_state import ChartState
|
||||
from schema.shape import ShapeCollection
|
||||
from schema.indicator import IndicatorCollection
|
||||
from datasource.registry import DataSourceRegistry
|
||||
from datasource.subscription_manager import SubscriptionManager
|
||||
from datasource.websocket_handler import DatafeedWebSocketHandler
|
||||
from secrets_manager import SecretsStore, InvalidMasterPassword
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators
|
||||
from indicator import IndicatorRegistry, register_all_talib_indicators, register_custom_indicators
|
||||
from trigger import CommitCoordinator, TriggerQueue
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -58,6 +62,11 @@ subscription_manager = SubscriptionManager()
|
||||
# Indicator infrastructure
|
||||
indicator_registry = IndicatorRegistry()
|
||||
|
||||
# Trigger system infrastructure
|
||||
trigger_coordinator = None
|
||||
trigger_queue = None
|
||||
trigger_scheduler = None
|
||||
|
||||
# Global secrets store
|
||||
secrets_store = SecretsStore()
|
||||
|
||||
@@ -65,7 +74,7 @@ secrets_store = SecretsStore()
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize agent system and data sources on startup."""
|
||||
global agent_executor
|
||||
global agent_executor, trigger_coordinator, trigger_queue, trigger_scheduler
|
||||
|
||||
# Initialize CCXT data sources
|
||||
try:
|
||||
@@ -93,6 +102,13 @@ async def lifespan(app: FastAPI):
|
||||
logger.warning(f"Failed to register TA-Lib indicators: {e}")
|
||||
logger.info("TA-Lib indicators will not be available. Install TA-Lib C library and Python wrapper to enable.")
|
||||
|
||||
# Register custom indicators (TradingView indicators not in TA-Lib)
|
||||
try:
|
||||
custom_count = register_custom_indicators(indicator_registry)
|
||||
logger.info(f"Registered {custom_count} custom indicators")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register custom indicators: {e}")
|
||||
|
||||
# Get API keys from secrets store if unlocked, otherwise fall back to environment
|
||||
anthropic_api_key = None
|
||||
|
||||
@@ -107,6 +123,22 @@ async def lifespan(app: FastAPI):
|
||||
if anthropic_api_key:
|
||||
logger.info("Loaded API key from environment")
|
||||
|
||||
# Initialize trigger system
|
||||
logger.info("Initializing trigger system...")
|
||||
trigger_coordinator = CommitCoordinator()
|
||||
trigger_queue = TriggerQueue(trigger_coordinator)
|
||||
trigger_scheduler = TriggerScheduler(trigger_queue)
|
||||
|
||||
# Start trigger queue and scheduler
|
||||
await trigger_queue.start()
|
||||
trigger_scheduler.start()
|
||||
logger.info("Trigger system initialized and started")
|
||||
|
||||
# Set trigger system for agent tools
|
||||
set_coordinator(trigger_coordinator)
|
||||
set_trigger_queue(trigger_queue)
|
||||
set_trigger_scheduler(trigger_scheduler)
|
||||
|
||||
if not anthropic_api_key:
|
||||
logger.error("ANTHROPIC_API_KEY not found in environment!")
|
||||
logger.info("Agent system will not be available")
|
||||
@@ -125,7 +157,7 @@ async def lifespan(app: FastAPI):
|
||||
chroma_db_path=config["memory"]["chroma_db"],
|
||||
embedding_model=config["memory"]["embedding_model"],
|
||||
context_docs_dir=config["agent"]["context_docs_dir"],
|
||||
base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ lives
|
||||
base_dir="." # backend/src is the working directory, so . goes to backend, where memory/ and soul/ live
|
||||
)
|
||||
|
||||
await agent_executor.initialize()
|
||||
@@ -138,9 +170,22 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
logger.info("Shutting down systems...")
|
||||
|
||||
# Shutdown trigger system
|
||||
if trigger_scheduler:
|
||||
trigger_scheduler.shutdown(wait=True)
|
||||
logger.info("Trigger scheduler shut down")
|
||||
|
||||
if trigger_queue:
|
||||
await trigger_queue.stop()
|
||||
logger.info("Trigger queue stopped")
|
||||
|
||||
# Shutdown agent system
|
||||
if agent_executor and agent_executor.memory_manager:
|
||||
await agent_executor.memory_manager.close()
|
||||
logger.info("Agent system shut down")
|
||||
|
||||
logger.info("All systems shut down")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -164,15 +209,21 @@ class ChartStore(BaseModel):
|
||||
class ShapeStore(BaseModel):
|
||||
shapes: dict[str, dict] = {} # Dictionary of shapes keyed by ID
|
||||
|
||||
# IndicatorStore model for synchronization
|
||||
class IndicatorStore(BaseModel):
|
||||
indicators: dict[str, dict] = {} # Dictionary of indicators keyed by ID
|
||||
|
||||
# Initialize stores
|
||||
order_store = OrderStore()
|
||||
chart_store = ChartStore()
|
||||
shape_store = ShapeStore()
|
||||
indicator_store = IndicatorStore()
|
||||
|
||||
# Register with SyncRegistry
|
||||
registry.register(order_store, store_name="OrderStore")
|
||||
registry.register(chart_store, store_name="ChartStore")
|
||||
registry.register(shape_store, store_name="ShapeStore")
|
||||
registry.register(indicator_store, store_name="IndicatorStore")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
40
backend.old/src/schema/indicator.py
Normal file
40
backend.old/src/schema/indicator.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class IndicatorInstance(BaseModel):
|
||||
"""
|
||||
Represents an instance of an indicator applied to a chart.
|
||||
|
||||
This schema holds both the TA-Lib metadata and TradingView-specific data
|
||||
needed for synchronization.
|
||||
"""
|
||||
id: str = Field(..., description="Unique identifier for this indicator instance")
|
||||
|
||||
# TA-Lib metadata
|
||||
talib_name: str = Field(..., description="TA-Lib indicator name (e.g., 'RSI', 'SMA', 'MACD')")
|
||||
instance_name: str = Field(..., description="User-friendly instance name")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="TA-Lib indicator parameters")
|
||||
|
||||
# TradingView metadata
|
||||
tv_study_id: Optional[str] = Field(default=None, description="TradingView study ID assigned by the chart widget")
|
||||
tv_indicator_name: Optional[str] = Field(default=None, description="TradingView indicator name if different from TA-Lib")
|
||||
tv_inputs: Optional[Dict[str, Any]] = Field(default=None, description="TradingView-specific input parameters")
|
||||
|
||||
# Visual properties
|
||||
visible: bool = Field(default=True, description="Whether indicator is visible on chart")
|
||||
pane: str = Field(default="chart", description="Pane where indicator is displayed ('chart' or 'separate')")
|
||||
|
||||
# Metadata
|
||||
symbol: Optional[str] = Field(default=None, description="Symbol this indicator is applied to")
|
||||
created_at: Optional[int] = Field(default=None, description="Creation timestamp (Unix seconds)")
|
||||
modified_at: Optional[int] = Field(default=None, description="Last modification timestamp (Unix seconds)")
|
||||
original_id: Optional[str] = Field(default=None, description="Original ID from backend before TradingView assigns its own ID")
|
||||
|
||||
|
||||
class IndicatorCollection(BaseModel):
|
||||
"""Collection of all indicator instances on the chart."""
|
||||
indicators: Dict[str, IndicatorInstance] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary of indicator instances keyed by ID"
|
||||
)
|
||||
@@ -116,6 +116,10 @@ class SyncRegistry:
|
||||
logger.info(f"apply_client_patch: New state after patch: {new_state}")
|
||||
self._update_model(entry.model, new_state)
|
||||
|
||||
# Verify the model was actually updated
|
||||
updated_state = entry.model.model_dump(mode="json")
|
||||
logger.info(f"apply_client_patch: Model state after _update_model: {updated_state}")
|
||||
|
||||
entry.commit_patch(patch)
|
||||
logger.info(f"apply_client_patch: Patch committed, new seq={entry.seq}")
|
||||
# Don't broadcast back to client - they already have this change
|
||||
@@ -206,7 +210,37 @@ class SyncRegistry:
|
||||
await self.websocket.send_json(msg.model_dump(mode="json"))
|
||||
|
||||
def _update_model(self, model: BaseModel, new_data: Dict[str, Any]):
|
||||
# Update model using model_validate for potentially nested models
|
||||
new_model = model.__class__.model_validate(new_data)
|
||||
for field in model.model_fields:
|
||||
setattr(model, field, getattr(new_model, field))
|
||||
# Update model fields in-place to preserve references
|
||||
# This is important for dict fields that may be referenced elsewhere
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
if field_name in new_data:
|
||||
new_value = new_data[field_name]
|
||||
current_value = getattr(model, field_name)
|
||||
|
||||
# For dict fields, update in-place instead of replacing
|
||||
if isinstance(current_value, dict) and isinstance(new_value, dict):
|
||||
self._deep_update_dict(current_value, new_value)
|
||||
else:
|
||||
# For other types, just set the new value
|
||||
setattr(model, field_name, new_value)
|
||||
|
||||
def _deep_update_dict(self, target: dict, source: dict):
|
||||
"""Deep update target dict with source dict, preserving nested dict references."""
|
||||
# Remove keys that are in target but not in source
|
||||
keys_to_remove = set(target.keys()) - set(source.keys())
|
||||
for key in keys_to_remove:
|
||||
del target[key]
|
||||
|
||||
# Update or add keys from source
|
||||
for key, source_value in source.items():
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# If both are dicts, recursively update
|
||||
if isinstance(target_value, dict) and isinstance(source_value, dict):
|
||||
self._deep_update_dict(target_value, source_value)
|
||||
else:
|
||||
# Replace the value
|
||||
target[key] = source_value
|
||||
else:
|
||||
# Add new key
|
||||
target[key] = source_value
|
||||
216
backend.old/src/trigger/PRIORITIES.md
Normal file
216
backend.old/src/trigger/PRIORITIES.md
Normal file
@@ -0,0 +1,216 @@
|
||||
# Priority System
|
||||
|
||||
Simple tuple-based priorities for deterministic execution ordering.
|
||||
|
||||
## Basic Concept
|
||||
|
||||
Priorities are just **Python tuples**. Python compares tuples element-by-element, left-to-right:
|
||||
|
||||
```python
|
||||
(0, 1000, 5) < (0, 1001, 3) # True: 0==0, but 1000 < 1001
|
||||
(0, 1000, 5) < (1, 500, 2) # True: 0 < 1
|
||||
(0, 1000) < (0, 1000, 5) # True: shorter wins if equal so far
|
||||
```
|
||||
|
||||
**Lower values = higher priority** (processed first).
|
||||
|
||||
## Priority Categories
|
||||
|
||||
```python
|
||||
class Priority(IntEnum):
|
||||
DATA_SOURCE = 0 # Market data, real-time feeds
|
||||
TIMER = 1 # Scheduled tasks, cron jobs
|
||||
USER_AGENT = 2 # User-agent interactions (chat)
|
||||
USER_DATA_REQUEST = 3 # User data requests (charts)
|
||||
SYSTEM = 4 # Background tasks, cleanup
|
||||
LOW = 5 # Retries after conflicts
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Simple Priority
|
||||
|
||||
```python
|
||||
# Just use the Priority enum
|
||||
trigger = MyTrigger("task", priority=Priority.SYSTEM)
|
||||
await queue.enqueue(trigger)
|
||||
|
||||
# Results in tuple: (4, queue_seq)
|
||||
```
|
||||
|
||||
### Compound Priority (Tuple)
|
||||
|
||||
```python
|
||||
# DataSource: sort by event time (older bars first)
|
||||
trigger = DataUpdateTrigger(
|
||||
source_name="binance",
|
||||
symbol="BTC/USDT",
|
||||
resolution="1m",
|
||||
bar_data={"time": 1678896000, "open": 50000, ...}
|
||||
)
|
||||
await queue.enqueue(trigger)
|
||||
|
||||
# Results in tuple: (0, 1678896000, queue_seq)
|
||||
# ^ ^ ^
|
||||
# | | Queue insertion order (FIFO)
|
||||
# | Event time (candle end time)
|
||||
# DATA_SOURCE priority
|
||||
```
|
||||
|
||||
### Manual Override
|
||||
|
||||
```python
|
||||
# Override at enqueue time
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
priority_override=(Priority.DATA_SOURCE, custom_time, custom_sort)
|
||||
)
|
||||
|
||||
# Queue appends queue_seq: (0, custom_time, custom_sort, queue_seq)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Market Data (Process Chronologically)
|
||||
|
||||
```python
|
||||
# Bar from 10:00 → (0, 10:00_timestamp, queue_seq)
|
||||
# Bar from 10:05 → (0, 10:05_timestamp, queue_seq)
|
||||
#
|
||||
# 10:00 bar processes first (earlier event_time)
|
||||
|
||||
DataUpdateTrigger(
|
||||
...,
|
||||
bar_data={"time": event_timestamp, ...}
|
||||
)
|
||||
```
|
||||
|
||||
### User Messages (FIFO Order)
|
||||
|
||||
```python
|
||||
# Message #1 → (2, msg1_timestamp, queue_seq)
|
||||
# Message #2 → (2, msg2_timestamp, queue_seq)
|
||||
#
|
||||
# Message #1 processes first (earlier timestamp)
|
||||
|
||||
AgentTriggerHandler(
|
||||
session_id="user1",
|
||||
message_content="...",
|
||||
message_timestamp=unix_timestamp # Optional, defaults to now
|
||||
)
|
||||
```
|
||||
|
||||
### Scheduled Tasks (By Schedule Time)
|
||||
|
||||
```python
|
||||
# Job scheduled for 9 AM → (1, 9am_timestamp, queue_seq)
|
||||
# Job scheduled for 2 PM → (1, 2pm_timestamp, queue_seq)
|
||||
#
|
||||
# 9 AM job processes first
|
||||
|
||||
CronTrigger(
|
||||
name="morning_sync",
|
||||
inner_trigger=...,
|
||||
scheduled_time=scheduled_timestamp
|
||||
)
|
||||
```
|
||||
|
||||
## Execution Order Example
|
||||
|
||||
```
|
||||
Queue contains:
|
||||
1. DataSource (BTC @ 10:00) → (0, 10:00, 1)
|
||||
2. DataSource (BTC @ 10:05) → (0, 10:05, 2)
|
||||
3. Timer (scheduled 9 AM) → (1, 09:00, 3)
|
||||
4. User message #1 → (2, 14:30, 4)
|
||||
5. User message #2 → (2, 14:35, 5)
|
||||
|
||||
Dequeue order:
|
||||
1. DataSource (BTC @ 10:00) ← 0 < all others
|
||||
2. DataSource (BTC @ 10:05) ← 0 < all others, 10:05 > 10:00
|
||||
3. Timer (scheduled 9 AM) ← 1 < remaining
|
||||
4. User message #1 ← 2 < remaining, 14:30 < 14:35
|
||||
5. User message #2 ← last
|
||||
```
|
||||
|
||||
## Short Tuple Wins
|
||||
|
||||
If tuples are equal up to the length of the shorter one, **shorter tuple has higher priority**:
|
||||
|
||||
```python
|
||||
(0, 1000) < (0, 1000, 5) # True: shorter wins
|
||||
(0,) < (0, 1000) # True: shorter wins
|
||||
(Priority.DATA_SOURCE,) < (Priority.DATA_SOURCE, 1000) # True
|
||||
```
|
||||
|
||||
This is Python's default tuple comparison behavior. In practice, we always append `queue_seq`, so this rarely matters (all tuples end up same length).
|
||||
|
||||
## Integration with Triggers
|
||||
|
||||
### Trigger Sets Its Own Priority
|
||||
|
||||
```python
|
||||
class MyTrigger(Trigger):
|
||||
def __init__(self, event_time):
|
||||
super().__init__(
|
||||
name="my_trigger",
|
||||
priority=Priority.DATA_SOURCE,
|
||||
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
|
||||
)
|
||||
```
|
||||
|
||||
Queue appends `queue_seq` automatically:
|
||||
```python
|
||||
# Trigger's tuple: (0, event_time)
|
||||
# After enqueue: (0, event_time, queue_seq)
|
||||
```
|
||||
|
||||
### Override at Enqueue
|
||||
|
||||
```python
|
||||
# Ignore trigger's priority, use override
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
priority_override=(Priority.TIMER, scheduled_time)
|
||||
)
|
||||
```
|
||||
|
||||
## Why Tuples?
|
||||
|
||||
✅ **Simple**: No custom classes, just native Python tuples
|
||||
✅ **Flexible**: Add as many sort keys as needed
|
||||
✅ **Efficient**: Python's tuple comparison is highly optimized
|
||||
✅ **Readable**: `(0, 1000, 5)` is obvious what it means
|
||||
✅ **Debuggable**: Can print and inspect easily
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Old: CompoundPriority(primary=0, secondary=1000, tertiary=5)
|
||||
# New: (0, 1000, 5)
|
||||
|
||||
# Same semantics, much simpler!
|
||||
```
|
||||
|
||||
## Advanced: Custom Sorting
|
||||
|
||||
Want to sort by multiple factors? Just add more elements:
|
||||
|
||||
```python
|
||||
# Sort by: priority → symbol → event_time → queue_seq
|
||||
priority_tuple = (
|
||||
Priority.DATA_SOURCE.value,
|
||||
symbol_id, # e.g., hash("BTC/USDT")
|
||||
event_time,
|
||||
# queue_seq appended by queue
|
||||
)
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
- **Priorities are tuples**: `(primary, secondary, ..., queue_seq)`
|
||||
- **Lower = higher priority**: Processed first
|
||||
- **Element-by-element comparison**: Left-to-right
|
||||
- **Shorter tuple wins**: If equal up to shorter length
|
||||
- **Queue appends queue_seq**: Always last element (FIFO within same priority)
|
||||
|
||||
That's it! No complex classes, just tuples. 🎯
|
||||
386
backend.old/src/trigger/README.md
Normal file
386
backend.old/src/trigger/README.md
Normal file
@@ -0,0 +1,386 @@
|
||||
# Trigger System
|
||||
|
||||
Lock-free, sequence-based execution system for deterministic event processing.
|
||||
|
||||
## Overview
|
||||
|
||||
All operations (WebSocket messages, cron tasks, data updates) flow through a **priority queue**, execute in **parallel**, but commit in **strict sequential order** with **optimistic conflict detection**.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Lock-free reads**: Snapshots are deep copies, no blocking
|
||||
- **Sequential commits**: Total ordering via sequence numbers
|
||||
- **Optimistic concurrency**: Conflicts detected, retry with same seq
|
||||
- **Priority preservation**: High-priority work never blocked by low-priority
|
||||
- **Long-running agents**: Execute in parallel, commit sequentially
|
||||
- **Deterministic replay**: Can reproduce exact system state at any seq
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ WebSocket │───┐
|
||||
│ Messages │ │
|
||||
└─────────────┘ │
|
||||
├──→ ┌─────────────────┐
|
||||
┌─────────────┐ │ │ TriggerQueue │
|
||||
│ Cron │───┤ │ (Priority Queue)│
|
||||
│ Scheduled │ │ └────────┬────────┘
|
||||
└─────────────┘ │ │ Assign seq
|
||||
│ ↓
|
||||
┌─────────────┐ │ ┌─────────────────┐
|
||||
│ DataSource │───┘ │ Execute Trigger│
|
||||
│ Updates │ │ (Parallel OK) │
|
||||
└─────────────┘ └────────┬────────┘
|
||||
│ CommitIntents
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ CommitCoordinator│
|
||||
│ (Sequential) │
|
||||
└────────┬────────┘
|
||||
│ Commit in seq order
|
||||
↓
|
||||
┌─────────────────┐
|
||||
│ VersionedStores │
|
||||
│ (w/ Backends) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. ExecutionContext (`context.py`)
|
||||
|
||||
Tracks execution seq and store snapshots via `contextvars` (auto-propagates through async calls).
|
||||
|
||||
```python
|
||||
from trigger import get_execution_context
|
||||
|
||||
ctx = get_execution_context()
|
||||
print(f"Running at seq {ctx.seq}")
|
||||
```
|
||||
|
||||
### 2. Trigger Types (`types.py`)
|
||||
|
||||
```python
|
||||
from trigger import Trigger, Priority, CommitIntent
|
||||
|
||||
class MyTrigger(Trigger):
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
# Read snapshot
|
||||
seq, data = some_store.read_snapshot()
|
||||
|
||||
# Modify
|
||||
new_data = modify(data)
|
||||
|
||||
# Prepare commit
|
||||
intent = some_store.prepare_commit(seq, new_data)
|
||||
return [intent]
|
||||
```
|
||||
|
||||
### 3. VersionedStore (`store.py`)
|
||||
|
||||
Stores with pluggable backends and optimistic concurrency:
|
||||
|
||||
```python
|
||||
from trigger import VersionedStore, PydanticStoreBackend
|
||||
|
||||
# Wrap existing Pydantic model
|
||||
backend = PydanticStoreBackend(order_store)
|
||||
versioned_store = VersionedStore("OrderStore", backend)
|
||||
|
||||
# Lock-free snapshot read
|
||||
seq, snapshot = versioned_store.read_snapshot()
|
||||
|
||||
# Prepare commit (does not modify yet)
|
||||
intent = versioned_store.prepare_commit(seq, modified_snapshot)
|
||||
```
|
||||
|
||||
**Pluggable Backends**:
|
||||
- `PydanticStoreBackend`: For existing Pydantic models (OrderStore, ChartStore, etc.)
|
||||
- `FileStoreBackend`: Future - version files (Python scripts, configs)
|
||||
- `DatabaseStoreBackend`: Future - version database rows
|
||||
|
||||
### 4. CommitCoordinator (`coordinator.py`)
|
||||
|
||||
Manages sequential commits with conflict detection:
|
||||
|
||||
- Waits for seq N to commit before N+1
|
||||
- Detects conflicts (expected_seq vs committed_seq)
|
||||
- Re-executes (not re-enqueues) on conflict **with same seq**
|
||||
- Tracks execution state for debugging
|
||||
|
||||
### 5. TriggerQueue (`queue.py`)
|
||||
|
||||
Priority queue with seq assignment:
|
||||
|
||||
```python
|
||||
from trigger import TriggerQueue
|
||||
|
||||
queue = TriggerQueue(coordinator)
|
||||
await queue.start()
|
||||
|
||||
# Enqueue trigger
|
||||
await queue.enqueue(my_trigger, Priority.HIGH)
|
||||
```
|
||||
|
||||
### 6. TriggerScheduler (`scheduler.py`)
|
||||
|
||||
APScheduler integration for cron triggers:
|
||||
|
||||
```python
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
scheduler = TriggerScheduler(queue)
|
||||
scheduler.start()
|
||||
|
||||
# Every 5 minutes
|
||||
scheduler.schedule_interval(
|
||||
IndicatorUpdateTrigger("rsi_14"),
|
||||
minutes=5
|
||||
)
|
||||
|
||||
# Daily at 9 AM
|
||||
scheduler.schedule_cron(
|
||||
SyncExchangeStateTrigger(),
|
||||
hour="9",
|
||||
minute="0"
|
||||
)
|
||||
```
|
||||
|
||||
## Integration Example
|
||||
|
||||
### Basic Setup in `main.py`
|
||||
|
||||
```python
|
||||
from trigger import (
|
||||
CommitCoordinator,
|
||||
TriggerQueue,
|
||||
VersionedStore,
|
||||
PydanticStoreBackend,
|
||||
)
|
||||
from trigger.scheduler import TriggerScheduler
|
||||
|
||||
# Create coordinator
|
||||
coordinator = CommitCoordinator()
|
||||
|
||||
# Wrap existing stores
|
||||
order_store_versioned = VersionedStore(
|
||||
"OrderStore",
|
||||
PydanticStoreBackend(order_store)
|
||||
)
|
||||
coordinator.register_store(order_store_versioned)
|
||||
|
||||
chart_store_versioned = VersionedStore(
|
||||
"ChartStore",
|
||||
PydanticStoreBackend(chart_store)
|
||||
)
|
||||
coordinator.register_store(chart_store_versioned)
|
||||
|
||||
# Create queue and scheduler
|
||||
trigger_queue = TriggerQueue(coordinator)
|
||||
await trigger_queue.start()
|
||||
|
||||
scheduler = TriggerScheduler(trigger_queue)
|
||||
scheduler.start()
|
||||
```
|
||||
|
||||
### WebSocket Message Handler
|
||||
|
||||
```python
|
||||
from trigger.handlers import AgentTriggerHandler
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data["type"] == "agent_user_message":
|
||||
# Enqueue agent trigger instead of direct Gateway call
|
||||
trigger = AgentTriggerHandler(
|
||||
session_id=data["session_id"],
|
||||
message_content=data["content"],
|
||||
gateway_handler=gateway.route_user_message,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
await trigger_queue.enqueue(trigger)
|
||||
```
|
||||
|
||||
### DataSource Updates
|
||||
|
||||
```python
|
||||
from trigger.handlers import DataUpdateTrigger
|
||||
|
||||
# In subscription_manager._on_source_update()
|
||||
def _on_source_update(self, source_key: tuple, bar: dict):
|
||||
# Enqueue data update trigger
|
||||
trigger = DataUpdateTrigger(
|
||||
source_name=source_key[0],
|
||||
symbol=source_key[1],
|
||||
resolution=source_key[2],
|
||||
bar_data=bar,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
asyncio.create_task(trigger_queue.enqueue(trigger))
|
||||
```
|
||||
|
||||
### Custom Trigger
|
||||
|
||||
```python
|
||||
from trigger import Trigger, CommitIntent, Priority
|
||||
|
||||
class RecalculatePortfolioTrigger(Trigger):
|
||||
def __init__(self, coordinator):
|
||||
super().__init__("recalc_portfolio", Priority.NORMAL)
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
# Read snapshots from multiple stores
|
||||
order_seq, orders = self.coordinator.get_store("OrderStore").read_snapshot()
|
||||
chart_seq, chart = self.coordinator.get_store("ChartStore").read_snapshot()
|
||||
|
||||
# Calculate portfolio value
|
||||
portfolio_value = calculate_portfolio(orders, chart)
|
||||
|
||||
# Update chart state with portfolio value
|
||||
chart.portfolio_value = portfolio_value
|
||||
|
||||
# Prepare commit
|
||||
intent = self.coordinator.get_store("ChartStore").prepare_commit(
|
||||
chart_seq,
|
||||
chart
|
||||
)
|
||||
|
||||
return [intent]
|
||||
|
||||
# Schedule it
|
||||
scheduler.schedule_interval(
|
||||
RecalculatePortfolioTrigger(coordinator),
|
||||
minutes=1
|
||||
)
|
||||
```
|
||||
|
||||
## Execution Flow
|
||||
|
||||
### Normal Flow (No Conflicts)
|
||||
|
||||
```
|
||||
seq=100: WebSocket message arrives → enqueue → dequeue → assign seq=100 → execute
|
||||
seq=101: Cron trigger fires → enqueue → dequeue → assign seq=101 → execute
|
||||
|
||||
seq=101 finishes first → waits in commit queue
|
||||
seq=100 finishes → commits immediately (next in order)
|
||||
seq=101 commits next
|
||||
```
|
||||
|
||||
### Conflict Flow
|
||||
|
||||
```
|
||||
seq=100: reads OrderStore at seq=99 → executes for 30 seconds
|
||||
seq=101: reads OrderStore at seq=99 → executes for 5 seconds
|
||||
|
||||
seq=101 finishes first → tries to commit based on seq=99
|
||||
seq=100 finishes → commits OrderStore at seq=100
|
||||
|
||||
Coordinator detects conflict:
|
||||
expected_seq=99, committed_seq=100
|
||||
|
||||
seq=101 evicted → RE-EXECUTES with same seq=101 (not re-enqueued)
|
||||
reads OrderStore at seq=100 → executes again
|
||||
finishes → commits successfully at seq=101
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
### For Agent System
|
||||
|
||||
- **Long-running agents work naturally**: Agent starts at seq=100, runs for 60 seconds while market data updates at seq=101-110, commits only if no conflicts
|
||||
- **No deadlocks**: No locks = no deadlock possibility
|
||||
- **Deterministic**: Can replay from any seq for debugging
|
||||
|
||||
### For Strategy Execution
|
||||
|
||||
- **High-frequency data doesn't block strategies**: Data updates enqueued, executed in parallel, commit sequentially
|
||||
- **Priority preservation**: Critical order execution never blocked by indicator calculations
|
||||
- **Conflict detection**: If market moved during strategy calculation, automatically retry with fresh data
|
||||
|
||||
### For Scaling
|
||||
|
||||
- **Single-node first**: Runs on single asyncio event loop, no complex distributed coordination
|
||||
- **Future-proof**: Can swap queue for Redis/PostgreSQL-backed distributed queue later
|
||||
- **Event sourcing ready**: All commits have seq numbers, can build event log
|
||||
|
||||
## Debugging
|
||||
|
||||
### Check Current State
|
||||
|
||||
```python
|
||||
# Coordinator stats
|
||||
stats = coordinator.get_stats()
|
||||
print(f"Current seq: {stats['current_seq']}")
|
||||
print(f"Pending commits: {stats['pending_commits']}")
|
||||
print(f"Executions by state: {stats['state_counts']}")
|
||||
|
||||
# Store state
|
||||
store = coordinator.get_store("OrderStore")
|
||||
print(f"Store: {store}") # Shows committed_seq and version
|
||||
|
||||
# Execution record
|
||||
record = coordinator.get_execution_record(100)
|
||||
print(f"Seq 100: {record}") # Shows state, retry_count, error
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Symptoms: High conflict rate**
|
||||
- **Cause**: Multiple triggers modifying same store frequently
|
||||
- **Solution**: Batch updates, use debouncing, or redesign to reduce contention
|
||||
|
||||
**Symptoms: Commits stuck (next_commit_seq not advancing)**
|
||||
- **Cause**: Execution at that seq failed or is taking too long
|
||||
- **Solution**: Check execution_records for that seq, look for errors in logs
|
||||
|
||||
**Symptoms: Queue depth growing**
|
||||
- **Cause**: Executions slower than enqueue rate
|
||||
- **Solution**: Profile trigger execution, optimize slow paths, add rate limiting
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Test: Conflict Detection
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from trigger import VersionedStore, PydanticStoreBackend, CommitCoordinator
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conflict_detection():
|
||||
coordinator = CommitCoordinator()
|
||||
|
||||
store = VersionedStore("TestStore", PydanticStoreBackend(TestModel()))
|
||||
coordinator.register_store(store)
|
||||
|
||||
# Seq 1: read at 0, modify, commit
|
||||
seq1, data1 = store.read_snapshot()
|
||||
data1.value = "seq1"
|
||||
intent1 = store.prepare_commit(seq1, data1)
|
||||
|
||||
# Seq 2: read at 0 (same snapshot), modify
|
||||
seq2, data2 = store.read_snapshot()
|
||||
data2.value = "seq2"
|
||||
intent2 = store.prepare_commit(seq2, data2)
|
||||
|
||||
# Commit seq 1 (should succeed)
|
||||
# ... coordinator logic ...
|
||||
|
||||
# Commit seq 2 (should conflict and retry)
|
||||
# ... verify conflict detected ...
|
||||
```
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- **Distributed queue**: Redis-backed queue for multi-worker deployment
|
||||
- **Event log persistence**: Store all commits for event sourcing/audit
|
||||
- **Metrics dashboard**: Real-time view of queue depth, conflict rate, latency
|
||||
- **Transaction snapshots**: Full system state at any seq for replay/debugging
|
||||
- **Automatic batching**: Coalesce rapid updates to same store
|
||||
35
backend.old/src/trigger/__init__.py
Normal file
35
backend.old/src/trigger/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Sequential execution trigger system with optimistic concurrency control.
|
||||
|
||||
All operations (websocket, cron, data events) flow through a priority queue,
|
||||
execute in parallel, but commit in strict sequential order with conflict detection.
|
||||
"""
|
||||
|
||||
from .context import ExecutionContext, get_execution_context
|
||||
from .types import Priority, PriorityTuple, Trigger, CommitIntent, ExecutionState
|
||||
from .store import VersionedStore, StoreBackend, PydanticStoreBackend
|
||||
from .coordinator import CommitCoordinator
|
||||
from .queue import TriggerQueue
|
||||
from .handlers import AgentTriggerHandler, LambdaHandler
|
||||
|
||||
__all__ = [
|
||||
# Context
|
||||
"ExecutionContext",
|
||||
"get_execution_context",
|
||||
# Types
|
||||
"Priority",
|
||||
"PriorityTuple",
|
||||
"Trigger",
|
||||
"CommitIntent",
|
||||
"ExecutionState",
|
||||
# Store
|
||||
"VersionedStore",
|
||||
"StoreBackend",
|
||||
"PydanticStoreBackend",
|
||||
# Coordination
|
||||
"CommitCoordinator",
|
||||
"TriggerQueue",
|
||||
# Handlers
|
||||
"AgentTriggerHandler",
|
||||
"LambdaHandler",
|
||||
]
|
||||
61
backend.old/src/trigger/context.py
Normal file
61
backend.old/src/trigger/context.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Execution context tracking using Python's contextvars.
|
||||
|
||||
Each execution gets a unique seq number that propagates through all async calls,
|
||||
allowing us to track which execution made which changes for conflict detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variables - automatically propagate through async call chains
|
||||
_execution_context: ContextVar[Optional["ExecutionContext"]] = ContextVar(
|
||||
"execution_context", default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""
|
||||
Execution context for a single trigger execution.
|
||||
|
||||
Automatically propagates through async calls via contextvars.
|
||||
Tracks the seq number and which store snapshots were read.
|
||||
"""
|
||||
|
||||
seq: int
|
||||
"""Sequential execution number - determines commit order"""
|
||||
|
||||
trigger_name: str
|
||||
"""Name/type of trigger being executed"""
|
||||
|
||||
snapshot_seqs: dict[str, int] = field(default_factory=dict)
|
||||
"""Store name -> seq number of snapshot that was read"""
|
||||
|
||||
def record_snapshot(self, store_name: str, snapshot_seq: int) -> None:
|
||||
"""Record that we read a snapshot from a store at a specific seq"""
|
||||
self.snapshot_seqs[store_name] = snapshot_seq
|
||||
logger.debug(f"Seq {self.seq}: Read {store_name} at seq {snapshot_seq}")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ExecutionContext(seq={self.seq}, trigger={self.trigger_name})"
|
||||
|
||||
|
||||
def get_execution_context() -> Optional[ExecutionContext]:
|
||||
"""Get the current execution context, or None if not in an execution"""
|
||||
return _execution_context.get()
|
||||
|
||||
|
||||
def set_execution_context(ctx: ExecutionContext) -> None:
|
||||
"""Set the execution context for the current async task"""
|
||||
_execution_context.set(ctx)
|
||||
logger.debug(f"Set execution context: {ctx}")
|
||||
|
||||
|
||||
def clear_execution_context() -> None:
|
||||
"""Clear the execution context"""
|
||||
_execution_context.set(None)
|
||||
302
backend.old/src/trigger/coordinator.py
Normal file
302
backend.old/src/trigger/coordinator.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Commit coordinator - manages sequential commits with conflict detection.
|
||||
|
||||
Ensures that commits happen in strict sequence order, even when executions
|
||||
complete out of order. Detects conflicts and triggers re-execution with the
|
||||
same seq number (not re-enqueue, just re-execute).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .context import ExecutionContext
|
||||
from .store import VersionedStore
|
||||
from .types import CommitIntent, ExecutionRecord, ExecutionState, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommitCoordinator:
|
||||
"""
|
||||
Manages sequential commits with optimistic concurrency control.
|
||||
|
||||
Key responsibilities:
|
||||
- Maintain strict sequential commit order (seq N+1 commits after seq N)
|
||||
- Detect conflicts between execution snapshot and committed state
|
||||
- Trigger re-execution (not re-enqueue) on conflicts with same seq
|
||||
- Track in-flight executions for debugging and monitoring
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._stores: dict[str, VersionedStore] = {}
|
||||
self._current_seq = 0 # Highest committed seq across all operations
|
||||
self._next_commit_seq = 1 # Next seq we're waiting to commit
|
||||
self._pending_commits: dict[int, tuple[ExecutionRecord, list[CommitIntent]]] = {}
|
||||
self._execution_records: dict[int, ExecutionRecord] = {}
|
||||
self._lock = asyncio.Lock() # Only for coordinator internal state, not stores
|
||||
|
||||
def register_store(self, store: VersionedStore) -> None:
|
||||
"""Register a versioned store with the coordinator"""
|
||||
self._stores[store.name] = store
|
||||
logger.info(f"Registered store: {store.name}")
|
||||
|
||||
def get_store(self, name: str) -> Optional[VersionedStore]:
|
||||
"""Get a registered store by name"""
|
||||
return self._stores.get(name)
|
||||
|
||||
async def start_execution(self, seq: int, trigger: Trigger) -> ExecutionRecord:
|
||||
"""
|
||||
Record that an execution is starting.
|
||||
|
||||
Args:
|
||||
seq: Sequence number assigned to this execution
|
||||
trigger: The trigger being executed
|
||||
|
||||
Returns:
|
||||
ExecutionRecord for tracking
|
||||
"""
|
||||
async with self._lock:
|
||||
record = ExecutionRecord(
|
||||
seq=seq,
|
||||
trigger=trigger,
|
||||
state=ExecutionState.EXECUTING,
|
||||
)
|
||||
self._execution_records[seq] = record
|
||||
logger.info(f"Started execution: seq={seq}, trigger={trigger.name}")
|
||||
return record
|
||||
|
||||
async def submit_for_commit(
|
||||
self,
|
||||
seq: int,
|
||||
commit_intents: list[CommitIntent],
|
||||
) -> None:
|
||||
"""
|
||||
Submit commit intents for sequential commit.
|
||||
|
||||
The commit will only happen when:
|
||||
1. All prior seq numbers have committed
|
||||
2. No conflicts detected with committed state
|
||||
|
||||
Args:
|
||||
seq: Sequence number of this execution
|
||||
commit_intents: List of changes to commit (empty if no changes)
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._execution_records.get(seq)
|
||||
if not record:
|
||||
logger.error(f"No execution record found for seq={seq}")
|
||||
return
|
||||
|
||||
record.state = ExecutionState.WAITING_COMMIT
|
||||
record.commit_intents = commit_intents
|
||||
self._pending_commits[seq] = (record, commit_intents)
|
||||
|
||||
logger.info(
|
||||
f"Seq {seq} submitted for commit with {len(commit_intents)} intents"
|
||||
)
|
||||
|
||||
# Try to process commits (this will handle sequential ordering)
|
||||
await self._process_commits()
|
||||
|
||||
async def _process_commits(self) -> None:
|
||||
"""
|
||||
Process pending commits in strict sequential order.
|
||||
|
||||
Only commits seq N if seq N-1 has already committed.
|
||||
Detects conflicts and triggers re-execution with same seq.
|
||||
"""
|
||||
while True:
|
||||
async with self._lock:
|
||||
# Check if next expected seq is ready to commit
|
||||
if self._next_commit_seq not in self._pending_commits:
|
||||
# Waiting for this seq to complete execution
|
||||
break
|
||||
|
||||
seq = self._next_commit_seq
|
||||
record, intents = self._pending_commits[seq]
|
||||
|
||||
logger.info(
|
||||
f"Processing commit for seq={seq} (current_seq={self._current_seq})"
|
||||
)
|
||||
|
||||
# Check for conflicts
|
||||
conflicts = self._check_conflicts(intents)
|
||||
|
||||
if conflicts:
|
||||
# Conflict detected - re-execute with same seq
|
||||
logger.warning(
|
||||
f"Seq {seq} has conflicts in stores: {conflicts}. Re-executing..."
|
||||
)
|
||||
|
||||
# Remove from pending (will be re-added when execution completes)
|
||||
del self._pending_commits[seq]
|
||||
|
||||
# Mark as evicted
|
||||
record.state = ExecutionState.EVICTED
|
||||
record.retry_count += 1
|
||||
|
||||
# Advance to next seq (this seq will be retried in background)
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
# Trigger re-execution (outside lock)
|
||||
asyncio.create_task(self._retry_execution(record))
|
||||
|
||||
continue
|
||||
|
||||
# No conflicts - commit all intents atomically
|
||||
for intent in intents:
|
||||
store = self._stores.get(intent.store_name)
|
||||
if not store:
|
||||
logger.error(
|
||||
f"Seq {seq}: Store '{intent.store_name}' not found"
|
||||
)
|
||||
continue
|
||||
|
||||
store.commit(intent.new_data, seq)
|
||||
|
||||
# Mark as committed
|
||||
record.state = ExecutionState.COMMITTED
|
||||
del self._pending_commits[seq]
|
||||
|
||||
# Advance seq counters
|
||||
self._current_seq = seq
|
||||
self._next_commit_seq = seq + 1
|
||||
|
||||
logger.info(
|
||||
f"Committed seq={seq}, current_seq now {self._current_seq}"
|
||||
)
|
||||
|
||||
def _check_conflicts(self, intents: list[CommitIntent]) -> list[str]:
|
||||
"""
|
||||
Check if any commit intents conflict with current committed state.
|
||||
|
||||
Args:
|
||||
intents: List of commit intents to check
|
||||
|
||||
Returns:
|
||||
List of store names that have conflicts (empty if no conflicts)
|
||||
"""
|
||||
conflicts = []
|
||||
|
||||
for intent in intents:
|
||||
store = self._stores.get(intent.store_name)
|
||||
if not store:
|
||||
logger.error(f"Store '{intent.store_name}' not found during conflict check")
|
||||
continue
|
||||
|
||||
if store.check_conflict(intent.expected_seq):
|
||||
conflicts.append(intent.store_name)
|
||||
|
||||
return conflicts
|
||||
|
||||
async def _retry_execution(self, record: ExecutionRecord) -> None:
|
||||
"""
|
||||
Re-execute a trigger that had conflicts.
|
||||
|
||||
Executes with the SAME seq number (not re-enqueued, just re-executed).
|
||||
This ensures the execution order remains deterministic.
|
||||
|
||||
Args:
|
||||
record: Execution record to retry
|
||||
"""
|
||||
from .context import ExecutionContext, set_execution_context, clear_execution_context
|
||||
|
||||
logger.info(
|
||||
f"Retrying execution: seq={record.seq}, trigger={record.trigger.name}, "
|
||||
f"retry_count={record.retry_count}"
|
||||
)
|
||||
|
||||
# Set execution context for retry
|
||||
ctx = ExecutionContext(
|
||||
seq=record.seq,
|
||||
trigger_name=record.trigger.name,
|
||||
)
|
||||
set_execution_context(ctx)
|
||||
|
||||
try:
|
||||
# Re-execute trigger
|
||||
record.state = ExecutionState.EXECUTING
|
||||
commit_intents = await record.trigger.execute()
|
||||
|
||||
# Submit for commit again (with same seq)
|
||||
await self.submit_for_commit(record.seq, commit_intents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retry execution failed for seq={record.seq}: {e}", exc_info=True
|
||||
)
|
||||
record.state = ExecutionState.FAILED
|
||||
record.error = str(e)
|
||||
|
||||
# Still need to advance past this seq
|
||||
async with self._lock:
|
||||
if record.seq == self._next_commit_seq:
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
# Try to process any pending commits
|
||||
await self._process_commits()
|
||||
|
||||
finally:
|
||||
clear_execution_context()
|
||||
|
||||
async def execution_failed(self, seq: int, error: Exception) -> None:
|
||||
"""
|
||||
Mark an execution as failed.
|
||||
|
||||
Args:
|
||||
seq: Sequence number that failed
|
||||
error: The exception that caused the failure
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._execution_records.get(seq)
|
||||
if record:
|
||||
record.state = ExecutionState.FAILED
|
||||
record.error = str(error)
|
||||
|
||||
# Remove from pending if present
|
||||
self._pending_commits.pop(seq, None)
|
||||
|
||||
# If this is the next seq to commit, advance past it
|
||||
if seq == self._next_commit_seq:
|
||||
self._next_commit_seq += 1
|
||||
self._current_seq += 1
|
||||
|
||||
logger.info(
|
||||
f"Seq {seq} failed, advancing current_seq to {self._current_seq}"
|
||||
)
|
||||
|
||||
# Try to process any pending commits
|
||||
await self._process_commits()
|
||||
|
||||
def get_current_seq(self) -> int:
|
||||
"""Get the current committed sequence number"""
|
||||
return self._current_seq
|
||||
|
||||
def get_execution_record(self, seq: int) -> Optional[ExecutionRecord]:
|
||||
"""Get execution record for a specific seq"""
|
||||
return self._execution_records.get(seq)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics about the coordinator state"""
|
||||
state_counts = {}
|
||||
for record in self._execution_records.values():
|
||||
state_name = record.state.name
|
||||
state_counts[state_name] = state_counts.get(state_name, 0) + 1
|
||||
|
||||
return {
|
||||
"current_seq": self._current_seq,
|
||||
"next_commit_seq": self._next_commit_seq,
|
||||
"pending_commits": len(self._pending_commits),
|
||||
"total_executions": len(self._execution_records),
|
||||
"state_counts": state_counts,
|
||||
"stores": {name: str(store) for name, store in self._stores.items()},
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"CommitCoordinator(current_seq={self._current_seq}, "
|
||||
f"pending={len(self._pending_commits)}, stores={len(self._stores)})"
|
||||
)
|
||||
304
backend.old/src/trigger/handlers.py
Normal file
304
backend.old/src/trigger/handlers.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Trigger handlers - concrete implementations for common trigger types.
|
||||
|
||||
Provides ready-to-use trigger handlers for:
|
||||
- Agent execution (WebSocket user messages)
|
||||
- Lambda/callable execution
|
||||
- Data update triggers
|
||||
- Indicator updates
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
from .coordinator import CommitCoordinator
|
||||
from .types import CommitIntent, Priority, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentTriggerHandler(Trigger):
|
||||
"""
|
||||
Trigger for agent execution from WebSocket user messages.
|
||||
|
||||
Wraps the Gateway's agent execution flow and captures any
|
||||
store modifications as commit intents.
|
||||
|
||||
Priority tuple: (USER_AGENT, message_timestamp, queue_seq)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
message_content: str,
|
||||
message_timestamp: Optional[int] = None,
|
||||
attachments: Optional[list] = None,
|
||||
gateway_handler: Optional[Callable] = None,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
):
|
||||
"""
|
||||
Initialize agent trigger.
|
||||
|
||||
Args:
|
||||
session_id: User session ID
|
||||
message_content: User message content
|
||||
message_timestamp: When user sent message (unix timestamp, defaults to now)
|
||||
attachments: Optional message attachments
|
||||
gateway_handler: Callable to route to Gateway (set during integration)
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
"""
|
||||
if message_timestamp is None:
|
||||
message_timestamp = int(time.time())
|
||||
|
||||
# Priority tuple: sort by USER_AGENT priority, then message timestamp
|
||||
super().__init__(
|
||||
name=f"agent_{session_id}",
|
||||
priority=Priority.USER_AGENT,
|
||||
priority_tuple=(Priority.USER_AGENT.value, message_timestamp)
|
||||
)
|
||||
self.session_id = session_id
|
||||
self.message_content = message_content
|
||||
self.message_timestamp = message_timestamp
|
||||
self.attachments = attachments or []
|
||||
self.gateway_handler = gateway_handler
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Execute agent interaction.
|
||||
|
||||
This will call into the Gateway, which will run the agent.
|
||||
The agent may read from stores and generate responses.
|
||||
Any store modifications are captured as commit intents.
|
||||
|
||||
Returns:
|
||||
List of commit intents (typically empty for now, as agent
|
||||
modifies stores via tools which will be integrated later)
|
||||
"""
|
||||
if not self.gateway_handler:
|
||||
logger.error("No gateway_handler configured for AgentTriggerHandler")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Agent trigger executing: session={self.session_id}, "
|
||||
f"content='{self.message_content[:50]}...'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Call Gateway to handle message
|
||||
# In future, Gateway/agent tools will use coordinator stores
|
||||
await self.gateway_handler(
|
||||
self.session_id,
|
||||
self.message_content,
|
||||
self.attachments,
|
||||
)
|
||||
|
||||
# For now, agent doesn't directly modify stores
|
||||
# Future: agent tools will return commit intents
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent execution error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
class LambdaHandler(Trigger):
|
||||
"""
|
||||
Generic trigger that executes an arbitrary async callable.
|
||||
|
||||
Useful for custom triggers, one-off tasks, or testing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
func: Callable[[], Awaitable[list[CommitIntent]]],
|
||||
priority: Priority = Priority.SYSTEM,
|
||||
):
|
||||
"""
|
||||
Initialize lambda handler.
|
||||
|
||||
Args:
|
||||
name: Descriptive name for this trigger
|
||||
func: Async callable that returns commit intents
|
||||
priority: Execution priority
|
||||
"""
|
||||
super().__init__(name, priority)
|
||||
self.func = func
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""Execute the callable"""
|
||||
logger.info(f"Lambda trigger executing: {self.name}")
|
||||
return await self.func()
|
||||
|
||||
|
||||
class DataUpdateTrigger(Trigger):
|
||||
"""
|
||||
Trigger for DataSource bar updates.
|
||||
|
||||
Fired when new market data arrives. Can update indicators,
|
||||
trigger strategy logic, or notify the agent of market events.
|
||||
|
||||
Priority tuple: (DATA_SOURCE, event_time, queue_seq)
|
||||
Ensures older bars process before newer ones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_name: str,
|
||||
symbol: str,
|
||||
resolution: str,
|
||||
bar_data: dict,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
):
|
||||
"""
|
||||
Initialize data update trigger.
|
||||
|
||||
Args:
|
||||
source_name: Name of data source (e.g., "binance")
|
||||
symbol: Trading pair symbol
|
||||
resolution: Time resolution
|
||||
bar_data: Bar data dict (time, open, high, low, close, volume)
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
"""
|
||||
event_time = bar_data.get('time', int(time.time()))
|
||||
|
||||
# Priority tuple: sort by DATA_SOURCE priority, then event time
|
||||
super().__init__(
|
||||
name=f"data_{source_name}_{symbol}_{resolution}",
|
||||
priority=Priority.DATA_SOURCE,
|
||||
priority_tuple=(Priority.DATA_SOURCE.value, event_time)
|
||||
)
|
||||
self.source_name = source_name
|
||||
self.symbol = symbol
|
||||
self.resolution = resolution
|
||||
self.bar_data = bar_data
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Process bar update.
|
||||
|
||||
Future implementations will:
|
||||
- Update indicator values
|
||||
- Check strategy conditions
|
||||
- Trigger alerts/notifications
|
||||
|
||||
Returns:
|
||||
Commit intents for any store updates
|
||||
"""
|
||||
logger.info(
|
||||
f"Data update trigger: {self.source_name}:{self.symbol}@{self.resolution}, "
|
||||
f"time={self.bar_data.get('time')}"
|
||||
)
|
||||
|
||||
# TODO: Update indicators
|
||||
# TODO: Check strategy conditions
|
||||
# TODO: Notify agent of significant events
|
||||
|
||||
# For now, just log
|
||||
return []
|
||||
|
||||
|
||||
class IndicatorUpdateTrigger(Trigger):
|
||||
"""
|
||||
Trigger for updating indicator values.
|
||||
|
||||
Can be fired by cron (periodic recalculation) or by data updates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indicator_id: str,
|
||||
force_full_recalc: bool = False,
|
||||
coordinator: Optional[CommitCoordinator] = None,
|
||||
priority: Priority = Priority.SYSTEM,
|
||||
):
|
||||
"""
|
||||
Initialize indicator update trigger.
|
||||
|
||||
Args:
|
||||
indicator_id: ID of indicator to update
|
||||
force_full_recalc: If True, recalculate entire history
|
||||
coordinator: CommitCoordinator for accessing stores
|
||||
priority: Execution priority
|
||||
"""
|
||||
super().__init__(f"indicator_{indicator_id}", priority)
|
||||
self.indicator_id = indicator_id
|
||||
self.force_full_recalc = force_full_recalc
|
||||
self.coordinator = coordinator
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Update indicator value.
|
||||
|
||||
Reads from IndicatorStore, recalculates, prepares commit.
|
||||
|
||||
Returns:
|
||||
Commit intents for updated indicator data
|
||||
"""
|
||||
if not self.coordinator:
|
||||
logger.error("No coordinator configured")
|
||||
return []
|
||||
|
||||
# Get indicator store
|
||||
indicator_store = self.coordinator.get_store("IndicatorStore")
|
||||
if not indicator_store:
|
||||
logger.error("IndicatorStore not registered")
|
||||
return []
|
||||
|
||||
# Read snapshot
|
||||
snapshot_seq, indicator_data = indicator_store.read_snapshot()
|
||||
|
||||
logger.info(
|
||||
f"Indicator update trigger: {self.indicator_id}, "
|
||||
f"snapshot_seq={snapshot_seq}, force_full={self.force_full_recalc}"
|
||||
)
|
||||
|
||||
# TODO: Implement indicator recalculation logic
|
||||
# For now, just return empty (no changes)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
class CronTrigger(Trigger):
|
||||
"""
|
||||
Trigger fired by APScheduler on a schedule.
|
||||
|
||||
Wraps another trigger or callable to execute periodically.
|
||||
|
||||
Priority tuple: (TIMER, scheduled_time, queue_seq)
|
||||
Ensures jobs scheduled for earlier times run first.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
inner_trigger: Trigger,
|
||||
scheduled_time: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize cron trigger.
|
||||
|
||||
Args:
|
||||
name: Descriptive name (e.g., "hourly_sync")
|
||||
inner_trigger: Trigger to execute on schedule
|
||||
scheduled_time: When this was scheduled to run (defaults to now)
|
||||
"""
|
||||
if scheduled_time is None:
|
||||
scheduled_time = int(time.time())
|
||||
|
||||
# Priority tuple: sort by TIMER priority, then scheduled time
|
||||
super().__init__(
|
||||
name=f"cron_{name}",
|
||||
priority=Priority.TIMER,
|
||||
priority_tuple=(Priority.TIMER.value, scheduled_time)
|
||||
)
|
||||
self.inner_trigger = inner_trigger
|
||||
self.scheduled_time = scheduled_time
|
||||
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""Execute the wrapped trigger"""
|
||||
logger.info(f"Cron trigger firing: {self.name}")
|
||||
return await self.inner_trigger.execute()
|
||||
224
backend.old/src/trigger/queue.py
Normal file
224
backend.old/src/trigger/queue.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Trigger queue - priority queue with sequence number assignment.
|
||||
|
||||
All operations flow through this queue:
|
||||
- WebSocket messages from users
|
||||
- Cron scheduled tasks
|
||||
- DataSource bar updates
|
||||
- Manual triggers
|
||||
|
||||
Queue assigns seq numbers on dequeue, executes triggers, and submits to coordinator.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .context import ExecutionContext, clear_execution_context, set_execution_context
|
||||
from .coordinator import CommitCoordinator
|
||||
from .types import Priority, PriorityTuple, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerQueue:
|
||||
"""
|
||||
Priority queue for trigger execution.
|
||||
|
||||
Key responsibilities:
|
||||
- Maintain priority queue (high priority dequeued first)
|
||||
- Assign sequence numbers on dequeue (determines commit order)
|
||||
- Execute triggers with context set
|
||||
- Submit results to CommitCoordinator
|
||||
- Handle execution errors gracefully
|
||||
"""
|
||||
|
||||
def __init__(self, coordinator: CommitCoordinator):
|
||||
"""
|
||||
Initialize trigger queue.
|
||||
|
||||
Args:
|
||||
coordinator: CommitCoordinator for handling commits
|
||||
"""
|
||||
self._coordinator = coordinator
|
||||
self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
|
||||
self._seq_counter = 0
|
||||
self._seq_lock = asyncio.Lock()
|
||||
self._processor_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the queue processor"""
|
||||
if self._running:
|
||||
logger.warning("TriggerQueue already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._processor_task = asyncio.create_task(self._process_loop())
|
||||
logger.info("TriggerQueue started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the queue processor gracefully"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._processor_task:
|
||||
self._processor_task.cancel()
|
||||
try:
|
||||
await self._processor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("TriggerQueue stopped")
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
priority_override: Optional[Priority | PriorityTuple] = None
|
||||
) -> int:
|
||||
"""
|
||||
Add a trigger to the queue.
|
||||
|
||||
Args:
|
||||
trigger: Trigger to execute
|
||||
priority_override: Override priority (simple Priority or tuple)
|
||||
If None, uses trigger's priority/priority_tuple
|
||||
If Priority enum, creates single-element tuple
|
||||
If tuple, uses as-is
|
||||
|
||||
Returns:
|
||||
Queue sequence number (appended to priority tuple)
|
||||
|
||||
Examples:
|
||||
# Simple priority
|
||||
await queue.enqueue(trigger, Priority.USER_AGENT)
|
||||
# Results in: (Priority.USER_AGENT, queue_seq)
|
||||
|
||||
# Tuple priority with event time
|
||||
await queue.enqueue(
|
||||
trigger,
|
||||
(Priority.DATA_SOURCE, bar_data['time'])
|
||||
)
|
||||
# Results in: (Priority.DATA_SOURCE, bar_time, queue_seq)
|
||||
|
||||
# Let trigger decide
|
||||
await queue.enqueue(trigger)
|
||||
"""
|
||||
# Get monotonic seq for queue ordering (appended to tuple)
|
||||
async with self._seq_lock:
|
||||
queue_seq = self._seq_counter
|
||||
self._seq_counter += 1
|
||||
|
||||
# Determine priority tuple
|
||||
if priority_override is not None:
|
||||
if isinstance(priority_override, Priority):
|
||||
# Convert simple priority to tuple
|
||||
priority_tuple = (priority_override.value, queue_seq)
|
||||
else:
|
||||
# Use provided tuple, append queue_seq
|
||||
priority_tuple = priority_override + (queue_seq,)
|
||||
else:
|
||||
# Let trigger determine its own priority tuple
|
||||
priority_tuple = trigger.get_priority_tuple(queue_seq)
|
||||
|
||||
# Priority queue: (priority_tuple, trigger)
|
||||
# Python's PriorityQueue compares tuples element-by-element
|
||||
await self._queue.put((priority_tuple, trigger))
|
||||
|
||||
logger.debug(
|
||||
f"Enqueued: {trigger.name} with priority_tuple={priority_tuple}"
|
||||
)
|
||||
|
||||
return queue_seq
|
||||
|
||||
async def _process_loop(self) -> None:
|
||||
"""
|
||||
Main processing loop.
|
||||
|
||||
Dequeues triggers, assigns execution seq, executes, and submits to coordinator.
|
||||
"""
|
||||
execution_seq = 0 # Separate counter for execution sequence
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Wait for next trigger (with timeout to check _running flag)
|
||||
try:
|
||||
priority_tuple, trigger = await asyncio.wait_for(
|
||||
self._queue.get(), timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
# Assign execution sequence number
|
||||
execution_seq += 1
|
||||
|
||||
logger.info(
|
||||
f"Dequeued: seq={execution_seq}, trigger={trigger.name}, "
|
||||
f"priority_tuple={priority_tuple}"
|
||||
)
|
||||
|
||||
# Execute in background (don't block queue)
|
||||
asyncio.create_task(
|
||||
self._execute_trigger(execution_seq, trigger)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process loop: {e}", exc_info=True)
|
||||
|
||||
async def _execute_trigger(self, seq: int, trigger: Trigger) -> None:
|
||||
"""
|
||||
Execute a trigger with proper context and error handling.
|
||||
|
||||
Args:
|
||||
seq: Execution sequence number
|
||||
trigger: Trigger to execute
|
||||
"""
|
||||
# Set up execution context
|
||||
ctx = ExecutionContext(
|
||||
seq=seq,
|
||||
trigger_name=trigger.name,
|
||||
)
|
||||
set_execution_context(ctx)
|
||||
|
||||
# Record execution start with coordinator
|
||||
await self._coordinator.start_execution(seq, trigger)
|
||||
|
||||
try:
|
||||
logger.info(f"Executing: seq={seq}, trigger={trigger.name}")
|
||||
|
||||
# Execute trigger (can be long-running)
|
||||
commit_intents = await trigger.execute()
|
||||
|
||||
logger.info(
|
||||
f"Execution complete: seq={seq}, {len(commit_intents)} commit intents"
|
||||
)
|
||||
|
||||
# Submit for sequential commit
|
||||
await self._coordinator.submit_for_commit(seq, commit_intents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Execution failed: seq={seq}, trigger={trigger.name}, error={e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Notify coordinator of failure
|
||||
await self._coordinator.execution_failed(seq, e)
|
||||
|
||||
finally:
|
||||
clear_execution_context()
|
||||
|
||||
def get_queue_size(self) -> int:
|
||||
"""Get current queue size (approximate)"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if queue processor is running"""
|
||||
return self._running
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"TriggerQueue(running={self._running}, queue_size={self.get_queue_size()})"
|
||||
)
|
||||
187
backend.old/src/trigger/scheduler.py
Normal file
187
backend.old/src/trigger/scheduler.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
APScheduler integration for cron-style triggers.
|
||||
|
||||
Provides scheduling of periodic triggers (e.g., sync exchange state hourly,
|
||||
recompute indicators every 5 minutes, daily portfolio reports).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger as APSCronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from .queue import TriggerQueue
|
||||
from .types import Priority, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerScheduler:
|
||||
"""
|
||||
Scheduler for periodic trigger execution.
|
||||
|
||||
Wraps APScheduler to enqueue triggers at scheduled times.
|
||||
"""
|
||||
|
||||
def __init__(self, trigger_queue: TriggerQueue):
|
||||
"""
|
||||
Initialize scheduler.
|
||||
|
||||
Args:
|
||||
trigger_queue: TriggerQueue to enqueue triggers into
|
||||
"""
|
||||
self.trigger_queue = trigger_queue
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self._job_counter = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the scheduler"""
|
||||
self.scheduler.start()
|
||||
logger.info("TriggerScheduler started")
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""
|
||||
Shut down the scheduler.
|
||||
|
||||
Args:
|
||||
wait: If True, wait for running jobs to complete
|
||||
"""
|
||||
self.scheduler.shutdown(wait=wait)
|
||||
logger.info("TriggerScheduler shut down")
|
||||
|
||||
def schedule_interval(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
seconds: Optional[int] = None,
|
||||
minutes: Optional[int] = None,
|
||||
hours: Optional[int] = None,
|
||||
priority: Optional[Priority] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Schedule a trigger to run at regular intervals.
|
||||
|
||||
Args:
|
||||
trigger: Trigger to execute
|
||||
seconds: Interval in seconds
|
||||
minutes: Interval in minutes
|
||||
hours: Interval in hours
|
||||
priority: Priority override for execution
|
||||
|
||||
Returns:
|
||||
Job ID (can be used to remove job later)
|
||||
|
||||
Example:
|
||||
# Run every 5 minutes
|
||||
scheduler.schedule_interval(
|
||||
IndicatorUpdateTrigger("rsi_14"),
|
||||
minutes=5
|
||||
)
|
||||
"""
|
||||
job_id = f"interval_{self._job_counter}"
|
||||
self._job_counter += 1
|
||||
|
||||
async def job_func():
|
||||
await self.trigger_queue.enqueue(trigger, priority)
|
||||
|
||||
self.scheduler.add_job(
|
||||
job_func,
|
||||
trigger=IntervalTrigger(seconds=seconds, minutes=minutes, hours=hours),
|
||||
id=job_id,
|
||||
name=f"Interval: {trigger.name}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduled interval job: {job_id}, trigger={trigger.name}, "
|
||||
f"interval=(s={seconds}, m={minutes}, h={hours})"
|
||||
)
|
||||
|
||||
return job_id
|
||||
|
||||
def schedule_cron(
|
||||
self,
|
||||
trigger: Trigger,
|
||||
minute: Optional[str] = None,
|
||||
hour: Optional[str] = None,
|
||||
day: Optional[str] = None,
|
||||
month: Optional[str] = None,
|
||||
day_of_week: Optional[str] = None,
|
||||
priority: Optional[Priority] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Schedule a trigger to run on a cron schedule.
|
||||
|
||||
Args:
|
||||
trigger: Trigger to execute
|
||||
minute: Minute expression (0-59, *, */5, etc.)
|
||||
hour: Hour expression (0-23, *, etc.)
|
||||
day: Day of month expression (1-31, *, etc.)
|
||||
month: Month expression (1-12, *, etc.)
|
||||
day_of_week: Day of week expression (0-6, mon-sun, *, etc.)
|
||||
priority: Priority override for execution
|
||||
|
||||
Returns:
|
||||
Job ID (can be used to remove job later)
|
||||
|
||||
Example:
|
||||
# Run at 9:00 AM every weekday
|
||||
scheduler.schedule_cron(
|
||||
SyncExchangeStateTrigger(),
|
||||
hour="9",
|
||||
minute="0",
|
||||
day_of_week="mon-fri"
|
||||
)
|
||||
"""
|
||||
job_id = f"cron_{self._job_counter}"
|
||||
self._job_counter += 1
|
||||
|
||||
async def job_func():
|
||||
await self.trigger_queue.enqueue(trigger, priority)
|
||||
|
||||
self.scheduler.add_job(
|
||||
job_func,
|
||||
trigger=APSCronTrigger(
|
||||
minute=minute,
|
||||
hour=hour,
|
||||
day=day,
|
||||
month=month,
|
||||
day_of_week=day_of_week,
|
||||
),
|
||||
id=job_id,
|
||||
name=f"Cron: {trigger.name}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduled cron job: {job_id}, trigger={trigger.name}, "
|
||||
f"schedule=(m={minute}, h={hour}, d={day}, dow={day_of_week})"
|
||||
)
|
||||
|
||||
return job_id
|
||||
|
||||
def remove_job(self, job_id: str) -> bool:
|
||||
"""
|
||||
Remove a scheduled job.
|
||||
|
||||
Args:
|
||||
job_id: Job ID returned from schedule_* methods
|
||||
|
||||
Returns:
|
||||
True if job was removed, False if not found
|
||||
"""
|
||||
try:
|
||||
self.scheduler.remove_job(job_id)
|
||||
logger.info(f"Removed scheduled job: {job_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not remove job {job_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_jobs(self) -> list:
|
||||
"""Get list of all scheduled jobs"""
|
||||
return self.scheduler.get_jobs()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
job_count = len(self.scheduler.get_jobs())
|
||||
running = self.scheduler.running
|
||||
return f"TriggerScheduler(running={running}, jobs={job_count})"
|
||||
301
backend.old/src/trigger/store.py
Normal file
301
backend.old/src/trigger/store.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Versioned store with pluggable backends.
|
||||
|
||||
Provides optimistic concurrency control via sequence numbers with support
|
||||
for different storage backends (Pydantic models, files, databases, etc.).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from .context import get_execution_context
|
||||
from .types import CommitIntent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class StoreBackend(ABC, Generic[T]):
|
||||
"""
|
||||
Abstract backend for versioned stores.
|
||||
|
||||
Allows different storage mechanisms (Pydantic models, files, databases)
|
||||
to be used with the same versioned store infrastructure.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read(self) -> T:
|
||||
"""
|
||||
Read the current data.
|
||||
|
||||
Returns:
|
||||
Current data in backend-specific format
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def write(self, data: T) -> None:
|
||||
"""
|
||||
Write new data (replaces existing).
|
||||
|
||||
Args:
|
||||
data: New data to write
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def snapshot(self) -> T:
|
||||
"""
|
||||
Create an immutable snapshot of current data.
|
||||
|
||||
Must return a deep copy or immutable version to prevent
|
||||
modifications from affecting the committed state.
|
||||
|
||||
Returns:
|
||||
Immutable snapshot of data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, data: T) -> bool:
|
||||
"""
|
||||
Validate that data is in correct format for this backend.
|
||||
|
||||
Args:
|
||||
data: Data to validate
|
||||
|
||||
Returns:
|
||||
True if valid
|
||||
|
||||
Raises:
|
||||
ValueError: If invalid with explanation
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PydanticStoreBackend(StoreBackend[T]):
|
||||
"""
|
||||
Backend for Pydantic BaseModel stores.
|
||||
|
||||
Supports the existing OrderStore, ChartStore, etc. pattern.
|
||||
"""
|
||||
|
||||
def __init__(self, model_instance: T):
|
||||
"""
|
||||
Initialize with a Pydantic model instance.
|
||||
|
||||
Args:
|
||||
model_instance: Instance of a Pydantic BaseModel
|
||||
"""
|
||||
self._model = model_instance
|
||||
|
||||
def read(self) -> T:
|
||||
return self._model
|
||||
|
||||
def write(self, data: T) -> None:
|
||||
# Replace the internal model
|
||||
self._model = data
|
||||
|
||||
def snapshot(self) -> T:
|
||||
# Use Pydantic's model_copy for deep copy
|
||||
if hasattr(self._model, "model_copy"):
|
||||
return self._model.model_copy(deep=True)
|
||||
# Fallback for older Pydantic or non-model types
|
||||
return deepcopy(self._model)
|
||||
|
||||
def validate(self, data: T) -> bool:
|
||||
# Pydantic models validate themselves on construction
|
||||
# If we got here with a model instance, it's valid
|
||||
return True
|
||||
|
||||
|
||||
class FileStoreBackend(StoreBackend[str]):
|
||||
"""
|
||||
Backend for file-based storage.
|
||||
|
||||
Future implementation for versioning files (e.g., Python scripts, configs).
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
self.file_path = file_path
|
||||
raise NotImplementedError("FileStoreBackend not yet implemented")
|
||||
|
||||
def read(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def write(self, data: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def snapshot(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def validate(self, data: str) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DatabaseStoreBackend(StoreBackend[dict]):
|
||||
"""
|
||||
Backend for database table storage.
|
||||
|
||||
Future implementation for versioning database interactions.
|
||||
"""
|
||||
|
||||
def __init__(self, table_name: str, connection):
|
||||
self.table_name = table_name
|
||||
self.connection = connection
|
||||
raise NotImplementedError("DatabaseStoreBackend not yet implemented")
|
||||
|
||||
def read(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
def write(self, data: dict) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def snapshot(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
def validate(self, data: dict) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class VersionedStore(Generic[T]):
|
||||
"""
|
||||
Store with optimistic concurrency control via sequence numbers.
|
||||
|
||||
Wraps any StoreBackend and provides:
|
||||
- Lock-free snapshot reads
|
||||
- Conflict detection on commit
|
||||
- Version tracking for debugging
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, backend: StoreBackend[T]):
|
||||
"""
|
||||
Initialize versioned store.
|
||||
|
||||
Args:
|
||||
name: Unique name for this store (e.g., "OrderStore")
|
||||
backend: Backend implementation for storage
|
||||
"""
|
||||
self.name = name
|
||||
self._backend = backend
|
||||
self._committed_seq = 0 # Highest committed seq
|
||||
self._version = 0 # Increments on each commit (for debugging)
|
||||
|
||||
@property
|
||||
def committed_seq(self) -> int:
|
||||
"""Get the current committed sequence number"""
|
||||
return self._committed_seq
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version (increments on each commit)"""
|
||||
return self._version
|
||||
|
||||
def read_snapshot(self) -> tuple[int, T]:
|
||||
"""
|
||||
Read an immutable snapshot of the store.
|
||||
|
||||
This is lock-free and can be called concurrently. The snapshot
|
||||
captures the current committed seq and a deep copy of the data.
|
||||
|
||||
Automatically records the snapshot seq in the execution context
|
||||
for conflict detection during commit.
|
||||
|
||||
Returns:
|
||||
Tuple of (seq, snapshot_data)
|
||||
"""
|
||||
snapshot_seq = self._committed_seq
|
||||
snapshot_data = self._backend.snapshot()
|
||||
|
||||
# Record in execution context for conflict detection
|
||||
ctx = get_execution_context()
|
||||
if ctx:
|
||||
ctx.record_snapshot(self.name, snapshot_seq)
|
||||
|
||||
logger.debug(
|
||||
f"Store '{self.name}': read_snapshot() -> seq={snapshot_seq}, version={self._version}"
|
||||
)
|
||||
|
||||
return (snapshot_seq, snapshot_data)
|
||||
|
||||
def read_current(self) -> T:
|
||||
"""
|
||||
Read the current data without snapshot tracking.
|
||||
|
||||
Use this for read-only operations that don't need conflict detection.
|
||||
|
||||
Returns:
|
||||
Current data (not a snapshot, modifications visible)
|
||||
"""
|
||||
return self._backend.read()
|
||||
|
||||
def prepare_commit(self, expected_seq: int, new_data: T) -> CommitIntent:
|
||||
"""
|
||||
Create a commit intent for later sequential commit.
|
||||
|
||||
Does NOT modify the store - that happens during the commit phase.
|
||||
|
||||
Args:
|
||||
expected_seq: The seq of the snapshot that was read
|
||||
new_data: The new data to commit
|
||||
|
||||
Returns:
|
||||
CommitIntent to be submitted to CommitCoordinator
|
||||
"""
|
||||
# Validate data before creating intent
|
||||
self._backend.validate(new_data)
|
||||
|
||||
intent = CommitIntent(
|
||||
store_name=self.name,
|
||||
expected_seq=expected_seq,
|
||||
new_data=new_data,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Store '{self.name}': prepare_commit(expected_seq={expected_seq}, current_seq={self._committed_seq})"
|
||||
)
|
||||
|
||||
return intent
|
||||
|
||||
def commit(self, new_data: T, commit_seq: int) -> None:
|
||||
"""
|
||||
Commit new data at a specific seq.
|
||||
|
||||
Called by CommitCoordinator during sequential commit phase.
|
||||
NOT for direct use by triggers.
|
||||
|
||||
Args:
|
||||
new_data: Data to commit
|
||||
commit_seq: Seq number of this commit
|
||||
"""
|
||||
self._backend.write(new_data)
|
||||
self._committed_seq = commit_seq
|
||||
self._version += 1
|
||||
|
||||
logger.info(
|
||||
f"Store '{self.name}': committed seq={commit_seq}, version={self._version}"
|
||||
)
|
||||
|
||||
def check_conflict(self, expected_seq: int) -> bool:
|
||||
"""
|
||||
Check if committing at expected_seq would conflict.
|
||||
|
||||
Args:
|
||||
expected_seq: The seq that was expected during execution
|
||||
|
||||
Returns:
|
||||
True if conflict (committed_seq has advanced beyond expected_seq)
|
||||
"""
|
||||
has_conflict = self._committed_seq != expected_seq
|
||||
if has_conflict:
|
||||
logger.warning(
|
||||
f"Store '{self.name}': conflict detected - "
|
||||
f"expected_seq={expected_seq}, committed_seq={self._committed_seq}"
|
||||
)
|
||||
return has_conflict
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"VersionedStore(name='{self.name}', committed_seq={self._committed_seq}, version={self._version})"
|
||||
175
backend.old/src/trigger/types.py
Normal file
175
backend.old/src/trigger/types.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Core types for the trigger system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Priority(IntEnum):
|
||||
"""
|
||||
Primary execution priority for triggers.
|
||||
|
||||
Lower numeric value = higher priority (dequeued first).
|
||||
|
||||
Priority hierarchy (highest to lowest):
|
||||
- DATA_SOURCE: Market data, real-time feeds (most time-sensitive)
|
||||
- TIMER: Scheduled tasks, cron jobs
|
||||
- USER_AGENT: User-agent interactions (WebSocket chat)
|
||||
- USER_DATA_REQUEST: User data requests (chart loads, symbol search)
|
||||
- SYSTEM: Background tasks, cleanup
|
||||
- LOW: Retries after conflicts, non-critical tasks
|
||||
"""
|
||||
|
||||
DATA_SOURCE = 0 # Market data updates, real-time feeds
|
||||
TIMER = 1 # Scheduled tasks, cron jobs
|
||||
USER_AGENT = 2 # User-agent interactions (WebSocket chat)
|
||||
USER_DATA_REQUEST = 3 # User data requests (chart loads, etc.)
|
||||
SYSTEM = 4 # Background tasks, cleanup, etc.
|
||||
LOW = 5 # Retries after conflicts, non-critical tasks
|
||||
|
||||
|
||||
# Type alias for priority tuples
|
||||
# Examples:
|
||||
# (Priority.DATA_SOURCE,) - Simple priority
|
||||
# (Priority.DATA_SOURCE, event_time) - Priority + event time
|
||||
# (Priority.DATA_SOURCE, event_time, queue_seq) - Full ordering
|
||||
#
|
||||
# Python compares tuples element-by-element, left-to-right.
|
||||
# Shorter tuple wins if all shared elements are equal.
|
||||
PriorityTuple = tuple[int, ...]
|
||||
|
||||
|
||||
class ExecutionState(IntEnum):
|
||||
"""State of an execution in the system"""
|
||||
|
||||
QUEUED = 0 # In queue, waiting to be dequeued
|
||||
EXECUTING = 1 # Currently executing
|
||||
WAITING_COMMIT = 2 # Finished executing, waiting for sequential commit
|
||||
COMMITTED = 3 # Successfully committed
|
||||
EVICTED = 4 # Evicted due to conflict, will retry
|
||||
FAILED = 5 # Failed with error
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitIntent:
|
||||
"""
|
||||
Intent to commit changes to a store.
|
||||
|
||||
Created during execution, validated and applied during sequential commit phase.
|
||||
"""
|
||||
|
||||
store_name: str
|
||||
"""Name of the store to commit to"""
|
||||
|
||||
expected_seq: int
|
||||
"""The seq number of the snapshot that was read (for conflict detection)"""
|
||||
|
||||
new_data: Any
|
||||
"""The new data to commit (format depends on store backend)"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
data_preview = str(self.new_data)[:50]
|
||||
return f"CommitIntent(store={self.store_name}, expected_seq={self.expected_seq}, data={data_preview}...)"
|
||||
|
||||
|
||||
class Trigger(ABC):
|
||||
"""
|
||||
Abstract base class for all triggers.
|
||||
|
||||
A trigger represents a unit of work that:
|
||||
1. Gets assigned a seq number when dequeued
|
||||
2. Executes (potentially long-running, async)
|
||||
3. Returns CommitIntents for any state changes
|
||||
4. Waits for sequential commit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
priority: Priority = Priority.SYSTEM,
|
||||
priority_tuple: Optional[PriorityTuple] = None
|
||||
):
|
||||
"""
|
||||
Initialize trigger.
|
||||
|
||||
Args:
|
||||
name: Descriptive name for logging
|
||||
priority: Simple priority (used if priority_tuple not provided)
|
||||
priority_tuple: Optional tuple for compound sorting
|
||||
Examples:
|
||||
(Priority.DATA_SOURCE, event_time)
|
||||
(Priority.USER_AGENT, message_timestamp)
|
||||
(Priority.TIMER, scheduled_time)
|
||||
"""
|
||||
self.name = name
|
||||
self.priority = priority
|
||||
self._priority_tuple = priority_tuple
|
||||
|
||||
def get_priority_tuple(self, queue_seq: int) -> PriorityTuple:
|
||||
"""
|
||||
Get the priority tuple for queue ordering.
|
||||
|
||||
If a priority tuple was provided at construction, append queue_seq.
|
||||
Otherwise, create tuple from simple priority.
|
||||
|
||||
Args:
|
||||
queue_seq: Queue insertion order (final sort key)
|
||||
|
||||
Returns:
|
||||
Priority tuple for queue ordering
|
||||
|
||||
Examples:
|
||||
(Priority.DATA_SOURCE,) + (queue_seq,) = (0, queue_seq)
|
||||
(Priority.DATA_SOURCE, 1000) + (queue_seq,) = (0, 1000, queue_seq)
|
||||
"""
|
||||
if self._priority_tuple is not None:
|
||||
return self._priority_tuple + (queue_seq,)
|
||||
else:
|
||||
return (self.priority.value, queue_seq)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> list[CommitIntent]:
|
||||
"""
|
||||
Execute the trigger logic.
|
||||
|
||||
Can be long-running and async. Should read from stores via
|
||||
VersionedStore.read_snapshot() and return CommitIntents for any changes.
|
||||
|
||||
Returns:
|
||||
List of CommitIntents (empty if no state changes)
|
||||
|
||||
Raises:
|
||||
Exception: On execution failure (will be logged, no commit)
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name='{self.name}', priority={self.priority.name})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionRecord:
|
||||
"""
|
||||
Record of an execution for tracking and debugging.
|
||||
|
||||
Maintained by the CommitCoordinator to track in-flight executions.
|
||||
"""
|
||||
|
||||
seq: int
|
||||
trigger: Trigger
|
||||
state: ExecutionState
|
||||
commit_intents: Optional[list[CommitIntent]] = None
|
||||
error: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ExecutionRecord(seq={self.seq}, trigger={self.trigger.name}, "
|
||||
f"state={self.state.name}, retry={self.retry_count})"
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user