From 942f6d7709dbe70556fea8df80ae4a59f2edb9bd Mon Sep 17 00:00:00 2001 From: Adrien Brault Date: Fri, 31 May 2024 23:52:26 +0200 Subject: [PATCH] feat: DSN configuration --- README.md | 47 +++++++++++++ src/Instructrice.php | 8 +-- src/LLM/Cost.php | 4 +- src/LLM/DSNParser.php | 122 +++++++++++++++++++++++++++++++++ src/LLM/LLMConfig.php | 2 + src/LLM/LLMFactory.php | 15 +++- src/LLM/Provider/Anthropic.php | 4 +- src/LLM/Provider/Google.php | 4 +- tests/LLM/DSNParserTest.php | 85 +++++++++++++++++++++++ 9 files changed, 280 insertions(+), 11 deletions(-) create mode 100644 src/LLM/DSNParser.php create mode 100644 tests/LLM/DSNParserTest.php diff --git a/README.md b/README.md index 66996a8..02b2920 100644 --- a/README.md +++ b/README.md @@ -276,6 +276,53 @@ $instructrice->get( ); ``` +#### DSN + +You may configure the LLM using a DSN: +- the scheme is the provider: `openai`, `openai-http`, `anthropic`, `google` (`openai-http` is `openai` over plain `http`; the other schemes use `https`) +- the password is the API key +- the host, port and path form the API endpoint, without the scheme (the path may be omitted to use the provider's default endpoint path) +- the query string: + - `model` is the model name (required) + - `context` is the context window, a positive integer (required) + - `strategy` is the optional strategy to use: + - `json` for JSON mode with the schema in the prompt only + - `json_with_schema` for JSON mode where the completion is most likely constrained exactly to the schema + - `tool_any` + - `tool_auto` + - `tool_function` + +Examples: +```php +use AdrienBrault\Instructrice\InstructriceFactory; + +$instructrice = InstructriceFactory::create( + defaultLlm: 'openai://:api_key@api.openai.com/v1/chat/completions?model=gpt-3.5-turbo&strategy=tool_auto&context=16000' +); + +$instructrice->get( + ..., + llm: 'openai-http://localhost:11434?model=adrienbrault/nous-hermes2theta-llama3-8b&strategy=json&context=8000' +); + +$instructrice->get( + ..., + llm: 'openai://:api_key@api.fireworks.ai/inference/v1/chat/completions?model=accounts/fireworks/models/llama-v3-70b-instruct&context=8000&strategy=json_with_schema' +); + +$instructrice->get( + ..., + llm: 
'google://:api_key@generativelanguage.googleapis.com/v1beta/models?model=gemini-1.5-flash&context=1000000' +); + +$instructrice->get( + ..., + llm: 'anthropic://:api_key@api.anthropic.com?model=claude-3-haiku-20240307&context=200000' +); +``` + +#### LLMInterface + You may also implement [LLMInterface](src/LLM/LLMInterface.php). ## Acknowledgements diff --git a/src/Instructrice.php b/src/Instructrice.php index d696982..d33f619 100644 --- a/src/Instructrice.php +++ b/src/Instructrice.php @@ -27,7 +27,7 @@ class Instructrice { public function __construct( - private readonly ProviderModel|LLMConfig $defaultLlm, + private readonly ProviderModel|LLMConfig|string $defaultLlm, private readonly LLMFactory $llmFactory, private readonly LoggerInterface $logger, private readonly SchemaFactory $schemaFactory, @@ -50,7 +50,7 @@ public function get( ?string $prompt = null, array $options = [], ?callable $onChunk = null, - LLMInterface|LLMConfig|ProviderModel|null $llm = null, + LLMInterface|LLMConfig|ProviderModel|string|null $llm = null, ) { $denormalize = fn (mixed $data) => $data; $schema = $type; @@ -89,7 +89,7 @@ public function list( ?string $prompt = null, array $options = [], ?callable $onChunk = null, - LLMInterface|LLMConfig|ProviderModel|null $llm = null, + LLMInterface|LLMConfig|ProviderModel|string|null $llm = null, ): array { $wrappedWithProperty = 'list'; $schema = [ @@ -145,7 +145,7 @@ private function getAndDenormalize( string $prompt, bool $truncateAutomatically = false, ?callable $onChunk = null, - LLMInterface|LLMConfig|ProviderModel|null $llm = null, + LLMInterface|LLMConfig|ProviderModel|string|null $llm = null, ): mixed { if (($schema['type'] ?? 
null) !== 'object') { $wrappedWithProperty = 'inner'; diff --git a/src/LLM/Cost.php b/src/LLM/Cost.php index 569a942..5f5cb4c 100644 --- a/src/LLM/Cost.php +++ b/src/LLM/Cost.php @@ -7,8 +7,8 @@ class Cost { public function __construct( - public readonly float $millionPromptTokensPrice, - public readonly float $millionCompletionTokensPrice, + public readonly float $millionPromptTokensPrice = 0, + public readonly float $millionCompletionTokensPrice = 0, ) { } diff --git a/src/LLM/DSNParser.php b/src/LLM/DSNParser.php new file mode 100644 index 0000000..7651da6 --- /dev/null +++ b/src/LLM/DSNParser.php @@ -0,0 +1,122 @@ + string(), + 'pass' => optional(string()), + 'host' => string(), + 'port' => optional(int()), + 'path' => optional(string()), + 'query' => string(), + ], true)->coerce($parsedUrl); + + $apiKey = $parsedUrl['pass'] ?? null; + $host = $parsedUrl['host']; + $port = $parsedUrl['port'] ?? null; + $path = $parsedUrl['path'] ?? null; + $query = $parsedUrl['query']; + + $hostWithPort = $host . ($port === null ? '' : ':' . $port); + + $client = union( + literal_scalar('openai'), + literal_scalar('openai-http'), + literal_scalar('anthropic'), + literal_scalar('google') + )->coerce($parsedUrl['scheme']); + + parse_str($query, $parsedQuery); + $model = $parsedQuery['model']; + $strategyName = $parsedQuery['strategy'] ?? null; + $context = (int) ($parsedQuery['context'] ?? null); + + if (! 
\is_string($model)) { + throw new InvalidArgumentException('The DSN "model" query string must be a string'); + } + + if ($context <= 0) { + throw new InvalidArgumentException('The DSN "context" query string must be a positive integer'); + } + + $scheme = 'https'; + + $strategy = null; + if ($strategyName === 'json') { + $strategy = OpenAiJsonStrategy::JSON; + } elseif ($strategyName === 'json_with_schema') { + $strategy = OpenAiJsonStrategy::JSON_WITH_SCHEMA; + } elseif ($strategyName === 'tool_any') { + $strategy = OpenAiToolStrategy::ANY; + } elseif ($strategyName === 'tool_auto') { + $strategy = OpenAiToolStrategy::AUTO; + } elseif ($strategyName === 'tool_function') { + $strategy = OpenAiToolStrategy::FUNCTION; + } + + if ($client === 'anthropic') { + $headers = [ + 'x-api-key' => $apiKey, + ]; + $llmClass = AnthropicLLM::class; + $path ??= '/v1/messages'; + } elseif ($client === 'google') { + $headers = [ + 'x-api-key' => $apiKey, + ]; + $llmClass = GoogleLLM::class; + $path ??= '/v1beta/models'; + } elseif ($client === 'openai' || $client === 'openai-http') { + $path ??= '/v1/chat/completions'; + $headers = $apiKey === null ? [] : [ + 'Authorization' => 'Bearer ' . $apiKey, + ]; + + $llmClass = OpenAiLLM::class; + + if ($client === 'openai-http') { + $scheme = 'http'; + } + } else { + throw new InvalidArgumentException(sprintf('Unknown client "%s", use one of %s', $client, implode(', ', ['openai', 'anthropic', 'google']))); + } + + $uri = $scheme . '://' . $hostWithPort . 
$path; + + return new LLMConfig( + $uri, + $model, + $context, + $model, + $hostWithPort, + new Cost(), + $strategy, + headers: $headers, + llmClass: $llmClass + ); + } +} diff --git a/src/LLM/LLMConfig.php b/src/LLM/LLMConfig.php index 401311e..4797a72 100644 --- a/src/LLM/LLMConfig.php +++ b/src/LLM/LLMConfig.php @@ -15,6 +15,7 @@ class LLMConfig * @param callable(mixed, string): string $systemPrompt * @param array $headers * @param list $stopTokens + * @param class-string $llmClass */ public function __construct( public readonly string $uri, @@ -29,6 +30,7 @@ public function __construct( public readonly ?int $maxTokens = null, public readonly ?string $docUrl = null, array|false|null $stopTokens = null, + public readonly ?string $llmClass = null, ) { if ($stopTokens !== false) { $this->stopTokens = $stopTokens ?? ["```\n\n", '<|im_end|>', "\n\n\n", "\t\n\t\n"]; diff --git a/src/LLM/LLMFactory.php b/src/LLM/LLMFactory.php index b8a4270..c471792 100644 --- a/src/LLM/LLMFactory.php +++ b/src/LLM/LLMFactory.php @@ -42,18 +42,23 @@ public function __construct( private readonly array $apiKeys = [], private readonly Gpt3Tokenizer $tokenizer = new Gpt3Tokenizer(new Gpt3TokenizerConfig()), private readonly ParserInterface $parser = new JsonParser(), + private readonly DSNParser $dsnParser = new DSNParser(), ) { } - public function create(LLMConfig|ProviderModel $config): LLMInterface + public function create(LLMConfig|ProviderModel|string $config): LLMInterface { + if (\is_string($config)) { + $config = $this->dsnParser->parse($config); + } + if ($config instanceof ProviderModel) { $apiKey = $this->apiKeys[$config::class] ?? null; $apiKey ??= self::getProviderModelApiKey($config, true) ?? 
'sk-xxx'; $config = $config->createConfig($apiKey); } - if (str_contains($config->uri, 'api.anthropic.com')) { + if ($config->llmClass === AnthropicLLM::class) { return new AnthropicLLM( $this->client, $this->logger, @@ -62,7 +67,7 @@ public function create(LLMConfig|ProviderModel $config): LLMInterface ); } - if (str_contains($config->uri, 'googleapis.com')) { + if ($config->llmClass === GoogleLLM::class) { return new GoogleLLM( $config, $this->client, @@ -72,6 +77,10 @@ public function create(LLMConfig|ProviderModel $config): LLMInterface ); } + if ($config->llmClass !== OpenAiLLM::class && $config->llmClass !== null) { + throw new InvalidArgumentException(sprintf('Unknown LLM class %s', $config->llmClass)); + } + return new OpenAiLLM( $config, $this->client, diff --git a/src/LLM/Provider/Anthropic.php b/src/LLM/Provider/Anthropic.php index 85676db..4609cc9 100644 --- a/src/LLM/Provider/Anthropic.php +++ b/src/LLM/Provider/Anthropic.php @@ -4,6 +4,7 @@ namespace AdrienBrault\Instructrice\LLM\Provider; +use AdrienBrault\Instructrice\LLM\Client\AnthropicLLM; use AdrienBrault\Instructrice\LLM\Cost; use AdrienBrault\Instructrice\LLM\LLMConfig; @@ -63,7 +64,8 @@ public function createConfig(string $apiKey): LLMConfig headers: [ 'x-api-key' => $apiKey, ], - docUrl: 'https://docs.anthropic.com/claude/docs/models-overview' + docUrl: 'https://docs.anthropic.com/claude/docs/models-overview', + llmClass: AnthropicLLM::class, ); } } diff --git a/src/LLM/Provider/Google.php b/src/LLM/Provider/Google.php index db0ee3c..580de5f 100644 --- a/src/LLM/Provider/Google.php +++ b/src/LLM/Provider/Google.php @@ -4,6 +4,7 @@ namespace AdrienBrault\Instructrice\LLM\Provider; +use AdrienBrault\Instructrice\LLM\Client\GoogleLLM; use AdrienBrault\Instructrice\LLM\Cost; use AdrienBrault\Instructrice\LLM\LLMConfig; @@ -36,7 +37,8 @@ public function createConfig(string $apiKey): LLMConfig headers: [ 'x-api-key' => $apiKey, ], - docUrl: 'https://ai.google.dev/gemini-api/docs/models/gemini' + 
docUrl: 'https://ai.google.dev/gemini-api/docs/models/gemini', + llmClass: GoogleLLM::class, ); } } diff --git a/tests/LLM/DSNParserTest.php b/tests/LLM/DSNParserTest.php new file mode 100644 index 0000000..d865582 --- /dev/null +++ b/tests/LLM/DSNParserTest.php @@ -0,0 +1,85 @@ +parse('openai://:api_key@api.openai.com/v1/chat/completions?model=gpt-3.5-turbo&strategy=tool_auto&context=16000'); + + $this->assertInstanceOf(LLMConfig::class, $config); + $this->assertSame(OpenAiLLM::class, $config->llmClass); + $this->assertSame('https://api.openai.com/v1/chat/completions', $config->uri); + $this->assertSame('gpt-3.5-turbo', $config->model); + $this->assertSame(16000, $config->contextWindow); + $this->assertSame('gpt-3.5-turbo', $config->label); + $this->assertSame('api.openai.com', $config->provider); + $this->assertSame(OpenAiToolStrategy::AUTO, $config->strategy); + $this->assertSame([ + 'Authorization' => 'Bearer api_key', + ], $config->headers); + $this->assertNull($config->docUrl); + } + + public function testOpenAiPathLess(): void + { + $config = (new DSNParser())->parse('openai://:api_key@api.openai.com?model=gpt-3.5-turbo&strategy=tool_auto&context=16000'); + + $this->assertInstanceOf(LLMConfig::class, $config); + $this->assertSame(OpenAiLLM::class, $config->llmClass); + $this->assertSame('https://api.openai.com/v1/chat/completions', $config->uri); + } + + public function testAnthropic(): void + { + $config = (new DSNParser())->parse('anthropic://:api_key@api.anthropic.com/v1/messages?model=claude-3-haiku-20240307&context=200000'); + + $this->assertInstanceOf(LLMConfig::class, $config); + $this->assertSame(AnthropicLLM::class, $config->llmClass); + $this->assertSame('https://api.anthropic.com/v1/messages', $config->uri); + $this->assertSame('claude-3-haiku-20240307', $config->model); + $this->assertSame(200000, $config->contextWindow); + $this->assertSame('claude-3-haiku-20240307', $config->label); + $this->assertSame('api.anthropic.com', $config->provider); + 
$this->assertNull($config->strategy); + $this->assertSame([ + 'x-api-key' => 'api_key', + ], $config->headers); + } + + public function testAnthropicPathLess(): void + { + $config = (new DSNParser())->parse('anthropic://:api_key@api.anthropic.com?model=claude-3-haiku-20240307&context=200000'); + + $this->assertInstanceOf(LLMConfig::class, $config); + $this->assertSame(AnthropicLLM::class, $config->llmClass); + $this->assertSame('https://api.anthropic.com/v1/messages', $config->uri); + } + + public function testOllama(): void + { + $config = (new DSNParser())->parse('openai-http://localhost:11434?model=adrienbrault/nous-hermes2theta-llama3-8b&strategy=json&context=8000'); + + $this->assertInstanceOf(LLMConfig::class, $config); + $this->assertSame(OpenAiLLM::class, $config->llmClass); + $this->assertSame('http://localhost:11434/v1/chat/completions', $config->uri); + $this->assertSame('adrienbrault/nous-hermes2theta-llama3-8b', $config->model); + $this->assertSame(8000, $config->contextWindow); + $this->assertSame('adrienbrault/nous-hermes2theta-llama3-8b', $config->label); + $this->assertSame('localhost:11434', $config->provider); + $this->assertSame(OpenAiJsonStrategy::JSON, $config->strategy); + } +}