KEP-2170: Create model and dataset initializers
Signed-off-by: Andrey Velichkevich <[email protected]>
andreyvelich committed Oct 23, 2024
1 parent fd4d102 commit fe481b7
Showing 14 changed files with 225 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/publish-core-images.yaml
@@ -30,6 +30,14 @@ jobs:
dockerfile: cmd/training-operator.v2alpha1/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
tag-prefix: v2alpha1
- component-name: model-initializer-v2
dockerfile: cmd/initializer_v2/model/Dockerfile
platforms: linux/amd64,linux/arm64
tag-prefix: v2
- component-name: dataset-initializer-v2
dockerfile: cmd/initializer_v2/dataset/Dockerfile
platforms: linux/amd64,linux/arm64
tag-prefix: v2
- component-name: kubectl-delivery
dockerfile: build/images/kubectl-delivery/Dockerfile
platforms: linux/amd64,linux/arm64,linux/ppc64le
4 changes: 2 additions & 2 deletions .gitignore
@@ -10,8 +10,8 @@ cover.out
.vscode/
__debug_bin

# Compiled python files.
*.pyc
# Python cache files
__pycache__/

# Emacs temporary files
*~
13 changes: 13 additions & 0 deletions cmd/initializer_v2/dataset/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/dataset/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.dataset"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/dataset/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
13 changes: 13 additions & 0 deletions cmd/initializer_v2/model/Dockerfile
@@ -0,0 +1,13 @@
FROM python:3.11-alpine

WORKDIR /workspace

# Copy the required Python modules.
COPY cmd/initializer_v2/model/requirements.txt .
COPY sdk/python/kubeflow sdk/python/kubeflow
COPY pkg/initializer_v2 pkg/initializer_v2

# Install the needed packages.
RUN pip install -r requirements.txt

ENTRYPOINT ["python", "-m", "pkg.initializer_v2.model"]
1 change: 1 addition & 0 deletions cmd/initializer_v2/model/requirements.txt
@@ -0,0 +1 @@
huggingface_hub==0.23.4
35 changes: 35 additions & 0 deletions pkg/initializer_v2/dataset/__main__.py
@@ -0,0 +1,35 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.INFO,
)

if __name__ == "__main__":
logging.info("Starting dataset initialization")

try:
storage_uri = os.environ[utils.STORAGE_URI_ENV]
except Exception as e:
logging.error("STORAGE_URI env variable must be set.")
raise e

logging.info(f"Storage URI: {storage_uri}")

storage_uri_parsed = urlparse(storage_uri)

match storage_uri_parsed.scheme:
# TODO (andreyvelich): Implement more dataset providers.
case utils.HF_SCHEME:
hf = HuggingFace()
hf.load_config()
hf.download_dataset(storage_uri_parsed)
case _:
logging.error("STORAGE_URI must have the valid dataset provider")
raise Exception
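
Note on usage: the entrypoint above selects a provider from the URI scheme and reads everything from environment variables. A minimal local sketch of exercising it (the hf://username/dataset-name repo id and the token value are hypothetical, and the repository root is assumed to be on PYTHONPATH):

import os
import runpy

# Hypothetical dataset URI; the hf:// scheme routes to the HuggingFace provider.
os.environ["STORAGE_URI"] = "hf://username/dataset-name"
# Optional token for gated or private datasets, picked up by get_config_from_env().
os.environ["ACCESS_TOKEN"] = "hf_..."

# Equivalent to `python -m pkg.initializer_v2.dataset`.
runpy.run_module("pkg.initializer_v2.dataset", run_name="__main__")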
8 changes: 8 additions & 0 deletions pkg/initializer_v2/dataset/config.py
@@ -0,0 +1,8 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceDatasetConfig:
access_token: Optional[str] = None
40 changes: 40 additions & 0 deletions pkg/initializer_v2/dataset/huggingface.py
@@ -0,0 +1,40 @@
import logging
from urllib.parse import ParseResult

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.INFO,
)


class HuggingFace:

def load_config(self):
config_dict = utils.get_config_from_env(HuggingFaceDatasetConfig)
logging.info(f"Config for HuggingFace dataset initiailizer: {config_dict}")
self.config = HuggingFaceDatasetConfig(**config_dict)

def download_dataset(self, storage_uri_parsed: ParseResult):
dataset_uri = storage_uri_parsed.netloc + storage_uri_parsed.path
logging.info(f"Downloading dataset: {dataset_uri}")
logging.info("-" * 40)

if self.config.access_token:
huggingface_hub.login(self.config.access_token)

huggingface_hub.snapshot_download(
repo_id=dataset_uri,
repo_type="dataset",
local_dir=constants.VOLUME_PATH_DATASET,
)

logging.info("Dataset has been downloaded")
37 changes: 37 additions & 0 deletions pkg/initializer_v2/model/__main__.py
@@ -0,0 +1,37 @@
import logging
import os
from urllib.parse import urlparse

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.INFO,
)

if __name__ == "__main__":
logging.info("Starting pre-trained model initialization")

try:
storage_uri = os.environ[utils.STORAGE_URI_ENV]
except Exception as e:
logging.error("STORAGE_URI env variable must be set.")
raise e

logging.info(f"Storage URI: {storage_uri}")

storage_uri_parsed = urlparse(storage_uri)

match storage_uri_parsed.scheme:
# TODO (andreyvelich): Implement more model providers.
case utils.HF_SCHEME:
hf = HuggingFace()
hf.load_config()
hf.download_model(storage_uri_parsed)
case _:
logging.error(
f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
)
raise Exception
9 changes: 9 additions & 0 deletions pkg/initializer_v2/model/config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Optional


# TODO (andreyvelich): This should be moved under Training V2 SDK.
@dataclass
class HuggingFaceModelInputConfig:
storage_uri: str
access_token: Optional[str] = None
42 changes: 42 additions & 0 deletions pkg/initializer_v2/model/huggingface.py
@@ -0,0 +1,42 @@
import logging
from urllib.parse import ParseResult

import huggingface_hub

import pkg.initializer_v2.utils.utils as utils

# TODO (andreyvelich): This should be moved to SDK V2 constants.
import sdk.python.kubeflow.storage_initializer.constants as constants
from pkg.initializer_v2.model.config import HuggingFaceModelInputConfig

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.INFO,
)


class HuggingFace:

def load_config(self):
config_dict = utils.get_config_from_env(HuggingFaceModelInputConfig)
logging.info(f"Config for HuggingFace model initiailizer: {config_dict}")
self.config = HuggingFaceModelInputConfig(**config_dict)

def download_model(self, storage_uri_parsed: ParseResult):
model_uri = storage_uri_parsed.netloc + storage_uri_parsed.path
logging.info(f"Downloading model: {model_uri}")
logging.info("-" * 40)

if self.config.access_token:
huggingface_hub.login(self.config.access_token)

# TODO (andreyvelich): We should verify these patterns for different models.
huggingface_hub.snapshot_download(
repo_id=model_uri,
local_dir=constants.VOLUME_PATH_MODEL,
allow_patterns=["*.json", "*.safetensors", "*.model"],
ignore_patterns=["*.msgpack", "*.h5", "*.bin"],
)

logging.info("Model has been downloaded")
Empty file.
16 changes: 16 additions & 0 deletions pkg/initializer_v2/utils/utils.py
@@ -0,0 +1,16 @@
import os
from dataclasses import fields
from typing import Dict, Optional

STORAGE_URI_ENV = "STORAGE_URI"
HF_SCHEME = "hf"


# Get DataClass config from the environment variables.
# Env names must be equal to the DataClass parameters.
def get_config_from_env(config) -> Dict[str, Optional[str]]:
config_from_env = {}
for field in fields(config):
config_from_env[field.name] = os.getenv(field.name.upper())

return config_from_env
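
A short sketch of how get_config_from_env() maps environment variables onto a config dataclass (field names are upper-cased, unset variables come back as None; the token value is hypothetical):

import os

from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig
from pkg.initializer_v2.utils.utils import get_config_from_env

os.environ["ACCESS_TOKEN"] = "hf_..."  # matches the access_token field
config_dict = get_config_from_env(HuggingFaceDatasetConfig)
print(config_dict)  # -> {"access_token": "hf_..."}
config = HuggingFaceDatasetConfig(**config_dict)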
