Skip to content

Commit

Permalink
feat(proofs): added process_with_gemini endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
TTalex committed Nov 9, 2024
1 parent 1337015 commit b40f3cf
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 0 deletions.
5 changes: 5 additions & 0 deletions config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,8 @@
# ------------------------------------------------------------------------------

# Google Cloud Vision API key, read from the environment (None when unset).
GOOGLE_CLOUD_VISION_API_KEY = os.getenv("GOOGLE_CLOUD_VISION_API_KEY")

# Google Gemini API
# ------------------------------------------------------------------------------

# Google Gemini API key, read from the environment (None when unset).
GOOGLE_GEMINI_API_KEY = os.getenv("GOOGLE_GEMINI_API_KEY")
8 changes: 8 additions & 0 deletions open_prices/api/proofs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ class ProofUpdateSerializer(serializers.ModelSerializer):
class Meta:
model = Proof
fields = Proof.UPDATE_FIELDS

class ProofProcessWithGeminiSerializer(serializers.Serializer):
    """Describe the multipart payload of the ``process_with_gemini`` action.

    This is a plain ``Serializer`` rather than a ``ModelSerializer``: neither
    field is backed by the ``Proof`` model, so no model binding (``Meta``) is
    declared.
    """

    # One entry per uploaded file; ``use_url=False`` is kept from the original
    # field declaration.
    files = serializers.ListField(
        child=serializers.FileField(required=True, use_url=False)
    )
    # TODO: this mode param should be used to select the prompt to execute,
    # unimplemented for now
    mode = serializers.CharField()
23 changes: 23 additions & 0 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
ProofFullSerializer,
ProofUpdateSerializer,
ProofUploadSerializer,
ProofProcessWithGeminiSerializer
)
from open_prices.api.utils import get_source_from_request
from open_prices.common.authentication import CustomAuthentication
from open_prices.common.gemini import handle_bulk_labels
from open_prices.proofs.models import Proof
from open_prices.proofs.utils import store_file

import PIL.Image


class ProofViewSet(
mixins.ListModelMixin,
Expand Down Expand Up @@ -94,3 +98,22 @@ def upload(self, request: Request) -> Response:
proof = serializer.save(owner=self.request.user.user_id, source=source)
# return full proof
return Response(ProofFullSerializer(proof).data, status=status.HTTP_201_CREATED)


@extend_schema(request=ProofProcessWithGeminiSerializer)
@action(
    detail=False,
    methods=["POST"],
    url_path="process_with_gemini",
    parser_classes=[MultiPartParser],
)
def process_with_gemini(self, request: Request) -> Response:
    """Extract label attributes from uploaded pictures with Gemini.

    Expects a multipart POST whose ``files`` field carries one or more
    uploaded images.

    Returns:
        HTTP 400 when no file was uploaded under ``files`` or when an upload
        cannot be identified as an image; otherwise HTTP 200 with the value
        returned by ``handle_bulk_labels`` for the opened images.
    """
    # Check the uploads themselves: a plain text field named "files" shows up
    # in ``request.data`` but not in ``request.FILES``, and must not lead to a
    # model call with zero images.
    files = request.FILES.getlist("files")
    if not files:
        return Response(
            {"files": ["This field is required."]},
            status=status.HTTP_400_BAD_REQUEST,
        )
    try:
        sample_files = [PIL.Image.open(file.file) for file in files]
    except PIL.UnidentifiedImageError:
        # A non-image upload is a client error, not a server failure.
        return Response(
            {"files": ["Upload a valid image."]},
            status=status.HTTP_400_BAD_REQUEST,
        )
    res = handle_bulk_labels(sample_files)
    return Response(res, status=status.HTTP_200_OK)
130 changes: 130 additions & 0 deletions open_prices/common/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import List, Dict
import google.generativeai as genai
import typing_extensions as typing
import json
import enum
import os
from django.conf import settings


# Configure the Gemini client with the API key from the Django settings.
# This is a module-level statement, so it runs when the module is imported.
genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY)
# Module-level model handle used by handle_bulk_labels below.
model = genai.GenerativeModel(model_name="gemini-1.5-flash")

