Skip to content

Commit

Permalink
add keras_core_functional.py
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Aug 4, 2023
1 parent 02771fa commit d4e5b53
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
34 changes: 34 additions & 0 deletions keras_cv_attention_models/keras_core_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import keras_core
from keras_core.ops import *
from keras_core.ops import concatenate as concat
from keras_core.ops import mean as reduce_mean
from keras_core.ops import sum as reduce_sum
from keras_core.ops import max as reduce_max
from keras_core.ops import min as reduce_min
from keras_core.ops import power as pow
from keras_core.ops import clip as clip_by_value
from keras_core.ops.image import extract_patches


def resize(images, size, method="bilinear", preserve_aspect_ratio=False, antialias=False, name=None):
    """Resize `images` to `size` by delegating to `keras_core.ops.image.resize`.

    Args:
      images: input passed through unchanged as the first argument.
      size: target size passed through unchanged.
      method: forwarded as the `interpolation` argument.
      preserve_aspect_ratio: accepted but not used by this wrapper.
      antialias: forwarded as-is.
      name: accepted but not used by this wrapper.

    Returns:
      The result of `keras_core.ops.image.resize`.
    """
    # The data format is always taken from the global backend setting, never from the caller.
    data_format = keras_core.backend.image_data_format()
    return keras_core.ops.image.resize(
        images,
        size,
        interpolation=method,
        antialias=antialias,
        data_format=data_format,
    )


def split(inputs, num_or_size_splits, axis=0, num=None, name="split"):
    """Split `inputs` along `axis` into equal parts or into sections of the given sizes.

    Args:
      inputs: object exposing a static `shape`; forwarded to `keras_core.ops.split`.
      num_or_size_splits: either an int, forwarded unchanged to `keras_core.ops.split`,
          or a sequence of section sizes. At most one entry may be `None` or `-1`
          (a literal `0` is treated the same way); that entry is inferred from the
          size of `axis`.
      axis: axis to split along, negative values count from the last axis.
      num: accepted but not used.
      name: accepted but not used.

    Returns:
      The result of `keras_core.ops.split` called with the cumulative split indices.

    Raises:
      AssertionError: if the size of `axis` is unknown, or more than one section size is unknown.
    """
    # The module-level `from keras_core.ops import *` may rebind `sum`; use the builtin explicitly.
    import builtins

    if isinstance(num_or_size_splits, int):
        return keras_core.ops.split(inputs, num_or_size_splits, axis=axis)

    axis = (len(inputs.shape) + axis) if axis < 0 else axis
    split_axis_shape = inputs.shape[axis]
    assert split_axis_shape is not None

    # Unknown entries are normalized to 0 so they do not contribute to the known total.
    size_splits = [0 if ii is None or ii == -1 else ii for ii in num_or_size_splits]
    num_unknown_dim = size_splits.count(0)
    assert num_unknown_dim < 2, "At most one unknown dimension in num_or_size_splits: {}".format(num_or_size_splits)

    if num_unknown_dim == 1:
        inferred_size = split_axis_shape - builtins.sum(size_splits)
        size_splits = [inferred_size if ii == 0 else ii for ii in size_splits]

    # Split points are the running totals of the RESOLVED sizes (all but the last section).
    # Using the raw `num_or_size_splits` here would add -1 / None placeholders into the indices.
    cum_split, running_total = [], 0
    for cur_size in size_splits[:-1]:
        running_total += cur_size
        cum_split.append(running_total)
    return keras_core.ops.split(inputs, cum_split, axis=axis)
41 changes: 41 additions & 0 deletions keras_cv_attention_models/llama2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from keras_cv_attention_models.llama2.llama2 import Llama2, Llama2_7B, RunPrediction, PositionalEncodingFourierRot1D, RMSNorm

# NOTE(review): the header text, the `pretrained` options and the architecture table below
# still describe GPT-2 although they are attached to the Llama2 builders. Confirm and replace
# them with Llama2 specific information.
__head_doc__ = """
Keras implementation of [Github openai/gpt-2](https://github.com/openai/gpt-2).
Paper [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
"""

# Shared trailing part of the `Args:` section, appended after each model's own arguments.
__tail_doc__ = """  vocab_size: model vocab size.
  max_block_size: number of tokens generated in each sample.
  include_top: boolean value if including output Dense head layer. Set false to exclude the head layer.
  dropout: float value for drop out rate for Embedding layer and attention blocks.
  activation: activation used in whole model, default `gelu/app`.
  pretrained: None or one of ["webtext", "huggingface"].
      - if "webtext", will try to download and load ported weights if available.
      - if "huggingface", will try converting and loading weights from huggingface `transformers` package.
      - if None, will initialize model with random weights.
Returns:
    A `keras.Model` instance.
"""

Llama2.__doc__ = __head_doc__ + """
Args:
  num_blocks: .
  embedding_size: .
  num_heads: .
  block_use_bias: .
  model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
  | Model       | Params  | FLOPs   | vocab_size | LAMBADA PPL |
  | ------------| ------- | ------- | ---------- | ----------- |
  | GPT2_Base   | 163.04M | 146.42G | 50257      | 35.13       |
  | GPT2_Medium | 406.29M | 415.07G | 50257      | 15.60       |
  | GPT2_Large  | 838.36M | 890.28G | 50257      | 10.87       |
  | GPT2_XLarge | 1.638B  | 1758.3G | 50257      | 8.63        |
"""

Llama2_7B.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__
5 changes: 4 additions & 1 deletion keras_cv_attention_models/pytorch_backend/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def __init__(self, value=0):
super().__init__(seed=None)

def __call__(self, shape, dtype=None, **kwargs):
    """Return the initial value for a weight of the requested `shape`.

    Args:
      shape: target shape of the weight.
      dtype: accepted but not used.
      **kwargs: accepted but not used.

    Returns:
      `self.value` itself when it exposes a `shape` equal to `shape`, otherwise a new
      tensor of `shape` filled with `self.value`.
    """
    # A stored value that already has the requested shape is used directly. The former
    # unconditional constant-fill return ahead of this check made the branch unreachable.
    if hasattr(self.value, "shape") and tuple(self.value.shape) == tuple(shape):
        return self.value
    return torch.nn.init.constant_(torch.empty(shape), val=self.value)

def get_config(self):
    """Return a config dict holding the stored `value` under the key "value"."""
    config = dict(value=self.value)
    return config
Expand Down

0 comments on commit d4e5b53

Please sign in to comment.