244 lines
7.0 KiB
TypeScript
244 lines
7.0 KiB
TypeScript
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
|
import type { FastifyBaseLogger } from 'fastify';
|
|
import { LLMProviderFactory, type ModelConfig, LLMProvider, type LicenseModelsConfig } from './provider.js';
|
|
import type { ModelMiddleware } from './middleware.js';
|
|
import type { License } from '../types/user.js';
|
|
|
|
/**
 * Model routing strategies.
 *
 * Each member's string value is what appears in the `strategy` field of the
 * router's log entries, so the values must stay stable.
 */
export enum RoutingStrategy {
  /** Use the user's preferred model from the license, if it is permitted. */
  USER_PREFERENCE = 'user_preference',

  /** Route based on query complexity. */
  COMPLEXITY = 'complexity',

  /** Route based on license tier. */
  LICENSE_TIER = 'license_tier',

  /** Use the cheapest available model. */
  COST_OPTIMIZED = 'cost_optimized',
}
|
|
|
|
/**
 * Model router.
 *
 * Picks a {@link ModelConfig} for each request according to a
 * {@link RoutingStrategy}, logs the decision, and delegates construction of
 * the model to the {@link LLMProviderFactory}.
 *
 * A per-tier entry in the factory's license-models configuration takes
 * precedence. The hardcoded Anthropic model names below are used only when no
 * entry exists for the caller's license type.
 */
export class ModelRouter {
  /** Hardcoded fallback model names, used when no tier configuration applies. */
  private static readonly HAIKU = 'claude-haiku-4-5-20251001';
  private static readonly SONNET = 'claude-sonnet-4-6';
  private static readonly OPUS = 'claude-opus-4-6';

  /** Fallback allow-list for the `free` tier. */
  private static readonly FREE_ALLOWED_MODELS: readonly string[] = [ModelRouter.HAIKU];

  /** Fallback block-list for the `pro` tier. */
  private static readonly PRO_BLOCKED_MODELS: readonly string[] = [ModelRouter.OPUS];

  /** Messages longer than this many characters are treated as complex. */
  private static readonly LONG_MESSAGE_THRESHOLD = 200;

  /**
   * Lower-case substrings that mark a message as complex. Matching is a plain
   * substring test against the lower-cased message.
   */
  private static readonly COMPLEXITY_KEYWORDS: readonly string[] = [
    // Multi-step analysis
    'backtest',
    'analyze',
    'compare',
    'optimize',

    // Code generation
    'write',
    'create',
    'implement',
    'build',

    // Deep reasoning
    'explain why',
    'what if',
    'how would',
  ];

  private readonly factory: LLMProviderFactory;
  private readonly logger: FastifyBaseLogger;
  private readonly defaultModel: ModelConfig;
  private readonly licenseModels?: LicenseModelsConfig;

  /**
   * @param factory Source of the default model, the optional per-tier model
   *   configuration, and the model instances themselves. The default model and
   *   the tier configuration are read once here and cached.
   * @param logger Receives one `info` entry per routing decision.
   */
  constructor(factory: LLMProviderFactory, logger: FastifyBaseLogger) {
    this.factory = factory;
    this.logger = logger;
    this.defaultModel = factory.getDefaultModel();
    this.licenseModels = factory.getLicenseModelsConfig();
  }

  /**
   * Route to the appropriate model based on context.
   *
   * @param message The user's message; only inspected by the COMPLEXITY strategy.
   * @param license The caller's license (tier and optional preferred model).
   * @param strategy How to choose the model. Defaults to USER_PREFERENCE; a
   *   value that matches no known strategy selects the factory's default model.
   * @param userId Optional identifier, used only in the log entry.
   * @returns Whatever `factory.createModel` produces for the chosen config.
   */
  async route(
    message: string,
    license: License,
    strategy: RoutingStrategy = RoutingStrategy.USER_PREFERENCE,
    userId?: string
  ): Promise<{ model: BaseChatModel; middleware: ModelMiddleware }> {
    let modelConfig: ModelConfig;

    switch (strategy) {
      case RoutingStrategy.USER_PREFERENCE:
        modelConfig = this.routeByUserPreference(license);
        break;

      case RoutingStrategy.COMPLEXITY:
        modelConfig = this.routeByComplexity(message, license);
        break;

      case RoutingStrategy.LICENSE_TIER:
        modelConfig = this.routeByLicenseTier(license);
        break;

      case RoutingStrategy.COST_OPTIMIZED:
        modelConfig = this.routeByCost(license);
        break;

      default:
        // Kept as a runtime safety net: the strategy may originate from
        // untyped input even though the enum is exhaustive at compile time.
        modelConfig = this.defaultModel;
    }

    this.logger.info(
      {
        userId,
        strategy,
        provider: modelConfig.provider,
        model: modelConfig.model,
      },
      'Routing to model'
    );

    return this.factory.createModel(modelConfig);
  }

