fixes rebase issues and adds aws key serializer

This commit is contained in:
nai-degen
2023-10-08 01:50:23 -05:00
parent 05ab8c37eb
commit 8884544b05
10 changed files with 86 additions and 57 deletions
@@ -84,11 +84,10 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const serializedKeys = await this.store.load();
if (serializedKeys.length === 0) {
this.log.warn(
return this.log.warn(
{ via: storeName },
"No Anthropic keys found. Anthropic API will not be available."
);
return;
}
this.keys.push(...serializedKeys.map(AnthropicKeySerializer.deserialize));
@@ -12,8 +12,8 @@ export const AnthropicKeySerializer: KeySerializer<AnthropicKey> = {
key,
service: "anthropic" as const,
modelFamilies: ["claude" as const],
isTrial: false,
isDisabled: false,
isRevoked: false,
isPozzed: false,
promptCount: 0,
lastUsed: 0,
+34 -38
View File
@@ -1,9 +1,10 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models";
import { Key, KeyProvider } from "../index";
import { KeyStore, SerializedKey } from "../stores";
import { AwsKeyChecker } from "./checker";
import { AwsBedrockKeySerializer } from "./serializer";
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
export const AWS_BEDROCK_SUPPORTED_MODELS = [
@@ -33,6 +34,15 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
awsLoggingStatus: "unknown" | "disabled" | "enabled";
}
const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
"key",
"service",
"hash",
"aws-claudeTokens",
];
export type SerializedAwsBedrockKey = SerializedKey &
Partial<Pick<AwsBedrockKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
@@ -46,48 +56,34 @@ const RATE_LIMIT_LOCKOUT = 300;
const KEY_REUSE_DELAY = 500;
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
readonly service = "aws";
readonly service = "aws" as const;
private keys: AwsBedrockKey[] = [];
private readonly keys: AwsBedrockKey[] = [];
private store: KeyStore<AwsBedrockKey>;
private checker?: AwsKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.awsCredentials?.trim();
if (!keyConfig) {
this.log.warn(
"AWS_CREDENTIALS is not set. AWS Bedrock API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: AwsBedrockKey = {
key,
service: this.service,
modelFamilies: ["aws-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
awsLoggingStatus: "unknown",
hash: `aws-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
["aws-claudeTokens"]: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
constructor(store: KeyStore<AwsBedrockKey>) {
this.store = store;
}
public init() {
public async init() {
const storeName = this.store.constructor.name;
const serializedKeys = await this.store.load();
if (serializedKeys.length === 0) {
return this.log.warn(
{ via: storeName },
"No AWS credentials found. AWS Bedrock API will not be available."
);
}
this.keys.push(...serializedKeys.map(AwsBedrockKeySerializer.deserialize));
this.log.info(
{ count: this.keys.length, via: storeName },
"Loaded AWS Bedrock keys."
);
if (config.checkKeys) {
this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
@@ -0,0 +1,33 @@
import crypto from "crypto";
import { AwsBedrockKey } from "..";
import { KeySerializer } from "../stores";
import { SerializedAwsBedrockKey } from "./provider";
export const AwsBedrockKeySerializer: KeySerializer<AwsBedrockKey> = {
serialize(key: AwsBedrockKey): SerializedAwsBedrockKey {
return { key: key.key };
},
deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey {
const { key, ...rest } = serializedKey;
return {
key,
service: "aws",
modelFamilies: ["aws-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
awsLoggingStatus: "unknown",
hash: `plm-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
["aws-claudeTokens"]: 0,
...rest,
};
},
};
+3 -6
View File
@@ -1,12 +1,9 @@
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
variable as a comma-separated list of keys. */
import crypto from "crypto";
import http from "http";
import { Key, KeyProvider, Model } from "../index";
import { IncomingHttpHeaders } from "http";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
import { Key, KeyProvider, Model } from "../index";
import { KeyStore, SerializedKey } from "../stores";
import { OpenAIKeyChecker } from "./checker";
import { OpenAIKeySerializer } from "./serializer";
@@ -337,7 +334,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
}
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
public updateRateLimits(keyHash: string, headers: IncomingHttpHeaders) {
const key = this.keys.find((k) => k.hash === keyHash)!;
const requestsReset = headers["x-ratelimit-reset-requests"];
const tokensReset = headers["x-ratelimit-reset-tokens"];
+1 -1
View File
@@ -1,7 +1,7 @@
import { Key, KeyProvider } from "..";
import { KeyStore, SerializedKey } from "../stores";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import { KeyStore, SerializedKey } from "../stores";
import { GooglePalmKeySerializer } from "./serializer";
// https://developers.generativeai.google.com/models/language
+1 -1
View File
@@ -13,8 +13,8 @@ export const GooglePalmKeySerializer: KeySerializer<GooglePalmKey> = {
key,
service: "google-palm" as const,
modelFamilies: ["bison"],
isTrial: false,
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
+5 -3
View File
@@ -1,19 +1,21 @@
import { APIFormat, Key } from ".";
import { LLMService, Key } from ".";
import { assertNever } from "../utils";
import { KeySerializer } from "./stores";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";
export function getSerializer(service: APIFormat): KeySerializer<Key> {
export function getSerializer(service: LLMService): KeySerializer<Key> {
switch (service) {
case "openai":
case "openai-text":
return OpenAIKeySerializer;
case "anthropic":
return AnthropicKeySerializer;
case "google-palm":
return GooglePalmKeySerializer;
case "aws":
return AwsBedrockKeySerializer;
default:
assertNever(service);
}
+2 -2
View File
@@ -1,7 +1,7 @@
import firebase from "firebase-admin";
import { getFirebaseApp } from "../../../config";
import { logger } from "../../../logger";
import { APIFormat, Key } from "..";
import { LLMService, Key } from "..";
import { KeyStore, assertSerializableKey } from ".";
import { KeySerializer } from ".";
@@ -13,7 +13,7 @@ export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
private flushInterval: NodeJS.Timeout | null = null;
constructor(
private service: APIFormat,
private readonly service: LLMService,
private serializer: KeySerializer<K>,
app = getFirebaseApp()
) {
+5 -3
View File
@@ -1,23 +1,25 @@
import { assertNever } from "../../utils";
import { APIFormat, Key } from "..";
import { LLMService, Key } from "..";
import { KeySerializer } from ".";
import { KeyStore } from ".";
export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
private env: string;
constructor(service: APIFormat, private serializer: KeySerializer<K>) {
constructor(service: LLMService, private serializer: KeySerializer<K>) {
switch (service) {
case "anthropic":
this.env = "ANTHROPIC_KEY";
break;
case "openai":
case "openai-text":
this.env = "OPENAI_KEY";
break;
case "google-palm":
this.env = "GOOGLE_PALM_KEY";
break;
case "aws":
this.env = "AWS_CREDENTIALS"; // TODO: parse AWS security credentials
break;
default:
assertNever(service);
}