Skip to content

Commit

Permalink
Merge branch 'master' into parler_tts
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Oct 11, 2024
2 parents ec50204 + 9dfbab1 commit 19445dc
Show file tree
Hide file tree
Showing 14 changed files with 61 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/analysis.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run: tokei . > tokei_output.txt

- name: Comment or Update PR
uses: actions/github-script@v6
uses: actions/github-script@v7
with:
script: |
const fs = require('fs');
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/build_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ jobs:
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
uses: docker/setup-buildx-action@v3
with:
install: true

- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
uses: rlespinasse/github-slug-action@v4.5.0

- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Extract metadata (tags, labels) for Docker
id: meta-cpu
uses: docker/metadata-action@v4.3.0
uses: docker/metadata-action@v5
with:
images: |
ghcr.io/${{env.GITHUB_REPOSITORY_OWNER_PART}}/${{env.GITHUB_REPOSITORY_NAME_PART}}
Expand All @@ -55,7 +55,7 @@ jobs:
type=raw,value=cpu-sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
id: build-and-push-cpu
uses: docker/build-push-action@v4
uses: docker/build-push-action@v6
with:
context: .
file: Dockerfile
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/build_cuda_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
uses: actions/checkout@v3

- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
uses: docker/setup-buildx-action@v3
with:
install: true

Expand All @@ -43,19 +43,19 @@ jobs:
${{ runner.os }}-buildx-
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
uses: rlespinasse/github-slug-action@v4.5.0

- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Extract metadata (tags, labels) for Docker
id: meta-cuda
uses: docker/metadata-action@v4.3.0
uses: docker/metadata-action@v5
with:
images: |
ghcr.io/${{env.GITHUB_REPOSITORY_OWNER_PART}}/${{env.GITHUB_REPOSITORY_NAME_PART}}
Expand All @@ -68,7 +68,7 @@ jobs:
type=raw,value=cuda-${{matrix.compute_capability}}-sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
id: build-and-push-cuda
uses: docker/build-push-action@v4
uses: docker/build-push-action@v6
with:
context: .
file: Dockerfile.cuda-all
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -38,7 +38,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -55,7 +55,7 @@ jobs:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -71,7 +71,7 @@ jobs:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -87,7 +87,7 @@ jobs:
name: Docs
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -102,7 +102,7 @@ jobs:
name: Typos
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -117,5 +117,5 @@ jobs:
# markdown-link-check:
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@master
# - uses: actions/checkout@v4
# - uses: gaurav-nelson/github-action-markdown-link-check@v1
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ jobs:
rust: [stable]
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
override: true
- name: Setup Pages
uses: actions/configure-pages@v3
uses: actions/configure-pages@v5
- uses: actions-rs/cargo@v1
with:
command: doc
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- uses: actions-rs/toolchain@v1
with:
Expand Down
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
<a name="top"></a>
<h1 align="center">
mistral.rs
</h1>
Expand Down Expand Up @@ -586,3 +587,7 @@ If you want to add a new model, please contact us via an issue and we can coordi
## Credits
This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.
<p align="right">
<a href="#top">⬆️ Back to Top</a>
</p>
2 changes: 1 addition & 1 deletion mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/loaders/vision_loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl FromStr for VisionLoaderType {
"llava_next" => Ok(Self::LLaVANext),
"llava" => Ok(Self::LLaVA),
"vllama" => Ok(Self::VLlama),
a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vsllama`.")),
a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`.")),
}
}
}
Expand Down
29 changes: 22 additions & 7 deletions mistralrs-core/src/vision_models/mllama/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ fn prepare_cross_attention_mask(
let bs = cross_attention_mask.dim(0)?;
let text_total_length = cross_attention_mask.dim(1)?;
let mut cross_attn_mask = repeat_interleave(
&cross_attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?,
&cross_attention_mask.to_dtype(DType::F32)?,
num_vision_tokens,
3,
)?;
cross_attn_mask = cross_attn_mask.reshape((bs, text_total_length, ()))?;
cross_attn_mask = cross_attn_mask.unsqueeze(1)?;

// Invert the mask
let inverted_cross_attn_mask = (1. - cross_attn_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
let inverted_cross_attn_mask = (1. - cross_attn_mask)?;
const NEG_INF_VALUE: f32 = -1e15;
cross_attn_mask = masked_fill(
&inverted_cross_attn_mask,
Expand All @@ -75,7 +75,9 @@ fn prepare_cross_attention_mask(
.unsqueeze(D::Minus1)?;

cross_attn_mask = cross_attn_mask
.broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?;
.broadcast_mul(&full_text_row_masked_out_mask.to_dtype(cross_attn_mask.dtype())?)?
.to_dtype(DType::F32)?
.to_dtype(dtype)?;

Ok((cross_attn_mask, full_text_row_masked_out_mask))
}
Expand All @@ -85,6 +87,7 @@ pub(crate) struct MLlamaModel {
language_model: MLlamaTextModel,
multi_modal_projector: Linear,
hidden_size: usize,
dtype: DType,
}

impl MLlamaModel {
Expand All @@ -96,10 +99,18 @@ impl MLlamaModel {
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
let real_dev = normal_loading_metadata.real_device.clone();
// This vision model is very sensitive.
let vision_model_dtype = if vb.dtype() == DType::F16 {
DType::F32
} else {
vb.dtype()
};
Ok(Self {
vision_model: MLlamaVisionModel::new(
&cfg.vision_config,
vb.pp("vision_model").set_device(real_dev.clone()),
vb.pp("vision_model")
.set_device(real_dev.clone())
.set_dtype(vision_model_dtype),
)?,
language_model: MLlamaTextModel::new(
&cfg.text_config,
Expand All @@ -111,9 +122,12 @@ impl MLlamaModel {
multi_modal_projector: linear(
cfg.vision_config.vision_output_dim,
cfg.text_config.hidden_size,
vb.pp("multi_modal_projector").set_device(real_dev.clone()),
vb.pp("multi_modal_projector")
.set_device(real_dev.clone())
.set_dtype(vision_model_dtype),
)?,
hidden_size: cfg.text_config.hidden_size,
dtype: vb.dtype(),
})
}

Expand Down Expand Up @@ -142,7 +156,8 @@ impl MLlamaModel {
let cross_attention_states = self
.multi_modal_projector
.forward(&vision_outputs.flatten(0, 1)?)?
.reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?;
.reshape(((), vision_outputs.dim(D::Minus2)?, self.hidden_size))?
.to_dtype(self.dtype)?;
Some(cross_attention_states)
} else {
None
Expand All @@ -153,7 +168,7 @@ impl MLlamaModel {
let (cmask, fmask) = prepare_cross_attention_mask(
cross_attn_mask,
self.vision_model.num_patches,
self.multi_modal_projector.weight().dtype(),
self.dtype,
)?;
(Some(cmask), Some(fmask))
} else {
Expand Down
4 changes: 2 additions & 2 deletions mistralrs-paged-attn/src/pagedattention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ __device__ void paged_attention_kernel(

// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
Expand Down Expand Up @@ -205,7 +205,7 @@ __device__ void paged_attention_kernel(

// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/Cargo_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pyo3.workspace = true
mistralrs-core = { version = "0.3.1", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "156ebd1", features=["$feature_name"] }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", features=["$feature_name"] }
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
Expand Down

0 comments on commit 19445dc

Please sign in to comment.