Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Azure ? #2

Open
jays0606 opened this issue Jun 22, 2023 · 2 comments
Open

Support for Azure ? #2

jays0606 opened this issue Jun 22, 2023 · 2 comments

Comments

@jays0606
Copy link

I tried Azure OpenAI for the embeddings endpoint, but I don't think it's working.

Any ideas on supporting it? The Azure documentation says it doesn't support batch requests yet.

https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#embeddings

@micycle1
Copy link

micycle1 commented Feb 8, 2024

Here's an updated version of the library that works with v1.0+ of the OpenAI library and supports AzureOpenAI. I've modified the library slightly to take in a user-supplied client, rather than creating one internally.

Example Usage

# Example usage (save the library below as openai_multi_client.py so the
# module name matches this import).
from openai import AsyncAzureOpenAI

from openai_multi_client import OpenAIMultiClient

# `key` and `url` are your Azure OpenAI credentials/endpoint.
async_client = AsyncAzureOpenAI(
    api_key=key,
    api_version="2023-12-01-preview",
    azure_endpoint=url,
)

# Defaults merged into every request's data.
params = dict(
    model="gpt-4-turbo",
    temperature=0.35,
    top_p=0.25,
)

api = OpenAIMultiClient(async_client, endpoint="chat.completions", data_template=params, max_retries=0)

openai_multi_client.py:

import logging
import asyncio
from dataclasses import dataclass
from threading import Thread
from typing import Any, Optional

from aioprocessing import AioJoinableQueue, AioQueue
from tenacity import wait_random_exponential, stop_after_attempt, AsyncRetrying, RetryError
from openai import AsyncOpenAI, AsyncAzureOpenAI

logger = logging.getLogger(__name__)

@dataclass
class Payload:
    """One queued API request plus its retry configuration and outcome.

    ``attempt``, ``failed`` and ``response`` are filled in by the worker;
    ``callback``, if set, is invoked with the payload when the result is
    handed to the consumer.
    """
    endpoint: str
    data: dict
    metadata: Optional[dict]
    max_retries: int
    retry_multiplier: float
    retry_max: float
    attempt: int = 0
    failed: bool = False
    response: Any = None
    callback: Any = None

    def call_callback(self):
        """Invoke the attached callback with this payload, if one is set."""
        cb = self.callback
        if cb:
            cb(self)


