fixes aws models endpoint

nai-degen
2024-08-12 19:26:55 -05:00
parent 2d8e1dac13
commit 49a89122f5
5 changed files with 63 additions and 65 deletions
+8 -25
@@ -428,31 +428,10 @@ export const config: Config = {
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
400
),
- allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
- "turbo",
- "gpt4",
- "gpt4-32k",
- "gpt4-turbo",
- "gpt4o",
- "claude",
- "claude-opus",
- "gemini-flash",
- "gemini-pro",
- "gemini-ultra",
- "mistral-tiny",
- "mistral-small",
- "mistral-medium",
- "mistral-large",
- "aws-claude",
- "aws-claude-opus",
- "gcp-claude",
- "gcp-claude-opus",
- "azure-turbo",
- "azure-gpt4",
- "azure-gpt4-32k",
- "azure-gpt4-turbo",
- "azure-gpt4o",
- ]),
+ allowedModelFamilies: getEnvWithDefault(
+ "ALLOWED_MODEL_FAMILIES",
+ getDefaultModelFamilies()
+ ),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
@@ -801,3 +780,7 @@ function parseCsv(val: string): string[] {
const matches = val.match(regex) || [];
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}
+
+ function getDefaultModelFamilies(): ModelFamily[] {
+ return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
+ }
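
The new default is computed rather than hand-maintained: every known model family except the image-generation ones. A standalone sketch of the same filter, using an abbreviated stand-in for the repo's MODEL_FAMILIES list:

// Abbreviated stand-in; the repo's real MODEL_FAMILIES tuple is much longer.
const MODEL_FAMILIES = ["turbo", "gpt4o", "claude", "aws-claude", "dall-e", "azure-dall-e"] as const;
type ModelFamily = (typeof MODEL_FAMILIES)[number];

function getDefaultModelFamilies(): ModelFamily[] {
  // Everything except the dall-e families is allowed by default.
  return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
}

console.log(getDefaultModelFamilies()); // ["turbo", "gpt4o", "claude", "aws-claude"]

Any family added to MODEL_FAMILIES later is allowed by default without touching the 25-line literal this replaces.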
+28 -22
@@ -8,17 +8,25 @@ import { awsMistral } from "./aws-mistral";
import { AwsBedrockKey, keyPool } from "../shared/key-management";
const awsRouter = Router();
awsRouter.get("/:vendor?/models", addV1, handleModelsRequest);
awsRouter.use("/claude", addV1, awsClaude);
awsRouter.use("/mistral", addV1, awsMistral);
awsRouter.get("/:vendor?/models", handleModelsRequest);
const MODELS_CACHE_TTL = 10000;
- let modelsCache: any = null;
- let modelsCacheTime = 0;
+ let modelsCache: Record<string, any> = {};
+ let modelsCacheTime: Record<string, number> = {};
function handleModelsRequest(req: Request, res: Response) {
if (!config.awsCredentials) return { object: "list", data: [] };
- if (new Date().getTime() - modelsCacheTime < MODELS_CACHE_TTL) {
- return res.json(modelsCache);
+ const vendor = req.params.vendor?.length
+ ? req.params.vendor === "claude"
+ ? "anthropic"
+ : req.params.vendor
+ : "all";
+ const cacheTime = modelsCacheTime[vendor] || 0;
+ if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
+ return res.json(modelsCache[vendor]);
}
const availableModelIds = new Set<string>();
@@ -43,27 +51,25 @@ function handleModelsRequest(req: Request, res: Response) {
]
.filter((id) => availableModelIds.has(id))
.map((id) => {
- const vendor = id.match(/^(.*)\./)?.[1];
- return {
- id,
- object: "model",
- created: new Date().getTime(),
- owned_by: vendor,
- permission: [],
- root: vendor,
- parent: null,
- };
- });
+ const vendor = id.match(/^(.*)\./)?.[1];
+ return {
+ id,
+ object: "model",
+ created: new Date().getTime(),
+ owned_by: vendor,
+ permission: [],
+ root: vendor,
+ parent: null,
+ };
+ });
- const requestedVendor = req.params.vendor;
- const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor;
- modelsCache = {
+ modelsCache[vendor] = {
object: "list",
- data: models.filter((m) => m.root === vendor),
+ data: models.filter((m) => vendor === "all" || m.root === vendor),
};
- modelsCacheTime = new Date().getTime();
+ modelsCacheTime[vendor] = new Date().getTime();
- return res.json(modelsCache);
+ return res.json(modelsCache[vendor]);
}
export const aws = awsRouter;
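
The rewritten handler amounts to a small per-key TTL cache, keyed by vendor instead of shared across all vendors. A generic sketch of the pattern (the getCached helper and buildModelList call are hypothetical, not part of the codebase):

const TTL_MS = 10_000;
const cache: Record<string, unknown> = {};
const cacheTime: Record<string, number> = {};

function getCached<T>(key: string, compute: () => T): T {
  const now = Date.now();
  // Serve the cached entry for this key while it is still fresh.
  if (now - (cacheTime[key] ?? 0) < TTL_MS && key in cache) {
    return cache[key] as T;
  }
  cache[key] = compute();
  cacheTime[key] = now;
  return cache[key] as T;
}

// e.g. getCached("anthropic", () => buildModelList("anthropic"));

Keying by vendor is the actual fix: with the old single cache, a request for /mistral/models within the TTL window could be answered with the previously cached /claude/models list.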
+24 -15
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
const pipelineAsync = promisify(pipeline);
/**
- * `handleStreamedResponse` consumes and transforms a streamed response from the
- * upstream service, forwarding events to the client in their requested format.
+ * `handleStreamedResponse` consumes a streamed response from the upstream API,
+ * decodes chunk-by-chunk into a stream of events, transforms those events into
+ * the client's requested format, and forwards the result to the client.
*
* After the entire stream has been consumed, it resolves with the full response
* body so that subsequent middleware in the chain can process it as if it were
- * a non-streaming response.
+ * a non-streaming response (to count output tokens, track usage, etc).
*
- * In the event of an error, the request's streaming flag is unset and the non-
- * streaming response handler is called instead.
- *
- * If the error is retryable, that handler will re-enqueue the request and also
- * reset the streaming flag. Unfortunately the streaming flag is set and unset
- * in multiple places, so it's hard to keep track of.
+ * In the event of an error, the request's streaming flag is unset and the
+ * request is bounced back to the non-streaming response handler. If the error
+ * is retryable, that handler will re-enqueue the request and also reset the
+ * streaming flag. Unfortunately the streaming flag is set and unset in multiple
+ * places, so it's hard to keep track of.
*/
export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes,
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
logger: req.log,
};
- // Decoder turns the raw response stream into a stream of events in some
- // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
- const decoder = getDecoder({ ...streamOptions, input: proxyRes });
- // Adapter transforms the decoded events into server-sent events.
- const adapter = new SSEStreamAdapter(streamOptions);
- // Aggregator compiles all events into a single response object.
+ // While the request is streaming, aggregator collects all events so that we
+ // can compile them into a single response object and publish that to the
+ // remaining middleware. Because we have an OpenAI transformer for every
+ // supported format, EventAggregator always consumes OpenAI events so that we
+ // only have to write one aggregator (OpenAI input) for each output format.
const aggregator = new EventAggregator({ format: req.outboundApi });
+ // Decoder reads from the raw response buffer and produces a stream of
+ // discrete events in some format (text/event-stream, vnd.amazon.event-stream,
+ // streaming JSON, etc).
+ const decoder = getDecoder({ ...streamOptions, input: proxyRes });
+ // Adapter consumes the decoded events and produces server-sent events so we
+ // have a standard transport for the client and to translate between API
+ // message formats.
+ const adapter = new SSEStreamAdapter(streamOptions);
// Transformer converts server-sent events from one vendor's API message
// format to another.
const transformer = new SSEMessageTransformer({
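
The rewritten comments above describe a four-stage flow. A rough, self-contained sketch with plain pass-through Transform streams standing in for the real stages (getDecoder, SSEStreamAdapter, SSEMessageTransformer, and EventAggregator all take richer options and do real parsing; how the aggregator is attached is an assumption here):

import { Transform, Writable, pipeline } from "stream";
import { promisify } from "util";

const pipelineAsync = promisify(pipeline);

// Pass-through stand-in for one stage; real stages decode bytes, emit SSE,
// or translate between vendor message formats.
const stage = () =>
  new Transform({
    objectMode: true,
    transform(chunk, _enc, cb) {
      cb(null, chunk);
    },
  });

async function streamToClient(source: NodeJS.ReadableStream, res: Writable) {
  const decoder = stage(); // raw response bytes -> discrete events
  const adapter = stage(); // decoded events -> server-sent events
  const transformer = stage(); // vendor-format SSE -> OpenAI-format SSE
  // Per the comments above, the aggregator also consumes the OpenAI-format
  // events and compiles them into one response object for later middleware.
  await pipelineAsync(source, decoder, adapter, transformer, res);
}

Pivoting everything through OpenAI-format events means only one aggregator (OpenAI input) is needed per output format, as the new comment explains.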
+1 -1
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
if (eventType === "chunk") {
result = input[eventType];
} else {
- // AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
+ // AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
result = { [eventType]: input[eventType] } as any;
}
return result;
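
For illustration, the logic around this comment re-wraps non-chunk events so downstream consumers always see one shape. A compressed sketch with an assumed type for the unmarshalled message (the real types come from the AWS SDK's event stream codec):

type UnmarshalledMessage = Record<string, unknown>;

function normalize(input: UnmarshalledMessage): unknown {
  const eventType = Object.keys(input)[0];
  if (eventType === "chunk") {
    // Payload chunks are unwrapped and passed along as-is.
    return input[eventType];
  }
  // Errors and exceptions are re-wrapped under their event type so every
  // consumer receives a consistent { [eventType]: payload } object.
  return { [eventType]: input[eventType] };
}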
+2 -2
@@ -149,8 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
"mistral-ai": `%BASE%/mistral-ai`,
},
aws: {
- claude: `%BASE%/aws/claude`,
- mistral: `%BASE%/aws/mistral`,
+ "aws-claude": `%BASE%/aws/claude`,
+ "aws-mistral": `%BASE%/aws/mistral`,
},
gcp: {
gcp: `%BASE%/gcp/claude`,
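
The renamed keys match the aws-claude and aws-mistral model-family identifiers seen in the allowed-families list above, presumably so that lookups keyed by model family resolve. As a minimal sketch of how a %BASE% endpoint map like this might be expanded, with a hypothetical render helper that is not part of the repo:

type EndpointMap = Record<string, string>;

// Hypothetical helper: substitute the %BASE% placeholder in each endpoint.
function renderEndpoints(map: EndpointMap, base: string): EndpointMap {
  return Object.fromEntries(
    Object.entries(map).map(([family, url]) => [family, url.replace("%BASE%", base)])
  );
}

// renderEndpoints({ "aws-claude": "%BASE%/aws/claude" }, "https://proxy.example")
// -> { "aws-claude": "https://proxy.example/aws/claude" }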