update transformers imports for deepspeed and `is_torch_xla_available` (#2012)

* change deepspeed to integrations.deepspeed

* add version check and change tpu to xla

* add version check
Rohan138 authored Sep 7, 2024
1 parent 29f23f1 commit 2335ec2
Showing 2 changed files with 26 additions and 5 deletions.
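
Both files below gate their imports on optimum's `check_if_transformers_greater` helper. As a rough sketch of what such a version gate can look like (not the exact optimum implementation; `transformers_at_least` is an illustrative name, and it assumes the `packaging` library that transformers already depends on):

import transformers
from packaging import version

def transformers_at_least(target: str) -> bool:
    # Compare the installed transformers version against a target like "4.33".
    return version.parse(transformers.__version__) >= version.parse(target)

# e.g. transformers_at_least("4.33") would pick transformers.integrations.deepspeed
# over the pre-4.33 transformers.deepspeed module.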
24 changes: 20 additions & 4 deletions optimum/onnxruntime/trainer.py
@@ -55,7 +55,6 @@
 from torch.utils.data import Dataset, RandomSampler
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.trainer import Trainer
@@ -81,10 +80,10 @@
     is_apex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
-    is_torch_tpu_available,
 )
 
 from ..utils import logging
+from ..utils.import_utils import check_if_transformers_greater
 from .training_args import ORTOptimizerNames, ORTTrainingArguments
 from .utils import (
     is_onnxruntime_training_available,
@@ -94,8 +93,25 @@
 if is_apex_available():
     from apex import amp
 
-if is_torch_tpu_available(check_device=False):
-    import torch_xla.core.xla_model as xm
+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import (
+        deepspeed_init,
+        deepspeed_load_checkpoint,
+        is_deepspeed_zero3_enabled,
+    )
+else:
+    from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+
+if check_if_transformers_greater("4.39"):
+    from transformers.utils import is_torch_xla_available
+
+    if is_torch_xla_available():
+        import torch_xla.core.xla_model as xm
+else:
+    from transformers.utils import is_torch_tpu_available
+
+    if is_torch_tpu_available(check_device=False):
+        import torch_xla.core.xla_model as xm
 
 if TYPE_CHECKING:
     import optuna
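
The TPU-to-XLA fallback above can be exercised on its own; a minimal standalone sketch (assuming transformers is installed and torch_xla is optional; `xla_usable` is an illustrative name, not from the diff):

import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.39"):
    # Newer transformers deprecates is_torch_tpu_available in favor of XLA naming.
    from transformers.utils import is_torch_xla_available

    xla_usable = is_torch_xla_available()
else:
    from transformers.utils import is_torch_tpu_available

    # check_device=False only verifies torch_xla is importable, without probing hardware.
    xla_usable = is_torch_tpu_available(check_device=False)

if xla_usable:
    import torch_xla.core.xla_model as xm  # noqa: F401

print(f"XLA backend available: {xla_usable}")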
7 changes: 6 additions & 1 deletion optimum/onnxruntime/trainer_seq2seq.py
@@ -19,10 +19,10 @@
 import torch
 from torch import nn
 from torch.utils.data import Dataset
-from transformers.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import is_accelerate_available, logging
 
+from ..utils.import_utils import check_if_transformers_greater
 from .trainer import ORTTrainer

@@ -33,6 +33,11 @@
         "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install."
     )
 
+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+else:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
+
 logger = logging.get_logger(__name__)
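
To confirm which module exposes `is_deepspeed_zero3_enabled` under a given install, a quick hypothetical sanity check (not part of the commit; assumes transformers and `packaging` are installed):

import importlib

import transformers
from packaging import version

mod_name = (
    "transformers.integrations.deepspeed"
    if version.parse(transformers.__version__) >= version.parse("4.33")
    else "transformers.deepspeed"
)
mod = importlib.import_module(mod_name)
print(mod_name, "->", hasattr(mod, "is_deepspeed_zero3_enabled"))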
