fixes aws models endpoint
This commit is contained in:
+8
-25
@@ -428,31 +428,10 @@ export const config: Config = {
|
||||
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
|
||||
400
|
||||
),
|
||||
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
|
||||
"turbo",
|
||||
"gpt4",
|
||||
"gpt4-32k",
|
||||
"gpt4-turbo",
|
||||
"gpt4o",
|
||||
"claude",
|
||||
"claude-opus",
|
||||
"gemini-flash",
|
||||
"gemini-pro",
|
||||
"gemini-ultra",
|
||||
"mistral-tiny",
|
||||
"mistral-small",
|
||||
"mistral-medium",
|
||||
"mistral-large",
|
||||
"aws-claude",
|
||||
"aws-claude-opus",
|
||||
"gcp-claude",
|
||||
"gcp-claude-opus",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-gpt4o",
|
||||
]),
|
||||
allowedModelFamilies: getEnvWithDefault(
|
||||
"ALLOWED_MODEL_FAMILIES",
|
||||
getDefaultModelFamilies()
|
||||
),
|
||||
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
"REJECT_MESSAGE",
|
||||
@@ -801,3 +780,7 @@ function parseCsv(val: string): string[] {
|
||||
const matches = val.match(regex) || [];
|
||||
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
|
||||
}
|
||||
|
||||
/**
 * Returns the model families used as the default value for the
 * ALLOWED_MODEL_FAMILIES setting: every entry of MODEL_FAMILIES whose name
 * does not contain the substring "dall-e".
 */
