diff --git a/src/config.ts b/src/config.ts
index 8318522..7d4a5a1 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -428,31 +428,10 @@ export const config: Config = {
     ["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
     400
   ),
-  allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
-    "turbo",
-    "gpt4",
-    "gpt4-32k",
-    "gpt4-turbo",
-    "gpt4o",
-    "claude",
-    "claude-opus",
-    "gemini-flash",
-    "gemini-pro",
-    "gemini-ultra",
-    "mistral-tiny",
-    "mistral-small",
-    "mistral-medium",
-    "mistral-large",
-    "aws-claude",
-    "aws-claude-opus",
-    "gcp-claude",
-    "gcp-claude-opus",
-    "azure-turbo",
-    "azure-gpt4",
-    "azure-gpt4-32k",
-    "azure-gpt4-turbo",
-    "azure-gpt4o",
-  ]),
+  allowedModelFamilies: getEnvWithDefault(
+    "ALLOWED_MODEL_FAMILIES",
+    getDefaultModelFamilies()
+  ),
   rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
   rejectMessage: getEnvWithDefault(
     "REJECT_MESSAGE",
@@ -801,3 +780,7 @@ function parseCsv(val: string): string[] {
   const matches = val.match(regex) || [];
   return matches.map((item) => item.replace(/^"|"$/g, "").trim());
 }
+
+function getDefaultModelFamilies(): ModelFamily[] {
+  return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
+}
diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts
index 2a54a13..6b81375 100644
--- a/src/proxy/aws.ts
+++ b/src/proxy/aws.ts
@@ -8,17 +8,25 @@ import { awsMistral } from "./aws-mistral";
 import { AwsBedrockKey, keyPool } from "../shared/key-management";
 
 const awsRouter = Router();
+awsRouter.get("/:vendor?/models", addV1, handleModelsRequest);
 awsRouter.use("/claude", addV1, awsClaude);
 awsRouter.use("/mistral", addV1, awsMistral);
-awsRouter.get("/:vendor?/models", handleModelsRequest);
 
 const MODELS_CACHE_TTL = 10000;
-let modelsCache: any = null;
-let modelsCacheTime = 0;
+let modelsCache: Record<string, any> = {};
+let modelsCacheTime: Record<string, number> = {};
 function handleModelsRequest(req: Request, res: Response) {
   if (!config.awsCredentials) return { object: "list", data: [] };
-  if (new Date().getTime() - modelsCacheTime < MODELS_CACHE_TTL) {
-    return res.json(modelsCache);
+
+  const vendor = req.params.vendor?.length
+    ? req.params.vendor === "claude"
+      ? "anthropic"
+      : req.params.vendor
+    : "all";
+
+  const cacheTime = modelsCacheTime[vendor] || 0;
+  if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
+    return res.json(modelsCache[vendor]);
   }
 
   const availableModelIds = new Set<string>();
@@ -43,27 +51,25 @@ function handleModelsRequest(req: Request, res: Response) {
   ]
     .filter((id) => availableModelIds.has(id))
     .map((id) => {
-    const vendor = id.match(/^(.*)\./)?.[1];
-    return {
-      id,
-      object: "model",
-      created: new Date().getTime(),
-      owned_by: vendor,
-      permission: [],
-      root: vendor,
-      parent: null,
-    };
-  });
+      const vendor = id.match(/^(.*)\./)?.[1];
+      return {
+        id,
+        object: "model",
+        created: new Date().getTime(),
+        owned_by: vendor,
+        permission: [],
+        root: vendor,
+        parent: null,
+      };
+    });
 
-  const requestedVendor = req.params.vendor;
-  const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor;
-  modelsCache = {
+  modelsCache[vendor] = {
     object: "list",
-    data: models.filter((m) => m.root === vendor),
+    data: models.filter((m) => vendor === "all" || m.root === vendor),
   };
-  modelsCacheTime = new Date().getTime();
+  modelsCacheTime[vendor] = new Date().getTime();
 
-  return res.json(modelsCache);
+  return res.json(modelsCache[vendor]);
 }
 
 export const aws = awsRouter;
diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts
index f84e263..439766b 100644
--- a/src/proxy/middleware/response/handle-streamed-response.ts
+++ b/src/proxy/middleware/response/handle-streamed-response.ts
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
 const pipelineAsync = promisify(pipeline);
 
 /**
- * `handleStreamedResponse` consumes and transforms a streamed response from the
- * upstream service, forwarding events to the client in their requested format.
+ * `handleStreamedResponse` consumes a streamed response from the upstream API,
+ * decodes it chunk-by-chunk into a stream of events, transforms those events
+ * into the client's requested format, and forwards the result to the client.
+ *
  * After the entire stream has been consumed, it resolves with the full response
  * body so that subsequent middleware in the chain can process it as if it were
- * a non-streaming response.
+ * a non-streaming response (to count output tokens, track usage, etc.).
  *
- * In the event of an error, the request's streaming flag is unset and the non-
- * streaming response handler is called instead.
- *
- * If the error is retryable, that handler will re-enqueue the request and also
- * reset the streaming flag. Unfortunately the streaming flag is set and unset
- * in multiple places, so it's hard to keep track of.
+ * In the event of an error, the request's streaming flag is unset and the
+ * request is bounced back to the non-streaming response handler. If the error
+ * is retryable, that handler will re-enqueue the request and also reset the
+ * streaming flag. Unfortunately the streaming flag is set and unset in multiple
+ * places, so it's hard to keep track of.
  */
 export const handleStreamedResponse: RawResponseBodyHandler = async (
   proxyRes,
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
     logger: req.log,
   };
 
-  // Decoder turns the raw response stream into a stream of events in some
-  // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
-  const decoder = getDecoder({ ...streamOptions, input: proxyRes });
-  // Adapter transforms the decoded events into server-sent events.
-  const adapter = new SSEStreamAdapter(streamOptions);
-  // Aggregator compiles all events into a single response object.
+  // While the request is streaming, aggregator collects all events so that we
+  // can compile them into a single response object and publish that to the
+  // remaining middleware. Because we have an OpenAI transformer for every
+  // supported format, EventAggregator always consumes OpenAI events so that we
+  // only have to write one aggregator (OpenAI input) for each output format.
   const aggregator = new EventAggregator({ format: req.outboundApi });
+
+  // Decoder reads from the raw response buffer and produces a stream of
+  // discrete events in some format (text/event-stream,
+  // vnd.amazon.event-stream, streaming JSON, etc.).
+  const decoder = getDecoder({ ...streamOptions, input: proxyRes });
+  // Adapter consumes the decoded events and produces server-sent events so we
+  // have a standard transport for the client and to translate between API
+  // message formats.
+  const adapter = new SSEStreamAdapter(streamOptions);
   // Transformer converts server-sent events from one vendor's API message
   // format to another.
   const transformer = new SSEMessageTransformer({
diff --git a/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts
index b394142..fe06289 100644
--- a/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts
+++ b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
       if (eventType === "chunk") {
         result = input[eventType];
       } else {
-        // AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
+        // AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
         result = { [eventType]: input[eventType] } as any;
       }
       return result;
diff --git a/src/service-info.ts b/src/service-info.ts
index b720d5d..98f71b4 100644
--- a/src/service-info.ts
+++ b/src/service-info.ts
@@ -149,8 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
     "mistral-ai": `%BASE%/mistral-ai`,
   },
   aws: {
-    claude: `%BASE%/aws/claude`,
-    mistral: `%BASE%/aws/mistral`,
+    "aws-claude": `%BASE%/aws/claude`,
+    "aws-mistral": `%BASE%/aws/mistral`,
   },
   gcp: {
     gcp: `%BASE%/gcp/claude`,
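Note on the route reordering in src/proxy/aws.ts: Express dispatches in registration order, and `router.use("/claude", ...)` matches any path *prefixed* with `/claude`, including `/claude/models`. Moving the models route above the vendor mounts ensures it wins. A minimal sketch of the behavior, with stub handlers standing in for the project's real middleware:

```ts
import express from "express";

const router = express.Router();

// Registered first: if the "/claude" mount below came first, it would swallow
// GET /claude/models before this route could see it, since router.use matches
// on path prefix rather than exact path.
router.get("/:vendor?/models", (req, res) => {
  res.json({ object: "list", data: [], vendor: req.params.vendor ?? "all" });
});

// Illustrative stubs; the real routers proxy to AWS Bedrock.
router.use("/claude", (_req, res) => res.status(502).send("proxy stub"));
router.use("/mistral", (_req, res) => res.status(502).send("proxy stub"));
```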
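The cache change replaces a single global models cache with one entry per vendor, so `/aws/claude/models` and `/aws/mistral/models` no longer overwrite each other's cached responses within the TTL. A standalone sketch of that behavior; `vendorKey`, `readCache`, and `writeCache` are illustrative names (the patch inlines this logic in `handleModelsRequest`):

```ts
const MODELS_CACHE_TTL = 10000; // ms

const modelsCache: Record<string, unknown> = {};
const modelsCacheTime: Record<string, number> = {};

// "claude" in the URL maps to Bedrock's "anthropic" model-ID prefix; an
// omitted vendor param means "return models from all vendors".
function vendorKey(param?: string): string {
  if (!param?.length) return "all";
  return param === "claude" ? "anthropic" : param;
}

// Returns the cached entry for this vendor if it is still fresh.
function readCache(param?: string): unknown | undefined {
  const key = vendorKey(param);
  const fresh = Date.now() - (modelsCacheTime[key] || 0) < MODELS_CACHE_TTL;
  return fresh ? modelsCache[key] : undefined;
}

// Stores a response under the vendor's key and stamps its time.
function writeCache(param: string | undefined, data: unknown): void {
  const key = vendorKey(param);
  modelsCache[key] = data;
  modelsCacheTime[key] = Date.now();
}
```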
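The rewritten comments in handle-streamed-response.ts describe a staged flow: decoder → adapter → transformer, with the aggregator collecting everything for downstream middleware. An illustrative-only sketch of that shape using plain Node streams; the real classes (`getDecoder`, `SSEStreamAdapter`, `SSEMessageTransformer`, `EventAggregator`) carry far more state and options than shown here:

```ts
import { Readable, Transform, Writable } from "stream";
import { pipeline } from "stream/promises";

async function demo(): Promise<string> {
  const upstream = Readable.from(["raw chunk 1", "raw chunk 2"]);

  // Decoder stand-in: raw bytes -> discrete events.
  const decoder = new Transform({
    objectMode: true,
    transform(chunk, _enc, cb) {
      cb(null, { text: String(chunk) });
    },
  });

  // Adapter stand-in: decoded events -> server-sent events, giving every
  // upstream format a common transport.
  const adapter = new Transform({
    objectMode: true,
    transform(event, _enc, cb) {
      cb(null, `data: ${JSON.stringify(event)}\n\n`);
    },
  });

  // Transformer stand-in: translate one vendor's message shape into another's.
  const transformer = new Transform({
    objectMode: true,
    transform(sse, _enc, cb) {
      cb(null, String(sse).replace('"text"', '"content"'));
    },
  });

  // Aggregator stand-in: collect all events so later middleware can treat the
  // result as a single non-streaming response body.
  const events: string[] = [];
  const aggregator = new Writable({
    objectMode: true,
    write(msg, _enc, cb) {
      events.push(String(msg));
      cb();
    },
  });

  await pipeline(upstream, decoder, adapter, transformer, aggregator);
  return events.join("");
}

demo().then(console.log);
```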