Files
ai/gateway/src/tools/tool-registry.ts

292 lines
8.9 KiB
TypeScript

import type { DynamicStructuredTool } from '@langchain/core/tools';
import type { FastifyBaseLogger } from 'fastify';
import type { MCPClientConnector } from '../harness/mcp-client.js';
import type { OHLCService } from '../services/ohlc-service.js';
import type { SymbolIndexService } from '../services/symbol-index-service.js';
import type { WorkspaceManager } from '../workspace/workspace-manager.js';
import { createSymbolLookupTool } from './platform/symbol-lookup.tool.js';
import { createGetChartDataTool } from './platform/get-chart-data.tool.js';
import { createMCPToolWrappers, type MCPToolInfo } from './mcp/mcp-tool-wrapper.js';
/**
 * Declares the tool set exposed to a single agent.
 *
 * An agent is identified by name and lists the platform tools it may use plus the
 * MCP tools (exact names or wildcard patterns) it may draw from the user's server.
 */
export interface AgentToolConfig {
  /** Unique agent identifier, e.g. 'main', 'research' or 'code-reviewer'. */
  agentName: string;
  /** Names of locally implemented platform tools to expose to this agent. */
  platformTools: string[];
  /** Exact MCP tool names or wildcard patterns (e.g. 'category_*') selecting remote tools. */
  mcpTools: string[];
}
/**
 * A service supplied either eagerly (the instance itself) or lazily (a getter that
 * may return `undefined` while the service is not yet available).
 */
type ServiceSource<TService> = TService | (() => TService | undefined);

/**
 * Services that platform tools depend on.
 *
 * Every entry is optional. Getters let callers hand over services that finish
 * initializing after the registry has been constructed.
 */
export interface PlatformServices {
  ohlcService?: ServiceSource<OHLCService>;
  symbolIndexService?: ServiceSource<SymbolIndexService>;
  workspaceManager?: ServiceSource<WorkspaceManager>;
}
/**
 * Tool Registry
 *
 * Builds the tool list handed to each agent and keeps the agent-to-tool mapping.
 * Supports:
 * - Platform tools (local services like symbol lookup, chart data)
 * - Remote MCP tools (per-user, session-scoped)
 * - Configurable tool routing (which tools for which agents)
 */
export class ToolRegistry {
  private readonly logger: FastifyBaseLogger;
  private readonly platformServices: PlatformServices;
  /** Tool configuration keyed by agent name; re-registering a name replaces the entry. */
  private readonly agentToolConfigs: Map<string, AgentToolConfig> = new Map();

  constructor(logger: FastifyBaseLogger, platformServices: PlatformServices) {
    this.logger = logger;
    this.platformServices = platformServices;
  }

  /**
   * Register (or replace) the tool configuration for an agent.
   *
   * @param config - Agent name plus the platform tools and MCP patterns it may use.
   */
  registerAgentTools(config: AgentToolConfig): void {
    this.agentToolConfigs.set(config.agentName, config);
    this.logger.debug(
      {
        agent: config.agentName,
        platformTools: config.platformTools,
        mcpTools: config.mcpTools,
      },
      'Registered agent tool configuration'
    );
  }

  /**
   * Build the tools for a specific agent.
   *
   * Platform tools are recreated on every call. MCP tools are added only when both
   * `mcpClient` and a non-empty `availableMCPTools` list are supplied, and only those
   * matching the agent's `mcpTools` names/patterns.
   *
   * @param agentName - Name of the agent ('main', 'research', etc.)
   * @param mcpClient - MCP client for remote tools (optional)
   * @param availableMCPTools - List of available MCP tools from user's server (optional)
   * @param workspaceManager - Workspace manager for this session (optional). Platform tools
   *   that need one prefer it over the globally configured manager.
   * @param onImage - Optional callback passed through to `createMCPToolWrappers`
   *   (presumably invoked when an MCP tool returns image content — confirm in mcp-tool-wrapper).
   * @returns Tools for this agent; an empty array if the agent has no registered configuration.
   */
  async getToolsForAgent(
    agentName: string,
    mcpClient?: MCPClientConnector,
    availableMCPTools?: MCPToolInfo[],
    workspaceManager?: WorkspaceManager,
    onImage?: (image: { data: string; mimeType: string }) => void
  ): Promise<DynamicStructuredTool[]> {
    const config = this.agentToolConfigs.get(agentName);
    if (!config) {
      this.logger.warn({ agent: agentName }, 'No tool configuration found for agent');
      return [];
    }

    const tools: DynamicStructuredTool[] = [];

    // Add platform tools
    for (const toolName of config.platformTools) {
      const tool = await this.getPlatformTool(toolName, workspaceManager);
      if (tool) {
        tools.push(tool);
      } else {
        this.logger.warn({ agent: agentName, tool: toolName }, 'Platform tool not found');
      }
    }

    // Add MCP tools (if MCP client and tools are available)
    if (mcpClient && availableMCPTools && availableMCPTools.length > 0) {
      const filteredMCPTools = this.filterMCPTools(availableMCPTools, config.mcpTools);
      const mcpToolInstances = createMCPToolWrappers(filteredMCPTools, mcpClient, this.logger, onImage);
      tools.push(...mcpToolInstances);
      this.logger.debug(
        {
          agent: agentName,
          mcpToolCount: mcpToolInstances.length,
          mcpToolNames: mcpToolInstances.map(t => t.name),
        },
        'Added MCP tools for agent'
      );
    }

    this.logger.info(
      {
        agent: agentName,
        toolCount: tools.length,
        toolNames: tools.map(t => t.name),
      },
      'Retrieved tools for agent'
    );
    return tools;
  }

