Skip to content

Commit

Permalink
[WIP] llama-70b
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Feb 6, 2024
1 parent 6eeeaa0 commit 3d99397
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 35 deletions.
6 changes: 3 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
patch_within_function,
patched_finfo,
)
from ..utils.misc import args_and_kwargs_to_kwargs_only
from ..utils.misc import args_and_kwargs_to_kwargs_only, is_main_worker
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
from .optimizer import NeuronAcceleratedOptimizer
from .scheduler import NeuronAcceleratedScheduler
Expand Down Expand Up @@ -219,14 +219,14 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
num_replicas = parallel_layers.parallel_state.get_data_parallel_size()
rank = parallel_layers.parallel_state.get_data_parallel_rank()
force_drop_last = parallel_layers.parallel_state.get_pipeline_model_parallel_size() > 1
if force_drop_last and xm.get_ordinal() == 0:
if is_main_worker() and force_drop_last:
logger.warning(
"Pipeline parallelsim: forcing the dataloader to drop the last incomplete batch because it can "
"cause failure if the last batch size is not divisible by the number of microbatches for the pipeline."
)
else:
num_replicas = xm.xrt_world_size()
rank = xm.get_local_ordinal()
rank = xm.get_ordinal()
if self.state.num_processes > 1:
data_loader = self._prepare_data_loader_for_distributed(
data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
Expand Down
8 changes: 4 additions & 4 deletions optimum/neuron/distributed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def predicate_func(layer):
for n, p in layer.named_parameters():
if p not in parameter_to_name:
xm.master_print(n)
return False
return True
names = {parameter_to_name[p] for p in layer.parameters()}
return names < names_of_the_parameters_to_consider

Expand All @@ -357,7 +357,7 @@ def predicate_func(layer):
sequence_parallel_enabled=sequence_parallel_enabled,
# should_parallelize_predicate_func=predicate_func,
)
xm.rendezvous("End of tensor parallelism")
# xm.rendezvous("End of tensor parallelism")

# Preparing the model for sequence parallelism:
sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS
Expand Down Expand Up @@ -507,7 +507,7 @@ def predicate_func(layer):
if left_uninitialized and hasattr(mod, "reset_parameters"):
initialize_torch_nn_module(mod, parameter_names)

xm.rendezvous("End of initalization")
# xm.rendezvous("End of initalization")

pp_size = get_pipeline_model_parallel_size()
if pp_size > 1:
Expand Down Expand Up @@ -535,7 +535,7 @@ def predicate_func(layer):
if gradient_checkpointing:
apply_checkpoint(model)

xm.rendezvous("End of pipeline paralellism")
# xm.rendezvous("End of pipeline paralellism")

if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)
Expand Down
1 change: 1 addition & 0 deletions optimum/neuron/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control
"""
Event called at the beginning of training.
"""
return
if is_precompilation() or self.neuron_cache_path is None:
return
if self.push:
Expand Down
45 changes: 24 additions & 21 deletions optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
_ORIGINAL_NEURON_CACHE_PATH: Optional[Path] = None
_TMP_NEURON_CACHE_DIR: Optional[TemporaryDirectory] = None
_TMP_NEURON_CACHE_PATH: Optional[Path] = None
_TCP_STORE_ADDRESS = "127.0.0.1"
_TCP_STORE_ADDRESS = os.environ.get("MASTER_ADDR", "127.0.0.1")
_TCP_STORE_PORT = 5000


Expand All @@ -130,22 +130,22 @@
_ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()

# _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
if _ORIGINAL_NEURON_CACHE_PATH is not None:
if is_precompilation():
# During precompilation, we make sure to set the cache path to the compile cache path defined by the
# user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
# /var/tmp/neuron-compile-cache
set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
else:
if os.environ["LOCAL_RANK"] == "0":
_TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
_TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
else:
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
_TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)
# if _ORIGINAL_NEURON_CACHE_PATH is not None:
# if is_precompilation():
# # During precompilation, we make sure to set the cache path to the compile cache path defined by the
# # user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
# # /var/tmp/neuron-compile-cache
# set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
# else:
# if os.environ["RANK"] == "0":
# _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
# store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
# store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
# _TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
# else:
# store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
# _TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
# set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)

torch.distributed.init_process_group(backend="xla")
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
Expand Down Expand Up @@ -196,8 +196,9 @@ def __init__(self, *args, **kwargs):
if self.args.local_rank <= 0:
logger.setLevel(logging.INFO)

push = self.args.local_rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
fetch = self.args.local_rank <= 0 or self.args.mp_plugin.should_parallelize
rank = xm.get_ordinal()
push = rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
fetch = rank <= 0 or self.args.mp_plugin.should_parallelize

callback = NeuronCacheCallback(
tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
Expand All @@ -207,7 +208,7 @@ def __init__(self, *args, **kwargs):
wait_for_everyone_on_fetch=True,
wait_for_everyone_on_push=True,
)
self.add_callback(callback)
# self.add_callback(callback)

# Make the model Neuron-compatible for generation.
patch_generation_mixin_to_neuron_generation_mixin(self.model)
Expand Down Expand Up @@ -422,6 +423,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
self._globalstep_last_logged = self.state.global_step
self.store_flos()

# if is_main_worker():
self.log(logs)

metrics = None
Expand Down Expand Up @@ -785,7 +787,8 @@ def _inner_training_loop(
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

# Train!
parameter_count = get_model_param_count(model, trainable_only=True)
# parameter_count = get_model_param_count(model, trainable_only=True)
parameter_count = 10
if is_main_worker():
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
Expand Down
11 changes: 7 additions & 4 deletions optimum/neuron/utils/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,10 +1043,13 @@ def push_to_cache_on_hub(
success = False

# Adding the model to the registry if the upload was successful.
# TODO: adding the model to the registry slows down training since it pushes a lot of stuff to the registry.
# A better way needs to be found. Disabling it for now since it's not used at all.
if success:
try:
add_in_registry(cache_repo_id, neuron_hash)
except HfHubHTTPError:
pass
pass
# try:
# add_in_registry(cache_repo_id, neuron_hash)
# except HfHubHTTPError:
# pass

return CachedModelOnTheHub(cache_repo_id, path_in_repo)
6 changes: 3 additions & 3 deletions optimum/neuron/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ def numel(parameter_name, parameter) -> int:
param_count = sum(numel(n, p) for n, p in named_parameters if not trainable_only or p.requires_grad)

if get_pipeline_model_parallel_size() > 1:
param_count = torch.tensor(param_count, dtype=torch.double).to(xm.xla_device())
param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
param_count = param_count.detach().item()
param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
# param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
param_count = int(param_count.detach().item())

return param_count

0 comments on commit 3d99397

Please sign in to comment.