diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0449d4daf..16e1d905d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -119,23 +119,3 @@ jobs:
       with:
         name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }}
         path: .test_durations
-
-  Collect:
-    needs: Tests
-    runs-on: ubuntu-latest
-    steps:
-    - name: Cache
-      uses: actions/cache@v3
-      with:
-        path: .test_durations
-        key: test_durations-0-${{ github.run_id }}
-    - name: Collect
-      uses: actions/download-artifact@v3
-      with:
-        path: artifacts
-    - name: Consolidate
-      run: |
-        jq -n -S \
-          'reduce (inputs | to_entries[]) as {$key, $value} ({}; .[$key] += $value)' \
-          artifacts/pytest_durations_*/.test_durations > .test_durations
-
diff --git a/pyproject.toml b/pyproject.toml
index 4fafb6293..d468260bd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
     'ftfy',
     'tqdm',
     'huggingface-hub',
+    'safetensors',
     'timm',
 ]
 dynamic = ["version"]
@@ -47,10 +48,10 @@ dynamic = ["version"]
 [project.optional-dependencies]
 training = [
     'torch>=2.0',
-    'webdataset>=0.2.5',
+    'webdataset>=0.2.5,<=0.2.86',
     'pandas',
     'transformers[sentencepiece]',
-    'timm>=1.0.7',
+    'timm>=1.0.10',
     'fsspec',
 ]
 test = [
diff --git a/requirements-test.txt b/requirements-test.txt
index 77c9ebfc9..0507fcd9b 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,4 +1,4 @@
 pytest-split==0.8.0
 pytest==7.2.0
 transformers[sentencepiece]
-timm>=1.0.7
+timm>=1.0.10
diff --git a/requirements-training.txt b/requirements-training.txt
index 0aa1fea8b..ab2754ba6 100644
--- a/requirements-training.txt
+++ b/requirements-training.txt
@@ -1,12 +1,13 @@
 torch>=1.9.0
 torchvision
-webdataset>=0.2.5
+webdataset>=0.2.5,<=0.2.86
 regex
 ftfy
 tqdm
 pandas
 braceexpand
 huggingface_hub
+safetensors
 transformers[sentencepiece]
-timm>=1.0.7
+timm>=1.0.10
 fsspec
diff --git a/requirements.txt b/requirements.txt
index 46e10425e..4b1ff4a3d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ regex
 ftfy
 tqdm
 huggingface_hub
+safetensors
 timm
diff --git a/src/open_clip/constants.py b/src/open_clip/constants.py
index 599c48c03..5bdfc2451 100644
--- a/src/open_clip/constants.py
+++ b/src/open_clip/constants.py
@@ -4,3 +4,8 @@
 IMAGENET_STD = (0.229, 0.224, 0.225)
 INCEPTION_MEAN = (0.5, 0.5, 0.5)
 INCEPTION_STD = (0.5, 0.5, 0.5)
+
+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' diff --git a/src/open_clip/convert.py b/src/open_clip/convert.py index 84571e0f1..f0c06ffba 100644 --- a/src/open_clip/convert.py +++ b/src/open_clip/convert.py @@ -18,7 +18,9 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): """ from timm.layers import resample_patch_embed, resample_abs_pos_embed - def _n2p(w, t=True): + def _n2p(w, t=True, idx=None): + if idx is not None: + w = w[idx] if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: @@ -66,21 +68,28 @@ def _convert_timm_img(module, prefix): mha_sub, b_sub, ln1_sub = (0, 0, 1) for i, block in enumerate(module.blocks.children()): - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w: + block_prefix = f'{prefix}Transformer/encoderblock/' + idx = i + else: + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + idx = None mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx)) block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx)) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx)) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx)) for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) + getattr(block.mlp, f'fc{r + 1}').weight.copy_( + _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx)) + getattr(block.mlp, f'fc{r + 1}').bias.copy_( + _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx)) module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -129,13 +138,14 @@ def _convert_openclip_txt(module: TextTransformer, prefix): _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) - 
-        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
-        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
-
-    _convert_timm_img(model.visual.trunk, 'params/img/')
-    _convert_openclip_txt(model.text, 'params/txt/')
-    model.logit_bias.copy_(_n2p(w['params/b'])[0])
-    model.logit_scale.copy_(_n2p(w['params/t'])[0])
+        if module.text_projection is not None:
+            module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+            module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+
+    _convert_timm_img(model.visual.trunk, 'img/')
+    _convert_openclip_txt(model.text, 'txt/')
+    model.logit_bias.copy_(_n2p(w['b'])[0])
+    model.logit_scale.copy_(_n2p(w['t'])[0])
 
 
 @torch.no_grad()
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 86b44862f..82ebe2bb9 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -125,8 +125,21 @@ def get_tokenizer(
     return tokenizer
 
 
-def load_state_dict(checkpoint_path: str, map_location='cpu'):
-    checkpoint = torch.load(checkpoint_path, map_location=map_location)
+def load_state_dict(
+        checkpoint_path: str,
+        device='cpu',
+        weights_only=True,
+):
+    # Check if safetensors or not and load weights accordingly
+    if str(checkpoint_path).endswith(".safetensors"):
+        from safetensors.torch import load_file
+        checkpoint = load_file(checkpoint_path, device=device)
+    else:
+        try:
+            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
+        except TypeError:
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+
     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     elif isinstance(checkpoint, torch.jit.ScriptModule):
@@ -144,6 +157,8 @@ def load_checkpoint(
         model: Union[CLIP, CustomTextCLIP],
         checkpoint_path: str,
         strict: bool = True,
+        weights_only: bool = True,
+        device='cpu',
 ):
     if Path(checkpoint_path).suffix in ('.npz', '.npy'):
         # Separate path loading numpy big_vision (SigLIP) weights
@@ -151,7 +166,7 @@
         load_big_vision_weights(model, checkpoint_path)
         return {}
 
-    state_dict = load_state_dict(checkpoint_path)
+    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
 
     # Detect & convert 3rd party state_dicts -> open_clip
     state_dict = convert_state_dict(model, state_dict)
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 5a0fc935f..989662ebb 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -72,6 +72,7 @@ class CLIPTextCfg:
     final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
     pool_type: str = 'argmax'
     proj_bias: bool = False
+    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
     output_tokens: bool = False
     act_kwargs: dict = None
     norm_kwargs: dict = None
@@ -209,6 +210,7 @@ def _build_text_tower(
             no_causal_mask=text_cfg.no_causal_mask,
             pad_id=text_cfg.pad_id,
             pool_type=text_cfg.pool_type,
+            proj_type=text_cfg.proj_type,
             proj_bias=text_cfg.proj_bias,
             output_tokens=text_cfg.output_tokens,
             act_layer=act_layer,
diff --git a/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json
new file mode 100644
index 000000000..6bc14fabc
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json
@@ -0,0 +1,30 @@
+{
+    "embed_dim": 1152,
+    "init_logit_bias": -10,
+    "custom_text": true,
+    "vision_cfg": {
+        "image_size": 378,
+        "timm_model_name": "vit_so400m_patch14_siglip_378",
+        "timm_model_pretrained": false,
"timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json b/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json new file mode 100644 index 000000000..4e39b1b46 --- /dev/null +++ b/src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_so400m_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 250000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "pool_type": "last", + "proj_type": "none", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 4dcbf4ae5..35de55064 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -3,10 +3,18 @@ import urllib import warnings from functools import partial -from typing import Dict, Union +from typing import Dict, Iterable, Optional, Union from tqdm import tqdm + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + + from .constants import ( IMAGENET_MEAN, IMAGENET_STD, @@ -14,6 +22,8 @@ INCEPTION_STD, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, + HF_WEIGHTS_NAME, + HF_SAFE_WEIGHTS_NAME, ) from .version import __version__ @@ -414,6 +424,12 @@ def _mccfg(url='', hf_hub='', **kwargs): "ViT-SO400M-14-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), ), + "ViT-SO400M-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'), + ), + "ViT-SO400M-14-SigLIP-378": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used + ), "ViT-SO400M-14-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), ), @@ -613,21 +629,52 @@ def has_hf_hub(necessary=False): return _has_hf_hub +def _get_safe_alternatives(filename: str) -> Iterable[str]: + """Returns potential safetensors alternatives for a given filename. + + Use case: + When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it. 
+ """ + if filename == HF_WEIGHTS_NAME: + yield HF_SAFE_WEIGHTS_NAME + + if filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin") or filename.endswith(".pth"): + yield filename[:-4] + ".safetensors" + + def download_pretrained_from_hf( model_id: str, - filename: str = 'open_clip_pytorch_model.bin', - revision=None, - cache_dir: Union[str, None] = None, + filename: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, ): has_hf_hub(True) - cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) - return cached_file + + filename = filename or HF_WEIGHTS_NAME + + # Look for .safetensors alternatives and load from it if it exists + if _has_safetensors: + for safe_filename in _get_safe_alternatives(filename): + try: + cached_file = hf_hub_download( + repo_id=model_id, filename=safe_filename, revision=revision, cache_dir=cache_dir) + return cached_file + except Exception: + pass + + try: + # Attempt to download the file + cached_file = hf_hub_download( + repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir) + return cached_file # Return the path to the downloaded file if successful + except Exception as e: + raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}") def download_pretrained( cfg: Dict, force_hf_hub: bool = False, - cache_dir: Union[str, None] = None, + cache_dir: Optional[str] = None, ): target = '' if not cfg: diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py index 26b5594e4..867a6d5f3 100644 --- a/src/open_clip/push_to_hf_hub.py +++ b/src/open_clip/push_to_hf_hub.py @@ -1,6 +1,5 @@ import argparse import json -import os from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union @@ -28,14 +27,10 @@ except ImportError: _has_safetensors = False +from .constants import HF_WEIGHTS_NAME, HF_SAFE_WEIGHTS_NAME, HF_CONFIG_NAME from .factory import create_model_from_pretrained, get_model_config, get_tokenizer from .tokenizer import HFTokenizer -# Default name for a weights file hosted on the Huggingface Hub. 
-HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl -HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version -HF_CONFIG_NAME = 'open_clip_config.json' - def save_config_for_hf( model, @@ -193,7 +188,7 @@ def push_pretrained_to_hf_hub( tokenizer = get_tokenizer(model_name) if hf_tokenizer_self: # make hf tokenizer config in the uploaded model point to self instead of original location - model_config['text']['hf_tokenizer_name'] = repo_id + model_config['text_cfg']['hf_tokenizer_name'] = repo_id push_to_hf_hub( model=model, @@ -316,6 +311,7 @@ def generate_readme(model_card: dict, model_name: str): image_std=args.image_std, image_interpolation=args.image_interpolation, image_resize_mode=args.image_resize_mode, + hf_tokenizer_self=args.hf_tokenizer_self, ) print(f'{args.model} saved.') diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 4932abf2c..bf85dc8e7 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -677,11 +677,12 @@ def __init__( layers: int = 12, mlp_ratio: float = 4.0, ls_init_value: float = None, - output_dim: int = 512, + output_dim: Optional[int] = 512, embed_cls: bool = False, no_causal_mask: bool = False, pad_id: int = 0, pool_type: str = 'argmax', + proj_type: str = 'linear', proj_bias: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, @@ -721,10 +722,13 @@ def __init__( else: self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) - if proj_bias: - self.text_projection = nn.Linear(width, output_dim) + if proj_type == 'none' or not output_dim: + self.text_projection = None else: - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) self.init_parameters()
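Usage sketch (not part of the diff, assumptions noted in comments): with the `safetensors` package installed, the Hub download path in `pretrained.py` should prefer `open_clip_model.safetensors` over the `.bin` weights, and the new `ViT-SO400M-16-SigLIP-i18n-256` config exercises `proj_type: "none"`, so its text features keep the transformer width rather than passing through a final projection. The snippet uses the existing public `open_clip.create_model_from_pretrained` and `open_clip.get_tokenizer` API plus the `webli` pretrained tag registered above.

```python
import torch
import open_clip

# Load one of the newly registered SigLIP checkpoints; with `safetensors`
# installed, the weights should be fetched as open_clip_model.safetensors
# when that file exists on the Hub (assumption: the Hub repo hosts one).
model, preprocess = open_clip.create_model_from_pretrained(
    'ViT-SO400M-16-SigLIP-i18n-256', pretrained='webli')
tokenizer = open_clip.get_tokenizer('ViT-SO400M-16-SigLIP-i18n-256')

with torch.no_grad():
    tokens = tokenizer(['a photo of a cat', 'Ein Foto einer Katze'])
    text_features = model.encode_text(tokens)
    # proj_type='none' in this config means text_projection is None,
    # so the feature width matches the text tower width (1152).
    print(text_features.shape)
```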