Files
ai/gateway/src/tools/tool-registry.ts

365 lines
12 KiB
TypeScript

import type { DynamicStructuredTool } from '@langchain/core/tools';
import type { FastifyBaseLogger } from 'fastify';
import type { MCPClientConnector } from '../harness/mcp-client.js';
import type { OHLCService } from '../services/ohlc-service.js';
import type { SymbolIndexService } from '../services/symbol-index-service.js';
import type { WorkspaceManager } from '../workspace/workspace-manager.js';
import type { Ticker24hSnapshot } from '../clients/zmq-relay-client.js';
import { createSymbolLookupTool } from './platform/symbol-lookup.tool.js';
import { createGetChartDataTool } from './platform/get-chart-data.tool.js';
import { createGetTicker24hTool } from './platform/get-ticker24h.tool.js';
import { createWebSearchTool } from './platform/web-search.tool.js';
import { createFetchPageTool } from './platform/fetch-page.tool.js';
import { createArxivSearchTool } from './platform/arxiv-search.tool.js';
import { createMCPToolWrappers, type MCPToolInfo } from './mcp/mcp-tool-wrapper.js';
/**
 * Per-agent tool routing entry.
 *
 * Pairs an agent name with the platform tools and MCP tools that agent may
 * use. ToolRegistry stores these entries keyed by `agentName`.
 */
export interface AgentToolConfig {
  /** Registry key for the agent, e.g. 'main', 'research', 'web-explore' */
  agentName: string;
  /** Names of platform (gateway-local) tools to build for this agent */
  platformTools: Array<string>;
  /**
   * Exact MCP tool names or wildcard patterns such as 'Python*'.
   * An empty list means the agent receives no MCP tools.
   */
  mcpTools: Array<string>;
}
/**
 * Dependencies ToolRegistry draws on when it builds platform tools.
 *
 * Service entries accept either the service object itself or a zero-argument
 * getter. Getters allow the registry to be constructed before the services are
 * ready and may return undefined while a service is unavailable. Every entry is
 * optional; a tool whose dependency is missing is skipped with a warning.
 */
export interface PlatformServices {
  /** Used by the SymbolLookup tool */
  symbolIndexService?: SymbolIndexService | (() => SymbolIndexService | undefined);
  /** Used by the GetChartData tool */
  ohlcService?: OHLCService | (() => OHLCService | undefined);
  /** Fallback for GetChartData when no session-specific workspace manager is supplied */
  workspaceManager?: WorkspaceManager | (() => WorkspaceManager | undefined);
  /** Returns a 24h ticker snapshot for an exchange (or undefined); enables GetTicker24h */
  ticker24hGetter?: (exchange: string) => Ticker24hSnapshot | undefined;
  /** Tavily API key; WebSearch is only created when this is set */
  tavilyApiKey?: string;
}
/**
 * Tool Registry
 *
 * Manages tool creation and agent-to-tool mappings.
 * Supports:
 * - Platform tools (local services like symbol lookup, chart data)
 * - Remote MCP tools (per-user, session-scoped)
 * - Configurable tool routing (which tools for which agents)
 */
export class ToolRegistry {
  private readonly logger: FastifyBaseLogger;
  private readonly platformServices: PlatformServices;
  private readonly agentToolConfigs: Map<string, AgentToolConfig> = new Map();

  constructor(logger: FastifyBaseLogger, platformServices: PlatformServices) {
    this.logger = logger;
    this.platformServices = platformServices;
  }

  /**
   * Register agent tool configuration.
   * A later registration with the same `agentName` replaces the earlier one.
   */
  registerAgentTools(config: AgentToolConfig): void {
    this.agentToolConfigs.set(config.agentName, config);
    this.logger.debug(
      {
        agent: config.agentName,
        platformTools: config.platformTools,
        mcpTools: config.mcpTools,
      },
      'Registered agent tool configuration'
    );
  }