  /**
   * Route based on the user's preferred model (if set in the license).
   * Falls back to the license-tier default when no preference is set or the
   * preferred model is not permitted for the tier.
   */
  private routeByUserPreference(license: License): ModelConfig {
    const preferredModel = license.preferredModel as ModelConfig | undefined;

    if (preferredModel && this.isModelAllowed(preferredModel, license)) {
      return preferredModel;
    }

    return this.routeByLicenseTier(license);
  }

  /**
   * Route based on query complexity: complex queries get the tier's stronger
   * model, everything else the tier's default.
   */
  private routeByComplexity(message: string, license: License): ModelConfig {
    const isComplex = this.isComplexQuery(message);

    const tierConfig = this.tierConfigFor(license);
    if (tierConfig) {
      // A tier entry that omits `complex` falls back to the tier default
      // rather than producing a config with an undefined model.
      const model = isComplex ? tierConfig.complex ?? tierConfig.default : tierConfig.default;
      return this.withDefaultProvider(model);
    }

    // Fallback to hardcoded defaults
    switch (license.licenseType) {
      case 'enterprise':
        return this.anthropic(isComplex ? ModelRouter.OPUS : ModelRouter.SONNET);

      case 'pro':
        return this.anthropic(isComplex ? ModelRouter.SONNET : ModelRouter.HAIKU);

      default:
        return this.anthropic(ModelRouter.HAIKU);
    }
  }

  /**
   * Route based on license tier. Unrecognised tiers with no configuration
   * entry get the factory's default model.
   */
  private routeByLicenseTier(license: License): ModelConfig {
    const tierConfig = this.tierConfigFor(license);
    if (tierConfig) {
      return this.withDefaultProvider(tierConfig.default);
    }

    // Fallback to hardcoded defaults
    switch (license.licenseType) {
      case 'enterprise':
      case 'pro':
        return this.anthropic(ModelRouter.SONNET);

      case 'free':
        return this.anthropic(ModelRouter.HAIKU);

      default:
        return this.defaultModel;
    }
  }

  /**
   * Route to the cheapest available model for the tier.
   */
  private routeByCost(license: License): ModelConfig {
    const tierConfig = this.tierConfigFor(license);
    if (tierConfig) {
      // A tier entry that omits `cost_optimized` falls back to the tier default.
      return this.withDefaultProvider(tierConfig.cost_optimized ?? tierConfig.default);
    }

    // Fallback: use Haiku for cost efficiency
    return this.anthropic(ModelRouter.HAIKU);
  }

  /**
   * Check whether a model is allowed for the user's license. Only the model
   * name is checked; the provider is not inspected.
   *
   * With a tier configuration entry: a non-empty `allowed_models` list is
   * authoritative; otherwise a non-empty `blocked_models` list is applied;
   * otherwise everything is allowed.
   *
   * Without one: `free` may use only the fallback allow-list, `pro` may use
   * anything not on the fallback block-list, `enterprise` may use anything.
   * Any other license type is denied (fail closed) so that an unrecognised
   * tier never silently receives enterprise-level access.
   */
  private isModelAllowed(model: ModelConfig, license: License): boolean {
    const tierConfig = this.tierConfigFor(license);
    if (tierConfig) {
      if (tierConfig.allowed_models && tierConfig.allowed_models.length > 0) {
        return tierConfig.allowed_models.includes(model.model);
      }

      if (tierConfig.blocked_models && tierConfig.blocked_models.length > 0) {
        return !tierConfig.blocked_models.includes(model.model);
      }

      // No restrictions if neither list is defined
      return true;
    }

    // Fallback to hardcoded defaults
    switch (license.licenseType) {
      case 'free':
        return ModelRouter.FREE_ALLOWED_MODELS.includes(model.model);

      case 'pro':
        return !ModelRouter.PRO_BLOCKED_MODELS.includes(model.model);

      case 'enterprise':
        return true;

      default:
        return false;
    }
  }

  /**
   * Determine if a query is complex: either it is longer than
   * LONG_MESSAGE_THRESHOLD characters or it contains one of the
   * COMPLEXITY_KEYWORDS (case-insensitive substring match).
   */
  private isComplexQuery(message: string): boolean {
    if (message.length > ModelRouter.LONG_MESSAGE_THRESHOLD) {
      return true;
    }

    const messageLower = message.toLowerCase();
    return ModelRouter.COMPLEXITY_KEYWORDS.some((keyword) => messageLower.includes(keyword));
  }

  /** Look up the configured entry for the license's tier, if any. */
  private tierConfigFor(license: License) {
    return this.licenseModels?.[license.licenseType];
  }

  /** Build a config for a configured model name using the default model's provider. */
  private withDefaultProvider(model: ModelConfig['model']): ModelConfig {
    return { provider: this.defaultModel.provider as LLMProvider, model };
  }

  /** Build a config for one of the hardcoded Anthropic fallback models. */
  private anthropic(model: ModelConfig['model']): ModelConfig {
    return { provider: LLMProvider.ANTHROPIC, model };
  }
}
|