class OpenAIMultiClient:
    """Concurrently dispatch OpenAI API calls from synchronous code.

    A user-supplied async client (e.g. ``AsyncOpenAI`` or ``AsyncAzureOpenAI``)
    is driven by ``concurrency`` worker coroutines running on a dedicated
    event-loop thread.  Submit work with :meth:`request`; consume results by
    iterating over the instance (results arrive in completion order, not
    submission order — see ``OpenAIMultiOrderedClient`` for ordered delivery).
    """

    def __init__(self,
                 aclient,
                 concurrency: int = 10,
                 max_retries: int = 10,
                 wait_interval: float = 0,
                 retry_multiplier: float = 1,
                 retry_max: float = 60,
                 endpoint: Optional[str] = None,
                 data_template: Optional[dict] = None,
                 metadata_template: Optional[dict] = None,
                 custom_api=None):
        """
        :param aclient: async OpenAI client instance (e.g. ``AsyncAzureOpenAI``).
        :param concurrency: number of concurrent worker coroutines.
        :param max_retries: default number of attempts per request.
        :param wait_interval: seconds each worker sleeps after a request.
        :param retry_multiplier: multiplier for the exponential back-off wait.
        :param retry_max: maximum back-off wait, in seconds.
        :param endpoint: default endpoint name (e.g. ``"chat.completions"``).
        :param data_template: defaults merged into each request's data.
        :param metadata_template: defaults merged into each request's metadata.
        :param custom_api: optional awaitable used in place of the real API
            (useful for testing).
        """
        self._aclient = aclient
        self._endpoint = endpoint
        self._wait_interval = wait_interval
        self._data_template = data_template or {}
        self._metadata_template = metadata_template or {}
        self._max_retries = max_retries
        self._retry_multiplier = retry_multiplier
        self._retry_max = retry_max
        self._concurrency = concurrency
        # The loop runs on its own thread so callers can use the blocking,
        # synchronous request()/iteration API from ordinary code.
        self._loop = asyncio.new_event_loop()
        # Bounded queues provide back-pressure: request() blocks when the
        # workers are saturated.
        self._in_queue = AioJoinableQueue(maxsize=concurrency)
        self._out_queue = AioQueue(maxsize=concurrency)
        self._event_loop_thread = Thread(target=self._run_event_loop)
        self._event_loop_thread.start()
        self._mock_api = custom_api
        for i in range(concurrency):
            asyncio.run_coroutine_threadsafe(self._worker(i), self._loop)

    def run_request_function(self, input_function, *args, stop_at_end=True, **kwargs):
        """Run ``input_function`` (which should enqueue requests) on a thread.

        When ``stop_at_end`` is true, :meth:`close` is called automatically
        once the function returns, so iteration over results terminates.
        """
        if stop_at_end:
            def f(*args, **kwargs):
                input_function(*args, **kwargs)
                self.close()
        else:
            f = input_function
        input_thread = Thread(target=f, args=args, kwargs=kwargs)
        input_thread.start()

    def _run_event_loop(self):
        # Thread target: adopt the loop created in __init__ and run it until
        # close() stops it.
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    async def _process_payload(self, payload: Payload) -> Payload:
        """Dispatch one payload to its endpoint and store the raw response."""
        logger.debug(f"Processing {payload}")
        if self._mock_api:
            payload.response = await self._mock_api(payload)
        elif payload.endpoint == "completions":
            payload.response = await self._aclient.completions.create(**payload.data)
        elif payload.endpoint == "chat.completions" or payload.endpoint == "chats":
            payload.response = await self._aclient.chat.completions.create(**payload.data)
        elif payload.endpoint == "embeddings":
            payload.response = await self._aclient.embeddings.create(**payload.data)
        elif payload.endpoint == "edits":
            # NOTE(review): "edits" and "fine-tunes" look like legacy endpoints;
            # confirm they exist on the client version in use.
            payload.response = await self._aclient.edits.create(**payload.data)
        elif payload.endpoint == "images":
            payload.response = await self._aclient.images.generate(**payload.data)
        elif payload.endpoint == "fine-tunes":
            payload.response = await self._aclient.fine_tunes.create(**payload.data)
        else:
            raise ValueError(f"Unknown endpoint {payload.endpoint}")
        logger.debug(f"Processed {payload}")
        return payload

    async def _worker(self, i):
        """Worker coroutine: pull payloads, process with retries, push results.

        A ``None`` payload is the shutdown sentinel (one per worker, enqueued
        by :meth:`close`).
        """
        while True:
            payload = await self._in_queue.coro_get()

            if payload is None:
                logger.debug(f"Exiting worker {i}")
                self._in_queue.task_done()
                break

            try:
                async for attempt in AsyncRetrying(
                        wait=wait_random_exponential(multiplier=payload.retry_multiplier, max=payload.retry_max),
                        stop=stop_after_attempt(payload.max_retries)):
                    with attempt:
                        try:
                            payload.attempt = attempt.retry_state.attempt_number
                            payload = await self._process_payload(payload)
                            await self._out_queue.coro_put(payload)
                            self._in_queue.task_done()
                        except Exception:
                            logger.exception(f"Error processing {payload}")
                            raise
            except RetryError:
                # All attempts exhausted: surface the payload marked failed so
                # the consumer sees it rather than it being silently dropped.
                payload.failed = True
                logger.error(f"Failed to process {payload}")
                await self._out_queue.coro_put(payload)
                self._in_queue.task_done()
            await asyncio.sleep(self._wait_interval)

    def close(self):
        """Shut down: signal workers, drain the input queue, stop the loop."""
        try:
            for i in range(self._concurrency):
                self._in_queue.put(None)
            self._in_queue.join()
            # Sentinel for the result consumer (__next__ raises StopIteration).
            self._out_queue.put(None)
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._event_loop_thread.join()
        except Exception as e:
            logger.error(f"Error closing: {e}")

    def __iter__(self):
        return self

    def __next__(self):
        """Block for the next completed payload; fire its callback first."""
        out = self._out_queue.get()
        if out is None:
            raise StopIteration
        out.call_callback()
        return out

    def request(self,
                data: dict,
                endpoint: Optional[str] = None,
                metadata: Optional[dict] = None,
                callback: Any = None,
                max_retries: Optional[int] = None,
                retry_multiplier: Optional[float] = None,
                retry_max: Optional[float] = None):
        """Enqueue one request; blocks when the input queue is full.

        Per-call values override the constructor defaults.  Explicit ``None``
        checks are used (rather than ``x or default``) so callers can pass
        legitimate falsy overrides such as ``max_retries=0`` or
        ``retry_multiplier=0`` — with ``or``, those were silently replaced by
        the constructor defaults.
        """
        payload = Payload(
            endpoint=endpoint if endpoint is not None else self._endpoint,
            data={**self._data_template, **data},
            metadata={**self._metadata_template, **(metadata or {})},
            callback=callback,
            max_retries=max_retries if max_retries is not None else self._max_retries,
            retry_multiplier=retry_multiplier if retry_multiplier is not None else self._retry_multiplier,
            retry_max=retry_max if retry_max is not None else self._retry_max
        )
        self._in_queue.put(payload)

    def pull_all(self):
        """Drain all remaining results (callbacks still fire) until closed."""
        for _ in self:
            pass