  /**
   * Get tools for a specific agent
   *
   * @param agentName - Name of the agent ('main', 'research', etc.)
   * @param mcpClient - MCP client for remote tools (optional)
   * @param availableMCPTools - List of available MCP tools from user's server (optional)
   * @param workspaceManager - Workspace manager for this session (optional, used by some platform tools)
   * @param onImage - Forwarded to the MCP tool wrappers (optional)
   * @param onWorkspaceMutation - Forwarded to the MCP tool wrappers (optional)
   * @returns Array of tools for this agent; empty when the agent has no registered config
   */
  async getToolsForAgent(
    agentName: string,
    mcpClient?: MCPClientConnector,
    availableMCPTools?: MCPToolInfo[],
    workspaceManager?: WorkspaceManager,
    onImage?: (image: { data: string; mimeType: string }) => void,
    onWorkspaceMutation?: (storeName: string, newState: unknown) => void
  ): Promise<DynamicStructuredTool[]> {
    const config = this.agentToolConfigs.get(agentName);
    if (!config) {
      this.logger.warn({ agent: agentName }, 'No tool configuration found for agent');
      return [];
    }

    const tools: DynamicStructuredTool[] = [];

    // Add platform tools
    for (const toolName of config.platformTools) {
      const tool = await this.getPlatformTool(toolName, workspaceManager);
      if (tool) {
        tools.push(tool);
      } else {
        this.logger.warn({ agent: agentName, tool: toolName }, 'Platform tool not found');
      }
    }

    // Add MCP tools (if MCP client and tools are available)
    if (mcpClient && availableMCPTools && availableMCPTools.length > 0) {
      const filteredMCPTools = this.filterMCPTools(availableMCPTools, config.mcpTools);
      const mcpToolInstances = createMCPToolWrappers(filteredMCPTools, mcpClient, this.logger, onImage, onWorkspaceMutation);
      tools.push(...mcpToolInstances);
      this.logger.debug(
        {
          agent: agentName,
          mcpToolCount: mcpToolInstances.length,
          mcpToolNames: mcpToolInstances.map(t => t.name),
        },
        'Added MCP tools for agent'
      );
    }

    this.logger.info(
      {
        agent: agentName,
        toolCount: tools.length,
        toolNames: tools.map(t => t.name),
      },
      'Retrieved tools for agent'
    );
    return tools;
  }

  /**
   * Get a platform tool by name
   *
   * Tools are not cached: each call re-resolves the platform services, because
   * they may be initialized asynchronously after the registry was created.
   *
   * @param toolName - Name of the tool to create
   * @param sessionWorkspaceManager - Optional session-specific workspace manager
   * @returns The tool, or null when the name is unknown or its dependencies are unavailable
   */
  private async getPlatformTool(
    toolName: string,
    sessionWorkspaceManager?: WorkspaceManager
  ): Promise<DynamicStructuredTool | null> {
    let tool: DynamicStructuredTool | null = null;

    switch (toolName) {
      case 'SymbolLookup': {
        const symbolIndexService = this.resolveService(this.platformServices.symbolIndexService);
        if (symbolIndexService) {
          tool = createSymbolLookupTool({
            symbolIndexService,
            logger: this.logger,
          });
        } else {
          this.logger.warn('SymbolIndexService not available for SymbolLookup tool');
        }
        break;
      }
      case 'GetChartData': {
        const ohlcService = this.resolveService(this.platformServices.ohlcService);
        // Use session workspace manager if provided, otherwise try global
        const workspaceManager = sessionWorkspaceManager ??
          this.resolveService(this.platformServices.workspaceManager);
        if (ohlcService && workspaceManager) {
          tool = createGetChartDataTool({
            ohlcService,
            workspaceManager,
            logger: this.logger,
          });
        } else {
          this.logger.warn(
            { hasOHLC: !!ohlcService, hasWorkspace: !!workspaceManager },
            'OHLCService or WorkspaceManager not available for GetChartData tool'
          );
        }
        break;
      }
      case 'GetTicker24h': {
        if (this.platformServices.ticker24hGetter) {
          tool = createGetTicker24hTool({
            getTicker24h: this.platformServices.ticker24hGetter,
            logger: this.logger,
          });
        } else {
          this.logger.warn('ticker24hGetter not configured — GetTicker24h tool unavailable');
        }
        break;
      }
      case 'WebSearch': {
        if (this.platformServices.tavilyApiKey) {
          tool = createWebSearchTool({ apiKey: this.platformServices.tavilyApiKey, logger: this.logger });
        } else {
          this.logger.warn('TAVILY_API_KEY not configured — WebSearch tool unavailable');
        }
        break;
      }
      case 'FetchPage': {
        tool = createFetchPageTool({ logger: this.logger });
        break;
      }
      case 'ArxivSearch': {
        tool = createArxivSearchTool({ logger: this.logger });
        break;
      }
      default:
        this.logger.warn({ tool: toolName }, 'Unknown platform tool');
        return null;
    }

    return tool;
  }

  /**
   * Resolve a service that may be supplied directly or as a getter function.
   *
   * The configured service types are object instances, so any function value
   * here is the lazy getter and is invoked. The previous `!prototype` test
   * skipped getters written with the `function` keyword, because such
   * functions carry a `prototype`. That returned the getter itself instead of
   * the service.
   */
  private resolveService<T>(service: T | (() => T | undefined) | undefined): T | undefined {
    if (typeof service === 'function') {
      return (service as () => T | undefined)();
    }
    return service as T | undefined;
  }

