Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zc): add MetaDiffuser and prompt-dt #771

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions ding/entry/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
from .serial_entry_meta_offline import serial_pipeline_meta_offline
120 changes: 120 additions & 0 deletions ding/entry/serial_entry_meta_offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Union, Optional, List, Any, Tuple
import os
import torch
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialMetaEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed, get_world_size, get_rank
from ding.utils.data import create_dataset

def serial_pipeline_meta_offline(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more details?

Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

cfg.env['seed'] = seed

# Dataset
dataset = create_dataset(cfg)

sampler, shuffle = None, True
if get_world_size() > 1:
sampler, shuffle = DistributedSampler(dataset), False
dataloader = DataLoader(
dataset,
# Dividing by get_world_size() here simply to make multigpu
# settings mathmatically equivalent to the singlegpu setting.
# If the training efficiency is the bottleneck, feel free to
# use the original batch size per gpu and increase learning rate
# correspondingly.
cfg.policy.learn.batch_size // get_world_size(),
# cfg.policy.learn.batch_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this line.

shuffle=shuffle,
sampler=sampler,
collate_fn=lambda x: x,
pin_memory=cfg.policy.cuda,
)

# Env, policy
env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])

if hasattr(policy, 'set_statistic'):
# useful for setting action bounds for ibc
policy.set_statistic(dataset.statistics)

if cfg.policy.need_init_dataprocess:
policy.init_dataprocess_func(dataset)

if get_rank() == 0:
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
else:
tb_logger = None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
evaluator = InteractionSerialMetaEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
evaluator.init_params(dataset.params)

learner.call_hook('before_run')
stop = False

for epoch in range(cfg.policy.learn.train_epoch):
if get_world_size() > 1:
dataloader.sampler.set_epoch(epoch)
for i in range(cfg.policy.train_num):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"train_num"->"batch_size"?

dataset.set_task_id(i)
for train_data in dataloader:
learner.train(train_data)

# Evaluate policy at most once per epoch.
if evaluator.should_eval(learner.train_iter):
if hasattr(policy, 'warm_train'):
# if algorithm need warm train
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter,
policy_warm_func=policy.warm_train, need_reward=cfg.policy.need_reward)
else:
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter,
need_reward=cfg.policy.need_reward)

if stop or learner.train_iter >= max_train_iter:
stop = True
break

learner.call_hook('after_run')
print('final reward is: {}'.format(reward))
return policy, stop
3 changes: 2 additions & 1 deletion ding/envs/env_manager/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2,\
MetaSyncSubprocessEnvManager
from .gym_vector_env_manager import GymVectorEnvManager
# Do not import PoolEnvManager here, because it depends on installation of `envpool`
from .env_supervisor import EnvSupervisor
20 changes: 20 additions & 0 deletions ding/envs/env_manager/subprocess_env_manager.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -832,3 +832,23 @@ def step(self, actions: Union[List[tnp.ndarray], tnp.ndarray]) -> List[tnp.ndarr
info = remove_illegal_item(info)
new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id}))
return new_data

@ENV_MANAGER_REGISTRY.register('meta_subprocess')
class MetaSyncSubprocessEnvManager(SyncSubprocessEnvManager):

@property
def method_name_list(self) -> list:
return [
'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure',
'set_all_goals', 'reset_task'
]

def set_all_goals(self, params):
for p in self._pipe_parents.values():
p.send(['set_all_goals', [params], {}])
data = {i: p.recv() for i, p in self._pipe_parents.items()}

def reset_task(self, id):
for p in self._pipe_parents.values():
p.send(['reset_task', [id], {}])
data = {i: p.recv() for i, p in self._pipe_parents.items()}
59 changes: 51 additions & 8 deletions ding/model/template/decision_transformer.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType
from ding.utils import SequenceType, MODEL_REGISTRY


class MaskedCausalAttention(nn.Module):
Expand Down Expand Up @@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# x = x + self.mlp(self.ln2(x))
return x