function getDefaultModelFamilies(): ModelFamily[] {
|
||||
return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
|
||||
}
|
||||
|
||||
+28
-22
@@ -8,17 +8,25 @@ import { awsMistral } from "./aws-mistral";
|
||||
import { AwsBedrockKey, keyPool } from "../shared/key-management";
|
||||
|
||||
const awsRouter = Router();
|
||||
awsRouter.get("/:vendor?/models", addV1, handleModelsRequest);
|
||||
awsRouter.use("/claude", addV1, awsClaude);
|
||||
awsRouter.use("/mistral", addV1, awsMistral);
|
||||
awsRouter.get("/:vendor?/models", handleModelsRequest);
|
||||
|
||||
const MODELS_CACHE_TTL = 10000;
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
let modelsCache: Record<string, any> = {};
|
||||
let modelsCacheTime: Record<string, number> = {};
|
||||
function handleModelsRequest(req: Request, res: Response) {
|
||||
if (!config.awsCredentials) return { object: "list", data: [] };
|
||||
if (new Date().getTime() - modelsCacheTime < MODELS_CACHE_TTL) {
|
||||
return res.json(modelsCache);
|
||||
|
||||
const vendor = req.params.vendor?.length
|
||||
? req.params.vendor === "claude"
|
||||
? "anthropic"
|
||||
: req.params.vendor
|
||||
: "all";
|
||||
|
||||
const cacheTime = modelsCacheTime[vendor] || 0;
|
||||
if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
|
||||
return res.json(modelsCache[vendor]);
|
||||
}
|
||||
|
||||
const availableModelIds = new Set<string>();
|
||||
@@ -43,27 +51,25 @@ function handleModelsRequest(req: Request, res: Response) {
|
||||
]
|
||||
.filter((id) => availableModelIds.has(id))
|
||||
.map((id) => {
|
||||
const vendor = id.match(/^(.*)\./)?.[1];
|
||||
return {
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: vendor,
|
||||
permission: [],
|
||||
root: vendor,
|
||||
parent: null,
|
||||
};
|
||||
});
|
||||
const vendor = id.match(/^(.*)\./)?.[1];
|
||||
return {
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: vendor,
|
||||
permission: [],
|
||||
root: vendor,
|
||||
parent: null,
|
||||
};
|
||||
});
|
||||
|
||||
const requestedVendor = req.params.vendor;
|
||||
const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor;
|
||||
modelsCache = {
|
||||
modelsCache[vendor] = {
|
||||
object: "list",
|
||||
data: models.filter((m) => m.root === vendor),
|
||||
data: models.filter((m) => vendor === "all" || m.root === vendor),
|
||||
};
|
||||
modelsCacheTime = new Date().getTime();
|
||||
modelsCacheTime[vendor] = new Date().getTime();
|
||||
|
||||
return res.json(modelsCache);
|
||||
return res.json(modelsCache[vendor]);
|
||||
}
|
||||
|
||||
export const aws = awsRouter;
|
||||
|
||||
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
|
||||
const pipelineAsync = promisify(pipeline);
|
||||
|
||||
/**
|
||||
* `handleStreamedResponse` consumes and transforms a streamed response from the
|
||||
* upstream service, forwarding events to the client in their requested format.
|
||||
* `handleStreamedResponse` consumes a streamed response from the upstream API,
|
||||
* decodes chunk-by-chunk into a stream of events, transforms those events into
|
||||
* the client's requested format, and forwards the result to the client.
|
||||
*
|
||||
* After the entire stream has been consumed, it resolves with the full response
|
||||
* body so that subsequent middleware in the chain can process it as if it were
|
||||
* a non-streaming response.
|
||||
* a non-streaming response (to count output tokens, track usage, etc).
|
||||
*
|
||||
* In the event of an error, the request's streaming flag is unset and the non-
|
||||
* streaming response handler is called instead.
|
||||
*
|
||||
* If the error is retryable, that handler will re-enqueue the request and also
|
||||
* reset the streaming flag. Unfortunately the streaming flag is set and unset
|
||||
* in multiple places, so it's hard to keep track of.
|
||||
* In the event of an error, the request's streaming flag is unset and the
|
||||
* request is bounced back to the non-streaming response handler. If the error
|
||||
* is retryable, that handler will re-enqueue the request and also reset the
|
||||
* streaming flag. Unfortunately the streaming flag is set and unset in multiple
|
||||
* places, so it's hard to keep track of.
|
||||
*/
|
||||
export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
proxyRes,
|
||||
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
logger: req.log,
|
||||
};
|
||||
|
||||
// Decoder turns the raw response stream into a stream of events in some
|
||||
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
|
||||
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
|
||||
// Adapter transforms the decoded events into server-sent events.
|
||||
const adapter = new SSEStreamAdapter(streamOptions);
|
||||
// Aggregator compiles all events into a single response object.
|
||||
// While the request is streaming, aggregator collects all events so that we
|
||||
// can compile them into a single response object and publish that to the
|
||||
// remaining middleware. Because we have an OpenAI transformer for every
|
||||
// supported format, EventAggregator always consumes OpenAI events so that we
|
||||
// only have to write one aggregator (OpenAI input) for each output format.
|
||||
const aggregator = new EventAggregator({ format: req.outboundApi });
|
||||
|
||||
// Decoder reads from the raw response buffer and produces a stream of
|
||||
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
|
||||
// streaming JSON, etc).
|
||||
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
|
||||
// Adapter consumes the decoded events and produces server-sent events so we
|
||||
// have a standard transport for the client and to translate between API
|
||||
// message formats.
|
||||
const adapter = new SSEStreamAdapter(streamOptions);
|
||||
// Transformer converts server-sent events from one vendor's API message
|
||||
// format to another.
|
||||
const transformer = new SSEMessageTransformer({
|
||||
|
||||
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
|
||||
if (eventType === "chunk") {
|
||||
result = input[eventType];
|
||||
} else {
|
||||
// AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
|
||||
// AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
|
||||
result = { [eventType]: input[eventType] } as any;
|
||||
}
|
||||
return result;
|
||||
|
||||
+2
-2
@@ -149,8 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
||||
"mistral-ai": `%BASE%/mistral-ai`,
|
||||
},
|
||||
aws: {
|
||||
claude: `%BASE%/aws/claude`,
|
||||
mistral: `%BASE%/aws/mistral`,
|
||||
"aws-claude": `%BASE%/aws/claude`,
|
||||
"aws-mistral": `%BASE%/aws/mistral`,
|
||||
},
|
||||
gcp: {
|
||||
gcp: `%BASE%/gcp/claude`,
|
||||
|
||||
Reference in New Issue
Block a user