
Commit

Merge pull request #6 from maren-ha/muon
merge muon with main
maren-ha authored Jun 8, 2023
2 parents 1093ea8 + 8297ea9 commit 07cfcb9
Showing 2 changed files with 14 additions and 18 deletions.
16 changes: 4 additions & 12 deletions docs/src/DataProcessing.md
@@ -1,20 +1,12 @@
# Data processing

-## `AnnData` object
+## `AnnData` object and I/O

```@docs
AnnData
```
+The `AnnData` struct is imported from [`Muon.jl`](https://github.com/scverse/Muon.jl). The package provides read and write functions for `.h5ad` and `.h5mu` files, the typical H5-based format for storing Python `anndata` objects. The `AnnData` object stores datasets together with metadata, such as information on the variables (genes in scRNA-seq data) and observations (cells), as well as different kinds of annotations and transformations of the original count matrix, such as PCA or UMAP embeddings, or graphs of observations or variables.

-## I/O
+For details on the Julia implementation in `Muon.jl`, see the [documentation](https://scverse.github.io/Muon.jl/dev/).

-```@docs
-read_h5ad
-```
-
-```@docs
-write_h5ad
-```
+For more details on the original Python implementation of the `anndata` object, see the [documentation](https://anndata.readthedocs.io/en/latest/) and [preprint](https://doi.org/10.1101/2021.12.16.473007).

## Library size and normalization

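Since this diff replaces scVI.jl's own `read_h5ad`/`write_h5ad` docstrings with a pointer to `Muon.jl`, here is a minimal I/O sketch for context. It assumes `Muon.jl`'s `readh5ad`/`writeh5ad` functions (see the Muon.jl documentation linked in the diff above); the file names are hypothetical.

```julia
using Muon

# Load an AnnData object from an .h5ad file ("pbmc.h5ad" is hypothetical).
adata = readh5ad("pbmc.h5ad")

size(adata.X)   # count matrix, observations × variables (anndata convention)
adata.obs       # per-cell metadata
adata.var       # per-gene metadata

# Write the (possibly modified) object back to the H5-based format.
writeh5ad("pbmc_processed.h5ad", adata)
```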
16 changes: 10 additions & 6 deletions src/tSNEPenalty.jl
@@ -117,9 +117,9 @@ function loss(m::scVAE, x::AbstractMatrix{S}, P::AbstractMatrix{S}, batch_indice
    px_scale, px_r, px_rate, px_dropout = scVI.generative(m, z, library)
    kl_divergence_z = -0.5f0 .* sum(1.0f0 .+ log.(qz_v) - qz_m.^2 .- qz_v, dims=1) # 2

-    kl_divergence_l = get_kl_divergence_l(m, ql_m, ql_v, batch_indices)
+    kl_divergence_l = scVI.get_kl_divergence_l(m, ql_m, ql_v, batch_indices)

-    reconst_loss = get_reconstruction_loss(m, x, px_rate, px_r, px_dropout)
+    reconst_loss = scVI.get_reconstruction_loss(m, x, px_rate, px_r, px_dropout)
    kl_local_for_warmup = kl_divergence_z
    kl_local_no_warmup = kl_divergence_l
    weighted_kl_local = kl_weight .* kl_local_for_warmup .+ kl_local_no_warmup
@@ -129,11 +129,15 @@
        P = P .* cheat_scale/sum(P) # normalize + early exaggeration
        sum_P = cheat_scale
    end
-    tsne_penalty = compute_kldiv(z, P, sum_P)
-    #println(tsne_penalty)
+    tsne_penalty = scVI.compute_kldiv(z, P, sum_P)
+    println(tsne_penalty)

+    #graph_loss = sum(z*(Diagonal(vec(sum(P, dims=1))) .- P)*z')
+    #0.5.*sum(P[i,j].*(z[:,i] .- z[:,j]).^2 for i in 1:size(P,1), j in 1:size(P,2))
+    #println(graph_loss)

-    lossval = mean(reconst_loss + weighted_kl_local)
-    return lossval + 100.0f0*tsne_penalty
+    lossval = mean(reconst_loss + weighted_kl_local)
+    return lossval + 150.0f0*tsne_penalty #+ 1.0f0 * graph_loss
end

function register_losses!(m::scVAE, x::AbstractMatrix{S}, P::AbstractMatrix{S}, batch_indices::Vector{Int};
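The `loss` function above adds a t-SNE-style penalty, `scVI.compute_kldiv(z, P, sum_P)`, to the ELBO. Below is a rough sketch of what such a penalty typically computes, assuming the standard t-SNE formulation: Student-t affinities `Q` over the latent points and the KL divergence against the high-dimensional affinities `P`. `tsne_kldiv_sketch` is a hypothetical stand-in, not the package's actual implementation.

```julia
using LinearAlgebra

# Hypothetical sketch of a t-SNE-style KL penalty; `scVI.compute_kldiv`
# may be implemented differently.
function tsne_kldiv_sketch(z::AbstractMatrix{Float32}, P::AbstractMatrix{Float32}, sum_P::Float32)
    n = size(z, 2)
    # pairwise squared Euclidean distances between latent points (columns of z)
    sqdist = [sum(abs2, z[:, i] .- z[:, j]) for i in 1:n, j in 1:n]
    num = 1.0f0 ./ (1.0f0 .+ sqdist)   # Student-t (Cauchy) kernel, as in t-SNE
    num[diagind(num)] .= 0.0f0         # exclude self-affinities
    Q = num ./ sum(num)                # normalize to a probability distribution
    Pn = P ./ sum_P                    # renormalize the exaggerated P
    ϵ = 1.0f-12                        # numerical floor inside the logarithm
    return sum(Pn .* log.((Pn .+ ϵ) ./ (Q .+ ϵ)))
end
```

In this sketch, `sum_P` simply renormalizes `P` after the `cheat_scale` early-exaggeration step in `loss`; how the actual implementation uses the exaggeration factor may differ.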
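The commented-out `graph_loss` lines reference the standard graph-Laplacian identity: for symmetric affinities `P` and `L = D - P` with `D = Diagonal(vec(sum(P, dims=1)))`, the pairwise form `0.5 * Σ_ij P[i,j] * ||z[:,i] - z[:,j]||^2` equals `tr(z * L * z')`. Note the identity holds with the matrix trace `tr`, not the entry-wise `sum` used in the first commented line. A quick check, with all names illustrative:

```julia
using LinearAlgebra

# Verify: 0.5 * Σ_ij P[i,j] * ||z[:,i] - z[:,j]||^2  ==  tr(z * L * z')
# for symmetric P and graph Laplacian L = D - P.
n, d = 6, 2
P = rand(Float32, n, n); P = 0.5f0 .* (P .+ P')       # symmetric affinities
z = randn(Float32, d, n)                              # latent points as columns
L = Diagonal(vec(sum(P, dims=1))) .- P                # graph Laplacian

lhs = 0.5f0 * sum(P[i, j] * sum(abs2, z[:, i] .- z[:, j]) for i in 1:n, j in 1:n)
rhs = tr(z * L * z')
@assert isapprox(lhs, rhs; rtol=1.0f-4)               # holds up to rounding
```

In this form the penalty pulls latent points with high affinity `P[i,j]` closer together, matching the role hinted at by the commented `#+ 1.0f0 * graph_loss` term in the return value.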