@MODEL_REGISTRY.register('dt')
class DecisionTransformer(nn.Module):
"""
Overview:
Expand All @@ -176,7 +176,8 @@ def __init__(
drop_p: float,
max_timestep: int = 4096,
state_encoder: Optional[nn.Module] = None,
continuous: bool = False
continuous: bool = False,
use_prompt: bool = False,
):
"""
Overview:
Expand Down Expand Up @@ -206,6 +207,9 @@ def __init__(
# projection heads (project to embedding)
self.embed_ln = nn.LayerNorm(h_dim)
self.embed_timestep = nn.Embedding(max_timestep, h_dim)
if use_prompt:
self.prompt_embed_timestep = nn.Embedding(max_timestep, h_dim)
input_seq_len *= 2
self.drop = nn.Dropout(drop_p)

self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
Expand All @@ -218,14 +222,21 @@ def __init__(
self.embed_state = torch.nn.Linear(state_dim, h_dim)
self.predict_rtg = torch.nn.Linear(h_dim, 1)
self.predict_state = torch.nn.Linear(h_dim, state_dim)
if use_prompt:
self.prompt_embed_state = torch.nn.Linear(state_dim, h_dim)
self.prompt_embed_rtg = torch.nn.Linear(1, h_dim)
if continuous:
# continuous actions
self.embed_action = torch.nn.Linear(act_dim, h_dim)
use_action_tanh = True # True for continuous actions
if use_prompt:
self.prompt_embed_action = torch.nn.Linear(act_dim, h_dim)
else:
# discrete actions
self.embed_action = torch.nn.Embedding(act_dim, h_dim)
use_action_tanh = False # False for discrete actions
if use_prompt:
self.prompt_embed_action = torch.nn.Embedding(act_dim, h_dim)
self.predict_action = nn.Sequential(
*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
)
Expand All @@ -243,7 +254,8 @@ def forward(
states: torch.Tensor,
actions: torch.Tensor,
returns_to_go: torch.Tensor,
tar: Optional[int] = None
tar: Optional[int] = None,
prompt: dict = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Overview:
Expand Down Expand Up @@ -299,7 +311,34 @@ def forward(
t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
h = self.embed_ln(t_p)

if prompt is not None:
prompt_states, prompt_actions, prompt_returns_to_go,\
prompt_timesteps, prompt_attention_mask = prompt
prompt_seq_length = prompt_states.shape[1]
prompt_state_embeddings = self.prompt_embed_state(prompt_states)
prompt_action_embeddings = self.prompt_embed_action(prompt_actions)
if prompt_returns_to_go.shape[1] % 10 == 1:
prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go[:,:-1])
else:
prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go)
prompt_time_embeddings = self.prompt_embed_timestep(prompt_timesteps)

prompt_state_embeddings = prompt_state_embeddings + prompt_time_embeddings
prompt_action_embeddings = prompt_action_embeddings + prompt_time_embeddings
prompt_returns_embeddings = prompt_returns_embeddings + prompt_time_embeddings
prompt_stacked_inputs = torch.stack(
(prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim)

# prompt_stacked_attention_mask = torch.stack(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these unused lines?

# (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1
# ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length
h = torch.cat((prompt_stacked_inputs, h), dim=1)
# stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1)

# transformer and prediction

h = self.transformer(h)
# get h reshaped such that its size = (B x 3 x T x h_dim) and
# h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
Expand All @@ -308,11 +347,15 @@ def forward(
# that is, for each timestep (t) we have 3 output embeddings from the transformer,
# each conditioned on all previous timesteps plus
# the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)
if prompt is None:
h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)
else:
h = h.reshape(B, -1, 3, self.h_dim).permute(0, 2, 1, 3)

return_preds = self.predict_rtg(h[:, 2])[:, -T:, :] # predict next rtg given r, s, a
state_preds = self.predict_state(h[:, 2])[:, -T:, :] # predict next state given r, s, a
action_preds = self.predict_action(h[:, 1])[:, -T:, :] # predict action given r, s

return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a
state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a
action_preds = self.predict_action(h[:, 1]) # predict action given r, s
else:
state_embeddings = self.state_encoder(
states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
Expand Down
Loading