Skip to content

Commit

Permalink
feat: add NeuronModel base class
Browse files Browse the repository at this point in the history
This base class will implement the transformers PreTrainedModel methods
that are not implemented in the optimum PreTrainedModel base class.
  • Loading branch information
dacorvo committed May 23, 2024
1 parent d4cdf77 commit 53f7ed4
Show file tree
Hide file tree
Showing 12 changed files with 673 additions and 633 deletions.
8 changes: 4 additions & 4 deletions docs/source/package_reference/modeling.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ limitations under the License.

## Generic model classes

### NeuronBaseModel
### NeuronTracedModel

The `NeuronBaseModel` class is available for instantiating a base Neuron model without a specific head.
The `NeuronTracedModel` class is available for instantiating a base Neuron model without a specific head.
It is used as the base class for all tasks but text generation.

[[autodoc]] modeling_base.NeuronBaseModel
[[autodoc]] modeling_traced.NeuronTracedModel

### NeuronDecoderModel

Expand Down Expand Up @@ -104,4 +104,4 @@ The following Neuron model classes are available for natural language processing

### NeuronStableDiffusionXLInpaintPipeline
[[autodoc]] modeling_diffusion.NeuronStableDiffusionXLInpaintPipeline
- __call__
- __call__
4 changes: 2 additions & 2 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"hf_argparser": ["NeuronHfArgumentParser"],
"trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer"],
"training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"],
"modeling_base": ["NeuronBaseModel"],
"modeling_traced": ["NeuronTracedModel"],
"modeling": [
"NeuronModelForFeatureExtraction",
"NeuronModelForSentenceTransformers",
Expand Down Expand Up @@ -73,7 +73,6 @@
NeuronModelForSequenceClassification,
NeuronModelForTokenClassification,
)
from .modeling_base import NeuronBaseModel
from .modeling_decoder import NeuronDecoderModel
from .modeling_diffusion import (
NeuronLatentConsistencyModelPipeline,
Expand All @@ -85,6 +84,7 @@
NeuronStableDiffusionXLPipeline,
)
from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .modeling_traced import NeuronTracedModel
from .pipelines import pipeline
from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
Expand Down
20 changes: 10 additions & 10 deletions optimum/neuron/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
from transformers.utils import ModelOutput

from .generation import TokenSelector
from .modeling_base import NeuronBaseModel
from .modeling_decoder import NeuronDecoderModel
from .modeling_traced import NeuronTracedModel


if TYPE_CHECKING:
Expand All @@ -61,13 +61,13 @@
_TOKENIZER_FOR_DOC = "AutoTokenizer"

NEURON_MODEL_START_DOCSTRING = r"""
This model inherits from [`~neuron.modeling.NeuronBaseModel`]. Check the superclass documentation for the generic methods the
This model inherits from [`~neuron.modeling_traced.NeuronTracedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving).
Args:
config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`optimum.neuron.modeling.NeuronBaseModel.from_pretrained`] method to load the model weights.
configuration. Check out the [`optimum.neuron.modeling_traced.NeuronTracedModel.from_pretrained`] method to load the model weights.
model (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module with embedded NEFF (Neuron Executable File Format) compiled by the neuron(x) compiler.
"""

Expand Down Expand Up @@ -125,7 +125,7 @@
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForFeatureExtraction(NeuronBaseModel):
class NeuronModelForFeatureExtraction(NeuronTracedModel):
"""
Feature Extraction model on Neuron devices.
"""
Expand Down Expand Up @@ -198,7 +198,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForSentenceTransformers(NeuronBaseModel):
class NeuronModelForSentenceTransformers(NeuronTracedModel):
"""
Sentence Transformers model on Neuron devices.
"""
Expand Down Expand Up @@ -283,7 +283,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForMaskedLM(NeuronBaseModel):
class NeuronModelForMaskedLM(NeuronTracedModel):
"""
Masked language model on Neuron devices.
"""
Expand Down Expand Up @@ -353,7 +353,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForQuestionAnswering(NeuronBaseModel):
class NeuronModelForQuestionAnswering(NeuronTracedModel):
"""
Question Answering model on Neuron devices.
"""
Expand Down Expand Up @@ -422,7 +422,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForSequenceClassification(NeuronBaseModel):
class NeuronModelForSequenceClassification(NeuronTracedModel):
"""
Sequence Classification model on Neuron devices.
"""
Expand Down Expand Up @@ -490,7 +490,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForTokenClassification(NeuronBaseModel):
class NeuronModelForTokenClassification(NeuronTracedModel):
"""
Token Classification model on Neuron devices.
"""
Expand Down Expand Up @@ -571,7 +571,7 @@ def forward(
""",
NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForMultipleChoice(NeuronBaseModel):
class NeuronModelForMultipleChoice(NeuronTracedModel):
"""
Multiple choice model on Neuron devices.
"""
Expand Down
Loading

0 comments on commit 53f7ed4

Please sign in to comment.