Files
ai/gateway/src/llm/router.ts

250 lines
7.1 KiB
TypeScript

import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { FastifyBaseLogger } from 'fastify';
import { LLMProviderFactory, type ModelConfig, LLMProvider, type LicenseModelsConfig } from './provider.js';
import type { ModelMiddleware } from './middleware.js';
import type { License } from '../types/user.js';
/**
* Model routing strategies accepted by `ModelRouter.route`.
*
* String-valued, so the selected strategy appears verbatim in the
* "Routing to model" log entry emitted by the router.
*/
export enum RoutingStrategy {
/** Use the license's `preferredModel` if it is allowed for the license tier; otherwise fall back to the tier default */
USER_PREFERENCE = 'user_preference',
/** Pick the tier's "complex" or "default" model using a keyword/message-length heuristic on the user's message */
COMPLEXITY = 'complexity',
/** Use the default model configured (or hardcoded) for the license tier */
LICENSE_TIER = 'license_tier',
/** Use the tier's configured `cost_optimized` model (the router does not compare prices itself) */
COST_OPTIMIZED = 'cost_optimized',
}
/**
 * Model router.
 *
 * Selects which chat model to use for a request, based on the requested
 * {@link RoutingStrategy}, the caller's license tier and (for the complexity
 * strategy) a keyword/length heuristic over the user's message.
 *
 * When the factory provides a per-license `LicenseModelsConfig`, model names come
 * from that configuration and are paired with the factory default model's
 * provider. Otherwise hardcoded DeepInfra fallbacks are used.
 */
export class ModelRouter {
  /** Model used by every hardcoded fallback except complex enterprise queries. */
  private static readonly FALLBACK_MODEL = 'zai-org/GLM-5';
  /** Large model used for complex enterprise queries; blocked for 'pro' in the hardcoded policy. */
  private static readonly FALLBACK_LARGE_MODEL = 'Qwen/Qwen3-235B-A22B-Instruct-2507';
  /** Messages longer than this many characters are always treated as complex. */
  private static readonly COMPLEX_LENGTH_THRESHOLD = 200;
  /**
   * Lower-case substrings that mark a message as complex.
   * Matching is plain substring matching, so 'write' also matches 'rewrite'.
   */
  private static readonly COMPLEXITY_KEYWORDS: readonly string[] = [
    // Multi-step analysis
    'backtest',
    'analyze',
    'compare',
    'optimize',
    // Code generation
    'write',
    'create',
    'implement',
    'build',
    // Deep reasoning
    'explain why',
    'what if',
    'how would',
  ];

  private readonly factory: LLMProviderFactory;
  private readonly logger: FastifyBaseLogger;
  /** Factory default; its provider is reused for every config-driven model name. */
  private readonly defaultModel: ModelConfig;
  /** Optional per-license model configuration; when absent, hardcoded fallbacks apply. */
  private readonly licenseModels?: LicenseModelsConfig;

  constructor(factory: LLMProviderFactory, logger: FastifyBaseLogger) {
    this.factory = factory;
    this.logger = logger;
    this.defaultModel = factory.getDefaultModel();
    this.licenseModels = factory.getLicenseModelsConfig();
  }

  /**
   * Route to the appropriate model for this request and instantiate it.
   *
   * @param message - The user's message (only inspected by the COMPLEXITY strategy).
   * @param license - The caller's license; selects tier configuration and permissions.
   * @param strategy - Routing strategy; unknown values fall back to the factory default model.
   * @param userId - Optional user id, used only for logging.
   * @param maxTokens - When given, overrides `maxTokens` on the selected model config.
   * @returns Whatever `LLMProviderFactory.createModel` returns for the selected config.
   */
  async route(
    message: string,
    license: License,
    strategy: RoutingStrategy = RoutingStrategy.USER_PREFERENCE,
    userId?: string,
    maxTokens?: number
  ): Promise<{ model: BaseChatModel; middleware: ModelMiddleware }> {
    let modelConfig: ModelConfig;
    switch (strategy) {
      case RoutingStrategy.USER_PREFERENCE:
        modelConfig = this.routeByUserPreference(license);
        break;
      case RoutingStrategy.COMPLEXITY:
        modelConfig = this.routeByComplexity(message, license);
        break;
      case RoutingStrategy.LICENSE_TIER:
        modelConfig = this.routeByLicenseTier(license);
        break;
      case RoutingStrategy.COST_OPTIMIZED:
        modelConfig = this.routeByCost(license);
        break;
      default:
        modelConfig = this.defaultModel;
    }

    if (maxTokens !== undefined) {
      // Copy rather than mutate: modelConfig may be the shared default or the license's own object.
      modelConfig = { ...modelConfig, maxTokens };
    }

    this.logger.info(
      {
        userId,
        strategy,
        provider: modelConfig.provider,
        model: modelConfig.model,
        maxTokens: modelConfig.maxTokens,
      },
      'Routing to model'
    );
    return this.factory.createModel(modelConfig);
  }