# TODO: what about other categories?
@enum.unique
class Products(enum.Enum):
    """Product categories usable as the ``product`` attribute of a label.

    Apart from the ``OTHER`` fallback, every value is an ``en:``-prefixed
    category tag. ``enum.unique`` makes a duplicated tag fail at import time
    instead of silently turning the second name into an alias.
    """

    OTHER = "other"
    APPLES = "en:apples"
    APRICOTS = "en:apricots"
    ARTICHOKES = "en:artichokes"
    ASPARAGUS = "en:asparagus"
    AUBERGINES = "en:aubergines"
    AVOCADOS = "en:avocados"
    BANANAS = "en:bananas"
    BEET = "en:beet"
    BERRIES = "en:berries"
    BLACKBERRIES = "en:blackberries"
    BLUEBERRIES = "en:blueberries"
    BOK_CHOY = "en:bok-choy"
    BROCCOLI = "en:broccoli"
    CABBAGES = "en:cabbages"
    CARROTS = "en:carrots"
    CAULIFLOWERS = "en:cauliflowers"
    CELERY = "en:celery"
    CELERY_STALK = "en:celery-stalk"
    CEP_MUSHROOMS = "en:cep-mushrooms"
    CHANTERELLES = "en:chanterelles"
    CHERRIES = "en:cherries"
    CHERRY_TOMATOES = "en:cherry-tomatoes"
    CHICKPEAS = "en:chickpeas"
    CHIVES = "en:chives"
    CLEMENTINES = "en:clementines"
    COCONUTS = "en:coconuts"
    CRANBERRIES = "en:cranberries"
    CUCUMBERS = "en:cucumbers"
    DATES = "en:dates"
    ENDIVES = "en:endives"
    FIGS = "en:figs"
    GARLIC = "en:garlic"
    GINGER = "en:ginger"
    GRAPEFRUITS = "en:grapefruits"
    GRAPES = "en:grapes"
    GREEN_BEANS = "en:green-beans"
    KIWIS = "en:kiwis"
    KAKIS = "en:kakis"
    LEEKS = "en:leeks"
    LEMONS = "en:lemons"
    LETTUCES = "en:lettuces"
    LIMES = "en:limes"
    LYCHEES = "en:lychees"
    MANDARIN_ORANGES = "en:mandarin-oranges"
    MANGOES = "en:mangoes"
    MELONS = "en:melons"
    MUSHROOMS = "en:mushrooms"
    NECTARINES = "en:nectarines"
    ONIONS = "en:onions"
    ORANGES = "en:oranges"
    PAPAYAS = "en:papayas"
    PASSION_FRUITS = "en:passion-fruits"
    PEACHES = "en:peaches"
    PEARS = "en:pears"
    PEAS = "en:peas"
    PEPPERS = "en:peppers"
    PINEAPPLE = "en:pineapple"
    PLUMS = "en:plums"
    POMEGRANATES = "en:pomegranates"
    POMELOS = "en:pomelos"
    POTATOES = "en:potatoes"
    PUMPKINS = "en:pumpkins"
    RADISHES = "en:radishes"
    RASPBERRIES = "en:raspberries"
    RHUBARBS = "en:rhubarbs"
    SCALLIONS = "en:scallions"
    SHALLOTS = "en:shallots"
    SPINACHS = "en:spinachs"
    SPROUTS = "en:sprouts"
    STRAWBERRIES = "en:strawberries"
    TOMATOES = "en:tomatoes"
    TURNIP = "en:turnip"
    WATERMELONS = "en:watermelons"
    WALNUTS = "en:walnuts"
    ZUCCHINI = "en:zucchini"

# TODO: what about other origins ?
@enum.unique
class Origin(enum.Enum):
    """Countries of origin usable as the ``origin`` attribute of a label.

    ``OTHER`` and ``UNKNOWN`` are the two non-country members; every other
    value is an ``en:``-prefixed country tag. ``enum.unique`` makes a
    duplicated tag fail at import time instead of creating an alias.
    """

    FRANCE = "en:france"
    ITALY = "en:italy"
    SPAIN = "en:spain"
    POLAND = "en:poland"
    CHINA = "en:china"
    BELGIUM = "en:belgium"
    MOROCCO = "en:morocco"
    PERU = "en:peru"
    PORTUGAL = "en:portugal"
    MEXICO = "en:mexico"
    OTHER = "other"
    UNKNOWN = "unknown"

@enum.unique
class Unit(enum.Enum):
    """Pricing basis of a label: per kilogram or per unit."""

    KILOGRAM = "KILOGRAM"
    UNIT = "UNIT"

class Label(typing.TypedDict):
    """Attributes requested from the model for one label picture."""

    # Product category, one of the Products members.
    product: Products
    # Price read from the label.
    price: float
    # Country of origin, one of the Origin members.
    origin: Origin
    # Pricing basis, one of the Unit members.
    unit: Unit
    # Whether the product is organic.
    organic: bool
    # Barcode read from the label.
    barcode: str

class Labels(typing.TypedDict):
    """Top-level reply schema: a list of Label entries under ``labels``."""

    labels: list[Label]

def handle_bulk_labels(images):
    """Ask Gemini to extract the attributes of each label picture.

    Args:
        images: Sequence of label pictures, sent to the model after the text
            prompt. Any sized iterable is accepted, not only a list.

    Returns:
        The model's JSON reply decoded into Python objects. The reply is
        requested as ``application/json`` with the ``Labels`` schema. An empty
        ``images`` returns ``{"labels": []}`` without calling the model.

    Raises:
        json.JSONDecodeError: If the reply text is not valid JSON.
    """
    count = len(images)
    if count == 0:
        # Nothing to decode: skip a model call that would ask for 0 labels.
        return {"labels": []}

    prompt = (
        f"Here are {count} pictures containing a label. "
        "For each picture of a label, please extract all the following "
        "attributes: the product category matching product name, the origin "
        "category matching country of origin, the price, is the product "
        "organic, the unit (per KILOGRAM or per UNIT) and the barcode. "
        f"I expect a list of {count} labels in your reply, no more, no less. "
        "If you cannot decode an attribute, set it to an empty string"
    )
    response = model.generate_content(
        [prompt, *images],
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json", response_schema=Labels
        ),
    )
    return json.loads(response.text)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ django-cors-headers = "^4.4.0"
sentry-sdk = {extras = ["django"], version = "^2.13.0"}
django-solo = "^2.3.0"
pillow = "^10.4.0"
google-generativeai = "^0.8.3"

[tool.poetry.group.dev.dependencies]
black = "~23.12.1"
Expand Down

0 comments on commit b40f3cf

Please sign in to comment.