WARNING: currently NOT compatible with keras 3.x. If using tensorflow>=2.16.0, install tf-keras manually with pip install tf-keras~=$(pip show tensorflow | awk -F ': ' '/Version/{print $2}'). When importing, import this package ahead of Tensorflow, or set export TF_USE_LEGACY_KERAS=1.
Downloading a model and loading it from the h5 file directly is not recommended. It is better to build the model and load weights, like import kecam; mm = kecam.models.LCNet050().
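As a rough illustration of the warning above, the sketch below builds the same `tf-keras~=<version>` requirement string that the shell one-liner produces, and sets `TF_USE_LEGACY_KERAS` before any TensorFlow import. The helper name `tf_keras_requirement` is hypothetical and not part of this package.

```python
import os


def tf_keras_requirement(tf_version):
    """Return the pip requirement for tf-keras matching a TensorFlow version, or None.

    Hypothetical helper: mirrors `pip install tf-keras~=<tensorflow version>`.
    """
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    if (major, minor) < (2, 16):
        return None  # the manual install is only called for with tensorflow>=2.16.0
    return f"tf-keras~={tf_version}"


# Must be set before TensorFlow (or this package) is imported.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

print(tf_keras_requirement("2.16.1"))  # tf-keras~=2.16.1
print(tf_keras_requirement("2.15.0"))  # None
```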
coco_train_script.py for TF is still under testing...
Install as a pip package. kecam is a short alias name of this package. Note: the pip package kecam doesn't set any backend requirement, so make sure either Tensorflow or PyTorch is installed beforehand. For PyTorch backend usage, refer to Keras PyTorch Backend.
pip install -U kecam
# Or
pip install -U keras-cv-attention-models
# Or
pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
Refer to each subdirectory for detailed usage.
Basic model prediction
from keras_cv_attention_models import volo
mm = volo.VOLO_d1(pretrained="imagenet")

""" Run predict """
import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models.test_images import cat
img = cat()
imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
pred = tf.nn.softmax(pred).numpy()  # If classifier activation is not softmax
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
# [('n02124075', 'Egyptian_cat', 0.99664897),
#  ('n02123045', 'tabby', 0.0007249644),
#  ('n02123159', 'tiger_cat', 0.00020345),
#  ('n02127052', 'lynx', 5.4973923e-05),
#  ('n02123597', 'Siamese_cat', 2.675306e-05)]
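For readers unfamiliar with the last two steps of the prediction snippet, here is a minimal standard-library sketch of what a softmax followed by a top-k decode does. The labels and logits are toy values made up for illustration; the functions `softmax` and `top_k` are local to this sketch and are not the Keras implementations.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy logits and labels, made up for illustration only.
labels = ["Egyptian_cat", "tabby", "tiger_cat", "lynx"]
logits = [8.0, 1.0, 0.5, -1.0]
probs = softmax(logits)
for label, prob in top_k(probs, labels, k=3):
    print(f"{label}: {prob:.4f}")
```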
Or just use model preset preprocess_input and decode_predictions
Setting num_classes={custom output classes} to a value other than 1000 or 0 will just skip loading the header Dense layer weights, as model.load_weights(weight_file, by_name=True, skip_mismatch=True) is used for loading weights.
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Tiny_window8(num_classes=64)
# >>>> Load pretrained from: ~/.keras/models/swin_transformer_v2_tiny_window8_256_imagenet.h5
# WARNING:tensorflow:Skipping loading weights for layer #601 (named predictions) due to mismatch in shape for weight predictions/kernel:0. Weight expects shape (768, 64). Received saved weight with shape (768, 1000)
# WARNING:tensorflow:Skipping loading weights for layer #601 (named predictions) due to mismatch in shape for weight predictions/bias:0. Weight expects shape (64,). Received saved weight with shape (1000,)
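The warnings above come from loading by name while skipping shape mismatches. The sketch below shows the idea on plain dictionaries: weights are copied only when both name and shape match, and everything else keeps its fresh initialization. `load_by_name_skip_mismatch` is a hypothetical helper for illustration, not the Keras loader.

```python
def load_by_name_skip_mismatch(model_weights, saved_weights):
    """Copy saved weights into model_weights where name and shape both match.

    Both arguments map a weight name to a (shape, values) tuple.
    Returns the list of names that were skipped.
    """
    skipped = []
    for name, (shape, _) in model_weights.items():
        if name not in saved_weights:
            skipped.append(name)
            continue
        saved_shape, saved_values = saved_weights[name]
        if saved_shape != shape:
            skipped.append(name)  # e.g. a 64-class head vs. a 1000-class checkpoint
            continue
        model_weights[name] = (shape, saved_values)
    return skipped


model_weights = {
    "stem/kernel": ((3, 3, 3, 32), "random-init"),
    "predictions/kernel": ((768, 64), "random-init"),
    "predictions/bias": ((64,), "random-init"),
}
saved_weights = {
    "stem/kernel": ((3, 3, 3, 32), "pretrained"),
    "predictions/kernel": ((768, 1000), "pretrained"),
    "predictions/bias": ((1000,), "pretrained"),
}
skipped = load_by_name_skip_mismatch(model_weights, saved_weights)
print(skipped)  # ['predictions/kernel', 'predictions/bias']
```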
Reload your own model weights by setting pretrained="xxx.h5". This is better than calling model.load_weights directly when reloading a model with a different input_shape and weight shapes that do not match.
import os
from keras_cv_attention_models import coatnet
pretrained = os.path.expanduser('~/.keras/models/coatnet0_224_imagenet.h5')
mm = coatnet.CoAtNet1(input_shape=(384, 384, 3), pretrained=pretrained)  # No sense, just showing usage
The alias name kecam can be used instead of keras_cv_attention_models. It contains only an __init__.py with from keras_cv_attention_models import *.
[Deprecated] tensorflow_addons is not imported by default. When reloading a model that depends on GroupNormalization, like MobileViTV2, from h5 directly, import tensorflow_addons manually first.
Export TF model to onnx. This needs tf2onnx for TF: pip install onnx tf2onnx onnxsim onnxruntime. When using the PyTorch backend, exporting onnx is supported by PyTorch.
T4 Inference values in the model tables are tested using trtexec on Tesla T4 with CUDA=12.0.1-1, Driver=525.60.13. All models are exported as ONNX using PyTorch backend, using batch_size=1 only. Note: this data is for reference only, and varies with batch size, benchmark tool, platform, and implementation.
All results are tested using colab trtexec.ipynb, and are thus reproducible by anyone.
attention_layers is __init__.py only, which imports core layers defined in model architectures. Like RelativePositionalEmbedding from botnet, outlook_attention from volo, and many other Positional Embedding Layers / Attention Blocks.
For a custom dataset, custom_dataset_script.py can be used to create a json format file, which can be passed as --data_name xxx.json for training. Detailed usage can be found in Custom recognition dataset.
# `antialias` is enabled by default for resize, and can be turned off by setting `--disable_antialias`.
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 train_script.py --seed 0 -s aotnet50
# Evaluation using input_shape (224, 224).
# `antialias` usage should be same with training.
CUDA_VISIBLE_DEVICES='1' python3 eval_script.py -m aotnet50_epoch_103_val_acc_0.7674.h5 -i 224 --central_crop 0.95
# >>>> Accuracy top1: 0.78466 top5: 0.94088
Restore from a break point by setting --restore_path and --initial_epoch, and keep other parameters the same. restore_path has higher priority than model and additional_model_kwargs, and also restores optimizer and loss. initial_epoch is mainly for the learning rate scheduler. If not sure where training stopped, check checkpoints/{save_name}_hist.json.
import json
with open("checkpoints/aotnet50_hist.json", "r") as ff:
    aa = json.load(ff)
len(aa['lr'])
# 41 ==> 41 epochs are finished, initial_epoch is 41 then, restart from epoch 42
custom_dataset_script.py can be used to create a json format file, which can be passed as --data_name xxx.json for training. Detailed usage can be found in Custom detection dataset.
Default parameters for coco_train_script.py are EfficientDetD0 with input_shape=(256, 256, 3), batch_size=64, mosaic_mix_prob=0.5, freeze_backbone_epochs=32, total_epochs=105. Technically, any combination of pyramid structure backbone + EfficientDet / YOLOX header / YOLOR header + anchor_free / yolor / efficientdet anchors is supported.
Currently 4 types of anchors are supported. The parameter anchors_mode controls which anchor to use, with value in ["efficientdet", "anchor_free", "yolor", "yolov8"]. Default None uses the det_header presets.
NOTE: YOLOV8 has a default regression_len=64 for bbox output length. Typically it's 4 for other detection models, for yolov8 it's reg_max=16 -> regression_len = 16 * 4 == 64.
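To make the `16 * 4 == 64` arithmetic concrete, here is a minimal sketch of how a distribution-style regression output is commonly decoded: each box side gets `reg_max` bins, a softmax is taken over those bins, and the expected bin index is used as the distance. This is an illustration of the general technique under the assumption that the 64 values are laid out as 4 consecutive groups of 16; it is not taken from this package's implementation, and `decode_distances` is a hypothetical name.

```python
import math

REG_MAX = 16  # bins per box side, as stated above
REGRESSION_LEN = REG_MAX * 4  # 64 values per anchor point


def softmax(values):
    """Numerically stable softmax over a flat list."""
    peak = max(values)
    exps = [math.exp(value - peak) for value in values]
    total = sum(exps)
    return [value / total for value in exps]


def decode_distances(regression):
    """Turn 64 raw values into 4 expected distances, one per box side."""
    assert len(regression) == REGRESSION_LEN
    distances = []
    for side in range(4):
        bins = regression[side * REG_MAX : (side + 1) * REG_MAX]
        probs = softmax(bins)
        distances.append(sum(index * prob for index, prob in enumerate(probs)))
    return distances


# Toy input: each side puts nearly all probability mass on a single bin.
raw = []
for peak_bin in (2, 5, 7, 11):
    raw.extend(50.0 if index == peak_bin else 0.0 for index in range(REG_MAX))
print([round(value, 3) for value in decode_distances(raw)])  # [2.0, 5.0, 7.0, 11.0]
```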
Note: COCO training is still under testing, and parameters and default behaviors may change. Take the risk if you would like to help developing.
coco_eval_script.py is used for evaluating model AP / AR on COCO validation set. It has a dependency pip install pycocotools which is not in package requirements. More usage can be found in COCO Evaluation.
custom_dataset_script.py can be used to create a tsv / json format file, which can be passed as --data_name xxx.tsv for training. Detailed usage can be found in Custom caption dataset.
Train using clip_train_script.py on COCO captions. Default --data_path is a testing one, datasets/coco_dog_cat/captions.tsv.
Note: Works better with the PyTorch backend. The Tensorflow one seems to overfit if training longer, like --epochs 200, and evaluation runs ~5 times slower.
Dataset can be a directory containing images for basic DDPM training using images only, or a recognition json file created following Custom recognition dataset, which will train using labels as instruction.
Train using ddpm_train_script.py on cifar10 with labels. Default --data_path is builtin cifar10.
# Set --eval_interval 50 as TF evaluation is rather slow
TF_XLA_FLAGS="--tf_xla_auto_jit=2" CUDA_VISIBLE_DEVICES=1 python ddpm_train_script.py --eval_interval 50
Train Using PyTorch backend by setting KECAM_BACKEND='torch'
Currently TFLite does not support tf.image.extract_patches / tf.transpose with len(perm) > 4. Some operations may be supported in the latest or tf-nightly version, like the previously unsupported gelu / Conv2D with groups>1, which are working now. May try a newer version if encountering issues.
Functions like model_surgery.convert_groups_conv2d_2_split_conv2d and model_surgery.convert_gelu_to_approximate are not needed when using an up-to-date TF version.
Converting VOLO / HaloNet models is not supported, because they need a longer tf.transpose perm.
model_surgery.convert_dense_to_conv converts all Dense layers with 3D / 4D inputs to Conv1D / Conv2D, as TFLite xnnpack currently does not support them.
For detection models, including efficientdet / yolox / yolor, the model can be converted to TFLite format directly. If DecodePredictions also needs to be included in the TFLite model, set use_static_output=True for DecodePredictions, as TFLite requires a more static output shape. The model output shape will be fixed as [batch, max_output_size, 6]. The last dimension 6 means [bbox_top, bbox_left, bbox_bottom, bbox_right, label_index, confidence], and the valid ones are those where confidence > 0.
""" Init model """
import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models import efficientdet
model = efficientdet.EfficientDetD0(pretrained="coco")

""" Create a model with DecodePredictions using `use_static_output=True` """
model.decode_predictions.use_static_output = True
# parameters like score_threshold / iou_or_sigma can be set another value if needed.
nn = model.decode_predictions(model.outputs[0], score_threshold=0.5)
bb = keras.models.Model(model.inputs[0], nn)

""" Convert TFLite """
converter = tf.lite.TFLiteConverter.from_keras_model(bb)
open(bb.name + ".tflite", "wb").write(converter.convert())

""" Inference test """
from keras_cv_attention_models.imagenet import eval_func
from keras_cv_attention_models import test_images
dd = eval_func.TFLiteModelInterf(bb.name + ".tflite")
imm = test_images.cat()
inputs = tf.expand_dims(tf.image.resize(imm, dd.input_shape[1:-1]), 0)
inputs = keras.applications.imagenet_utils.preprocess_input(inputs, mode='torch')
preds = dd(inputs)[0]
print(f"{preds.shape=}")
# preds.shape = (100, 6)
pred = preds[preds[:, -1] > 0]
bboxes, labels, confidences = pred[:, :4], pred[:, 4], pred[:, -1]
print(f"{bboxes=}, {labels=}, {confidences=}")
# bboxes = array([[0.22825494, 0.47238672, 0.816262 , 0.8700745 ]], dtype=float32),
# labels = array([16.], dtype=float32),
# confidences = array([0.8309707], dtype=float32)

""" Show result """
from keras_cv_attention_models.coco import data
data.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=90)
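If the static output needs to be drawn with another tool, the rows can be filtered and scaled by hand. The sketch below assumes the box values are normalized to [0, 1] relative to the image, which is what the sample output above suggests; `to_pixel_boxes` is a hypothetical helper and the input rows are toy values.

```python
def to_pixel_boxes(preds, image_height, image_width):
    """Keep rows with confidence > 0 and scale [top, left, bottom, right] to pixels."""
    results = []
    for top, left, bottom, right, label, confidence in preds:
        if confidence <= 0:
            continue  # zero-padded row from the static output shape
        box = (
            round(top * image_height),
            round(left * image_width),
            round(bottom * image_height),
            round(right * image_width),
        )
        results.append((box, int(label), confidence))
    return results


# Toy output: one detection followed by zero padding.
preds = [[0.25, 0.5, 0.75, 1.0, 16.0, 0.83]] + [[0.0] * 6 for _ in range(3)]
print(to_pixel_boxes(preds, image_height=400, image_width=600))
# [((100, 300, 300, 600), 16, 0.83)]
```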
Set os environment export KECAM_BACKEND='torch' to enable this PyTorch backend.
Currently supports most recognition and detection models except hornet*gf / nfnets / volo. For detection models, using torchvision.ops.nms while running prediction.
Basic model build and prediction.
Will load same h5 weights as TF one if available.
Note: input_shape will auto fit image data format. Given input_shape=(224, 224, 3) or input_shape=(3, 224, 224), will both set to (3, 224, 224) if channels_first.
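The auto-fit behaviour described in the note can be pictured with a small sketch. It assumes, for illustration only, that the channel axis is the smallest dimension and sits either first or last; `align_input_shape` is a hypothetical helper, and the package's own logic may differ.

```python
def align_input_shape(input_shape, channels_first):
    """Move the channel axis to the front or back of a 3-element input shape.

    Assumption for this sketch: the channel axis is the smallest dimension
    and sits either first or last.
    """
    height_width = list(input_shape)
    channel_axis = 0 if input_shape[0] < input_shape[-1] else -1
    channels = height_width.pop(channel_axis)
    if channels_first:
        return (channels, *height_width)
    return (*height_width, channels)


print(align_input_shape((224, 224, 3), channels_first=True))   # (3, 224, 224)
print(align_input_shape((3, 224, 224), channels_first=True))   # (3, 224, 224)
print(align_input_shape((3, 224, 224), channels_first=False))  # (224, 224, 3)
```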
[Experimental] Set os environment export KECAM_BACKEND='keras_core' to enable this keras_core backend. keras>3.0 is not used, as it is still not compiling with TensorFlow==2.15.0.
keras-core has its own backends, supporting tensorflow / torch / jax, selected by editing the "backend" value in ~/.keras/keras.json.
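Editing the "backend" value can also be scripted. The following is a minimal sketch that updates the key in a keras.json style file while keeping the other keys; `set_keras_backend` is a hypothetical helper, and the demo writes to a temporary file rather than the real ~/.keras/keras.json.

```python
import json
import os
import tempfile


def set_keras_backend(config_path, backend):
    """Write the "backend" value in a keras.json style config, keeping other keys."""
    if backend not in ("tensorflow", "torch", "jax"):
        raise ValueError(f"Unsupported backend: {backend}")
    config = {}
    if os.path.exists(config_path):
        with open(config_path, "r") as ff:
            config = json.load(ff)
    config["backend"] = backend
    with open(config_path, "w") as ff:
        json.dump(config, ff, indent=4)
    return config


# Demo on a temporary file rather than the real ~/.keras/keras.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "keras.json")
    with open(demo_path, "w") as ff:
        json.dump({"backend": "tensorflow", "floatx": "float32"}, ff)
    updated = set_keras_backend(demo_path, "jax")

print(updated)  # {'backend': 'jax', 'floatx': 'float32'}
```

Setting the KERAS_BACKEND environment variable before import, as in the example in this section, avoids touching the file at all.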
Currently most recognition models except HaloNet / BotNet supported, also GPT2 / LLaMA2 supported.
Basic model build and prediction.
!pip install sentencepiece  # required for llama2 tokenizer
import os
os.environ['KECAM_BACKEND'] = 'keras_core'
os.environ['KERAS_BACKEND'] = 'jax'
import kecam
print(f"{kecam.backend.backend() = }")
# kecam.backend.backend() = 'jax'
mm = kecam.llama2.LLaMA2_42M()
# >>>> Load pretrained from: ~/.keras/models/llama2_42m_tiny_stories.h5
mm.run_prediction('As evening fell, a maiden stood at the edge of a wood. In her hands,')
# >>>> Load tokenizer from file: ~/.keras/datasets/llama_tokenizer.model
# <s>
# As evening fell, a maiden stood at the edge of a wood. In her hands, she held a beautiful diamond. Everyone was surprised to see it.
# "What is it?" one of the kids asked.
# "It's a diamond," the maiden said.
# ...
Recognition Models
AotNet
Keras AotNet is just a ResNet / ResNetV2 like framework that sets parameters like attn_types, se_ratio and others, which are used to apply different types of attention layers. It works like byoanet / byobnet from timm.
Default parameters set is a typical ResNet architecture with Conv2D use_bias=False and padding like PyTorch.
from keras_cv_attention_models import aotnet
# Mixing se and outlook and halo and mhsa and cot_attention, 21M parameters.
# 50 is just a picked number that is larger than the relative `num_block`.
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
se_ratio = [0.25, 0, 0, 0]
model = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, stem_type="deep", strides=1)
model.summary()
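The `["bot", "halo"] * 50` entry may look odd at first. The sketch below shows one plausible reading: each stage takes one attention type per block, a single value is repeated for every block, and an over-long list is simply cut to the stage depth. The stage depths `[3, 4, 6, 3]` are the usual ResNet50 values, and `expand_attn_types` is a hypothetical helper written for this illustration, not the package's internal code.

```python
def expand_attn_types(attn_types, num_blocks):
    """Return one attention type per block for every stage."""
    expanded = []
    for stage_attn, num_block in zip(attn_types, num_blocks):
        if isinstance(stage_attn, (list, tuple)):
            # A long alternating list is truncated to the stage depth.
            expanded.append(list(stage_attn[:num_block]))
        else:
            # A single value (or None) applies to every block in the stage.
            expanded.append([stage_attn] * num_block)
    return expanded


attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
num_blocks = [3, 4, 6, 3]  # the usual ResNet50 stage depths
for stage in expand_attn_types(attn_types, num_blocks):
    print(stage)
```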
Code. The code here is licensed MIT. It is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue. So far all of the pretrained weights available here are pretrained on ImageNet and COCO with a select few that have some additional pretraining.
ImageNet Pretrained Weights. ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
COCO Pretrained Weights. Should follow cocodataset termsofuse. The annotations in COCO dataset belong to the COCO Consortium and are licensed under a Creative Commons Attribution 4.0 License. The COCO Consortium does not own the copyright of the images. Use of the images must abide by the Flickr Terms of Use. The users of the images accept full responsibility for the use of the dataset, including but not limited to the use of any copies of copyrighted images that they may create from the dataset.
Pretrained on more than ImageNet and COCO. Several weights included or referenced here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
Citing
BibTeX
@misc{leondgarse,
author = {Leondgarse},
title = {Keras CV Attention Models},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
doi = {10.5281/zenodo.6506947},
howpublished = {\url{https://github.com/leondgarse/keras_cv_attention_models}}
}