Skip to content

Commit

Permalink
Combined merge PR, safetensors load & SigLIP i18n SO400m (#959)
Browse files Browse the repository at this point in the history
* safetensors support when loading from hf_hub --> check .bin file first, .safetensors only if .bin not found

* Add model defs & weights for new so400m i18n variant. Add a 378x378 config for the original 384x348 so400m because the patch size doesn't divide 384 properly.

* pin webdataset <= 0.2.86 due to breaks, keep timm at 1.0.9 for test/train until 1.0.10 is released

* Add webdataset max version to pyproject.toml as well

* Cleanup safetensors load support

* Update timm deps

* Attempt to fix broken collect in tests

* Remove collect from tests, should remove duration tracking too as it's overcomplicated

---------

Co-authored-by: Mehmet Deniz Birlikci <[email protected]>
  • Loading branch information
rwightman and deniz-birlikci authored Oct 15, 2024
1 parent fc5a37b commit 921b27c
Show file tree
Hide file tree
Showing 14 changed files with 187 additions and 65 deletions.
20 changes: 0 additions & 20 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,3 @@ jobs:
with:
name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }}
path: .test_durations

Collect:
needs: Tests
runs-on: ubuntu-latest
steps:
- name: Cache
uses: actions/cache@v3
with:
path: .test_durations
key: test_durations-0-${{ github.run_id }}
- name: Collect
uses: actions/download-artifact@v3
with:
path: artifacts
- name: Consolidate
run: |
jq -n -S \
'reduce (inputs | to_entries[]) as {$key, $value} ({}; .[$key] += $value)' \
artifacts/pytest_durations_*/.test_durations > .test_durations
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,18 @@ dependencies = [
'ftfy',
'tqdm',
'huggingface-hub',
'safetensors',
'timm',
]
dynamic = ["version"]

[project.optional-dependencies]
training = [
'torch>=2.0',
'webdataset>=0.2.5',
'webdataset>=0.2.5,<=0.2.86',
'pandas',
'transformers[sentencepiece]',
'timm>=1.0.7',
'timm>=1.0.10',
'fsspec',
]
test = [
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest-split==0.8.0
pytest==7.2.0
transformers[sentencepiece]
timm>=1.0.7
timm>=1.0.10
5 changes: 3 additions & 2 deletions requirements-training.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
torch>=1.9.0
torchvision
webdataset>=0.2.5
webdataset>=0.2.5,<=0.2.86
regex
ftfy
tqdm
pandas
braceexpand
huggingface_hub
safetensors
transformers[sentencepiece]
timm>=1.0.7
timm>=1.0.10
fsspec
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ regex
ftfy
tqdm
huggingface_hub
safetensors
timm
5 changes: 5 additions & 0 deletions src/open_clip/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
HF_CONFIG_NAME = 'open_clip_config.json'
48 changes: 29 additions & 19 deletions src/open_clip/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed

def _n2p(w, t=True):
def _n2p(w, t=True, idx=None):
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
Expand Down Expand Up @@ -66,21 +68,28 @@ def _convert_timm_img(module, prefix):

mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
block_prefix = f'{prefix}Transformer/encoderblock/'
idx = i
else:
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
idx = None
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))

module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
Expand Down Expand Up @@ -129,13 +138,14 @@ def _convert_openclip_txt(module: TextTransformer, prefix):
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])
if module.text_projection is not None:
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

_convert_timm_img(model.visual.trunk, 'img/')
_convert_openclip_txt(model.text, 'txt/')
model.logit_bias.copy_(_n2p(w['b'])[0])
model.logit_scale.copy_(_n2p(w['t'])[0])


@torch.no_grad()
Expand Down
21 changes: 18 additions & 3 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,21 @@ def get_tokenizer(
return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
def load_state_dict(
checkpoint_path: str,
device='cpu',
weights_only=True,
):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
from safetensors.torch import load_file
checkpoint = load_file(checkpoint_path, device=device)
else:
try:
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location=device)

if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, torch.jit.ScriptModule):
Expand All @@ -144,14 +157,16 @@ def load_checkpoint(
model: Union[CLIP, CustomTextCLIP],
checkpoint_path: str,
strict: bool = True,
weights_only: bool = True,
device='cpu',
):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
# Separate path loading numpy big_vision (SigLIP) weights
from open_clip.convert import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}

state_dict = load_state_dict(checkpoint_path)
state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

# Detect & convert 3rd party state_dicts -> open_clip
state_dict = convert_state_dict(model, state_dict)
Expand Down
2 changes: 2 additions & 0 deletions src/open_clip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class CLIPTextCfg:
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
pool_type: str = 'argmax'
proj_bias: bool = False
proj_type: str = 'linear' # control final text projection, 'none' forces no projection
output_tokens: bool = False
act_kwargs: dict = None
norm_kwargs: dict = None
Expand Down Expand Up @@ -209,6 +210,7 @@ def _build_text_tower(
no_causal_mask=text_cfg.no_causal_mask,
pad_id=text_cfg.pad_id,
pool_type=text_cfg.pool_type,
proj_type=text_cfg.proj_type,
proj_bias=text_cfg.proj_bias,
output_tokens=text_cfg.output_tokens,
act_layer=act_layer,
Expand Down
30 changes: 30 additions & 0 deletions src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 378,
"timm_model_name": "vit_so400m_patch14_siglip_378",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
30 changes: 30 additions & 0 deletions src/open_clip/model_configs/ViT-SO400M-16-SigLIP-i18n-256.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_so400m_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 250000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"pool_type": "last",
"proj_type": "none",
"norm_kwargs":{
"eps": 1e-6
}
}
}
61 changes: 54 additions & 7 deletions src/open_clip/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@
import urllib
import warnings
from functools import partial
from typing import Dict, Union
from typing import Dict, Iterable, Optional, Union

from tqdm import tqdm


try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False


from .constants import (
IMAGENET_MEAN,
IMAGENET_STD,
INCEPTION_MEAN,
INCEPTION_STD,
OPENAI_DATASET_MEAN,
OPENAI_DATASET_STD,
HF_WEIGHTS_NAME,
HF_SAFE_WEIGHTS_NAME,
)
from .version import __version__

Expand Down Expand Up @@ -414,6 +424,12 @@ def _mccfg(url='', hf_hub='', **kwargs):
"ViT-SO400M-14-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
),
"ViT-SO400M-16-SigLIP-i18n-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),
),
"ViT-SO400M-14-SigLIP-378": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used
),
"ViT-SO400M-14-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
),
Expand Down Expand Up @@ -613,21 +629,52 @@ def has_hf_hub(necessary=False):
return _has_hf_hub


def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
Use case:
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME

if filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin") or filename.endswith(".pth"):
yield filename[:-4] + ".safetensors"


def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
filename: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file

filename = filename or HF_WEIGHTS_NAME

# Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_file = hf_hub_download(
repo_id=model_id, filename=safe_filename, revision=revision, cache_dir=cache_dir)
return cached_file
except Exception:
pass

try:
# Attempt to download the file
cached_file = hf_hub_download(
repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir)
return cached_file # Return the path to the downloaded file if successful
except Exception as e:
raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}")


def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
cache_dir: Optional[str] = None,
):
target = ''
if not cfg:
Expand Down
Loading

0 comments on commit 921b27c

Please sign in to comment.