Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(whl): add AWR algorithm. #828

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
- Exploration algorithms: HER, RND, ICM, NGU
- LLM + RL Algorithms: PPO-max, DPO, PromptPG
- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
- Other algorithms: such as PER, PLR, PCGrad
- MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
- Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
Expand Down Expand Up @@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

</details>

Expand Down
43 changes: 33 additions & 10 deletions ding/model/template/language_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict
from typing import List, Dict, Optional
import torch
from torch import nn

Expand All @@ -18,13 +18,16 @@ class LanguageTransformer(nn.Module):
Interfaces:
``__init__``, ``forward``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
embedding_size: int = 128,
freeze_encoder: bool = True
freeze_encoder: bool = True,
hidden_dim: int = 768,
norm_embedding: bool = False
) -> None:
"""
Overview:
Expand All @@ -36,10 +39,16 @@ def __init__(
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
defaults to be ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
correspond to the model you use. For bert-base-uncased, this value is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Default to be ``False``.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)
in_channel = hidden_dim if not add_linear else embedding_size
self.value_head = nn.Linear(in_channel, 1)
self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(normalized_shape=in_channel)
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
Expand All @@ -49,9 +58,7 @@ def __init__(
if add_linear:
# Add a small, adjustable linear layer on top of language model tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
else:
self.linear = None

Expand All @@ -66,19 +73,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
sentence_embedding = self.norm(sentence_embedding)

PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
def forward(
self,
train_samples: List[str],
candidate_samples: Optional[List[str]] = None,
mode: str = 'compute_actor'
) -> Dict:
"""
Overview:
LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
Different ``mode`` will forward with different network modules to get different outputs.
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
- - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
Returns:
- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
corresponding ``torch.distributions.Categorical`` object.
Expand All @@ -96,7 +111,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dic
>>> scores = model(ctxt_list, cands_list)
>>> assert scores.shape == (1, 3)
"""
assert mode in self.mode
prompt_embedding = self._calc_embedding(train_samples)
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}

res_dict = {}
if mode in ['compute_actor', 'compute_actor_critic']:
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
kxzxvbk marked this conversation as resolved.
Show resolved Hide resolved
if mode in ['compute_critic', 'compute_actor_critic']:
value = self.value_head(prompt_embedding)
res_dict.update({'value': value})
return res_dict
34 changes: 29 additions & 5 deletions ding/model/template/tests/test_language_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,33 @@ def check_model(self):
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)
1 change: 1 addition & 0 deletions ding/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@
# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
from .prompt_awr import PromptAWRPolicy
from .happo import HAPPOPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy
from .prompt_awr import PromptAWRPolicy


class EpsCommandModePolicy(CommandModePolicy):
Expand Down Expand Up @@ -455,3 +456,8 @@ def _get_setting_eval(self, command_info: dict) -> dict:
@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('prompt_awr_command')
class PromptAWRCommandModePolicy(PromptAWRPolicy, DummyCommandModePolicy):
pass
Loading
Loading