  /**
   * Use the license's preferred model when it is well-formed and allowed for the
   * license tier. Otherwise fall back to the tier default.
   */
  private routeByUserPreference(license: License): ModelConfig {
    const preferred: unknown = license.preferredModel;
    // The license payload is not trusted to have the ModelConfig shape. A malformed
    // value (e.g. a bare model-name string) must not reach the factory.
    if (ModelRouter.isModelConfigLike(preferred) && this.isModelAllowed(preferred, license)) {
      return preferred;
    }
    return this.routeByLicenseTier(license);
  }

  /**
   * Pick the tier's "complex" or "default" model depending on {@link isComplexQuery}.
   * The hardcoded fallback only escalates enterprise licenses to the large model.
   */
  private routeByComplexity(message: string, license: License): ModelConfig {
    const isComplex = this.isComplexQuery(message);
    const tierConfig = this.getTierConfig(license);
    if (tierConfig) {
      return this.configuredModel(isComplex ? tierConfig.complex : tierConfig.default);
    }

    if (license.licenseType === 'enterprise' && isComplex) {
      return ModelRouter.deepInfra(ModelRouter.FALLBACK_LARGE_MODEL);
    }
    return ModelRouter.deepInfra(ModelRouter.FALLBACK_MODEL);
  }

  /**
   * Use the tier's configured default model. The hardcoded fallback uses GLM-5 for
   * known tiers and the factory default for any other license type.
   */
  private routeByLicenseTier(license: License): ModelConfig {
    const tierConfig = this.getTierConfig(license);
    if (tierConfig) {
      return this.configuredModel(tierConfig.default);
    }

    switch (license.licenseType) {
      case 'enterprise':
      case 'pro':
      case 'free':
        return ModelRouter.deepInfra(ModelRouter.FALLBACK_MODEL);
      default:
        return this.defaultModel;
    }
  }

  /**
   * Use the tier's configured `cost_optimized` model (GLM-5 when unconfigured).
   * No price comparison happens here; "cheapest" is whatever the config names.
   */
  private routeByCost(license: License): ModelConfig {
    const tierConfig = this.getTierConfig(license);
    return tierConfig
      ? this.configuredModel(tierConfig.cost_optimized)
      : ModelRouter.deepInfra(ModelRouter.FALLBACK_MODEL);
  }

  /**
   * Check whether a model is allowed for the user's license. Only the model name is
   * compared; the provider is not checked.
   *
   * With configuration, a non-empty `allowed_models` list restricts the model to that
   * list, and a non-empty `blocked_models` list excludes its entries. Both lists apply
   * when both are set. Without configuration, 'free' is limited to GLM-5, 'pro' may
   * use anything except the large Qwen model, and every other tier is unrestricted.
   */
  private isModelAllowed(model: ModelConfig, license: License): boolean {
    const tierConfig = this.getTierConfig(license);
    if (tierConfig) {
      const allowed = tierConfig.allowed_models;
      if (allowed && allowed.length > 0 && !allowed.includes(model.model)) {
        return false;
      }
      const blocked = tierConfig.blocked_models;
      if (blocked && blocked.length > 0 && blocked.includes(model.model)) {
        return false;
      }
      return true;
    }

    switch (license.licenseType) {
      case 'free':
        return model.model === ModelRouter.FALLBACK_MODEL;
      case 'pro':
        return model.model !== ModelRouter.FALLBACK_LARGE_MODEL;
      default:
        // Enterprise (and any other tier): all models allowed.
        return true;
    }
  }

  /**
   * Heuristically decide whether a query is complex. A query is complex when it is
   * longer than {@link COMPLEX_LENGTH_THRESHOLD} characters or, case-insensitively,
   * contains any of {@link COMPLEXITY_KEYWORDS}.
   */
  private isComplexQuery(message: string): boolean {
    if (message.length > ModelRouter.COMPLEX_LENGTH_THRESHOLD) {
      return true;
    }
    const lower = message.toLowerCase();
    return ModelRouter.COMPLEXITY_KEYWORDS.some((keyword) => lower.includes(keyword));
  }

  /** Configured model settings for the license's tier, or undefined when none are loaded. */
  private getTierConfig(license: License) {
    return this.licenseModels?.[license.licenseType];
  }

  /** Pair a config-supplied model name with the factory default model's provider. */
  private configuredModel(model: ModelConfig['model']): ModelConfig {
    return { provider: this.defaultModel.provider as LLMProvider, model };
  }

  /** Build a fresh DeepInfra model config for a hardcoded fallback model name. */
  private static deepInfra(model: ModelConfig['model']): ModelConfig {
    return { provider: LLMProvider.DEEP_INFRA, model };
  }

  /**
   * Minimal runtime shape check for an untrusted model preference: a non-null object
   * with a non-empty string `model` and a `provider` that is a known `LLMProvider`
   * value. Optional ModelConfig fields are not validated.
   */
  private static isModelConfigLike(value: unknown): value is ModelConfig {
    if (typeof value !== 'object' || value === null) {
      return false;
    }
    const { provider, model } = value as { provider?: unknown; model?: unknown };
    return (
      typeof model === 'string' &&
      model.length > 0 &&
      (Object.values(LLMProvider) as unknown[]).includes(provider)
    );
  }
}