  /**
   * Create a platform tool by name.
   *
   * @param toolName - Name of the tool to create
   * @param sessionWorkspaceManager - Optional session-specific workspace manager
   * @returns The tool, or `null` when the name is unknown or a required service is missing
   *   (a warning is logged in both cases).
   */
  private async getPlatformTool(
    toolName: string,
    sessionWorkspaceManager?: WorkspaceManager
  ): Promise<DynamicStructuredTool | null> {
    // Don't cache tools - recreate each time to get latest services
    // (services might be initialized asynchronously after registry creation)
    switch (toolName) {
      case 'symbol_lookup': {
        const symbolIndexService = this.resolveService(this.platformServices.symbolIndexService);
        if (!symbolIndexService) {
          this.logger.warn('SymbolIndexService not available for symbol_lookup tool');
          return null;
        }
        return createSymbolLookupTool({
          symbolIndexService,
          logger: this.logger,
        });
      }
      case 'get_chart_data': {
        const ohlcService = this.resolveService(this.platformServices.ohlcService);
        // Use session workspace manager if provided, otherwise try global
        const workspaceManager =
          sessionWorkspaceManager ?? this.resolveService(this.platformServices.workspaceManager);
        if (!ohlcService || !workspaceManager) {
          this.logger.warn(
            { hasOHLC: !!ohlcService, hasWorkspace: !!workspaceManager },
            'OHLCService or WorkspaceManager not available for get_chart_data tool'
          );
          return null;
        }
        return createGetChartDataTool({
          ohlcService,
          workspaceManager,
          logger: this.logger,
        });
      }
      default:
        this.logger.warn({ tool: toolName }, 'Unknown platform tool');
        return null;
    }
  }

  /**
   * Resolve a service given either directly or as a lazy getter.
   *
   * Services are object instances, so any function value is treated as a getter and
   * invoked. Checking for a missing `prototype` is not enough: getters written as
   * `function () { ... }` expressions have a prototype and would otherwise be returned
   * uncalled, and then used as if they were the service itself.
   */
  private resolveService<T>(service: T | (() => T | undefined) | undefined): T | undefined {
    if (typeof service === 'function') {
      return (service as () => T | undefined)();
    }
    return service as T | undefined;
  }

  /**
   * Filter MCP tools based on patterns/names.
   * Supports wildcards like 'category_*' or exact names like 'execute_research'.
   * An empty pattern list selects no tools.
   */
  private filterMCPTools(availableTools: MCPToolInfo[], patterns: string[]): MCPToolInfo[] {
    if (patterns.length === 0) {
      return [];
    }
    // Compile each pattern once instead of once per (tool, pattern) pair
    const matchers = patterns.map(pattern => this.createMatcher(pattern));
    return availableTools.filter(tool => matchers.some(matches => matches(tool.name)));
  }

  /**
   * Check if a tool name matches a pattern.
   * Supports wildcards: 'category_*' matches 'category_write', 'category_read', etc.;
   * '?' matches exactly one character. All other characters match literally.
   */
  private matchesPattern(toolName: string, pattern: string): boolean {
    return this.createMatcher(pattern)(toolName);
  }

  /**
   * Build a predicate for one pattern. An exact string match always succeeds; otherwise
   * the pattern is treated as a glob when it contains '*' or '?'.
   */
  private createMatcher(pattern: string): (toolName: string) => boolean {
    const regex = /[*?]/.test(pattern) ? ToolRegistry.wildcardToRegExp(pattern) : null;
    return toolName => toolName === pattern || (regex !== null && regex.test(toolName));
  }

  /**
   * Convert a glob-style pattern to an anchored RegExp. Regex metacharacters are escaped
   * first, so '.', '(', '[', '+' etc. match literally and cannot make the RegExp invalid.
   */
  private static wildcardToRegExp(pattern: string): RegExp {
    const source = pattern
      .replace(/[.+^${}()|[\]\\]/g, '\\$&')
      .replace(/\*/g, '.*')
      .replace(/\?/g, '.');
    return new RegExp(`^${source}$`);
  }

  /**
   * Get all registered agent names
   */
  getRegisteredAgents(): string[] {
    return Array.from(this.agentToolConfigs.keys());
  }

  /**
   * Get tool configuration for an agent, or `null` if none is registered.
   */
  getAgentToolConfig(agentName: string): AgentToolConfig | null {
    return this.agentToolConfigs.get(agentName) ?? null;
  }
}
/**
 * Module-level registry singleton; remains `null` until `initializeToolRegistry()` runs
 * (intended to happen at gateway startup).
 */
let globalToolRegistry: ToolRegistry | null = null;

/**
 * Create the global tool registry.
 *
 * Calling this again after a successful initialization logs a warning and returns the
 * existing registry; the `platformServices` passed on the later call are not used.
 *
 * @param logger - Logger given to the registry and used for the initialization messages.
 * @param platformServices - Services (or lazy getters) backing the platform tools.
 * @returns The global registry instance.
 */
export function initializeToolRegistry(
  logger: FastifyBaseLogger,
  platformServices: PlatformServices
): ToolRegistry {
  if (globalToolRegistry !== null) {
    logger.warn('Global tool registry already initialized');
    return globalToolRegistry;
  }
  const registry = new ToolRegistry(logger, platformServices);
  globalToolRegistry = registry;
  logger.info('Tool registry initialized');
  return registry;
}

/**
 * Return the global tool registry.
 *
 * @throws Error when `initializeToolRegistry()` has not been called yet.
 */
export function getToolRegistry(): ToolRegistry {
  const registry = globalToolRegistry;
  if (registry === null) {
    throw new Error('Tool registry not initialized. Call initializeToolRegistry() first.');
  }
  return registry;
}