Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update APIs service #17

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Aloha!

[![License](https://img.shields.io/github/license/QPod/aloha)](https://github.com/QPod/aloha/blob/main/LICENSE)
[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/QPod/aloha/build)](https://github.com/QPod/aloha/actions)
[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/QPod/aloha-python/pip.yml?branch=main)](https://github.com/QPod/aloha-python/actions)
[![Join the Gitter Chat](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/QPod/)
[![PyPI version](https://img.shields.io/pypi/v/aloha)](https://pypi.python.org/pypi/aloha/)
[![PyPI Downloads](https://img.shields.io/pypi/dm/aloha)](https://pepy.tech/badge/aloha/)
Expand All @@ -21,6 +21,6 @@ Please generously STAR★ our project or donate to us! [![GitHub Starts](https:

## Getting started

```py
```shell
pip install aloha[all]
```
Empty file.
86 changes: 86 additions & 0 deletions demo/app_common/ainlp/model_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import List

import torch
from transformers import AutoTokenizer, AutoModel

from aloha.service.streamer import ManagedModel

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


class TextUnmaskModel:
def __init__(self, max_sent_len=16, model_path="bert-base-uncased"):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
self.transformer = AutoModel.from_pretrained(self.model_path)
self.transformer.eval()
self.transformer.to(device="cuda")
self.max_sent_len = max_sent_len

def predict(self, batch: List[str]) -> List[str]:
"""predict masked word"""
batch_inputs = []
masked_indexes = []

for text in batch:
tokenized_text = self.tokenizer.tokenize(text)
if len(tokenized_text) > self.max_sent_len - 2:
tokenized_text = tokenized_text[: self.max_sent_len - 2]

tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
tokenized_text += ['[PAD]'] * (self.max_sent_len - len(tokenized_text))

indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
batch_inputs.append(indexed_tokens)
masked_indexes.append(tokenized_text.index('[MASK]'))

tokens_tensor = torch.tensor(batch_inputs).to("cuda")

with torch.no_grad():
# prediction_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
prediction_scores = self.transformer(tokens_tensor)[0]

batch_outputs = []
for i in range(len(batch_inputs)):
predicted_index = torch.argmax(prediction_scores[i, masked_indexes[i]]).item()
predicted_token = self.tokenizer.convert_ids_to_tokens(predicted_index)
batch_outputs.append(predicted_token)

return batch_outputs


class ManagedBertModel(ManagedModel):
def init_model(self):
self.model = TextUnmaskModel()

def predict(self, batch):
return self.model.predict(batch)


def test_simple():
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello! My name is [MASK]!", return_tensors="pt")
outputs = model(**inputs)
print(outputs)

predicted_index = torch.argmax(outputs[1]).item()
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
print(predicted_token)


def test_batch():
batch_text = [
"twinkle twinkle [MASK] star.",
"Happy birthday to [MASK].",
'the answer to life, the [MASK], and everything.'
]
model = TextUnmaskModel()
outputs = model.predict(batch_text)
print(outputs)


if __name__ == "__main__":
test_simple()
Empty file.
14 changes: 14 additions & 0 deletions demo/app_common/api/api_multipart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from aloha.logger import LOG
from aloha.service.api.v0 import APIHandler


class MultipartHandler(APIHandler):
def response(self, params=None, *args, **kwargs):
LOG.debug(params)
return params


default_handlers = [
# internal API: QueryDB Postgres with sql directly
(r"/api_internal/multipart", MultipartHandler),
]
1 change: 1 addition & 0 deletions demo/app_common/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ def main():
modules_to_load = [
"app_common.api.api_common_sys_info",
"app_common.api.api_common_query_postgres",
"app_common.api.api_multipart",
]

if 'service' not in SETTINGS.config:
Expand Down
8 changes: 6 additions & 2 deletions src/aloha/config/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,19 @@ def get_config_files() -> list:

files = files_config.split(',')
ret = []
msgs = []
for f in files:
file = get_config_dir(f)
if not os.path.exists(file):
warnings.warn('Expecting config file [%s] but it does not exists!' % file)
msgs.append('Expecting config file [%s] but it does not exists!' % file)
else:
print(' ---> Loading config file [%s]' % file)
ret.append(os.path.expandvars(f))
if len(ret) == 0:
warnings.warn('No config files set properly, EMPTY config will be used!')
msgs.append('No config files set properly, EMPTY config will be used!')

if len(msgs) > 0:
warnings.warn('\n'.join(msgs))
return ret


Expand Down
3 changes: 2 additions & 1 deletion src/aloha/encrypt/vault/cyberark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from ...logger import LOG

requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL'
if hasattr(requests.packages.urllib3.util.ssl_, 'DEFAULT_CIPHERS'):
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL'


class CyberArkVault(BaseVault, AesEncryptor):
Expand Down
6 changes: 4 additions & 2 deletions src/aloha/service/api/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ class APIHandler(AbstractApiHandler, ABC):
}

async def post(self, *args, **kwargs):
body_arguments = self.request_body
kwargs.update(body_arguments)
req_body = self.request_body

if req_body is not None: # body_arguments
kwargs.update(req_body)

resp = dict(code=5200, message=['success'])
try:
Expand Down
14 changes: 11 additions & 3 deletions src/aloha/service/http/base_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def request_body(self) -> dict:
body_arguments: dict = Optional[None]

if content_type.startswith('multipart/form-data'): # only parse files when 'Content-Type' starts with 'multipart/form-data'
body_arguments = self.request.body_arguments
body_arguments = self.request_param # self.request.body_arguments
else:
try:
body = self.request.body.decode('utf-8')
Expand All @@ -62,8 +62,16 @@ def request_body(self) -> dict:

@property
def request_param(self) -> dict:
url_arguments: dict = {k: v[0].decode('utf-8') for k, v in self.request.arguments.items()}
return url_arguments
ret: dict = {}
for k, v in self.request.arguments.items():
val = v[0].decode('utf-8')
try:
value = json.loads(val)
except json.JSONDecodeError:
value = val
ret[k] = value

return ret


class DefaultHandler404(AbstractApiHandler):
Expand Down
33 changes: 33 additions & 0 deletions src/aloha/service/http/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import time

import requests

from ...logger import LOG


def iter_over_request_files(request, url_files):
for file_key, files in request.files.items(): # iter over files uploaded by multipart
for f in files:
file_name, content_type = f["filename"], f["content_type"]
body = f.get('body', b"")
LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}")
yield file_key, file_name, content_type, body

for file_key, list_url in {'url_files': url_files or []}.items(): # iter over files specified by `url_files`
for url in sorted(set(list_url)):
try:
t_start = time.time()
resp = requests.get(url, stream=True) # download the file from given url
if resp.status_code == 200:
body = resp.content
content_type = resp.headers.get("Content-Type", "UNKNOWN")
else:
raise RuntimeError("Failed to download file after %s seconds with code=%s from URL %s" % (
time.time() - t_start, resp.status_code, url
))
del resp
except Exception as e:
raise e
t_cost = time.time() - t_start
LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds")
yield 'url_files', url, content_type, body
8 changes: 6 additions & 2 deletions src/aloha/service/streamer/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import threading
import time

from redis import Redis

from .base import BaseStreamer, BaseWorker, TIMEOUT, TIME_SLEEP, logger
from ...logger import LOG

try:
from redis import Redis
except ImportError:
LOG.warn('redis not installed, service.streamer.RedisStreamer will no be available!')


class RedisWorker(BaseWorker):
Expand Down
Loading
Loading