class OrderedPayload(Payload):
    """A :class:`Payload` tagged with its submission index.

    ``put_counter`` records the order in which the request was enqueued so
    ``OpenAIMultiOrderedClient`` can hand results back in submission order.
    """

    # Submission index assigned by the ordered client's request().
    put_counter: int

    def __init__(self, *args, put_counter, **kwargs):
        """Accept all Payload arguments plus the keyword-only ``put_counter``."""
        self.put_counter = put_counter
        super().__init__(*args, **kwargs)


class OpenAIMultiOrderedClient(OpenAIMultiClient):
    """Variant of :class:`OpenAIMultiClient` that yields results in the exact
    order requests were submitted, buffering out-of-order completions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._put_counter = 0   # index assigned to the next submitted request
        self._get_counter = 0   # index of the next result to hand out
        self._get_cache = {}    # early results parked here, keyed by put_counter
        self._stopped = False   # True once the out-queue sentinel was seen

    def __iter__(self):
        return self

    def __next__(self):
        """Return completed payloads strictly in submission order.

        Results arriving ahead of their turn are parked in ``_get_cache``;
        after the ``None`` sentinel appears the remaining cache is drained
        in order before StopIteration is raised.
        """
        while True:
            if self._stopped:
                out = None
            else:
                out = self._out_queue.get()
            if out is None:
                self._stopped = True
                if self._get_counter == self._put_counter:
                    raise StopIteration
                else:
                    # Shutdown seen but results remain: drain the cache in order.
                    out = self._get_cache[self._get_counter]
                    del self._get_cache[self._get_counter]
                    self._get_counter += 1
                    out.call_callback()
                    return out

            data_counter = out.put_counter
            if data_counter == self._get_counter:
                # This result is the next one due; hand it out directly.
                self._get_counter += 1
                out.call_callback()
                return out
            # Arrived early: park it, then check whether the due result is
            # already cached from a previous iteration.
            self._get_cache[data_counter] = out
            if self._get_counter in self._get_cache:
                out = self._get_cache[self._get_counter]
                del self._get_cache[self._get_counter]
                self._get_counter += 1
                out.call_callback()
                return out

    def request(self,
                data: dict,
                endpoint: Optional[str] = None,
                metadata: Optional[dict] = None,
                callback: Any = None,
                max_retries: Optional[int] = None,
                retry_multiplier: Optional[float] = None,
                retry_max: Optional[float] = None):
        """Enqueue one request tagged with its submission index.

        Explicit ``None`` checks are used (rather than ``x or default``) so
        callers can pass legitimate falsy overrides such as ``max_retries=0``
        — with ``or``, those were silently replaced by the constructor
        defaults.
        """
        payload = OrderedPayload(
            endpoint=endpoint if endpoint is not None else self._endpoint,
            data={**self._data_template, **data},
            metadata={**self._metadata_template, **(metadata or {})},
            callback=callback,
            max_retries=max_retries if max_retries is not None else self._max_retries,
            retry_multiplier=retry_multiplier if retry_multiplier is not None else self._retry_multiplier,
            retry_max=retry_max if retry_max is not None else self._retry_max,
            put_counter=self._put_counter
        )
        self._put_counter += 1
        self._in_queue.put(payload)

@gigberg
Copy link

gigberg commented Apr 28, 2024

Thanks a lot — it works for me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants