rename dens heads
lbluque committed Nov 8, 2024
1 parent 6ac0683 commit fcbeb0a
Showing 2 changed files with 40 additions and 30 deletions.
68 changes: 39 additions & 29 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py
@@ -381,13 +381,18 @@ def forward(self, data) -> dict[str, torch.Tensor]:
         return {"node_embedding": x, "graph": graph}


-@registry.register_model("equiformer_v2_dens_energy_head")
-class DeNSEnergyHead(torch.nn.Module, HeadInterface):
-    def __init__(self, backbone: BackboneInterface, reduce: str = "sum"):
+@registry.register_model("eqV2_DeNS_scalar_head")
+class DeNSScalarHead(torch.nn.Module, HeadInterface):
+    def __init__(
+        self,
+        backbone: BackboneInterface,
+        output_name: str = "energy",
+        reduce: str = "sum",
+    ):
         super().__init__()
         self.reduce = reduce
         self.avg_num_nodes = backbone.avg_num_nodes
-        self.energy_block = FeedForwardNetwork(
+        self.scalar_block = FeedForwardNetwork(
             backbone.sphere_channels,
             backbone.ffn_hidden_channels,
             1,
@@ -399,49 +404,51 @@ def __init__(self, backbone: BackboneInterface, reduce: str = "sum"):
             backbone.use_grid_mlp,
             backbone.use_sep_s2_act,
         )
+        self.output_name = output_name
         self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
         self.use_denoising_energy = backbone.use_denoising_energy

     def forward(
         self, data: Batch, emb: dict[str, torch.Tensor | GraphData]
     ) -> dict[str, torch.Tensor]:
-        node_energy = self.energy_block(emb["node_embedding"])
-        node_energy = node_energy.embedding.narrow(1, 0, 1)
+        node_out = self.scalar_block(emb["node_embedding"])
+        node_out = node_out.embedding.narrow(1, 0, 1)
         if gp_utils.initialized():
-            node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
-        energy = torch.zeros(
+            node_out = gp_utils.gather_from_model_parallel_region(node_out, dim=0)
+        output_scalar = torch.zeros(
             len(data.natoms),
-            device=node_energy.device,
-            dtype=node_energy.dtype,
+            device=node_out.device,
+            dtype=node_out.dtype,
         )

-        energy.index_add_(0, data.batch, node_energy.view(-1))
+        output_scalar.index_add_(0, data.batch, node_out.view(-1))

         if (
             hasattr(data, "denoising_pos_forward")
             and data.denoising_pos_forward
             and not self.use_denoising_energy
         ):
-            energy = energy * 0.0
+            output_scalar = output_scalar * 0.0

         if self.reduce == "sum":
-            return {"energy": energy / self.avg_num_nodes}
+            return {self.output_name: output_scalar / self.avg_num_nodes}
         elif self.reduce == "mean":
-            return {"energy": energy / data.natoms}
+            return {self.output_name: output_scalar / data.natoms}
         else:
             raise ValueError(
                 f"reduce can only be sum or mean, user provided: {self.reduce}"
             )


-@registry.register_model("equiformer_v2_dens_force_head")
-class DeNSForceHead(torch.nn.Module, HeadInterface):
-    def __init__(self, backbone):
+@registry.register_model("eqV2_DeNS_vector_head")
+class DeNSVectorHead(torch.nn.Module, HeadInterface):
+    def __init__(self, backbone: BackboneInterface, output_name: str = "forces"):
         super().__init__()

+        self.output_name = output_name
         self.activation_checkpoint = backbone.activation_checkpoint

-        self.force_block = SO2EquivariantGraphAttention(
+        self.vector_block = SO2EquivariantGraphAttention(
             backbone.sphere_channels,
             backbone.attn_hidden_channels,
             backbone.num_heads,
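
Note on the pattern above: `DeNSScalarHead.forward` reduces per-node scalars to per-graph values with `index_add_`, then normalizes by either a dataset-level average node count (`reduce="sum"`) or the true per-graph node count (`reduce="mean"`). A minimal standalone sketch of that reduction, with made-up shapes and values:

```python
import torch

# Sketch of the per-graph reduction in DeNSScalarHead.forward
# (illustrative shapes and values; not taken from the repo).
node_out = torch.randn(7, 1)                 # one scalar per node
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # graph index of each node
natoms = torch.tensor([3, 2, 2])             # nodes per graph
avg_num_nodes = 4.0                          # dataset-level normalizer (assumed value)

output_scalar = torch.zeros(len(natoms), dtype=node_out.dtype)
output_scalar.index_add_(0, batch, node_out.view(-1))  # segment-sum per graph

print(output_scalar / avg_num_nodes)  # reduce == "sum"
print(output_scalar / natoms)         # reduce == "mean"
```
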
@@ -494,8 +501,8 @@ def forward(
         self, data: Batch, emb: dict[str, torch.Tensor]
     ) -> dict[str, torch.Tensor]:
         if self.activation_checkpoint:
-            forces = torch.utils.checkpoint.checkpoint(
-                self.force_block,
+            output_vector = torch.utils.checkpoint.checkpoint(
+                self.vector_block,
                 emb["node_embedding"],
                 emb["graph"].atomic_numbers_full,
                 emb["graph"].edge_distance,
@@ -513,7 +520,7 @@ def forward(
                 use_reentrant=not self.training,
             )
         else:
-            forces = self.force_block(
+            output_vector = self.vector_block(
                 emb["node_embedding"],
                 emb["graph"].atomic_numbers_full,
                 emb["graph"].edge_distance,
@@ -527,28 +534,31 @@ def forward(
                 emb["graph"].edge_index,
                 node_offset=emb["graph"].node_offset,
             )
-        forces = forces.embedding.narrow(1, 1, 3)
-        forces = forces.view(-1, 3).contiguous()
+        output_vector = output_vector.embedding.narrow(1, 1, 3)
+        output_vector = output_vector.view(-1, 3).contiguous()
         denoising_pos_vec = denoising_pos_vec.embedding.narrow(1, 1, 3)
         denoising_pos_vec = denoising_pos_vec.view(-1, 3)
         if gp_utils.initialized():
-            forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
+            output_vector = gp_utils.gather_from_model_parallel_region(
+                output_vector, dim=0
+            )
             denoising_pos_vec = gp_utils.gather_from_model_parallel_region(
                 denoising_pos_vec, dim=0
             )

         if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward:
             if hasattr(data, "noise_mask"):
                 noise_mask_tensor = data.noise_mask.view(-1, 1)
-                forces = denoising_pos_vec * noise_mask_tensor + forces * (
-                    ~noise_mask_tensor
+                output_vector = (
+                    denoising_pos_vec * noise_mask_tensor
+                    + output_vector * (~noise_mask_tensor)
                 )
             else:
-                forces = denoising_pos_vec + 0 * forces
+                output_vector = denoising_pos_vec + 0 * output_vector
         else:
-            forces = 0 * denoising_pos_vec + forces
+            output_vector = 0 * denoising_pos_vec + output_vector

-        return {"forces": forces}
+        return {self.output_name: output_vector}


 @registry.register_model("dens_rank2_symmetric_head")
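
The denoising branch of `DeNSVectorHead.forward` blends two per-node vectors: where `noise_mask` is true the head emits the denoising direction, elsewhere the regular prediction (the `0 *` terms presumably keep the unused branch connected to the autograd graph for distributed-training bookkeeping). A standalone sketch of the blend, with made-up tensors:

```python
import torch

# Sketch of the noise-mask blend in DeNSVectorHead.forward (made-up tensors).
output_vector = torch.randn(5, 3)      # regular per-node vector prediction
denoising_pos_vec = torch.randn(5, 3)  # denoising direction per node
noise_mask = torch.tensor([True, True, False, False, True])

mask = noise_mask.view(-1, 1)
# Noised nodes take the denoising vector; the rest keep the prediction.
blended = denoising_pos_vec * mask + output_vector * (~mask)
print(blended.shape)  # torch.Size([5, 3])
```
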
2 changes: 1 addition & 1 deletion
@@ -380,7 +380,7 @@ def train(self, disable_eval_tqdm=False):
                     all_atoms=self.denoising_pos_params.all_atoms,
                 )

-            # Forward, loss, backward.
+            # Forward, loss, backward. #TODO update this with new signatures
             with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                 out = self._forward(batch)
                 loss = self._compute_loss(out, batch)
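
Since the registry keys changed, anything that resolved the old names (`equiformer_v2_dens_energy_head` / `equiformer_v2_dens_force_head`) needs updating. A hypothetical sketch of looking up the renamed heads, assuming fairchem's `registry.get_model_class` and an already-constructed DeNS backbone:

```python
from fairchem.core.common.registry import registry

def build_dens_heads(backbone):
    """Hypothetical helper (not part of this commit): resolve the renamed
    heads by their new registry keys and attach them to a DeNS backbone."""
    ScalarHead = registry.get_model_class("eqV2_DeNS_scalar_head")
    VectorHead = registry.get_model_class("eqV2_DeNS_vector_head")
    return {
        "energy": ScalarHead(backbone, output_name="energy", reduce="sum"),
        "forces": VectorHead(backbone, output_name="forces"),
    }
```
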