-
Notifications
You must be signed in to change notification settings - Fork 373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(zc): add MetaDiffuser and prompt-dt #771
base: main
Are you sure you want to change the base?
Changes from 15 commits
b05c856
a459fd0
1c08ede
e97725c
32ccf3f
94648d1
16e8144
3524c72
b0e7274
6be5920
7519400
3bafbf1
2b1bdaa
c8d9c7f
fd2896c
35e8e77
9b611db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from typing import Union, Optional, List, Any, Tuple | ||
import os | ||
import torch | ||
from functools import partial | ||
from tensorboardX import SummaryWriter | ||
from copy import deepcopy | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.distributed import DistributedSampler | ||
|
||
from ding.envs import get_vec_env_setting, create_env_manager | ||
from ding.worker import BaseLearner, InteractionSerialMetaEvaluator | ||
from ding.config import read_config, compile_config | ||
from ding.policy import create_policy | ||
from ding.utils import set_pkg_seed, get_world_size, get_rank | ||
from ding.utils.data import create_dataset | ||
|
||
def serial_pipeline_meta_offline( | ||
input_cfg: Union[str, Tuple[dict, dict]], | ||
seed: int = 0, | ||
env_setting: Optional[List[Any]] = None, | ||
model: Optional[torch.nn.Module] = None, | ||
max_train_iter: Optional[int] = int(1e10), | ||
) -> 'Policy': # noqa | ||
""" | ||
Overview: | ||
Serial pipeline entry. | ||
Arguments: | ||
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ | ||
``str`` type means config file path. \ | ||
``Tuple[dict, dict]`` type means [user_config, create_cfg]. | ||
- seed (:obj:`int`): Random seed. | ||
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ | ||
``BaseEnv`` subclass, collector env config, and evaluator env config. | ||
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. | ||
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. | ||
Returns: | ||
- policy (:obj:`Policy`): Converged policy. | ||
""" | ||
if isinstance(input_cfg, str): | ||
cfg, create_cfg = read_config(input_cfg) | ||
else: | ||
cfg, create_cfg = deepcopy(input_cfg) | ||
create_cfg.policy.type = create_cfg.policy.type + '_command' | ||
cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) | ||
|
||
cfg.env['seed'] = seed | ||
|
||
# Dataset | ||
dataset = create_dataset(cfg) | ||
|
||
sampler, shuffle = None, True | ||
if get_world_size() > 1: | ||
sampler, shuffle = DistributedSampler(dataset), False | ||
dataloader = DataLoader( | ||
dataset, | ||
# Dividing by get_world_size() here simply to make multigpu | ||
# settings mathmatically equivalent to the singlegpu setting. | ||
# If the training efficiency is the bottleneck, feel free to | ||
# use the original batch size per gpu and increase learning rate | ||
# correspondingly. | ||
cfg.policy.learn.batch_size // get_world_size(), | ||
# cfg.policy.learn.batch_size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this line. |
||
shuffle=shuffle, | ||
sampler=sampler, | ||
collate_fn=lambda x: x, | ||
pin_memory=cfg.policy.cuda, | ||
) | ||
|
||
# Env, policy | ||
env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False) | ||
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) | ||
|
||
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) | ||
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) | ||
|
||
if hasattr(policy, 'set_statistic'): | ||
# useful for setting action bounds for ibc | ||
policy.set_statistic(dataset.statistics) | ||
|
||
if cfg.policy.need_init_dataprocess: | ||
policy.init_dataprocess_func(dataset) | ||
|
||
if get_rank() == 0: | ||
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) | ||
else: | ||
tb_logger = None | ||
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) | ||
evaluator = InteractionSerialMetaEvaluator( | ||
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name | ||
) | ||
evaluator.init_params(dataset.params) | ||
|
||
learner.call_hook('before_run') | ||
stop = False | ||
|
||
for epoch in range(cfg.policy.learn.train_epoch): | ||
if get_world_size() > 1: | ||
dataloader.sampler.set_epoch(epoch) | ||
for i in range(cfg.policy.train_num): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "train_num"->"batch_size"? |
||
dataset.set_task_id(i) | ||
for train_data in dataloader: | ||
learner.train(train_data) | ||
|
||
# Evaluate policy at most once per epoch. | ||
if evaluator.should_eval(learner.train_iter): | ||
if hasattr(policy, 'warm_train'): | ||
# if algorithm need warm train | ||
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, | ||
policy_warm_func=policy.warm_train, need_reward=cfg.policy.need_reward) | ||
else: | ||
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, | ||
need_reward=cfg.policy.need_reward) | ||
|
||
if stop or learner.train_iter >= max_train_iter: | ||
stop = True | ||
break | ||
|
||
learner.call_hook('after_run') | ||
print('final reward is: {}'.format(reward)) | ||
return policy, stop |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls | ||
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2 | ||
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2,\ | ||
MetaSyncSubprocessEnvManager | ||
from .gym_vector_env_manager import GymVectorEnvManager | ||
# Do not import PoolEnvManager here, because it depends on installation of `envpool` | ||
from .env_supervisor import EnvSupervisor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from ding.utils import SequenceType | ||
from ding.utils import SequenceType, MODEL_REGISTRY | ||
|
||
|
||
class MaskedCausalAttention(nn.Module): | ||
|
@@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# x = x + self.mlp(self.ln2(x)) | ||
return x | ||
|
||
|
||
@MODEL_REGISTRY.register('dt') | ||
class DecisionTransformer(nn.Module): | ||
""" | ||
Overview: | ||
|
@@ -176,7 +176,8 @@ def __init__( | |
drop_p: float, | ||
max_timestep: int = 4096, | ||
state_encoder: Optional[nn.Module] = None, | ||
continuous: bool = False | ||
continuous: bool = False, | ||
use_prompt: bool = False, | ||
): | ||
""" | ||
Overview: | ||
|
@@ -206,6 +207,9 @@ def __init__( | |
# projection heads (project to embedding) | ||
self.embed_ln = nn.LayerNorm(h_dim) | ||
self.embed_timestep = nn.Embedding(max_timestep, h_dim) | ||
if use_prompt: | ||
self.prompt_embed_timestep = nn.Embedding(max_timestep, h_dim) | ||
input_seq_len *= 2 | ||
self.drop = nn.Dropout(drop_p) | ||
|
||
self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) | ||
|
@@ -218,14 +222,21 @@ def __init__( | |
self.embed_state = torch.nn.Linear(state_dim, h_dim) | ||
self.predict_rtg = torch.nn.Linear(h_dim, 1) | ||
self.predict_state = torch.nn.Linear(h_dim, state_dim) | ||
if use_prompt: | ||
self.prompt_embed_state = torch.nn.Linear(state_dim, h_dim) | ||
self.prompt_embed_rtg = torch.nn.Linear(1, h_dim) | ||
if continuous: | ||
# continuous actions | ||
self.embed_action = torch.nn.Linear(act_dim, h_dim) | ||
use_action_tanh = True # True for continuous actions | ||
if use_prompt: | ||
self.prompt_embed_action = torch.nn.Linear(act_dim, h_dim) | ||
else: | ||
# discrete actions | ||
self.embed_action = torch.nn.Embedding(act_dim, h_dim) | ||
use_action_tanh = False # False for discrete actions | ||
if use_prompt: | ||
self.prompt_embed_action = torch.nn.Embedding(act_dim, h_dim) | ||
self.predict_action = nn.Sequential( | ||
*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) | ||
) | ||
|
@@ -243,7 +254,8 @@ def forward( | |
states: torch.Tensor, | ||
actions: torch.Tensor, | ||
returns_to_go: torch.Tensor, | ||
tar: Optional[int] = None | ||
tar: Optional[int] = None, | ||
prompt: dict = None, | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
""" | ||
Overview: | ||
|
@@ -299,7 +311,34 @@ def forward( | |
t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings), | ||
dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) | ||
h = self.embed_ln(t_p) | ||
|
||
if prompt is not None: | ||
prompt_states, prompt_actions, prompt_returns_to_go,\ | ||
prompt_timesteps, prompt_attention_mask = prompt | ||
prompt_seq_length = prompt_states.shape[1] | ||
prompt_state_embeddings = self.prompt_embed_state(prompt_states) | ||
prompt_action_embeddings = self.prompt_embed_action(prompt_actions) | ||
if prompt_returns_to_go.shape[1] % 10 == 1: | ||
prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go[:,:-1]) | ||
else: | ||
prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go) | ||
prompt_time_embeddings = self.prompt_embed_timestep(prompt_timesteps) | ||
|
||
prompt_state_embeddings = prompt_state_embeddings + prompt_time_embeddings | ||
prompt_action_embeddings = prompt_action_embeddings + prompt_time_embeddings | ||
prompt_returns_embeddings = prompt_returns_embeddings + prompt_time_embeddings | ||
prompt_stacked_inputs = torch.stack( | ||
(prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 | ||
).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) | ||
|
||
# prompt_stacked_attention_mask = torch.stack( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove these unused lines? |
||
# (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 | ||
# ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length | ||
h = torch.cat((prompt_stacked_inputs, h), dim=1) | ||
# stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) | ||
|
||
# transformer and prediction | ||
|
||
h = self.transformer(h) | ||
# get h reshaped such that its size = (B x 3 x T x h_dim) and | ||
# h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t | ||
|
@@ -308,11 +347,15 @@ def forward( | |
# that is, for each timestep (t) we have 3 output embeddings from the transformer, | ||
# each conditioned on all previous timesteps plus | ||
# the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. | ||
h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) | ||
if prompt is None: | ||
h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) | ||
else: | ||
h = h.reshape(B, -1, 3, self.h_dim).permute(0, 2, 1, 3) | ||
|
||
return_preds = self.predict_rtg(h[:, 2])[:, -T:, :] # predict next rtg given r, s, a | ||
state_preds = self.predict_state(h[:, 2])[:, -T:, :] # predict next state given r, s, a | ||
action_preds = self.predict_action(h[:, 1])[:, -T:, :] # predict action given r, s | ||
|
||
return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a | ||
state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a | ||
action_preds = self.predict_action(h[:, 1]) # predict action given r, s | ||
else: | ||
state_embeddings = self.state_encoder( | ||
states.reshape(-1, *self.state_dim).type(torch.float32).contiguous() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add more details?