chore(format): run black on main (#510)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] authored Jul 2, 2024
1 parent 46200b3 commit a13daa4
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 3 additions & 1 deletion ChatTTS/core.py
@@ -295,7 +295,9 @@ def _load(
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT(**cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger).eval()
+            gpt = GPT(
+                **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger
+            ).eval()
             assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
             gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
             gpt.prepare(compile=compile and "cuda" in str(device))
12 changes: 9 additions & 3 deletions ChatTTS/model/gpt.py
@@ -99,7 +99,9 @@ def get(self) -> bool:
         return self._interrupt
 
     def _build_llama(
-        self, config: omegaconf.DictConfig, device: torch.device,
+        self,
+        config: omegaconf.DictConfig,
+        device: torch.device,
     ) -> LlamaModel:
 
         model = None
@@ -122,7 +124,9 @@ def _build_llama(
                     **config,
                     attn_implementation="flash_attention_2",
                 )
-                self.logger.warn("enabling flash_attention_2 may make gpt be even slower")
+                self.logger.warn(
+                    "enabling flash_attention_2 may make gpt be even slower"
+                )
             else:
                 llama_config = LlamaConfig(**config)
             model = LlamaModel(llama_config)
@@ -439,7 +443,9 @@ def generate(
             )
             del_all(model_input)
             attentions.append(outputs.attentions)
-            hidden_states = outputs.last_hidden_state.to(self.device, dtype=torch.float)  # 🐻
+            hidden_states = outputs.last_hidden_state.to(
+                self.device, dtype=torch.float
+            )  # 🐻
             past_key_values = outputs.past_key_values
             del_all(outputs)
             if return_hidden:
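These rewrites are black's two standard behaviors: a call whose line exceeds the default 88-character limit gets its argument list wrapped, and a parameter list that already ends with a trailing comma (black's "magic trailing comma") is exploded onto one line per parameter, as in the _build_llama signature above. Below is a minimal sketch of reproducing both rewrites with black's Python API, assuming black is installed and the repository uses black's default configuration; the two source strings are illustrative fragments lifted from this diff.

import black

# Over-long call from ChatTTS/core.py before this commit: it exceeds black's
# default 88-character limit, so black wraps the argument list.
long_call = (
    "gpt = GPT(**cfg, use_flash_attn=use_flash_attn, "
    "device=device, logger=self.logger).eval()\n"
)

# Signature from ChatTTS/model/gpt.py: the trailing comma after the last
# parameter makes black place each parameter on its own line.
signature = (
    "def _build_llama(\n"
    "    self, config: omegaconf.DictConfig, device: torch.device,\n"
    ") -> LlamaModel:\n"
    "    pass\n"
)

mode = black.Mode()  # black's defaults, including the 88-character line length
print(black.format_str(long_call, mode=mode))
print(black.format_str(signature, mode=mode))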
