diff --git a/library/train_util.py b/library/train_util.py index bdf7774e4..c4845c54b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4141,22 +4141,7 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - elif optimizer_type == "Ademamix8bit".lower(): - logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.AdEMAMix8bit - except AttributeError: - raise AttributeError( - "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdemamix8bit".lower(): - logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdEMAMix8bit - except AttributeError: - raise AttributeError( - "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower():