Skip to content

Commit

Permalink
FEAT: TimeLLM is faster and supports more LLMs (#1139)
Browse files Browse the repository at this point in the history
* Fix issue #950: Reduce TimeLLM setup time for training

* Restore changes on the examples

* Revert changes to nbs/models.ipynb, nbs/models.softs.ipynb and neuralforecast/_modidx.py

* Revert changes to nbs/models.ipynb, nbs/models.softs.ipynb and neuralforecast/_modidx.py

* Refactor code to dynamically load models with AutoModel, AutoTokenizer, and AutoConfig

- Updated load_model_and_tokenizer function to use AutoModel, AutoTokenizer, and AutoConfig for flexible model loading.
- Included default model(gpt2) for cases where the specified model fails to load.
- Kept llm, llm_config, and llm_tokenizer arguments to minimize changes.
- Changed llm from storing pretrained weights to accepting pretrained model path to reduce necessary modifications.

This update enhances the flexibility and reliability of model loading based on received feedback while minimizing necessary changes.

* Refactor code to dynamically load models with AutoModel, AutoTokenizer, and AutoConfig

- Updated load_model_and_tokenizer function to use AutoModel, AutoTokenizer, and AutoConfig for flexible model loading.
- Included default model(gpt2) for cases where the specified model fails to load.
- Kept llm, llm_config, and llm_tokenizer arguments to minimize changes.
- Changed llm from storing pretrained weights to accepting pretrained model path to reduce necessary modifications.

This update enhances the flexibility and reliability of model loading based on received feedback while minimizing necessary changes.

* clear output

* modify test code

* Optimize model loading and add deprecation warning

- Simplify model loading logic
- Add constant for default model name
- Improve error handling for model loading
- Add success messages for model loading
- Implement deprecation warning for 'llm_config' and 'llm_tokenizer' parameters
- Update print messages for clarity
- Remove redundant code

This commit improves code readability, maintainability, and user experience
by providing clearer feedback and warnings about deprecated parameters.

* Resolved conflict in nbs/models.timellm.ipynb

---------

Co-authored-by: ive2go <[email protected]>
Co-authored-by: Olivier Sprangers <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2024
1 parent 5f35427 commit 29366af
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 89 deletions.
112 changes: 39 additions & 73 deletions nbs/models.timellm.ipynb
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@
"from neuralforecast.losses.pytorch import MAE\n",
"\n",
"try:\n",
" from transformers import GPT2Config, GPT2Model, GPT2Tokenizer\n",
" from transformers import AutoModel, AutoTokenizer, AutoConfig\n",
" IS_TRANSFORMERS_INSTALLED = True\n",
"except ImportError:\n",
" IS_TRANSFORMERS_INSTALLED = False"
" IS_TRANSFORMERS_INSTALLED = False\n",
"\n",
"import warnings"
]
},
{
Expand Down Expand Up @@ -321,14 +323,14 @@
" `stride`: int=8, stride of patch.<br>\n",
" `d_ff`: int=128, dimension of fcn.<br>\n",
" `top_k`: int=5, top tokens to consider.<br>\n",
" `d_llm`: int=768, hidden dimension of LLM.<br>\n",
" `d_llm`: int=768, hidden dimension of LLM.<br> # LLama7b:4096; GPT2-small:768; BERT-base:768\n",
" `d_model`: int=32, dimension of model.<br>\n",
" `n_heads`: int=8, number of heads in attention layer.<br>\n",
" `enc_in`: int=7, encoder input size.<br>\n",
" `dec_in`: int=7, decoder input size.<br>\n",
" `llm` = None, LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm_config` = None, configuration of LLM. If not specified, it will use the configuration of GPT-2 from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm_tokenizer` = None, tokenizer of LLM. If not specified, it will use the GPT-2 tokenizer from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm` = None, Path to pretrained LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm_config` = Deprecated, configuration of LLM. If not specified, it will use the configuration of GPT-2 from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm_tokenizer` = Deprecated, tokenizer of LLM. If not specified, it will use the GPT-2 tokenizer from https://huggingface.co/openai-community/gpt2\"<br>\n",
" `llm_num_hidden_layers` = 32, hidden layers in LLM\n",
" `llm_output_attention`: bool = True, whether to output attention in encoder.<br>\n",
" `llm_output_hidden_states`: bool = True, whether to output hidden states.<br>\n",
Expand Down Expand Up @@ -456,19 +458,34 @@
" self.enc_in = enc_in\n",
" self.dec_in = dec_in\n",
"\n",
" self.llm_config = llm_config\n",
" self.llm = llm\n",
" self.llm_tokenizer = llm_tokenizer\n",
" DEFAULT_MODEL = \"openai-community/gpt2\"\n",
"\n",
" if self.llm is None:\n",
" if llm is None:\n",
" if not IS_TRANSFORMERS_INSTALLED:\n",
" raise ImportError(\"Please install `transformers` to use the default LLM\")\n",
" \n",
" print(\"Using GPT2 model as default and ignoring `llm_config` and `llm_tokenizer`\")\n",
"\n",
" self.llm_confg = GPT2Config.from_pretrained('openai-community/gpt2')\n",
" self.llm = GPT2Model.from_pretrained('openai-community/gpt2', config=self.llm_confg)\n",
" self.llm_tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2')\n",
" raise ImportError(\n",
" \"Please install `transformers` to use the default LLM.\"\n",
" )\n",
" \n",
" print(f\"Using {DEFAULT_MODEL} as default.\")\n",
" model_name = DEFAULT_MODEL\n",
" else:\n",
" model_name = llm\n",
"\n",
" if llm_config is not None or llm_tokenizer is not None:\n",
" warnings.warn(\"'llm_config' and 'llm_tokenizer' parameters are deprecated and will be ignored. \"\n",
" \"The config and tokenizer will be automatically loaded from the specified model.\", \n",
" DeprecationWarning)\n",
"\n",
" try:\n",
" self.llm_config = AutoConfig.from_pretrained(model_name)\n",
" self.llm = AutoModel.from_pretrained(model_name, config=self.llm_config)\n",
" self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" print(f\"Successfully loaded model: {model_name}\")\n",
" except EnvironmentError:\n",
" print(f\"Failed to load {model_name}. Loading the default model ({DEFAULT_MODEL})...\")\n",
" self.llm_config = AutoConfig.from_pretrained(DEFAULT_MODEL)\n",
" self.llm = AutoModel.from_pretrained(DEFAULT_MODEL, config=self.llm_config)\n",
" self.llm_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)\n",
"\n",
" self.llm_num_hidden_layers = llm_num_hidden_layers\n",
" self.llm_output_attention = llm_output_attention\n",
Expand Down Expand Up @@ -626,27 +643,20 @@
"from neuralforecast.models import TimeLLM\n",
"from neuralforecast.utils import AirPassengersPanel, augment_calendar_df\n",
"\n",
"from transformers import GPT2Config, GPT2Model, GPT2Tokenizer\n",
"\n",
"AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
"\n",
"Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
"\n",
"gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')\n",
"gpt2 = GPT2Model.from_pretrained('openai-community/gpt2', config=gpt2_config)\n",
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2')\n",
"\n",
"prompt_prefix = \"The dataset contains data on monthly air passengers. There is a yearly seasonality\"\n",
"\n",
"timellm = TimeLLM(h=12,\n",
" input_size=36,\n",
" llm=gpt2,\n",
" llm_config=gpt2_config,\n",
" llm_tokenizer=gpt2_tokenizer,\n",
" llm='openai-community/gpt2',\n",
" prompt_prefix=prompt_prefix,\n",
" batch_size=24,\n",
" windows_batch_size=24)\n",
" batch_size=16,\n",
" valid_batch_size=16,\n",
" windows_batch_size=16)\n",
"\n",
"nf = NeuralForecast(\n",
" models=[timellm],\n",
Expand All @@ -662,51 +672,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"#| eval: false\n",
"try:\n",
" from transformers import GPT2Config, GPT2Model, GPT2Tokenizer\n",
"except ImportError:\n",
" raise ImportError('The transformers library is required for Time-LLM to work')\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import TimeLLM\n",
"\n",
"from neuralforecast.utils import AirPassengersPanel, augment_calendar_df\n",
"\n",
"AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
"\n",
"Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
"\n",
"gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')\n",
"gpt2 = GPT2Model.from_pretrained('openai-community/gpt2', config=gpt2_config)\n",
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2')\n",
"\n",
"prompt_prefix = \"The dataset contains data on monthly air passengers. There is a yearly seasonality\"\n",
"\n",
"timellm = TimeLLM(h=12,\n",
" input_size=36,\n",
" llm=gpt2,\n",
" llm_config=gpt2_config,\n",
" llm_tokenizer=gpt2_tokenizer,\n",
" prompt_prefix=prompt_prefix,\n",
" batch_size=24,\n",
" windows_batch_size=24)\n",
"\n",
"nf = NeuralForecast(\n",
" models=[timellm],\n",
" freq='M'\n",
")\n",
"\n",
"nf.fit(df=Y_train_df, val_size=12)\n",
"forecasts = nf.predict(futr_df=Y_test_df)\n",
"\n",
"# Asserts\n",
"assert 'TimeLLM' in forecasts.columns, \"The column TimeLLM does not exist. Something went wrong with the model\"\n",
"assert not forecasts['TimeLLM'].isnull().any(), \"Predictions contain NaN values.\""
]
"source": []
}
],
"metadata": {
Expand Down
47 changes: 31 additions & 16 deletions neuralforecast/models/timellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from ..losses.pytorch import MAE

try:
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer
from transformers import AutoModel, AutoTokenizer, AutoConfig

IS_TRANSFORMERS_INSTALLED = True
except ImportError:
IS_TRANSFORMERS_INSTALLED = False

import warnings

# %% ../../nbs/models.timellm.ipynb 9
class ReplicationPad1d(nn.Module):
"""
Expand Down Expand Up @@ -256,14 +258,14 @@ class TimeLLM(BaseWindows):
`stride`: int=8, stride of patch.<br>
`d_ff`: int=128, dimension of fcn.<br>
`top_k`: int=5, top tokens to consider.<br>
`d_llm`: int=768, hidden dimension of LLM.<br>
`d_llm`: int=768, hidden dimension of LLM.<br> # LLama7b:4096; GPT2-small:768; BERT-base:768
`d_model`: int=32, dimension of model.<br>
`n_heads`: int=8, number of heads in attention layer.<br>
`enc_in`: int=7, encoder input size.<br>
`dec_in`: int=7, decoder input size.<br>
`llm` = None, LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2"<br>
`llm_config` = None, configuration of LLM. If not specified, it will use the configuration of GPT-2 from https://huggingface.co/openai-community/gpt2"<br>
`llm_tokenizer` = None, tokenizer of LLM. If not specified, it will use the GPT-2 tokenizer from https://huggingface.co/openai-community/gpt2"<br>
`llm` = None, Path to pretrained LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2"<br>
`llm_config` = Deprecated, configuration of LLM. If not specified, it will use the configuration of GPT-2 from https://huggingface.co/openai-community/gpt2"<br>
`llm_tokenizer` = Deprecated, tokenizer of LLM. If not specified, it will use the GPT-2 tokenizer from https://huggingface.co/openai-community/gpt2"<br>
`llm_num_hidden_layers` = 32, hidden layers in LLM
`llm_output_attention`: bool = True, whether to output attention in encoder.<br>
`llm_output_hidden_states`: bool = True, whether to output hidden states.<br>
Expand Down Expand Up @@ -395,25 +397,38 @@ def __init__(
self.enc_in = enc_in
self.dec_in = dec_in

self.llm_config = llm_config
self.llm = llm
self.llm_tokenizer = llm_tokenizer
DEFAULT_MODEL = "openai-community/gpt2"

if self.llm is None:
if llm is None:
if not IS_TRANSFORMERS_INSTALLED:
raise ImportError(
"Please install `transformers` to use the default LLM"
"Please install `transformers` to use the default LLM."
)

print(
"Using GPT2 model as default and ignoring `llm_config` and `llm_tokenizer`"
print(f"Using {DEFAULT_MODEL} as default.")
model_name = DEFAULT_MODEL
else:
model_name = llm

if llm_config is not None or llm_tokenizer is not None:
warnings.warn(
"'llm_config' and 'llm_tokenizer' parameters are deprecated and will be ignored. "
"The config and tokenizer will be automatically loaded from the specified model.",
DeprecationWarning,
)

self.llm_confg = GPT2Config.from_pretrained("openai-community/gpt2")
self.llm = GPT2Model.from_pretrained(
"openai-community/gpt2", config=self.llm_confg
try:
self.llm_config = AutoConfig.from_pretrained(model_name)
self.llm = AutoModel.from_pretrained(model_name, config=self.llm_config)
self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Successfully loaded model: {model_name}")
except EnvironmentError:
print(
f"Failed to load {model_name}. Loading the default model ({DEFAULT_MODEL})..."
)
self.llm_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
self.llm_config = AutoConfig.from_pretrained(DEFAULT_MODEL)
self.llm = AutoModel.from_pretrained(DEFAULT_MODEL, config=self.llm_config)
self.llm_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

self.llm_num_hidden_layers = llm_num_hidden_layers
self.llm_output_attention = llm_output_attention
Expand Down

0 comments on commit 29366af

Please sign in to comment.