  /**
   * Filter MCP tools based on patterns/names
   * Supports wildcards like 'Python*' or exact names like 'ExecuteResearch'.
   * An empty pattern list selects no tools.
   */
  private filterMCPTools(availableTools: MCPToolInfo[], patterns: string[]): MCPToolInfo[] {
    if (patterns.length === 0) {
      return [];
    }
    return availableTools.filter(tool =>
      patterns.some(pattern => this.matchesPattern(tool.name, pattern))
    );
  }

  /**
   * Check if a tool name matches a pattern.
   *
   * Wildcards: '*' matches any run of characters (including none), and '?'
   * matches exactly one character. For example, 'Python*' matches
   * 'PythonWrite' and 'PythonRead'. Every other character is matched
   * literally, including regex metacharacters such as '.', '(' and '+'.
   */
  private matchesPattern(toolName: string, pattern: string): boolean {
    if (pattern === toolName) {
      return true; // Exact match
    }
    if (!pattern.includes('*') && !pattern.includes('?')) {
      return false; // Plain name that did not match exactly
    }
    // Escape regex metacharacters in the literal segments first, so the
    // user-supplied pattern can neither change meaning nor yield an invalid
    // RegExp. Then expand the wildcards.
    const regexSource = pattern
      .split('*')
      .map(segment =>
        segment
          .split('?')
          .map(literal => literal.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
          .join('.')
      )
      .join('.*');
    return new RegExp(`^${regexSource}$`).test(toolName);
  }

  /**
   * Resolve tools directly from explicit platform tool names and MCP patterns,
   * without requiring a pre-registered agent config.
   * Used by SpawnService to build tool lists from wiki frontmatter at spawn time.
   *
   * @param platformTools - Platform tool names to build
   * @param mcpPatterns - Exact MCP tool names or wildcard patterns to include
   * @param mcpClient - MCP client for remote tools (optional)
   * @param availableMCPTools - MCP tools offered by the user's server (optional)
   * @param workspaceManager - Session workspace manager (optional)
   * @param onImage - Forwarded to the MCP tool wrappers (optional)
   * @param onWorkspaceMutation - Forwarded to the MCP tool wrappers (optional)
   */
  async resolveTools(
    platformTools: string[],
    mcpPatterns: string[],
    mcpClient?: MCPClientConnector,
    availableMCPTools?: MCPToolInfo[],
    workspaceManager?: WorkspaceManager,
    onImage?: (image: { data: string; mimeType: string }) => void,
    onWorkspaceMutation?: (storeName: string, newState: unknown) => void
  ): Promise<DynamicStructuredTool[]> {
    const tools: DynamicStructuredTool[] = [];

    for (const toolName of platformTools) {
      const tool = await this.getPlatformTool(toolName, workspaceManager);
      if (tool) {
        tools.push(tool);
      } else {
        this.logger.warn({ tool: toolName }, 'resolveTools: platform tool not found');
      }
    }

    if (mcpClient && availableMCPTools && availableMCPTools.length > 0 && mcpPatterns.length > 0) {
      const filteredMCPTools = this.filterMCPTools(availableMCPTools, mcpPatterns);
      const mcpToolInstances = createMCPToolWrappers(filteredMCPTools, mcpClient, this.logger, onImage, onWorkspaceMutation);
      tools.push(...mcpToolInstances);
    }

    return tools;
  }

  /**
   * Get all registered agent names (in registration order)
   */
  getRegisteredAgents(): string[] {
    return Array.from(this.agentToolConfigs.keys());
  }

  /**
   * Get tool configuration for an agent, or null if none is registered
   */
  getAgentToolConfig(agentName: string): AgentToolConfig | null {
    return this.agentToolConfigs.get(agentName) ?? null;
  }
}
/**
 * Process-wide ToolRegistry. It starts out null, is created once by
 * initializeToolRegistry() during gateway startup, and is read back via
 * getToolRegistry().
 */
let globalToolRegistry: ToolRegistry | null = null;

/**
 * Create the process-wide tool registry.
 *
 * Only the first call constructs a registry. Any later call logs a warning
 * and returns the existing instance; its `platformServices` argument is
 * ignored in that case.
 *
 * @param logger - Logger handed to the registry and used for the init messages
 * @param platformServices - Services (or getters) the registry builds platform tools from
 * @returns The global registry instance
 */
export function initializeToolRegistry(
  logger: FastifyBaseLogger,
  platformServices: PlatformServices
): ToolRegistry {
  if (globalToolRegistry === null) {
    globalToolRegistry = new ToolRegistry(logger, platformServices);
    logger.info('Tool registry initialized');
  } else {
    logger.warn('Global tool registry already initialized');
  }
  return globalToolRegistry;
}

/**
 * Return the process-wide tool registry.
 *
 * @throws Error if initializeToolRegistry() has not been called yet
 */
export function getToolRegistry(): ToolRegistry {
  const registry = globalToolRegistry;
  if (registry === null) {
    throw new Error('Tool registry not initialized. Call initializeToolRegistry() first.');
  }
  return registry;
}