diff --git a/.typos.toml b/.typos.toml index 0b4a95179..7b8092737 100644 --- a/.typos.toml +++ b/.typos.toml @@ -4,7 +4,8 @@ extend-ignore-identifiers-re = [ "mmaped", "arange", "Nd", - "nin" + "nin", + "cudaDevAttrMaxSharedMemoryPerBlockOptin" ] [files] diff --git a/Cargo.lock b/Cargo.lock index 121b6c542..fcef40396 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,9 +20,9 @@ checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" [[package]] name = "addr2line" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -354,18 +354,18 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", @@ -393,13 +393,14 @@ checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "candle-core" version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80" +source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562" dependencies = [ "accelerate-src", "byteorder", "candle-kernels", "candle-metal-kernels", "cudarc", + "float8", "gemm", "half", "intel-mkl-src", @@ -420,7 +421,7 @@ dependencies = [ [[package]] name = "candle-flash-attn" version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80" +source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562" dependencies = [ "anyhow", "bindgen_cuda 0.1.5", @@ -431,7 +432,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80" +source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562" dependencies = [ "bindgen_cuda 0.1.5", ] @@ -439,7 +440,7 @@ dependencies = [ [[package]] name = "candle-metal-kernels" version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80" +source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562" dependencies = [ "metal", "once_cell", @@ -450,7 +451,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.7.2" -source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80" +source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562" dependencies = [ "accelerate-src", "candle-core", @@ -467,9 +468,9 @@ dependencies = [ 
[[package]] name = "cc" -version = "1.1.24" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812acba72f0a070b003d3697490d2b55b837230ae7c6c6497f05cc2ddbb8d938" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ "shlex", ] @@ -538,9 +539,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -548,9 +549,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -882,18 +883,18 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ "darling 0.20.10", "proc-macro2", @@ -903,9 +904,9 @@ dependencies = [ [[package]] name = "derive_builder_macro" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", "syn 2.0.79", @@ -1119,6 +1120,19 @@ dependencies = [ "miniz_oxide 0.8.0", ] +[[package]] +name = "float8" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c3475274d374d263c4c40c43ad854c5bdf733c7db775bbd3c1ca2ad7427978" +dependencies = [ + "cudarc", + "half", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "flume" version = "0.11.0" @@ -1187,9 +1201,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1202,9 +1216,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1212,15 +1226,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1229,15 +1243,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -1246,21 +1260,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1447,9 +1461,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1583,9 +1597,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -1794,9 +1808,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is_terminal_polyfill" @@ -1845,9 +1859,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2155,6 +2169,7 @@ dependencies = [ "derive_more", "dirs", "either", + "float8", "futures", "galil-seiferas", "half", @@ -2208,6 +2223,7 @@ dependencies = [ "anyhow", "bindgen_cuda 0.1.6", "candle-core", + "float8", "half", ] @@ -2243,8 +2259,10 @@ dependencies = [ "byteorder", "candle-core", "candle-nn", + "float8", "half", "lazy_static", + "once_cell", "paste", "rayon", "serde", @@ -2484,9 +2502,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.4" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] @@ -2535,12 +2553,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.1" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1" -dependencies = [ - "portable-atomic", -] +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "onig" @@ -2815,9 +2830,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" dependencies = [ "unicode-ident", ] @@ -2836,9 +2851,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" dependencies = [ "anyhow", "cfg-if", @@ -2867,9 +2882,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" dependencies = [ "once_cell", "target-lexicon", @@ -2877,9 +2892,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" dependencies = [ "libc", "pyo3-build-config", @@ -2887,9 +2902,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2899,9 +2914,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +checksum = 
"ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" dependencies = [ "heck", "proc-macro2", @@ -3295,9 +3310,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" dependencies = [ "log", "once_cell", @@ -3319,9 +3334,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -3336,9 +3351,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -3367,9 +3382,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ "windows-sys 0.59.0", ] @@ -4215,9 +4230,9 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" @@ -4329,9 +4344,9 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "4.3.0" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" dependencies = [ "proc-macro-error", "proc-macro2", @@ -4446,9 +4461,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4457,9 +4472,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -4472,9 +4487,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4484,9 +4499,9 
@@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4494,9 +4509,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -4507,15 +4522,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 109717ae2..4d4a2c746 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,8 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941" } +candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } @@ -37,7 +37,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } futures = "0.3" clap = { version = "4.5.1", features = ["derive"] } -pyo3 = { version = "0.22.0", features = ["full", "extension-module", "either"] } +pyo3 = { version = "0.22.4", features = ["full", "extension-module", "either"] } tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] } once_cell = "1.19.0" # All features but avif, avif increases the msrv dramatically @@ -49,3 +49,4 @@ rayon = "1.1.0" url = "2.5.2" data-url = "0.3.1" buildstructor = "0.5.4" +float8 = "0.1.1" diff --git a/README.md b/README.md index 5d4cefc52..bbfbdfd30 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,9 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis *After following installation instructions* +- Check out UQFF for prequantized models of various methods! + - Models can be found [here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c). 
+ - 🦙📷 Run the **Llama 3.2 Vision** Model: [documentation and guide here](docs/VLLAMA.md) Mount Washington @@ -70,7 +73,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis - Other models: [see a support matrix](#support-matrix) and [how to run them](#run-with-the-cli) -Mistal.rs supports several model categories: +Mistral.rs supports several model categories: - Text to Text - Text+Image to Text: Vision (see [the docs](docs/VISION_MODELS.md)) - Text to Image: Image Generation (see [the docs](docs/IMAGEGEN_MODELS.md)) @@ -91,7 +94,7 @@ Mistal.rs supports several model categories: **Quantization**: - [Details](docs/QUANTS.md) - GGML: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit, with ISQ support. -- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit +- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit, with [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit. - HQQ: 4-bit and 8 bit, with ISQ support **Powerful**: @@ -106,7 +109,7 @@ Mistal.rs supports several model categories: - [PagedAttention](docs/PAGED_ATTENTION.md) and continuous batching - Prefix caching - [Topology](docs/TOPOLOGY.md): Configure ISQ and device mapping easily -- [UQFF](docs/UQFF.md): Quantized file format for easy mixing of quants, see some [models](docs/UQFF.md#list-of-models) which have already been converted. +- [UQFF](docs/UQFF.md): Quantized file format for easy mixing of quants, [collection here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c). - Speculative Decoding: Mix supported models as the draft model or the target model - Dynamic LoRA adapter activation with adapter preloading: [examples and docs](docs/ADAPTER_MODELS.md#adapter-model-dynamic-adapter-activation) @@ -202,7 +205,7 @@ Enabling features is done by passing `--features ...` to the build system. When - Install the [Python package here](mistralrs-pyo3/README.md). -1) Install required packages +1) Install required packages: - `OpenSSL` (*Example on Ubuntu:* `sudo apt install libssl-dev`) - *Linux only:* `pkg-config` (*Example on Ubuntu:* `sudo apt install pkg-config`) @@ -220,13 +223,13 @@ Enabling features is done by passing `--features ...` to the build system. When huggingface-cli login ``` -4) Download the code +4) Download the code: ```bash git clone https://github.com/EricLBuehler/mistral.rs.git cd mistral.rs ``` -5) Build or install +5) Build or install: - Base build command ```bash cargo build --release @@ -257,14 +260,14 @@ Enabling features is done by passing `--features ...` to the build system. When ```bash cargo install --path mistralrs-server --features cuda ``` -6) The build process will output a binary `misralrs-server` at `./target/release/mistralrs-server` which may be copied into the working directory with the following command: +6) The build process will output a binary `mistralrs-server` at `./target/release/mistralrs-server` which may be copied into the working directory with the following command: *Example on Ubuntu:* ``` cp ./target/release/mistralrs-server ./mistralrs-server ``` -7) Use our APIs and integrations +7) Use our APIs and integrations: [APIs and integrations list](#apis-and-integrations) @@ -377,8 +380,6 @@ please consider using the method demonstrated in examples below, where the token Mistral.rs uses subcommands to control the model type. They are generally of format `-`. Please run `./mistralrs-server --help` to see the subcommands. 
-Additionally, for models without quantization, the model architecture should be provided as the `--arch` or `-a` argument in contrast to GGUF models which encode the architecture in the file. - ### Architecture for plain models > Note: for plain models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dype`/`-d` parameter after the model architecture (`plain`). diff --git a/docs/ISQ.md b/docs/ISQ.md index 76cff4fc0..bfaad1a04 100644 --- a/docs/ISQ.md +++ b/docs/ISQ.md @@ -21,6 +21,7 @@ To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md) - Q8K (*not available on CUDA*) - HQQ4 - HQQ8 +- FP8 When using ISQ, it will automatically load ISQ-able weights into CPU memory before applying ISQ. The ISQ application process moves the weights to device memory. This process is implemented to avoid memory spikes from loading the model in full precision. diff --git a/docs/QUANTS.md b/docs/QUANTS.md index 7daa93a1c..6b37d35a0 100644 --- a/docs/QUANTS.md +++ b/docs/QUANTS.md @@ -12,6 +12,7 @@ Mistral.rs supports the following quantization: - Supported in all plain and adapter models - CUDA only - 2, 3, 4, 8 bit + - [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit. - HQQ - Supported in all plain and adapter models via ISQ - CUDA and CPU only @@ -41,6 +42,7 @@ cargo run --features cuda -- -i --isq Q4K plain -m microsoft/Phi-3-mini-4k-instr - Use the `plain` (cli) / `Plain` (Python) model selector - Provide the model ID for the GPTQ model - Mistral.rs will automatically detect and use GPTQ quantization. +- The [Marlin](https://github.com/IST-DASLab/marlin) kernel will automatically be used in 4-bit and 8-bit. ``` cargo run --features cuda -- -i plain -m kaitchup/Phi-3-mini-4k-instruct-gptq-4bit -a phi3 diff --git a/docs/UQFF.md b/docs/UQFF.md index 7dfa4a30b..6a9686ac3 100644 --- a/docs/UQFF.md +++ b/docs/UQFF.md @@ -51,24 +51,31 @@ The following quantization formats are supported in UQFF. One can, of course, be - HQQ4 - HQQ8 +- FP8: + - FP8 E4M3 (4-bit exponent, 3-bit mantissa) + ## Loading a UQFF model -To load a UQFF model, one should specify the artifact path. This can be either be a path to a UQFF file locally, or a Hugging Face model ID with the format `/`. For example, the following work: +To load a UQFF model, one should specify the filename. This will be located based on the model ID, and can +be loaded locally or from Hugging Face based on the model ID. -- `EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff` +- `phi3.5-mini-instruct-q4k.uqff` - `../UQFF/phi3.5-mini-instruct-q4k.uqff` -> Note: when loading an UQFF model, it will take precedence over any ISQ setting. +You can find a [collection of UQFF models here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c), which each include a simple +command to get started. + +> Note: when loading an UQFF model, *any* ISQ setting will be ignored. 
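As a rough illustration of the FP8 E4M3 format mentioned above (4 exponent bits, 3 mantissa bits), the sketch below rounds a few `f32` values through the `float8` crate that this change adds as a dependency. `F8E4M3::ZERO`, `F8E4M3::MAX` and `to_f64()` appear elsewhere in this diff; `from_f32`/`to_f32` are assumed by analogy with the `half` crate, so treat this as a sketch rather than confirmed API.

```rust
// Sketch only: shows the coarse precision and narrow range of FP8 E4M3 values.
// Assumes `F8E4M3::from_f32`/`to_f32` exist (mirroring the `half` crate);
// `ZERO`, `MAX` and `to_f64()` are used elsewhere in this PR.
use float8::F8E4M3;

fn main() {
    // With 3 mantissa bits there are only 8 steps between consecutive powers of
    // two, so rounding error is large compared to f16/bf16.
    for x in [0.1f32, 1.0, 3.14159, 100.0] {
        let q = F8E4M3::from_f32(x);
        let back = q.to_f32();
        println!("{x:>10} -> {back:>10} (abs err {:.4})", (x - back).abs());
    }
    // The representable range is narrow, which is why quantized FP8 layers store
    // separate f32 scale factors alongside the weight data (see the UQFF layout
    // doc changes in this PR).
    println!("F8E4M3::MAX  = {}", F8E4M3::MAX.to_f64());
    println!("F8E4M3::ZERO = {}", F8E4M3::ZERO.to_f64());
}
```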
### Running with the CLI ``` -cargo run --features cuda -- -i plain -m microsoft/Phi-3.5-mini-instruct --from-uqff EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff +./mistralrs-server -i plain -m EricB/Phi-3.5-mini-instruct-UQFF --from-uqff phi3.5-mini-instruct-f8e4m3.uqff ``` ### Using with the Rust API -Modify the Normal or Vision config as follows: +Modify the Normal or Vision config as follows and update the model ID to point to a UQFF model: ```diff NormalSpecificConfig { @@ -78,7 +85,7 @@ NormalSpecificConfig { organization: Default::default(), write_uqff: None, - from_uqff: None, -+ from_uqff: Some("EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff".to_string()), ++ from_uqff: Some("phi3.5-mini-instruct-q4k.uqff".to_string()), // Pull from specified HF hub repo } ``` @@ -89,7 +96,7 @@ VisionSpecificConfig { topology: None, write_uqff: None, - from_uqff: None, -+ from_uqff: Some("../UQFF/phi3.5-mini-instruct-q4k.uqff".to_string()), ++ from_uqff: Some("../phi3.5-mini-instruct-q4k.uqff".to_string()), // Local path } ``` @@ -97,8 +104,8 @@ VisionSpecificConfig { Modify the `Which` instantiation as follows: ```diff Which.Plain( - model_id="microsoft/Phi-3.5-mini-instruct", -+ from_uqff="EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff" + model_id="EricB/Phi-3.5-mini-instruct-UQFF", ++ from_uqff="phi3.5-mini-instruct-q4k.uqff" ), ``` @@ -109,6 +116,11 @@ Creating a UQFF model requires you to generate the UQFF file. - This means specifying a local path to a file ending in `.uqff`, where your new UQFF model will be created. - The quantization of a UQFF model is determined from the ISQ or model topology (see the [topology docs](TOPOLOGY.md) for more details on how ISQ and the topology mix). +Along with the UQFF file, the generation process will also output several `.json` configuration files and `residual.safetensors`. All of these files are considered the +UQFF model, and should be kept together or uploaded. + +> Note: Only the `.uqff` files are unique to the quantization level(s). If you are generating multiple UQFF files, it is OK for the others to be overwritten. + After creating the UQFF file, you can upload the model to Hugging Face. To do this: 1) [Create a new model](https://huggingface.co/docs/transformers/v4.17.0/en/create_a_model). 2) Upload the UQFF file: @@ -120,7 +132,7 @@ After creating the UQFF file, you can upload the model to Hugging Face. To do th ### Creating with the CLI ``` -cargo run --features cuda -- --isq Q4K -i plain -m microsoft/Phi-3.5-mini-instruct --write-uqff phi3.5-mini-instruct-q4k.uqff +./mistralrs-server --isq Q4K -i plain -m microsoft/Phi-3.5-mini-instruct --write-uqff phi3.5-mini-instruct-q4k.uqff ``` ### Creating with the Rust API @@ -151,7 +163,7 @@ VisionSpecificConfig { ``` ### Creating with the Python API -Modify the `Which` instantiation as follows: +Modify the `Which` instantiation as follows. Be sure to add the `in_situ_quant`. ```diff Which.Plain( model_id="microsoft/Phi-3.5-mini-instruct", @@ -170,10 +182,6 @@ After this, you can use Git to track, commit, and push files. ## List of models -Have you created a UQFF model on Hugging Face? If so, please [create an issue](https://github.com/EricLBuehler/mistral.rs/issues/new) and we will include it here! +You can find a list of models in the [Hugging Face model collection](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c). 
-| Name | Base model | UQFF model | -| -- | -- | -- | -| Phi 3.5 Mini Instruct | microsoft/Phi-3.5-mini-instruct | [EricB/Phi-3.5-mini-instruct-UQFF](EricB/Phi-3.5-mini-instruct-UQFF) | -| Llama 3.2 Vision | meta-llama/Llama-3.2-11B-Vision-Instruct | [EricB/Llama-3.2-11B-Vision-Instruct-UQFF](https://huggingface.co/EricB/Llama-3.2-11B-Vision-Instruct-UQFF) | -| Mistral Nemo 2407 | mistralai/Mistral-Nemo-Instruct-2407 | [EricB/Mistral-Nemo-Instruct-2407-UQFF](https://huggingface.co/EricB/Mistral-Nemo-Instruct-2407-UQFF) | +Have you created a UQFF model on Hugging Face? If so, please [create an issue](https://github.com/EricLBuehler/mistral.rs/issues/new). diff --git a/docs/UQFF/LAYOUT.md b/docs/UQFF/LAYOUT.md index ceabd7c88..c0b29e65d 100644 --- a/docs/UQFF/LAYOUT.md +++ b/docs/UQFF/LAYOUT.md @@ -6,6 +6,7 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0 - [GGUF quantization](#gguf-quantization) - [HQQ quantization](#hqq-quantization) - [Uquantized layers](#unquantized-layers) +- [FP8 layers](#fp8-layers) - [Standard tensors](#standard-tensors) @@ -32,6 +33,18 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0 | **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | | **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | +## FP8 layers +| ID | Element type | Endianness | +| -------- | -------- | -------- | +| HQFF version | u32 | little endian | +| ISQ type (1) | u8 | little endian | +| Whether bias data is included (boolean) | u8 | little endian | +| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | +| Dequant W scalar | f32 | little endian +| Dequant X scalar | f32 | little endian +| Quant scalar | f32 | little endian +| Quantization type | u32 | little endian +| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | ## HQQ quantization | ID | Element type | Endianness | @@ -51,6 +64,19 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0 | CFG round zeroes (boolean) | u8 | little endian | | CFG channel wise (boolean) | u8 | little endian | +## FP8 layers +| ID | Element type | Endianness | +| -------- | -------- | -------- | +| HQFF version | u32 | little endian | +| ISQ type (3) | u8 | little endian | +| Whether bias data is included (boolean) | u8 | little endian | +| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | +| Dequant scale W | f32 | little endian | +| Dequant scale X | f32 | little endian | +| Quant scale | f32 | little endian | +| Layer dtype | u32 | little endian | +| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) | + ## Standard tensors | ID | Element type | Endianness | | -------- | -------- | -------- | diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index dbcfab697..3a6e8a060 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true candle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", 
optional = true } +candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941", optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" @@ -78,10 +78,11 @@ regex = "1.10.6" safetensors = "0.4.5" serde_plain = "1.0.2" as-any = "0.3.1" +float8.workspace = true [features] pyo3_macros = ["pyo3"] -cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"] +cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"] cudnn = ["candle-core/cudnn"] metal = ["candle-core/metal", "candle-nn/metal"] flash-attn = ["cuda", "dep:candle-flash-attn"] diff --git a/mistralrs-core/src/common_models/t5/mod.rs b/mistralrs-core/src/common_models/t5/mod.rs index dc4918ee4..b51592725 100644 --- a/mistralrs-core/src/common_models/t5/mod.rs +++ b/mistralrs-core/src/common_models/t5/mod.rs @@ -5,6 +5,7 @@ use candle_core::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, VarBuilder}; +use float8::F8E4M3; use serde::Deserialize; use std::sync::Arc; @@ -596,6 +597,7 @@ impl TensorInfExtend for Tensor { DType::BF16 => Ok(sum.to_scalar::()? == half::bf16::from_f32_const(0.)), DType::F32 => Ok(sum.to_scalar::()? == 0.), DType::F64 => Ok(sum.to_scalar::()? == 0.), + DType::F8E4M3 => Ok(sum.to_scalar::()? == F8E4M3::ZERO), } } } @@ -611,6 +613,7 @@ fn clamp_for_f16(xs: &Tensor) -> Result { DType::BF16 => half::bf16::MAX.to_f64_const() - 1000., DType::F32 => f32::MAX as f64 - 1000., DType::F64 => f64::MAX - 1000., + DType::F8E4M3 => F8E4M3::MAX.to_f64() - 1000., }; if xs.is_inf()?.any()? 
{ max -= 1000.; diff --git a/mistralrs-core/src/cublaslt/api.rs b/mistralrs-core/src/cublaslt/api.rs index 24aca6ba2..8bb11d028 100644 --- a/mistralrs-core/src/cublaslt/api.rs +++ b/mistralrs-core/src/cublaslt/api.rs @@ -1,13 +1,14 @@ -pub use candle_core::cuda_backend::cudarc::cublaslt::Activation; +use candle_core::cuda::cudarc::driver::DevicePtr; +use float8::F8E4M3; use std::ffi::c_int; use candle_core::backend::BackendStorage; use candle_core::cuda_backend::WrapErr; -use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor}; +use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; use half::{bf16, f16}; use std::sync::Arc; -use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig}; +use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig}; #[derive(Debug, Clone)] pub struct CublasLt(Arc); @@ -858,11 +859,12 @@ pub fn fused_batch_matmul( a.apply_op2(b, op) } } - #[cfg(test)] mod tests { + use std::f32::consts::PI; + use super::*; - use candle_core::{DType, Device}; + use candle_core::{DType, Device, IndexOp}; fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { let b = 10f32.powi(digits); diff --git a/mistralrs-core/src/cublaslt/matmul.rs b/mistralrs-core/src/cublaslt/matmul.rs new file mode 100644 index 000000000..898a30522 --- /dev/null +++ b/mistralrs-core/src/cublaslt/matmul.rs @@ -0,0 +1,453 @@ +use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; +use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; +use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; +use candle_core::cuda::cudarc::driver::{ + CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, +}; +use core::ffi::c_int; +use core::mem; +use float8::F8E4M3; +use half::bf16; +use std::sync::Arc; + +/// Wrapper around [sys::cublasLtHandle_t] +/// +/// 1. Create with [CudaBlasLT::new()] +/// 2. Execute matmul kernel with matmul. f32 is supported. f16 and bf16 are supported +/// if feature `half` is activated +/// +/// Note: This maintains a instance of [`Arc`], so will prevent the device +/// from being dropped. Kernels will be launched on the device device default stream. +#[derive(Debug)] +pub struct CudaBlasLT { + handle: sys::cublasLtHandle_t, + workspace: Workspace, + device: Arc, +} + +unsafe impl Send for CudaBlasLT {} + +unsafe impl Sync for CudaBlasLT {} + +impl CudaBlasLT { + /// Creates a new cublasLt handle. + pub fn new(device: Arc) -> Result { + let handle = result::create_handle()?; + let workspace = Workspace::new(device.clone()).unwrap(); + + Ok(Self { + handle, + workspace, + device, + }) + } +} + +impl Drop for CudaBlasLT { + fn drop(&mut self) { + let handle = mem::replace(&mut self.handle, std::ptr::null_mut()); + if !handle.is_null() { + unsafe { result::destroy_handle(handle) }.unwrap(); + } + } +} + +/// User owned CublasLt workspace buffer. +/// The workspace is initialised following the Nvidia recommendations: +/// +/// 1. NVIDIA Hopper Architecture: 32 MiB +/// 2. 
Other: 4 MiB +#[derive(Debug, Clone)] +pub struct Workspace { + pub(crate) buffer: CudaSlice, + pub(crate) size: usize, +} + +impl Workspace { + /// Creates a CublasLt workspace buffer on the provided device + pub fn new(device: Arc) -> Result { + device.bind_to_thread()?; + + let major = + device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?; + let workspace_size = if major >= 9 { 33_554_432 } else { 4_194_304 }; + + let buffer = unsafe { device.alloc::(workspace_size)? }; + Ok(Self { + buffer, + size: workspace_size, + }) + } +} + +/// Available activation for kernel fusing in matmul +#[derive(Debug, Clone)] +pub enum Activation { + Relu, + Gelu, +} + +/// MatrixLayout helper type +struct MatrixLayout { + handle: sys::cublasLtMatrixLayout_t, +} + +impl MatrixLayout { + fn new( + matrix_type: sys::cudaDataType, + rows: u64, + cols: u64, + ld: i64, + ) -> Result { + let handle = result::create_matrix_layout(matrix_type, rows, cols, ld)?; + Ok(Self { handle }) + } + + fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> { + unsafe { + // Set batch size + set_matrix_layout_attribute( + self.handle, + sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + (&size) as *const _ as *const _, + mem::size_of::(), + )?; + // Set batch stride + set_matrix_layout_attribute( + self.handle, + sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + (&stride) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatrixLayout { + fn drop(&mut self) { + // panic on failure + unsafe { + result::destroy_matrix_layout(self.handle).expect("Unable to destroy matrix layout") + } + } +} + +enum Matrix { + A, + B, + C, + D, +} + +/// MatmulDesc helper type +struct MatmulDesc { + handle: sys::cublasLtMatmulDesc_t, +} + +impl MatmulDesc { + fn new( + compute_type: sys::cublasComputeType_t, + scale_type: sys::cudaDataType, + ) -> Result { + let handle = result::create_matmul_desc(compute_type, scale_type)?; + Ok(Self { handle }) + } + + fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> { + // Set transpose + // 1 == T, 0 == N + let transpose = transpose as i32; + let attr = match matrix { + Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA, + Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB, + Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC, + Matrix::D => unreachable!(), + }; + + unsafe { + result::set_matmul_desc_attribute( + self.handle, + attr, + (&transpose) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } + + // Epilogue system can be leveraged to fuse add and activation operations + fn set_epilogue( + &self, + act: Option<&Activation>, + bias_ptr: Option<&CUdeviceptr>, + stride_bias: Option, + ) -> Result<(), CublasError> { + let epilogue = if let Some(bias_ptr) = bias_ptr { + let epilogue = act + .map(|act| match act { + // Act + bias + Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS, + Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS, + }) + // Only bias + .unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS); + + // Set bias CUdeviceptr in matmul_desc + unsafe { + result::set_matmul_desc_attribute( + self.handle, + sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER, + bias_ptr as *const CUdeviceptr as *const _, + mem::size_of::(), + )?; + } + + if let 
Some(stride_bias) = stride_bias { + // Set bias batch stride + unsafe { + result::set_matmul_desc_attribute( + self.handle, + sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE, + (&stride_bias) as *const _ as *const _, + mem::size_of::(), + )?; + } + } + epilogue + } else if let Some(act) = act { + // Only Act + match act { + Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU, + Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU, + } + } else { + // No epilogue + sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT + }; + + // Set epilogue + unsafe { + result::set_matmul_desc_attribute( + self.handle, + sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE, + (&epilogue) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatmulDesc { + fn drop(&mut self) { + unsafe { result::destroy_matmul_desc(self.handle).expect("Unable to destroy matmul desc") } + } +} + +/// MatmulPref helper type +struct MatmulPref { + handle: sys::cublasLtMatmulPreference_t, +} + +impl MatmulPref { + fn new() -> Result { + let handle = result::create_matmul_pref()?; + Ok(Self { handle }) + } + + fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> { + unsafe { + // Set workspace size + result::set_matmul_pref_attribute( + self.handle, + sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + (&size) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatmulPref { + fn drop(&mut self) { + unsafe { result::destroy_matmul_pref(self.handle).expect("Unable to destroy matmul pref") } + } +} + +/// [Matmul] super-trait +pub trait MatmulShared { + /// Returns a reference to the underlying cublasLt handle. + fn handle(&self) -> &sys::cublasLtHandle_t; + + /// Returns a reference to the underlying cublasLt workspace + fn workspace(&self) -> &Workspace; + + /// Returns a reference to the underlying stream + fn stream(&self) -> &CUstream; +} + +/// Configuration for [Matmul] +#[derive(Debug, Copy, Clone)] +pub struct MatmulConfig { + pub transa: bool, + pub transb: bool, + pub m: u64, + pub n: u64, + pub k: u64, + pub alpha: f32, + pub lda: i64, + pub ldb: i64, + pub beta: f32, + pub ldc: i64, + pub stride_a: Option, + pub stride_b: Option, + pub stride_c: Option, + pub stride_bias: Option, + pub batch_size: Option, +} + +/// Matrix matrix multiplication with elements of type `T`. +pub trait Matmul: MatmulShared { + /// Underlying CUDA Type for `T` + fn matrix_type() -> sys::cudaDataType; + + /// Underlying CUDA Compute Type for `T` + fn compute_type() -> sys::cublasComputeType_t; + + /// Matrix matrix multiplication. See + /// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul) + /// + /// # Safety + /// This is unsafe because improper arguments may lead to invalid + /// memory accesses. 
+ unsafe fn matmul, O: DevicePtrMut>( + &self, + cfg: MatmulConfig, + a: &I, + b: &I, + c: &mut O, + bias: Option<&I>, + act: Option<&Activation>, + ) -> Result<(), CublasError> { + let (a_rows, a_cols) = if cfg.transa { + (cfg.k, cfg.m) + } else { + (cfg.m, cfg.k) + }; + let (b_rows, b_cols) = if cfg.transb { + (cfg.n, cfg.k) + } else { + (cfg.k, cfg.n) + }; + + // Creates matrix layouts + let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?; + if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) { + a_layout.set_batch(batch_size, stride_a)?; + } + + let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?; + if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) { + b_layout.set_batch(batch_size, stride_b)?; + } + + let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?; + if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) { + c_layout.set_batch(batch_size, stride_c)?; + } + + // Matmul description + let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?; + + // Set transa + matmul_desc.set_transpose(cfg.transa, Matrix::A)?; + // Set transb + matmul_desc.set_transpose(cfg.transb, Matrix::B)?; + + // Epilogue system can be leveraged to fuse add and activation operations + matmul_desc.set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias)?; + + // Create matmul heuristic search preferences + let matmul_pref = MatmulPref::new()?; + + // Set workspace size + matmul_pref.set_workspace_size(self.workspace().size)?; + + // Get heuristic given Config, bias, act and workspace size + let heuristic = result::get_matmul_algo_heuristic( + *self.handle(), + matmul_desc.handle, + a_layout.handle, + b_layout.handle, + c_layout.handle, + c_layout.handle, + matmul_pref.handle, + )?; + + // Launch matmul kernel + result::matmul( + *self.handle(), + matmul_desc.handle, + (&cfg.alpha) as *const _ as *const _, + (&cfg.beta) as *const _ as *const _, + *a.device_ptr() as *const _, + a_layout.handle, + *b.device_ptr() as *const _, + b_layout.handle, + *c.device_ptr_mut() as *const _, + c_layout.handle, + *c.device_ptr_mut() as *mut _, + c_layout.handle, + (&heuristic.algo) as *const _, + *self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _, + self.workspace().size, + *self.stream() as *mut _, + ) + } +} + +impl MatmulShared for CudaBlasLT { + fn handle(&self) -> &sys::cublasLtHandle_t { + &self.handle + } + + fn workspace(&self) -> &Workspace { + &self.workspace + } + + fn stream(&self) -> &CUstream { + self.device.cu_stream() + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_32F + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_16F + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_16BF + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + } +} diff --git a/mistralrs-core/src/cublaslt/mod.rs b/mistralrs-core/src/cublaslt/mod.rs index 7657186ca..9d6046b38 100644 --- a/mistralrs-core/src/cublaslt/mod.rs +++ b/mistralrs-core/src/cublaslt/mod.rs @@ -9,9 +9,11 @@ use 
std::sync::{Mutex, Once}; #[cfg(feature = "cuda")] mod api; +#[cfg(feature = "cuda")] +mod matmul; #[cfg(feature = "cuda")] -use api::{fused_batch_matmul, fused_matmul, Activation, CublasLt}; +use api::{fused_batch_matmul, fused_matmul, CublasLt}; static INIT: Once = Once::new(); static mut CUBLASLT: Option = None; @@ -70,8 +72,8 @@ impl CublasLtWrapper { #[cfg(feature = "cuda")] { let inner_act = act.map(|a| match a { - CandleActivation::Relu => Activation::Relu, - CandleActivation::Gelu => Activation::Gelu, + CandleActivation::Relu => matmul::Activation::Relu, + CandleActivation::Gelu => matmul::Activation::Gelu, _ => unreachable!("Unsupported activation in cublaslt matmul"), }); let mut result = fused_batch_matmul( diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 14be94939..4d9afe5e8 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -18,7 +18,7 @@ use candle_nn::{ Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, Linear, Module, VarBuilder, }; use mistralrs_quant::QuantMethod; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; pub use crate::attention::Sdpa; pub use crate::layers_masker::CausalMasker; @@ -51,9 +51,21 @@ impl RmsNorm { Ok(Self { eps, weight: w }) } + /// Gemma uses weight + 1.0. Undo for UQFF generation. + pub fn undo_gemma(&self) -> Result { + Ok(Self { + eps: self.eps, + weight: (&self.weight - 1.0)?, + }) + } + pub fn from_w(w: Tensor, eps: f64) -> Result { Ok(Self { eps, weight: w }) } + + pub fn weight(&self) -> &Tensor { + &self.weight + } } impl Module for RmsNorm { @@ -92,7 +104,8 @@ pub struct PhiRotaryEmbedding { original_max_position_embeddings: usize, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] pub enum ScaledRopeType { #[serde(alias = "su")] #[serde(alias = "longrope")] @@ -114,7 +127,7 @@ impl FromStr for ScaledRopeType { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(untagged)] pub enum PhiRopeScalingConfig { Classic { @@ -395,7 +408,7 @@ pub enum Llama3RotaryEmbedding { Default(RotaryEmbedding), } -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub enum Llama3RopeType { #[serde(rename = "llama3")] Llama3, @@ -404,7 +417,7 @@ pub enum Llama3RopeType { Default, } -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct Llama3RopeConfig { pub factor: f32, pub low_freq_factor: f32, @@ -872,6 +885,51 @@ impl RotaryEmbedding { } } +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum Activation { + #[default] + #[serde(alias = "gelu")] + Gelu, + #[serde(alias = "gelu_new")] + NewGelu, + Relu, + Relu2, + Relu6, + Silu, + Sigmoid, + HardSigmoid, + Swiglu, + Swish, + HardSwish, + Elu(f64), + LeakyRelu(f64), + #[serde(alias = "gelu_pytorch_tanh")] + GeluPytorchTanh, +} + +impl Module for Activation { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Gelu => xs.gelu_erf(), + // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 + Self::NewGelu => xs.gelu(), + Self::Relu => xs.relu(), + Self::Relu2 => xs.relu()?.sqr(), + Self::Relu6 => xs.clamp(0f32, 6f32), + Self::Silu => xs.silu(), + Self::Sigmoid => candle_nn::ops::sigmoid(xs), + Self::HardSigmoid => 
candle_nn::ops::hard_sigmoid(xs), + Self::Swiglu => candle_nn::ops::swiglu(xs), + Self::Swish => xs * candle_nn::ops::sigmoid(xs)?, + Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?, + &Self::Elu(alpha) => xs.elu(alpha), + &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope), + Self::GeluPytorchTanh => xs.gelu(), + } + } +} + fn lp_norm(xs: &Tensor, p: usize, dim: usize) -> Result { let l2_norm = xs.powf(p as f64)?.sum_keepdim(dim)?.sqrt()?; Ok(l2_norm) //xs.broadcast_div(&l2_norm) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 25674e5c4..e38ae1632 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use candle_core::Device; use cublaslt::setup_cublas_lt_wrapper; use engine::Engine; pub use engine::{EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP}; @@ -107,6 +108,11 @@ pub use utils::paged_attn_supported; pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false); static ENGINE_ID: AtomicUsize = AtomicUsize::new(0); +pub struct MistralRsConfig { + pub kind: ModelKind, + pub device: Device, +} + /// The MistralRs struct handles sending requests to the engine. /// It is the core multi-threaded component of mistral.rs, and uses `mspc` /// `Sender` and `Receiver` primitives to send and receive requests to the @@ -121,6 +127,7 @@ pub struct MistralRs { engine_handler: RwLock>, engine_id: usize, category: ModelCategory, + config: MistralRsConfig, } #[derive(Clone)] @@ -324,6 +331,10 @@ impl MistralRs { let sender = RwLock::new(tx); let id = pipeline.try_lock().unwrap().name(); + let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone(); + let device = pipeline.try_lock().unwrap().device(); + let config = MistralRsConfig { kind, device }; + let engine_handler = thread::spawn(move || { let rt = Runtime::new().unwrap(); rt.block_on(async move { @@ -357,6 +368,7 @@ impl MistralRs { reboot_state, engine_handler: RwLock::new(engine_handler), category, + config, }) } @@ -483,4 +495,8 @@ impl MistralRs { .expect("Unable to write data"); } } + + pub fn config(&self) -> &MistralRsConfig { + &self.config + } } diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index 16d17d4cd..8a1b4ef1f 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, Linear, RotaryEmbedding, VarBuilder}; +use candle_nn::{Linear, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use crate::{ @@ -14,7 +14,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -23,7 +23,7 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; fn default_max_position_embeddings() -> usize { @@ -32,7 +32,7 @@ fn default_max_position_embeddings() -> usize { serde_default_fn!(bool, word_emb_default, false); -#[derive(serde::Deserialize, Debug, Clone, Default)] 
+#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)] pub struct Config { pub attention_bias: bool, pub head_dim: usize, @@ -75,7 +75,7 @@ struct MLP { gate_proj: Arc, up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, params: Vec, } @@ -538,7 +538,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, - head_dim: None, + head_dim: Some(cfg.head_dim), }, }) } @@ -615,6 +615,26 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap()); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l + .pp("input_layernorm") + .add(&layer.input_layernorm.undo_gemma().unwrap()); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm.undo_gemma().unwrap()); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index 3edba1491..6448e6407 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, Linear, RotaryEmbedding, VarBuilder}; +use candle_nn::{Linear, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use crate::{ @@ -14,17 +14,17 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa}, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, serde::Serialize)] pub struct Config { pub attention_bias: bool, pub head_dim: usize, @@ -69,7 +69,7 @@ struct MLP { gate_proj: Arc, up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, params: Vec, } @@ -687,6 +687,32 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap()); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l + .pp("input_layernorm") + .add(&layer.input_layernorm.undo_gemma().unwrap()); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm.undo_gemma().unwrap()); + uvb_l + .pp("pre_feedforward_layernorm") + .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap()); + uvb_l + .pp("post_feedforward_layernorm") + .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap()); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 
7969b637d..0f10d03e9 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -3,7 +3,7 @@ use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::{collections::HashMap, sync::Arc}; use crate::{ @@ -23,12 +23,12 @@ use crate::{ IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct Config { pub hidden_size: usize, pub intermediate_size: usize, @@ -585,6 +585,22 @@ impl IsqModel for Llama { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.wte); + uvb_m.pp("norm").add(&self.ln_f); + + for (layer_idx, layer) in self.blocks.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.rms_1); + uvb_l.pp("post_attention_layernorm").add(&layer.rms_2); + } + + uvb.to_safetensors() + } } impl NormalModel for Llama { diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 9aad4cc25..af86aa02d 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -2,8 +2,9 @@ /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::VarBuilder; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; +use serde::Serialize; use std::{collections::HashMap, sync::Arc}; use crate::{ @@ -14,7 +15,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -22,10 +23,10 @@ use crate::{ text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -658,6 +659,24 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index ba37dab52..147204187 100644 --- 
a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -4,16 +4,16 @@ /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::{collections::HashMap, sync::Arc}; use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, device_map::DeviceMapper, - layers::{CausalMasker, MatMul, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -22,13 +22,13 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); /// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -663,6 +663,24 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index 913275e2e..9f07075cc 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -7,11 +7,9 @@ use std::{collections::HashMap, sync::Arc}; /// This corresponds to the model update made with the following commit: /// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{ - embedding, layer_norm, Activation, Embedding, LayerNorm, RotaryEmbedding, VarBuilder, -}; +use candle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantizedConfig}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::{ amoe::{ @@ -21,7 +19,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, Sdpa}, + layers::{Activation, CausalMasker, MatMul, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -30,13 +28,13 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); // 
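The residual_tensors() implementations being added across these model files all follow one pattern: an UnVarBuilder mirrors VarBuilder's pp() prefixing and collects fully qualified (name, tensor) pairs such as model.layers.0.input_layernorm for the non-quantized remainder of a UQFF file. Below is a dependency-free stand-in that only demonstrates the naming scheme; the real add() takes the layer itself (for Gemma, the undo_gemma() output) and records the tensor, whereas this sketch passes a literal leaf name so it runs with std alone:

// Minimal stand-in for the UnVarBuilder prefix/collect pattern.
use std::cell::RefCell;
use std::rc::Rc;

struct Uvb {
    prefix: String,
    out: Rc<RefCell<Vec<String>>>,
}

impl Uvb {
    fn new() -> Self {
        Self { prefix: String::new(), out: Rc::new(RefCell::new(Vec::new())) }
    }

    // `pp` pushes one path segment, mirroring VarBuilder::pp.
    fn pp(&self, seg: impl ToString) -> Self {
        let seg = seg.to_string();
        let prefix = if self.prefix.is_empty() { seg } else { format!("{}.{}", self.prefix, seg) };
        Self { prefix, out: Rc::clone(&self.out) }
    }

    // Records a fully qualified name (the real impl stores the tensor alongside it).
    fn add(&self, leaf: &str) {
        self.out.borrow_mut().push(format!("{}.{}", self.prefix, leaf));
    }

    fn names(&self) -> Vec<String> {
        self.out.borrow().clone()
    }
}

fn main() {
    let uvb = Uvb::new();
    let uvb_m = uvb.pp("model");
    uvb_m.pp("embed_tokens").add("weight");
    uvb_m.pp("norm").add("weight");
    for layer_idx in 0..2 {
        let uvb_l = uvb_m.pp("layers").pp(layer_idx);
        uvb_l.pp("input_layernorm").add("weight");
        uvb_l.pp("post_attention_layernorm").add("weight");
    }
    // Prints names like "model.layers.0.input_layernorm.weight".
    for name in uvb.names() {
        println!("{name}");
    }
}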
https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Default, Serialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -599,6 +597,21 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.final_layernorm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index b0d9ddce2..c371ab0f9 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -16,8 +16,8 @@ use crate::{ device_map::DeviceMapper, get_delta_from_lora_ab, layers::{ - CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, RmsNorm, - Sdpa, + Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, + RmsNorm, Sdpa, }, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, @@ -27,16 +27,16 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json -#[derive(Debug, Clone, serde::Deserialize, Default)] +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)] pub struct Config { pub vocab_size: usize, - pub hidden_act: candle_nn::Activation, + pub hidden_act: Activation, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, @@ -239,7 +239,7 @@ impl Attention { struct Mlp { gate_up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, i_size: usize, params: Vec, } @@ -597,6 +597,24 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs index c1f032692..0f02c1c68 100644 --- a/mistralrs-core/src/models/phi3_5_moe.rs +++ b/mistralrs-core/src/models/phi3_5_moe.rs @@ -11,7 +11,10 @@ use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, device_map::DeviceMapper, - layers::{CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, Sdpa}, + layers::{ + Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, + Sdpa, + }, layers_masker::{masked_fill, PastKvLenCache}, ops::NonZeroOp, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, @@ -21,16 +24,16 @@ use 
crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json -#[derive(Debug, Clone, serde::Deserialize, Default)] +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)] pub struct Config { pub(crate) vocab_size: usize, - pub(crate) hidden_act: candle_nn::Activation, + pub(crate) hidden_act: Activation, pub(crate) hidden_size: usize, pub(crate) intermediate_size: usize, pub(crate) num_hidden_layers: usize, @@ -250,7 +253,7 @@ struct Mlp { w1: Arc, w2: Arc, w3: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, } impl Mlp { @@ -732,6 +735,48 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } + + fn residual_tensors_moe_experts_only(&self) -> Option> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + + let uvb_attn = uvb_l.pp("self_attn"); + uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj); + uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj); + uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj); + uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj); + } + + Some(uvb.to_safetensors()) + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index ee4a64f88..7d3cf7b03 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use std::{collections::HashMap, sync::Arc}; @@ -13,7 +13,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -22,12 +22,12 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, }; serde_default_fn!(bool, word_emb_default, false); -#[derive(Debug, Clone, serde::Deserialize, Default)] +#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)] pub struct Config { pub vocab_size: usize, pub hidden_size: 
usize, @@ -599,6 +599,24 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs index 3c3e8d849..458c48cea 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -10,7 +10,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RotaryEmbedding, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -19,13 +19,13 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, NormalModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, AnyMoeConfig, AnyMoeExpertType, }; serde_default_fn!(bool, word_emb_default, false); -#[derive(Debug, Clone, serde::Deserialize, Default)] +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, @@ -33,7 +33,7 @@ pub struct Config { pub(crate) num_hidden_layers: usize, pub(crate) num_attention_heads: usize, pub(crate) num_key_value_heads: usize, - pub(crate) hidden_act: candle_nn::Activation, + pub(crate) hidden_act: Activation, pub(crate) max_position_embeddings: usize, pub(crate) norm_epsilon: f64, pub(crate) rope_theta: f64, @@ -51,7 +51,7 @@ pub struct Config { struct MLP { c_fc: Arc, c_proj: Arc, - act: candle_nn::Activation, + act: Activation, params: Vec, } @@ -587,6 +587,24 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index 0d7b5321d..020bfe04f 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -123,6 +123,7 @@ impl CustomOp2 for BitWise { CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")), CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")), CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")), + CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")), } } #[cfg(feature = "cuda")] @@ -191,6 +192,9 @@ impl CustomOp2 for BitWise { DType::F64 => { return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")); } + DType::F8E4M3 => { + return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, 
"bitwise")); + } }; let dst = match s1.dtype() { DType::U8 => { @@ -397,6 +401,7 @@ fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n), + candle_core::DType::F8E4M3 => todo!(), } } } @@ -438,6 +443,7 @@ fn nonzero_cuda( candle_core::DType::F64 => { ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out) } + candle_core::DType::F8E4M3 => todo!(), } } } @@ -461,6 +467,7 @@ impl CustomOp1 for NonZero { candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout), candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout), + candle_core::CpuStorage::F8E4M3(_vs) => todo!(), }; let index_len = layout.dims().len(); let result_len = result.len() / index_len; @@ -488,6 +495,7 @@ impl CustomOp1 for NonZero { candle_core::DType::F16 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::F32 => *storage.as_cuda_slice::()?.device_ptr(), candle_core::DType::F64 => *storage.as_cuda_slice::()?.device_ptr(), + candle_core::DType::F8E4M3 => todo!(), } as *const c_void; let n = layout.shape().elem_count(); let num_nonzero = count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?); diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 0b64fc206..352db6a5b 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -109,7 +109,7 @@ impl GGMLLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { - let kind = ModelKind::Quantized { + let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Ggml, }; @@ -339,8 +339,8 @@ impl Loader for GGMLLoader { // Config into model: // NOTE: No architecture to infer like GGUF, Llama model is implicitly matched let model = match self.kind { - ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?), - ModelKind::AdapterQuantized { .. } => { + ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?), + ModelKind::GgufAdapter { .. } => { Model::XLoraLlama(XLoraQLlama::try_from(model_config)?) } _ => unreachable!(), @@ -410,7 +410,8 @@ impl Loader for GGMLLoader { self, self.quantized_model_id, Some(vec![self.quantized_filename.as_ref().unwrap().clone()]), - silent + silent, + false // Never loading UQFF ); self.load_model_from_path( &paths?, diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index febd76e99..432d1e050 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -124,7 +124,7 @@ impl GGUFLoaderBuilder { quantized_filenames: Vec, config: GGUFSpecificConfig, ) -> Self { - let kind = ModelKind::Quantized { + let kind = ModelKind::GgufQuantized { quant: QuantizationKind::Gguf, }; @@ -394,7 +394,7 @@ impl Loader for GGUFLoader { let has_adapter = self.kind.is_adapted(); let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); - let paged_attn_config = if matches!(self.kind, ModelKind::AdapterQuantized { .. }) { + let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) { warn!("Adapter models do not currently support PagedAttention, running without"); None } else { @@ -431,7 +431,7 @@ impl Loader for GGUFLoader { // Config into model: let model = match self.kind { - ModelKind::Quantized { .. } => match arch { + ModelKind::GgufQuantized { .. 
} => match arch { GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?), GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?), GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?), @@ -440,7 +440,7 @@ impl Loader for GGUFLoader { } a => bail!("Unsupported architecture `{a:?}` for GGUF"), }, - ModelKind::AdapterQuantized { adapter, .. } => match arch { + ModelKind::GgufAdapter { adapter, .. } => match arch { GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?), GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?), a => bail!( diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs index d0b187040..5bcdc5b9f 100644 --- a/mistralrs-core/src/pipeline/isq.rs +++ b/mistralrs-core/src/pipeline/isq.rs @@ -1,6 +1,7 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, + fs::File, path::PathBuf, str::FromStr, sync::{atomic::AtomicUsize, Arc}, @@ -8,17 +9,21 @@ use std::{ }; use anyhow::Result; -use candle_core::{Device, Tensor}; +use candle_core::{Context, Device, Tensor}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle}; use mistralrs_quant::{ - GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizedSerde, QuantizedSerdeType, UnquantLinear, + FP8Linear, GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizedSerde, QuantizedSerdeType, + UnquantLinear, }; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use regex::Regex; use serde::Deserialize; +use tokenizers::Tokenizer; use tracing::info; -use crate::{device_map::DeviceMapper, serde_default_fn, topology::LayerTopology, Topology}; +use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology}; + +pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors"; /// Parse ISQ value: one of /// - `Q4_0` @@ -54,10 +59,11 @@ pub fn parse_isq_value(s: &str) -> Result { "q8k" => IsqType::Q8K, "hqq8" => IsqType::HQQ8, "hqq4" => IsqType::HQQ4, + "fp8" => IsqType::F8E4M3, // "hqq3" => IsqType::HQQ3, // "hqq2" => IsqType::HQQ2, // "hqq1" => IsqType::HQQ1, - _ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`.")), + _ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`.")), }; #[cfg(feature = "cuda")] { @@ -74,11 +80,12 @@ pub fn parse_isq_value(s: &str) -> Result { | IsqType::Q5K | IsqType::Q6K | IsqType::HQQ8 - | IsqType::HQQ4 // | IsqType::HQQ3 - // | IsqType::HQQ2 - // | IsqType::HQQ1 + | IsqType::HQQ4 + | IsqType::F8E4M3 // | IsqType::HQQ3 + // | IsqType::HQQ2 + // | IsqType::HQQ1 ) { - return Err("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`".to_string()); + return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string()); } } Ok(tp) @@ -108,6 +115,15 @@ impl FromStr for IsqOrganization { } } +pub struct UqffFullSer<'a> { + pub tokenizer: &'a Tokenizer, + pub template_filename: &'a Option, + pub generation_config: Option<&'a PathBuf>, + pub config: String, + pub processor_filename: &'a Option, + pub preprocessor_filename: &'a Option, +} + pub trait IsqModel { /// Corresponds to `IsqOrganization::Default` #[allow(clippy::type_complexity)] @@ 
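parse_isq_value above now accepts fp8 and maps it to the new F8E4M3 type. The following is a self-contained sketch of that mapping; IsqKind is a local placeholder for mistralrs_quant::IsqType (only a few variants are shown), and the real function additionally applies the per-backend allow-lists from the diff:

// Stand-in sketch of the string -> ISQ type mapping after the `fp8` addition.
use std::str::FromStr;

#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq)]
enum IsqKind {
    Q4_0,
    Q8_0,
    Hqq4,
    F8E4M3,
}

impl FromStr for IsqKind {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "q4_0" => Ok(Self::Q4_0),
            "q8_0" => Ok(Self::Q8_0),
            "hqq4" => Ok(Self::Hqq4),
            // New in this change: FP8 (E4M3) in-situ quantization.
            "fp8" => Ok(Self::F8E4M3),
            other => Err(format!(
                "ISQ type {other} unknown, choose one of `Q4_0`, `Q8_0`, `HQQ4`, `FP8`."
            )),
        }
    }
}

fn main() {
    assert_eq!("fp8".parse::<IsqKind>(), Ok(IsqKind::F8E4M3));
    assert!("q9_9".parse::<IsqKind>().is_err());
    println!("FP8 parses to {:?}", "FP8".parse::<IsqKind>().unwrap());
}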
-130,7 +146,19 @@ pub trait IsqModel { self.get_layers() } + /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers`]. + fn residual_tensors(&self) -> Vec<(String, Tensor)>; + + /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers_moe_experts_only`]. + fn residual_tensors_moe_experts_only(&self) -> Option> { + None + } + /// Quantize the model in-situ. + /// + /// This function will also create a UQFF file, or, if the model supports it (residual tensors are returned), + /// a full serialization is created. + #[allow(clippy::too_many_arguments)] fn quantize( &mut self, dtype: Option, @@ -139,6 +167,7 @@ pub trait IsqModel { silent: bool, organization: IsqOrganization, write_artifacts: Option<&PathBuf>, + full_ser: UqffFullSer<'_>, ) -> candle_core::Result<()> { { let (mut tensors, mapper) = match organization { @@ -275,10 +304,7 @@ pub trait IsqModel { ); if !serialized.extension().is_some_and(|ext| ext == "uqff") { - candle_core::bail!( - "UQFF output path extension must be {:?}", - serialized.extension().as_ref().unwrap() - ); + candle_core::bail!("UQFF output path extension must be `.uqff`",); } let bar = ProgressBar::new(total_tensors as u64); @@ -331,7 +357,99 @@ pub trait IsqModel { } }); + let parent = serialized + .parent() + .context("Target UQFF path must have a filename!")?; + + std::fs::create_dir_all(parent)?; + safetensors::serialize_to_file(quantized_values?, &None, serialized)?; + + let residual = match organization { + IsqOrganization::Default => self.residual_tensors(), + IsqOrganization::MoeExpertsOnly => self + .residual_tensors_moe_experts_only() + .unwrap_or(self.residual_tensors()), + }; + + let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS); + let config_out = parent.join("config.json"); + let tokenizer_out = parent.join("tokenizer.json"); + let tokenizer_cfg_out = parent.join("tokenizer_config.json"); + let gen_cfg_out = parent.join("generation_config.json"); + let processor_out = parent.join("processor_config.json"); + let preprocessor_out = parent.join("preprocessor_config.json"); + + info!( + "Serializing {} residual tensors to `{}`.", + residual.len(), + residual_out.display() + ); + + safetensors::serialize_to_file(residual, &None, &residual_out)?; + + let UqffFullSer { + tokenizer, + template_filename, + generation_config, + config, + processor_filename, + preprocessor_filename, + } = full_ser; + + info!("Serializing configuration to `{}`.", config_out.display()); + + std::fs::write(config_out, config)?; + + info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); + + serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) + .map_err(candle_core::Error::msg)?; + + if let Some(template_filename) = template_filename { + info!( + "Serializing tokenizer config to `{}`.", + tokenizer_cfg_out.display() + ); + + let template = + std::fs::read(template_filename).map_err(candle_core::Error::msg)?; + std::fs::write(&tokenizer_cfg_out, template) + .map_err(candle_core::Error::msg)?; + } + + if let Some(generation_config) = generation_config { + info!( + "Serializing generation config to `{}`.", + gen_cfg_out.display() + ); + + let cfg = + std::fs::read(generation_config).map_err(candle_core::Error::msg)?; + std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; + } + + if let Some(processor_config) = processor_filename { + info!( + "Serializing processor config to `{}`.", + processor_out.display() + ); + + let cfg = + 
std::fs::read(processor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; + } + + if let Some(preprocessor_config) = preprocessor_filename { + info!( + "Serializing preprocessor config to `{}`.", + preprocessor_out.display() + ); + + let cfg = + std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; + } } } @@ -412,7 +530,97 @@ pub trait IsqModel { .collect::>>() }; - safetensors::serialize_to_file(quantized_values?, &None, serialized)?; + let parent = serialized + .parent() + .context("Target UQFF path must have a filename!")?; + + std::fs::create_dir_all(parent)?; + + let residual = match organization { + IsqOrganization::Default => self.residual_tensors(), + IsqOrganization::MoeExpertsOnly => self + .residual_tensors_moe_experts_only() + .unwrap_or(self.residual_tensors()), + }; + + let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS); + let config_out = parent.join("config.json"); + let tokenizer_out = parent.join("tokenizer.json"); + let tokenizer_cfg_out = parent.join("tokenizer_config.json"); + let gen_cfg_out = parent.join("generation_config.json"); + let processor_out = parent.join("processor_config.json"); + let preprocessor_out = parent.join("preprocessor_config.json"); + + info!( + "Serializing {} residual tensors to `{}`.", + residual.len(), + residual_out.display() + ); + + safetensors::serialize_to_file(residual, &None, &residual_out)?; + + let UqffFullSer { + tokenizer, + template_filename, + generation_config, + config, + processor_filename, + preprocessor_filename, + } = full_ser; + + info!("Serializing configuration to `{}`.", config_out.display()); + + std::fs::write(config_out, config)?; + + info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); + + serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) + .map_err(candle_core::Error::msg)?; + + if let Some(template_filename) = template_filename { + info!( + "Serializing tokenizer config to `{}`.", + tokenizer_cfg_out.display() + ); + + let template = + std::fs::read(template_filename).map_err(candle_core::Error::msg)?; + std::fs::write(&tokenizer_cfg_out, template) + .map_err(candle_core::Error::msg)?; + } + + if let Some(generation_config) = generation_config { + info!( + "Serializing generation config to `{}`.", + gen_cfg_out.display() + ); + + let cfg = + std::fs::read(generation_config).map_err(candle_core::Error::msg)?; + std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; + } + + if let Some(processor_config) = processor_filename { + info!( + "Serializing processor config to `{}`.", + processor_out.display() + ); + + let cfg = + std::fs::read(processor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; + } + + if let Some(preprocessor_config) = preprocessor_filename { + info!( + "Serializing preprocessor config to `{}`.", + preprocessor_out.display() + ); + + let cfg = + std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; + std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; + } } } let delta = Instant::now().duration_since(t_start).as_secs_f32(); @@ -511,6 +719,9 @@ pub trait IsqModel { QuantizedSerdeType::Hqq => { HqqLayer::deserialize(Cow::from(artifact), &devices[i])? } + QuantizedSerdeType::Fp8 => { + FP8Linear::deserialize(Cow::from(artifact), &devices[i])? 
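The serialization code in the two branches above derives every output path from the requested UQFF target: it insists on a .uqff extension, creates the parent directory, and places residual.safetensors plus the config/tokenizer/processor JSON files next to the quantized artifact. A std-only sketch of that path derivation; the function name and the example target are illustrative, and the real code also writes the safetensors payloads and copies the source config files:

// Sketch of the sidecar paths written next to a `--write-uqff` target.
use std::path::{Path, PathBuf};

const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";

fn sidecar_paths(serialized: &Path) -> Result<Vec<PathBuf>, String> {
    if !serialized.extension().is_some_and(|ext| ext == "uqff") {
        return Err("UQFF output path extension must be `.uqff`".to_string());
    }
    let parent = serialized
        .parent()
        .ok_or_else(|| "Target UQFF path must have a filename!".to_string())?;
    Ok(vec![
        parent.join(UQFF_RESIDUAL_SAFETENSORS),
        parent.join("config.json"),
        parent.join("tokenizer.json"),
        parent.join("tokenizer_config.json"),
        parent.join("generation_config.json"),
        parent.join("processor_config.json"),
        parent.join("preprocessor_config.json"),
    ])
}

fn main() {
    // Hypothetical target path; in practice this is the user-supplied UQFF output file.
    let target = Path::new("uqff/mistral-q4k.uqff");
    for p in sidecar_paths(target).unwrap() {
        println!("{}", p.display());
    }
    // A non-.uqff extension is rejected, as in the diff above.
    assert!(sidecar_paths(Path::new("model.safetensors")).is_err());
}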
+ } }; *tensor = deserialized; } @@ -537,6 +748,9 @@ pub trait IsqModel { QuantizedSerdeType::Hqq => { HqqLayer::deserialize(Cow::from(artifact), &devices[i])? } + QuantizedSerdeType::Fp8 => { + FP8Linear::deserialize(Cow::from(artifact), &devices[i])? + } }; *tensor = deserialized; } @@ -568,11 +782,3 @@ pub(crate) trait IsqModelLoader { self.isq_layer_regexes(config) } } - -serde_default_fn!(bool, word_emb_default, false); - -#[derive(Deserialize)] -pub(crate) struct WordEmbeddingsShim { - #[serde(default = "word_emb_default")] - pub(crate) tie_word_embeddings: bool, -} diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 72eace924..f6c6fefd2 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -236,17 +236,17 @@ impl fmt::Display for TokenSource { #[derive(Clone, Default, derive_more::From, strum::Display)] pub enum ModelKind { #[default] - #[strum(to_string = "normal (no quant, no adapters)")] + #[strum(to_string = "normal (no adapters)")] Normal, - #[strum(to_string = "quantized from {quant} (no adapters)")] - Quantized { quant: QuantizationKind }, + #[strum(to_string = "gguf quantized from {quant} (no adapters)")] + GgufQuantized { quant: QuantizationKind }, - #[strum(to_string = "{adapter}, (no quant)")] + #[strum(to_string = "{adapter}")] Adapter { adapter: AdapterKind }, - #[strum(to_string = "{adapter}, quantized from {quant}")] - AdapterQuantized { + #[strum(to_string = "{adapter}, gguf quantized from {quant}")] + GgufAdapter { adapter: AdapterKind, quant: QuantizationKind, }, @@ -311,7 +311,7 @@ impl ModelKind { match self { Normal | Adapter { .. } => vec![None], - Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)], + GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)], Speculative { target, draft } => { let t = *target.clone(); let d = *draft.clone(); @@ -335,8 +335,8 @@ impl ModelKind { use ModelKind::*; match self { - Normal | Quantized { .. } => vec![None], - Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)], + Normal | GgufQuantized { .. } => vec![None], + Adapter { adapter } | GgufAdapter { adapter, .. 
} => vec![Some(*adapter)], Speculative { target, draft } => { let t = *target.clone(); let d = *draft.clone(); diff --git a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs index 21d93ca5e..b5ee6d336 100644 --- a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs @@ -7,11 +7,11 @@ use std::{ use crate::{ amoe::AnyMoeBaseModelMixin, device_map::DeviceMapper, - layers::{Llama3RopeConfig, PhiRopeScalingConfig}, + layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig}, lora::{LoraConfig, Ordering}, paged_attention::{AttentionImplementation, ModelConfigMetadata}, pipeline::{ - isq::{IsqModelLoader, WordEmbeddingsShim}, + isq::IsqModelLoader, text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, Cache, IsqModel, }, @@ -21,7 +21,7 @@ use crate::{ }; use anyhow::Result; use candle_core::{Device, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle_nn::VarBuilder; use mistralrs_quant::QuantizedConfig; #[cfg(feature = "pyo3_macros")] @@ -391,35 +391,19 @@ impl NormalModelLoader for MistralLoader { } impl IsqModelLoader for MistralLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -535,35 +519,19 @@ impl NormalModelLoader for GemmaLoader { } impl IsqModelLoader for GemmaLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - 
r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -673,35 +641,19 @@ impl NormalModelLoader for LlamaLoader { } impl IsqModelLoader for LlamaLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -807,40 +759,20 @@ impl NormalModelLoader for MixtralLoader { } impl IsqModelLoader for MixtralLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // Experts - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$", - )?); - regexes.push(Regex::new( - 
r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // Experts + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?, + ]) } } @@ -947,30 +879,18 @@ impl NormalModelLoader for Phi2Loader { } impl IsqModelLoader for Phi2Loader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?, + ]) } } @@ -979,7 +899,7 @@ impl IsqModelLoader for Phi2Loader { #[derive(Deserialize)] struct Phi3BasicConfig { vocab_size: usize, - hidden_act: candle_nn::Activation, + hidden_act: Activation, hidden_size: usize, intermediate_size: usize, num_hidden_layers: usize, @@ -1083,28 +1003,17 @@ impl NormalModelLoader for Phi3Loader { } impl IsqModelLoader for Phi3Loader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + 
Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -1199,35 +1108,19 @@ impl NormalModelLoader for Qwen2Loader { } impl IsqModelLoader for Qwen2Loader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -1348,35 +1241,19 @@ impl NormalModelLoader for Gemma2Loader { } impl IsqModelLoader for Gemma2Loader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?, + // MLP + 
Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -1390,7 +1267,7 @@ struct Starcoder2BasicConfig { num_hidden_layers: usize, num_attention_heads: usize, num_key_value_heads: usize, - hidden_act: candle_nn::Activation, + hidden_act: Activation, max_position_embeddings: usize, norm_epsilon: f64, rope_theta: f64, @@ -1482,30 +1359,18 @@ impl NormalModelLoader for Starcoder2Loader { } impl IsqModelLoader for Starcoder2Loader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.c_fc\.(weight|bias)$")?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?, + ]) } } @@ -1514,7 +1379,7 @@ impl IsqModelLoader for Starcoder2Loader { #[derive(Deserialize)] struct Phi3_5MoEBasicConfig { vocab_size: usize, - hidden_act: candle_nn::Activation, + hidden_act: Activation, hidden_size: usize, intermediate_size: usize, num_hidden_layers: usize, @@ -1622,56 +1487,28 @@ impl NormalModelLoader for Phi3_5MoELoader { } impl IsqModelLoader for Phi3_5MoELoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$", - )?); - Ok(regexes) - } - - fn isq_layer_regexes_moqe(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - 
regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?, + ]) + } + + fn isq_layer_regexes_moqe(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?, + ]) } } diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs index efe84b2c3..a41ca568b 100644 --- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs @@ -15,7 +15,7 @@ use serde::Deserialize; use super::NormalLoadingMetadata; use crate::amoe::AnyMoeBaseModelMixin; use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata}; -use crate::pipeline::isq::{IsqModelLoader, WordEmbeddingsShim}; +use crate::pipeline::isq::IsqModelLoader; use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}; use crate::pipeline::{Cache, IsqModel, Processor, ProcessorCreator}; use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2}; @@ -157,28 +157,16 @@ impl VisionModelLoader for Phi3VLoader { } impl IsqModelLoader for Phi3VLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)?.tie_word_embeddings { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?, +
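The loader regexes above were collapsed into literal vectors and the tie_word_embeddings special case was removed, so embed_tokens never matches and ISQ selects only lm_head plus the per-layer projections. A quick check with the regex crate; the tensor names are illustrative samples:

// Checks a few of the rewritten ISQ layer regexes against sample tensor names.
// Requires the `regex` crate.
use regex::Regex;

fn main() -> Result<(), regex::Error> {
    let regexes = vec![
        Regex::new(r"lm_head\.(weight|bias)$")?,
        Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
        Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
    ];

    let names = [
        "lm_head.weight",
        "model.layers.17.self_attn.q_proj.weight",
        "model.layers.17.mlp.gate_up_proj.weight",
        // With the tie_word_embeddings branch removed, embeddings no longer match.
        "model.embed_tokens.weight",
    ];

    for name in names {
        let matched = regexes.iter().any(|re| re.is_match(name));
        println!("{name}: quantize = {matched}");
    }
    Ok(())
}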
Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } @@ -240,7 +228,6 @@ impl VisionModelLoader for Idefics2Loader { impl IsqModelLoader for Idefics2Loader { fn isq_layer_regexes(&self, _config: &str) -> Result> { Ok(vec![ - // Tie weights is unsupported for this model Regex::new(r"lm_head\.(weight|bias)$")?, // Attention Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, @@ -310,7 +297,6 @@ impl VisionModelLoader for LLaVANextLoader { impl IsqModelLoader for LLaVANextLoader { fn isq_layer_regexes(&self, _config: &str) -> Result> { Ok(vec![ - // Tie weights is unsupported for this model Regex::new(r"lm_head\.(weight|bias)$")?, // Attention Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, @@ -380,7 +366,6 @@ impl VisionModelLoader for LLaVALoader { impl IsqModelLoader for LLaVALoader { fn isq_layer_regexes(&self, _config: &str) -> Result> { Ok(vec![ - // Tie weights is unsupported for this model Regex::new(r"lm_head\.(weight|bias)$")?, // Attention Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, @@ -448,37 +433,18 @@ impl VisionModelLoader for VLlamaLoader { } impl IsqModelLoader for VLlamaLoader { - fn isq_layer_regexes(&self, config: &str) -> Result> { - let mut regexes = Vec::new(); - if serde_json::from_str::(config)? - .text_config - .tie_word_embeddings - { - regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?); - } else { - regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?); - } - // Attention - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new( - r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$", - )?); - // MLP - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$", - )?); - regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?); - regexes.push(Regex::new( - r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$", - )?); - Ok(regexes) + fn isq_layer_regexes(&self, _config: &str) -> Result> { + Ok(vec![ + Regex::new(r"lm_head\.(weight|bias)$")?, + // Attention + Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?, + // MLP + Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?, + Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?, + ]) } } diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index a74af8efc..7f3b8420b 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -56,7 +56,16 @@ macro_rules! api_get_file { #[doc(hidden)] #[macro_export] macro_rules! get_paths { - ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{ + ( + $path_name:ident, + $token_source:expr, + $revision:expr, + $this:expr, + $quantized_model_id:expr, + $quantized_filename:expr, + $silent:expr, + $loading_uqff:expr + ) => {{ let api = ApiBuilder::new() .with_progress(!$silent) .with_token(get_token($token_source)?) @@ -84,6 +93,7 @@ macro_rules! 
get_paths { &$quantized_filename, &api, &model_id, + $loading_uqff, )?; let XLoraPaths { adapter_configs, @@ -169,54 +179,49 @@ macro_rules! get_paths { #[doc(hidden)] #[macro_export] -macro_rules! get_write_uqff_paths { +macro_rules! get_uqff_paths { ($from_uqff:expr, $this:expr, $silent:expr) => {{ - if !$from_uqff.exists() { - // Assume it's a HF model id - let path = $from_uqff.to_string_lossy().to_string(); - let parts = path.rsplitn(2, '/').collect::>(); - - if parts.len() != 2 { - anyhow::bail!("ISQ artifact load path `{path}` not found locally must have format `/`"); - } - - let file = parts[0]; - let model_id = parts[1]; + let api = ApiBuilder::new() + .with_progress(!$silent) + .with_token(get_token( + &$this + .token_source + .read() + .expect("Failed to read token source") + .clone() + .unwrap_or(TokenSource::None), + )?) + .build()?; + let revision = $this + .revision + .read() + .expect("Failed to read revision") + .clone() + .unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision( + $this.model_id.to_string(), + RepoType::Model, + revision.clone(), + )); - let api = ApiBuilder::new() - .with_progress(!$silent) - .with_token(get_token( - &$this - .token_source - .read() - .expect("Failed to read token source") - .clone() - .unwrap_or(TokenSource::None), - )?) - .build()?; - let revision = $this - .revision - .read() - .expect("Failed to read revision") - .clone() - .unwrap_or("main".to_string()); - let api = api.repo(Repo::with_revision( - model_id.to_string(), - RepoType::Model, - revision.clone(), - )); + let file = $from_uqff.display().to_string(); - api_get_file!(api, file, Path::new(model_id)) - } else { - $from_uqff - } + api_get_file!(api, &file, Path::new(&$this.model_id)) }}; } #[doc(hidden)] #[macro_export] macro_rules! get_paths_gguf { - ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filenames:expr, $silent:expr) => {{ + ( + $path_name:ident, + $token_source:expr, + $revision:expr, + $this:expr, + $quantized_model_id:expr, + $quantized_filenames:expr, + $silent:expr + ) => {{ let api = ApiBuilder::new() .with_progress(!$silent) .with_token(get_token($token_source)?) @@ -258,6 +263,7 @@ macro_rules! get_paths_gguf { &Some($quantized_filenames), &api, &model_id, + false, // Never loading UQFF )?; let XLoraPaths { @@ -345,7 +351,21 @@ macro_rules! get_paths_gguf { #[doc(hidden)] #[macro_export] macro_rules! normal_model_loader { - ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $loading_uqff:expr, $real_device:expr, $attention_mechanism:expr, $is_moqe:expr) => {{ + ( + $paths:expr, + $dtype:expr, + $device:expr, + $config:expr, + $loader:expr, + $use_flash_attn:expr, + $silent:expr, + $mapper:expr, + $loading_isq:expr, + $loading_uqff:expr, + $real_device:expr, + $attention_mechanism:expr, + $is_moqe:expr + ) => {{ let regexes = if $loading_isq && $loading_uqff { // Dummy weights for the layers which will be overwritten... Some(std::sync::Arc::new(if $is_moqe { @@ -384,7 +404,20 @@ macro_rules! normal_model_loader { #[doc(hidden)] #[macro_export] macro_rules! 
vision_normal_model_loader { - ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $loading_uqff:expr, $real_device:expr, $attention_mechanism:expr) => {{ + ( + $paths:expr, + $dtype:expr, + $device:expr, + $config:expr, + $loader:expr, + $use_flash_attn:expr, + $silent:expr, + $mapper:expr, + $loading_isq:expr, + $loading_uqff:expr, + $real_device:expr, + $attention_mechanism:expr + ) => {{ let regexes = if $loading_isq && $loading_uqff { // Dummy weights for the layers which will be overwritten... Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?)) @@ -419,7 +452,18 @@ macro_rules! vision_normal_model_loader { #[doc(hidden)] #[macro_export] macro_rules! xlora_model_loader { - ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ + ( + $paths:expr, + $dtype:expr, + $device:expr, + $config:expr, + $loader:expr, + $use_flash_attn:expr, + $silent:expr, + $mapper:expr, + $loading_isq:expr, + $real_device:expr + ) => {{ let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::>(); safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap()); let vb = from_mmaped_safetensors( diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 49c22a277..7ef391988 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -18,6 +18,7 @@ use crate::amoe::AnyMoeExpertType; use crate::lora::Ordering; use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine}; use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; +use crate::pipeline::isq::UqffFullSer; use crate::pipeline::sampling::sample_and_add_toks; use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; @@ -28,9 +29,9 @@ use crate::utils::tokenizer::get_tokenizer; use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors}; use crate::xlora_models::NonGranularState; use crate::{ - api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_write_uqff_paths, - lora_model_loader, normal_model_loader, xlora_model_loader, DeviceMapMetadata, - PagedAttentionConfig, Pipeline, Topology, TryIntoDType, + api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader, + normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline, + Topology, TryIntoDType, }; use anyhow::Result; use candle_core::{Device, Tensor, Var}; @@ -59,6 +60,10 @@ pub struct NormalPipeline { topology: Option, silent: bool, organization: IsqOrganization, + // For full UQFF serialization + template_filename: Option, + generation_config: Option, + config: String, } /// A loader for a "normal" (non-quantized) model. 
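// Illustrative sketch only (not part of this patch): how the `isq_layer_regexes` /
// `isq_layer_regexes_moqe` lists above are typically consumed. Each regex is matched against a
// fully qualified tensor name to decide which layers are quantized in-situ or, when
// `loading_isq && loading_uqff`, loaded as dummy weights and later overwritten by the UQFF
// artifacts. Only the `regex` crate is assumed; the helper name is hypothetical.
fn is_isq_layer(tensor_name: &str, isq_regexes: &[regex::Regex]) -> bool {
    // e.g. "model.layers.0.self_attn.q_proj.weight" matches the attention patterns above
    isq_regexes.iter().any(|re| re.is_match(tensor_name))
}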
@@ -75,6 +80,7 @@ pub struct NormalLoader { tgt_non_granular_index: Option, token_source: RwLock>, revision: RwLock>, + from_uqff: RwLock>, } #[derive(Default)] @@ -203,6 +209,7 @@ impl NormalLoaderBuilder { tgt_non_granular_index: self.tgt_non_granular_index, token_source: RwLock::new(None), revision: RwLock::new(None), + from_uqff: RwLock::new(None), })) } } @@ -227,8 +234,12 @@ impl Loader for NormalLoader { self, None, None, - silent + silent, + self.config.from_uqff.is_some() ); + if let Some(from_uqff) = self.config.from_uqff.clone() { + *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent)); + } *self .token_source .write() @@ -374,14 +385,21 @@ impl Loader for NormalLoader { silent, self.config.organization, self.config.write_uqff.as_ref(), + UqffFullSer { + tokenizer: &tokenizer, + template_filename: paths.get_template_filename(), + generation_config: paths.get_gen_conf_filename(), + config: config.clone(), + processor_filename: &None, + preprocessor_filename: &None, + }, )?; - } else if let Some(mut from_uqff) = self.config.from_uqff.clone() { - from_uqff = get_write_uqff_paths!(from_uqff, self, silent); + } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() { model.load_from_artifacts( device.clone(), self.config.topology.as_ref(), silent, - &from_uqff, + from_uqff, )?; } @@ -441,6 +459,9 @@ impl Loader for NormalLoader { topology: self.config.topology.clone(), silent, organization: self.config.organization, + template_filename: paths.get_template_filename().clone(), + generation_config: paths.get_gen_conf_filename().cloned(), + config, }))) } @@ -476,6 +497,14 @@ impl IsqPipelineMixin for NormalPipeline { self.silent, self.organization, None, + UqffFullSer { + tokenizer: &self.tokenizer, + template_filename: &self.template_filename, + generation_config: self.generation_config.as_ref(), + config: self.config.clone(), + processor_filename: &None, + preprocessor_filename: &None, + }, ) .map_err(anyhow::Error::msg) } diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index a7ca84730..c2dfad8c5 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -17,7 +17,10 @@ use tracing::{info, warn}; use crate::{ api_dir_list, api_get_file, lora::LoraConfig, - pipeline::chat_template::{ChatTemplate, ChatTemplateValue}, + pipeline::{ + chat_template::{ChatTemplate, ChatTemplateValue}, + isq::UQFF_RESIDUAL_SAFETENSORS, + }, utils::tokens::get_token, xlora_models::XLoraConfig, ModelPaths, Ordering, TokenSource, @@ -262,6 +265,7 @@ pub fn get_model_paths( quantized_filename: &Option>, api: &ApiRepo, model_id: &Path, + loading_from_uqff: bool, ) -> Result> { match &quantized_filename { Some(names) => { @@ -294,6 +298,7 @@ pub fn get_model_paths( safetensor_match.is_match(x) || pickle_match.is_match(x) || quant_safetensor_match.is_match(x) + || x == UQFF_RESIDUAL_SAFETENSORS }); let safetensors = listing .clone() @@ -303,12 +308,18 @@ pub fn get_model_paths( .clone() .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin")) .collect::>(); + let uqff_residual = listing + .clone() + .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS) + .collect::>(); let files = if !safetensors.is_empty() { // Always prefer safetensors safetensors } else if !pickles.is_empty() { // Fall back to pickle pickles + } else if !uqff_residual.is_empty() && loading_from_uqff { + uqff_residual } else { anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin."); }; diff 
--git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index e08e8601a..e8b808392 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -1,4 +1,5 @@ use super::cache_manager::DefaultCacheManager; +use super::isq::UqffFullSer; use super::{ get_model_paths, get_xlora_paths, AdapterActivationMixin, AnyMoePipelineMixin, Cache, CacheManager, CacheManagerMixin, ForwardInputsResult, GeneralMetadata, IsqPipelineMixin, @@ -21,7 +22,7 @@ use crate::vision_models::preprocessor_config::PreProcessorConfig; use crate::vision_models::processor_config::ProcessorConfig; use crate::vision_models::ModelInputs; use crate::{ - api_dir_list, api_get_file, get_paths, get_write_uqff_paths, vision_normal_model_loader, + api_dir_list, api_get_file, get_paths, get_uqff_paths, vision_normal_model_loader, AnyMoeExpertType, DeviceMapMetadata, Ordering, PagedAttentionConfig, Pipeline, Topology, TryIntoDType, }; @@ -51,6 +52,12 @@ pub struct VisionPipeline { preprocessor_config: Arc, topology: Option, silent: bool, + // For full UQFF serialization + template_filename: Option, + generation_config: Option, + config: String, + processor_filename: Option, + preprocessor_filename: Option, } /// A loader for a vision (non-quantized) model. @@ -65,6 +72,7 @@ pub struct VisionLoader { xlora_order: Option, token_source: RwLock>, revision: RwLock>, + from_uqff: RwLock>, } #[derive(Default)] @@ -122,6 +130,7 @@ impl VisionLoaderBuilder { xlora_order: None, token_source: RwLock::new(None), revision: RwLock::new(None), + from_uqff: RwLock::new(None), }) } } @@ -146,8 +155,12 @@ impl Loader for VisionLoader { self, None, None, - silent + silent, + self.config.from_uqff.is_some() ); + if let Some(from_uqff) = self.config.from_uqff.clone() { + *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent)); + } *self .token_source .write() @@ -289,14 +302,21 @@ impl Loader for VisionLoader { silent, IsqOrganization::Default, self.config.write_uqff.as_ref(), + UqffFullSer { + tokenizer: &tokenizer, + template_filename: paths.get_template_filename(), + generation_config: paths.get_gen_conf_filename(), + config: config.clone(), + processor_filename: paths.get_processor_config(), + preprocessor_filename: paths.get_preprocessor_config(), + }, )?; - } else if let Some(mut from_uqff) = self.config.from_uqff.clone() { - from_uqff = get_write_uqff_paths!(from_uqff, self, silent); + } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() { model.load_from_artifacts( device.clone(), self.config.topology.as_ref(), silent, - &from_uqff, + from_uqff, )?; } @@ -347,6 +367,11 @@ impl Loader for VisionLoader { preprocessor_config: Arc::new(preprocessor_config), topology: self.config.topology.clone(), silent, + template_filename: paths.get_template_filename().clone(), + generation_config: paths.get_gen_conf_filename().cloned(), + config, + processor_filename: paths.get_processor_config().clone(), + preprocessor_filename: paths.get_preprocessor_config().clone(), }))) } @@ -382,6 +407,14 @@ impl IsqPipelineMixin for VisionPipeline { self.silent, IsqOrganization::Default, None, + UqffFullSer { + tokenizer: &self.tokenizer, + template_filename: &self.template_filename, + generation_config: self.generation_config.as_ref(), + config: self.config.clone(), + processor_filename: &self.processor_filename, + preprocessor_filename: &self.preprocessor_filename, + }, ) .map_err(anyhow::Error::msg) } diff --git a/mistralrs-core/src/utils/memory_usage.rs 
b/mistralrs-core/src/utils/memory_usage.rs index 611cd8ee2..dcf52a479 100644 --- a/mistralrs-core/src/utils/memory_usage.rs +++ b/mistralrs-core/src/utils/memory_usage.rs @@ -6,6 +6,7 @@ const KB_TO_BYTES: usize = 1024; pub struct MemoryUsage; impl MemoryUsage { + /// Amount of available memory in bytes. pub fn get_memory_available(&self, device: &Device) -> Result { match device { Device::Cpu => { @@ -30,6 +31,7 @@ impl MemoryUsage { } } + /// Amount of total memory in bytes. pub fn get_total_memory(&self, device: &Device) -> Result { match device { Device::Cpu => { diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 6b202470e..bea723738 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod normal; pub(crate) mod progress; pub(crate) mod tokenizer; pub(crate) mod tokens; +pub(crate) mod unvarbuilder; pub(crate) mod varbuilder_utils; #[doc(hidden)] diff --git a/mistralrs-core/src/utils/unvarbuilder.rs b/mistralrs-core/src/utils/unvarbuilder.rs new file mode 100644 index 000000000..74fb06217 --- /dev/null +++ b/mistralrs-core/src/utils/unvarbuilder.rs @@ -0,0 +1,176 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use candle_core::{quantized::QMatMul, Tensor}; +use candle_nn::{Conv2d, Embedding, LayerNorm, Linear}; +use itertools::Itertools; +use mistralrs_quant::QuantMethod; + +use crate::layers::{FusedBiasLinear, QLinear, RmsNorm}; + +pub trait ToTensors { + /// Tensor names to tensors + fn to_tensors(&self) -> HashMap; +} + +impl ToTensors for Embedding { + fn to_tensors(&self) -> HashMap { + HashMap::from_iter([("weight".to_string(), self.embeddings().clone())]) + } +} + +impl ToTensors for RmsNorm { + fn to_tensors(&self) -> HashMap { + HashMap::from_iter([("weight".to_string(), self.weight().clone())]) + } +} + +impl ToTensors for LayerNorm { + fn to_tensors(&self) -> HashMap { + HashMap::from_iter([ + ("weight".to_string(), self.weight().clone()), + ("bias".to_string(), self.bias().clone()), + ]) + } +} + +impl ToTensors for Linear { + fn to_tensors(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert("weight".to_string(), self.weight().clone()); + if let Some(bias) = self.bias() { + map.insert("bias".to_string(), bias.clone()); + } + map + } +} + +impl ToTensors for Conv2d { + fn to_tensors(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert("weight".to_string(), self.weight().clone()); + if let Some(bias) = self.bias() { + map.insert("bias".to_string(), bias.clone()); + } + map + } +} + +impl ToTensors for FusedBiasLinear { + fn to_tensors(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert("weight".to_string(), self.w.clone()); + map.insert("bias".to_string(), self.b.clone()); + map + } +} + +impl ToTensors for QLinear { + fn to_tensors(&self) -> HashMap { + let mut map = HashMap::new(); + match self.inner_ref() { + QMatMul::Tensor(w) | QMatMul::TensorF16(w) => { + map.insert("weight".to_string(), w.clone()); + if let Some(bias) = self.bias() { + map.insert("bias".to_string(), bias.clone()); + } + } + QMatMul::QTensor(_) => return HashMap::new(), + } + map + } +} + +impl ToTensors for Arc { + fn to_tensors(&self) -> HashMap { + let (w, b) = match self.unquant_weight_bias() { + Some(x) => x, + None => return HashMap::new(), + }; + let mut map = HashMap::new(); + map.insert("weight".to_string(), w); + if let Some(bias) = b { + map.insert("bias".to_string(), bias.clone()); + } + map + } +} + +pub struct UnVarBuilder { + data: Arc>>, + 
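+    // Prefix segments pushed via `pp`/`push_prefix`; they are joined with '.' to form the fully
+    // qualified tensor name whenever tensors are added.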
path: Vec, +} + +impl UnVarBuilder { + pub fn new() -> Self { + Self { + data: Arc::new(RwLock::new(HashMap::new())), + path: Vec::new(), + } + } + + pub fn push_prefix(&self, s: S) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + } + } + + pub fn pp(&self, s: S) -> Self { + self.push_prefix(s) + } + + pub fn path(&self) -> String { + self.path.iter().filter(|p| !p.trim().is_empty()).join(".") + } + + pub fn add(&self, item: &T) { + let mut data = self.data.write().expect("Write failed!"); + let path = self.path(); + data.extend( + item.to_tensors() + .into_iter() + .map(|(n, t)| (format!("{path}.{n}"), t)) + .collect::>(), + ); + } + + pub fn add_tensor(&self, s: S, v: Tensor) { + let mut data = self.data.write().expect("Write failed!"); + let mut path = self.path.clone(); + path.push(s.to_string()); + data.insert( + path.into_iter().filter(|p| !p.trim().is_empty()).join("."), + v, + ); + } + + pub fn extend(&self, other: Vec<(String, Tensor)>) { + let mut data = self.data.write().expect("Write failed!"); + let path = self.path(); + data.extend( + other + .into_iter() + .map(|(n, t)| { + ( + if path.is_empty() { + n + } else { + format!("{path}.{n}") + }, + t, + ) + }) + .collect::>(), + ); + } + + pub fn to_safetensors(&self) -> Vec<(String, Tensor)> { + let data = self.data.read().expect("Read failed!"); + data.iter().map(|(p, t)| (p.clone(), t.clone())).collect() + } +} diff --git a/mistralrs-core/src/vision_models/clip.rs b/mistralrs-core/src/vision_models/clip.rs index 9cea49f3f..9df0b7c2d 100644 --- a/mistralrs-core/src/vision_models/clip.rs +++ b/mistralrs-core/src/vision_models/clip.rs @@ -4,7 +4,7 @@ use candle_core::{IndexOp, Result, Shape, Tensor, D}; use candle_nn::{Conv2dConfig, Module}; -use crate::{layers::FusedBiasLinear, serde_default_fn}; +use crate::{layers::FusedBiasLinear, serde_default_fn, utils::unvarbuilder::UnVarBuilder}; #[derive(Debug, Clone, Copy, serde::Deserialize)] pub enum Activation { @@ -316,4 +316,48 @@ impl ClipVisionTransformer { result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); Ok(result) } + + pub fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("pre_layrnorm").add(&self.pre_layer_norm); + uvb.pp("post_layernorm").add(&self.final_layer_norm); + + // vision embeddings + { + let uvb_emb = uvb.pp("embeddings"); + + uvb_emb.add_tensor("class_embedding", self.embeddings.class_embedding.clone()); + uvb_emb + .pp("position_embedding") + .add(&self.embeddings.position_embedding); + uvb_emb + .pp("patch_embedding") + .add(&self.embeddings.patch_embedding); + } + + // encoder + { + let uvb_enc = uvb.pp("encoder"); + + for (i, layer) in self.encoder.layers.iter().enumerate() { + let uvb_l = uvb_enc.pp("layers").pp(i); + + uvb_l.pp("layer_norm1").add(&layer.layer_norm1); + uvb_l.pp("layer_norm2").add(&layer.layer_norm2); + + let uvb_mlp = uvb_l.pp("mlp"); + uvb_mlp.pp("fc1").add(&layer.mlp.fc1); + uvb_mlp.pp("fc2").add(&layer.mlp.fc2); + + let uvb_attn = uvb_l.pp("self_attn"); + uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj); + uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj); + uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj); + uvb_attn.pp("out_proj").add(&layer.self_attn.out_proj); + } + } + + uvb.to_safetensors() + } } diff --git a/mistralrs-core/src/vision_models/idefics2.rs b/mistralrs-core/src/vision_models/idefics2.rs index 9e8a6af4a..0eb6b10f2 100644 --- a/mistralrs-core/src/vision_models/idefics2.rs +++ 
b/mistralrs-core/src/vision_models/idefics2.rs @@ -2,8 +2,8 @@ use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{ - conv2d, embedding, layer_norm, linear, linear_no_bias, Activation, Conv2d, Conv2dConfig, - Embedding, LayerNorm, Module, VarBuilder, + conv2d, embedding, layer_norm, linear, linear_no_bias, Conv2d, Conv2dConfig, Embedding, + LayerNorm, Module, VarBuilder, }; use serde::Deserialize; use std::{any::Any, ops::Mul}; @@ -11,13 +11,14 @@ use std::{any::Any, ops::Mul}; use crate::{ amoe::{AnyMoeBaseModelMixin, MlpLayer}, device_map::DeviceMapper, - layers::{repeat_kv, CausalMasker, QLinear, RmsNorm}, + layers::{repeat_kv, Activation, CausalMasker, QLinear, RmsNorm}, models::mistral::Model as Mistral, paged_attention::{AttentionImplementation, ModelConfigMetadata}, pipeline::{ text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, Cache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel, }, + utils::unvarbuilder::UnVarBuilder, AnyMoeConfig, AnyMoeExpertType, }; @@ -332,6 +333,15 @@ impl VisionEmbeddings { let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?; embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("patch_embedding").add(&self.patch_embedding); + uvb.pp("position_embedding").add(&self.position_embedding); + + uvb.to_safetensors() + } } struct Attention { @@ -421,6 +431,17 @@ impl Attention { } Ok(res) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("q_proj").add(&self.q_proj); + uvb.pp("k_proj").add(&self.k_proj); + uvb.pp("v_proj").add(&self.v_proj); + uvb.pp("out_proj").add(&self.o_proj); + + uvb.to_safetensors() + } } struct VisionMLP { @@ -454,6 +475,15 @@ impl VisionMLP { } Ok(res) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("fc1").add(&self.fc1); + uvb.pp("fc2").add(&self.fc2); + + uvb.to_safetensors() + } } struct EncoderLayer { @@ -580,6 +610,26 @@ impl VisionTransformer { .forward(&hidden_states, attention_mask.as_ref())?; hidden_states.apply(&self.post_layernorm) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("post_layernorm").add(&self.post_layernorm); + uvb.pp("embeddings") + .extend(self.embeddings.residual_tensors()); + + let uvb_enc = uvb.pp("encoder"); + for (i, layer) in self.encoder.layers.iter().enumerate() { + let uvb_l = uvb_enc.pp("layers").pp(i); + + uvb_l.pp("layer_norm1").add(&layer.layer_norm_1); + uvb_l.pp("layer_norm2").add(&layer.layer_norm_2); + uvb_l.pp("mlp").extend(layer.mlp.residual_tensors()); + uvb_l.pp("self_attn").extend(layer.attn.residual_tensors()); + } + + uvb.to_safetensors() + } } // == END VISION MODEL == @@ -626,6 +676,16 @@ impl Mlp { } Ok(res) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("gate_proj").add(&self.gate_proj); + uvb.pp("up_proj").add(&self.up_proj); + uvb.pp("down_proj").add(&self.down_proj); + + uvb.to_safetensors() + } } struct PerceiverAttention { @@ -728,6 +788,17 @@ impl PerceiverAttention { } Ok(res) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("q_proj").add(&self.q_proj); + uvb.pp("k_proj").add(&self.k_proj); + uvb.pp("v_proj").add(&self.v_proj); + uvb.pp("o_proj").add(&self.o_proj); + + 
uvb.to_safetensors() + } } struct PerceiverLayer { @@ -829,6 +900,33 @@ impl PerceiverResampler { } self.norm.forward(&compressed_context) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("norm").add(&self.norm); + uvb.add_tensor("latents", self.latents.clone()); + + for (i, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb.pp("layers").pp(i); + + uvb_l + .pp("input_latents_norm") + .add(&layer.input_latents_norm); + uvb_l + .pp("input_context_norm") + .add(&layer.input_context_norm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attn_norm); + uvb_l.pp("mlp").extend(layer.mlp.residual_tensors()); + uvb_l + .pp("self_attn") + .extend(layer.self_attn.residual_tensors()); + } + + uvb.to_safetensors() + } } struct Connector { @@ -857,6 +955,17 @@ impl Connector { self.perceiver_resampler .forward(&image_hidden_states, attention_mask) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("modality_projection") + .extend(self.modality_projection.residual_tensors()); + uvb.pp("perceiver_resampler") + .extend(self.perceiver_resampler.residual_tensors()); + + uvb.to_safetensors() + } } // == END CONNECTOR == @@ -1078,6 +1187,23 @@ impl IsqModel for Idefics2 { ) { self.text_model.get_layers() } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m + .pp("text_model") + .extend(self.text_model.residual_tensors()); + uvb_m + .pp("vision_model") + .extend(self.vision_model.residual_tensors()); + uvb_m + .pp("connector") + .extend(self.connector.residual_tensors()); + + uvb.to_safetensors() + } } // AnyMoE is forwarded to the base model diff --git a/mistralrs-core/src/vision_models/llava/config.rs b/mistralrs-core/src/vision_models/llava/config.rs index 549d7571e..fd6cfdfef 100644 --- a/mistralrs-core/src/vision_models/llava/config.rs +++ b/mistralrs-core/src/vision_models/llava/config.rs @@ -1,7 +1,6 @@ -use candle_nn::Activation; use serde::Deserialize; -use crate::layers::Llama3RopeConfig; +use crate::layers::{Activation, Llama3RopeConfig}; use crate::serde_default_fn; use crate::models::llama::Config as LLaMAConfig; diff --git a/mistralrs-core/src/vision_models/llava/llava15.rs b/mistralrs-core/src/vision_models/llava/llava15.rs index 33b219267..0fcbe870e 100644 --- a/mistralrs-core/src/vision_models/llava/llava15.rs +++ b/mistralrs-core/src/vision_models/llava/llava15.rs @@ -14,6 +14,7 @@ use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata; use crate::pipeline::IsqModel; use crate::pipeline::NormalLoadingMetadata; use crate::pipeline::VisionModel; +use crate::utils::unvarbuilder::UnVarBuilder; use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer}; use crate::vision_models::llava::config::Config; use crate::AnyMoeConfig; @@ -269,6 +270,24 @@ impl IsqModel for Model { ) { self.llm.get_layers() } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + // MM projectors + uvb.pp("multi_modal_projector.linear_1") + .add(&self.mm_projector.linear_1); + uvb.pp("multi_modal_projector.linear_2") + .add(&self.mm_projector.linear_2); + + // Vision tower + { + let uvb_vt = uvb.pp("vision_tower.vision_model"); + uvb_vt.extend(self.clip_vision_tower.model.residual_tensors()); + } + + uvb.to_safetensors() + } } impl VisionModel for Model { diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs 
b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs index cfb6cda3a..c4a86ce08 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs @@ -488,6 +488,10 @@ impl IsqModel for Llama { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + Vec::new() + } } impl LLaVALLM for Llama { diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs index ab3170331..c14928df7 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs @@ -4,7 +4,7 @@ use std::sync::Arc; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{linear_no_bias, Activation, VarBuilder}; +use candle_nn::{linear_no_bias, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use crate::{ @@ -12,7 +12,7 @@ use crate::{ attention::SdpaParams, device_map::DeviceMapper, get_delta_from_lora_ab, - layers::{CausalMasker, MatMul, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa}, layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ @@ -572,6 +572,10 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + Vec::new() + } } impl LLaVALLM for Model { diff --git a/mistralrs-core/src/vision_models/llava/llava_next.rs b/mistralrs-core/src/vision_models/llava/llava_next.rs index 99842145a..250ea68d0 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next.rs @@ -15,6 +15,7 @@ use crate::pipeline::IsqModel; use crate::pipeline::NormalLoadingMetadata; use crate::pipeline::VisionModel; +use crate::utils::unvarbuilder::UnVarBuilder; use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer}; use crate::vision_models::llava::config::Config; use crate::vision_models::llava::utils::get_anyres_image_grid_shape; @@ -348,6 +349,26 @@ impl IsqModel for Model { ) { self.llm.get_layers() } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + // MM projectors + uvb.pp("multi_modal_projector.linear_1") + .add(&self.mm_projector.linear_1); + uvb.pp("multi_modal_projector.linear_2") + .add(&self.mm_projector.linear_2); + + // Vision tower + { + let uvb_vt = uvb.pp("vision_tower.vision_model"); + uvb_vt.extend(self.clip_vision_tower.model.residual_tensors()); + } + + uvb.add_tensor("image_newline", self.image_newline.clone()); + + uvb.to_safetensors() + } } impl VisionModel for Model { diff --git a/mistralrs-core/src/vision_models/mllama/mod.rs b/mistralrs-core/src/vision_models/mllama/mod.rs index 1d2248467..2a934a4b1 100644 --- a/mistralrs-core/src/vision_models/mllama/mod.rs +++ b/mistralrs-core/src/vision_models/mllama/mod.rs @@ -26,6 +26,7 @@ use crate::{ text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, Cache, IsqModel, NormalLoadingMetadata, VisionModel, }, + utils::unvarbuilder::UnVarBuilder, }; fn repeat_interleave(xs: &Tensor, repeats: usize, dim: usize) -> Result { @@ -250,6 +251,19 @@ impl IsqModel for MLlamaModel { ) { self.language_model.get_layers() } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("multi_modal_projector") + 
.add(&self.multi_modal_projector); + uvb.pp("language_model") + .extend(self.language_model.residual_tensors()); + uvb.pp("vision_model") + .extend(self.vision_model.residual_tensors()); + + uvb.to_safetensors() + } } impl AnyMoeBaseModelMixin for MLlamaModel {} diff --git a/mistralrs-core/src/vision_models/mllama/text.rs b/mistralrs-core/src/vision_models/mllama/text.rs index 48e4b332e..2e06a1ae1 100644 --- a/mistralrs-core/src/vision_models/mllama/text.rs +++ b/mistralrs-core/src/vision_models/mllama/text.rs @@ -13,6 +13,7 @@ use crate::{ layers_masker::PastKvLenCache, paged_attention::{AttentionImplementation, ModelConfigMetadata}, pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata}, + utils::unvarbuilder::UnVarBuilder, }; use super::config::MLlamaTextConfig; @@ -446,16 +447,12 @@ impl MLlamaCrossAttentionDecoderLayer { mlp, input_layernorm, post_attention_layernorm, - // NOTE: Preapply the tanh attn_gate: mapper .set_device(layer_idx, vb.clone(), false) - .get((1,), "cross_attn_attn_gate")? - .tanh()?, - // NOTE: Preapply the tanh + .get((1,), "cross_attn_attn_gate")?, mlp_gate: mapper .set_device(layer_idx, vb.clone(), false) - .get((1,), "cross_attn_mlp_gate")? - .tanh()?, + .get((1,), "cross_attn_mlp_gate")?, }) } @@ -474,7 +471,7 @@ impl MLlamaCrossAttentionDecoderLayer { hidden_states = self.attn .forward(&hidden_states, cross_attn_states, attention_mask, kv_cache)?; - hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate)?)?; + hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?; let residual = &hidden_states; let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; @@ -486,7 +483,7 @@ impl MLlamaCrossAttentionDecoderLayer { .broadcast_mul(&hidden_states)?; } - residual + hidden_states.broadcast_mul(&self.mlp_gate)? + residual + hidden_states.broadcast_mul(&self.mlp_gate.tanh()?)? 
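+        // The tanh is now applied at forward time, i.e. h = residual + tanh(cross_attn_mlp_gate) * h,
+        // rather than being pre-applied to the gate at load time. Keeping `cross_attn_attn_gate` and
+        // `cross_attn_mlp_gate` as raw tensors lets `residual_tensors` below serialize them unchanged
+        // for UQFF.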
} } @@ -713,4 +710,51 @@ impl IsqModel for MLlamaTextModel { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("model.embed_tokens").add(&self.embed_tokens); + uvb.pp("lm_head").add(&self.lm_head); + + let uvb = uvb.pp("model"); + + uvb.pp("norm").add(&self.norm); + + for (i, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb.pp("layers").pp(i); + match layer { + MLlamaDecoderLayer::CrossAttn(crossattn) => { + // Cross attention layers are not quantized + uvb_l + .pp("post_attention_layernorm") + .add(&crossattn.post_attention_layernorm); + uvb_l.pp("input_layernorm").add(&crossattn.input_layernorm); + uvb_l.add_tensor("cross_attn_attn_gate", crossattn.attn_gate.clone()); + uvb_l.add_tensor("cross_attn_mlp_gate", crossattn.mlp_gate.clone()); + + let uvb_attn = uvb_l.pp("cross_attn"); + uvb_attn.pp("q_proj").add(&crossattn.attn.q_proj); + uvb_attn.pp("k_proj").add(&crossattn.attn.k_proj); + uvb_attn.pp("v_proj").add(&crossattn.attn.v_proj); + uvb_attn.pp("o_proj").add(&crossattn.attn.o_proj); + uvb_attn.pp("q_norm").add(&crossattn.attn.q_norm); + uvb_attn.pp("k_norm").add(&crossattn.attn.k_norm); + + let uvb_mlp = uvb_l.pp("mlp"); + uvb_mlp.pp("gate_proj").add(&crossattn.mlp.gate_proj); + uvb_mlp.pp("up_proj").add(&crossattn.mlp.up_proj); + uvb_mlp.pp("down_proj").add(&crossattn.mlp.down_proj); + } + MLlamaDecoderLayer::SelfAttn(selfattn) => { + uvb_l + .pp("post_attention_layernorm") + .add(&selfattn.post_attention_layernorm); + uvb_l.pp("input_layernorm").add(&selfattn.input_layernorm); + } + } + } + + uvb.to_safetensors() + } } diff --git a/mistralrs-core/src/vision_models/mllama/vision.rs b/mistralrs-core/src/vision_models/mllama/vision.rs index d438a1004..c24308f2d 100644 --- a/mistralrs-core/src/vision_models/mllama/vision.rs +++ b/mistralrs-core/src/vision_models/mllama/vision.rs @@ -11,6 +11,8 @@ use candle_nn::{ use crate::{ attention::SdpaParams, layers::{FusedBiasLinear, Sdpa}, + pipeline::IsqModel, + utils::unvarbuilder::UnVarBuilder, }; use super::{MLlamaVisionConfig, VisionActivation}; @@ -28,8 +30,7 @@ impl MLlamaPrecomputedPositionEmbedding { fn new(cfg: &MLlamaVisionConfig, vb: VarBuilder) -> Result { let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1; Ok(Self { - // NOTE: Preapply the tanh - gate: vb.get((1,), "gate")?.tanh()?, + gate: vb.get((1,), "gate")?, embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?, tile_embedding: embedding( cfg.max_aspect_ratio_id() + 1, @@ -45,7 +46,7 @@ impl MLlamaPrecomputedPositionEmbedding { // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L197 fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result { // position embeddings - let mut gated_pos_embed = (1. - &self.gate)?.broadcast_mul(&self.embedding)?; + let mut gated_pos_embed = (1. 
- &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?; let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape(( 1, 1, @@ -62,10 +63,20 @@ impl MLlamaPrecomputedPositionEmbedding { self.num_patches, self.hidden_size, ))?; - gated_pos_embed = self.gate.broadcast_mul(&tile_position_embedding)?; + gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?; hidden_state.broadcast_add(&gated_pos_embed) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb_gpe = UnVarBuilder::new(); + + uvb_gpe.add_tensor("gate", self.gate.clone()); + uvb_gpe.add_tensor("embedding", self.embedding.clone()); + uvb_gpe.pp("tile_embedding").add(&self.tile_embedding); + + uvb_gpe.to_safetensors() + } } struct MLlamaPrecomputedAspectRatioEmbedding { @@ -84,8 +95,7 @@ impl MLlamaPrecomputedAspectRatioEmbedding { vb.pp("embedding"), )?, gate: if GATED { - // NOTE: Preapply the tanh - Some(vb.get((1,), "gate")?.tanh()?) + Some(vb.get((1,), "gate")?) } else { None }, @@ -99,11 +109,22 @@ impl MLlamaPrecomputedAspectRatioEmbedding { embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?; if let Some(gate) = &self.gate { - embeddings = embeddings.broadcast_mul(gate)?; + embeddings = embeddings.broadcast_mul(&gate.tanh()?)?; } hidden_state.broadcast_add(&embeddings) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb_ptpe = UnVarBuilder::new(); + + if let Some(gate) = self.gate.clone() { + uvb_ptpe.add_tensor("gate", gate); + } + uvb_ptpe.pp("embedding").add(&self.embedding); + + uvb_ptpe.to_safetensors() + } } struct MLlamaVisionAttention { @@ -246,10 +267,8 @@ impl MLlamaVisionEncoderLayer { mlp, input_layernorm, post_attention_layernorm, - // NOTE: Preapply the tanh - gate_attn: Some(vb.get((1,), "gate_attn")?.tanh()?), - // NOTE: Preapply the tanh - gate_ffn: Some(vb.get((1,), "gate_ffn")?.tanh()?), + gate_attn: Some(vb.get((1,), "gate_attn")?), + gate_ffn: Some(vb.get((1,), "gate_ffn")?), }) } else { Ok(Self { @@ -272,7 +291,7 @@ impl MLlamaVisionEncoderLayer { hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?; if let Some(gate) = &self.gate_attn { - hidden_state = gate.broadcast_mul(&hidden_state)?; + hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?; } hidden_state = (residual + hidden_state)?; @@ -283,7 +302,7 @@ hidden_state = self.mlp.forward(&hidden_state)?; if let Some(gate) = &self.gate_ffn { - hidden_state = gate.broadcast_mul(&hidden_state)?; + hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?; } residual + hidden_state } @@ -326,6 +345,36 @@ impl MLlamaVisionEncoder { hidden_states.push(hidden_state.clone()); Ok((hidden_state, hidden_states)) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb_t = UnVarBuilder::new(); + + for (i, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_t.pp("layers").pp(i); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + if let Some(gate) = layer.gate_attn.clone() { + uvb_l.add_tensor("gate_attn", gate); + } + if let Some(gate) = layer.gate_ffn.clone() { + uvb_l.add_tensor("gate_ffn", gate); + } + + let uvb_attn = uvb_l.pp("self_attn"); + uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj); + uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj); + uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj); + uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj); + + let uvb_mlp = uvb_l.pp("mlp"); +
uvb_mlp.pp("fc1").add(&layer.mlp.fc1); + uvb_mlp.pp("fc2").add(&layer.mlp.fc2); + } + + uvb_t.to_safetensors() + } } fn _prepare_aspect_ratio_attention_mask( @@ -605,3 +654,48 @@ impl MLlamaVisionModel { Tensor::cat(&[class_embedding, hidden_state.clone()], 1) } } + +impl IsqModel for MLlamaVisionModel { + fn get_layers( + &mut self, + ) -> ( + Vec<( + &mut std::sync::Arc, + Option, + )>, + &dyn crate::device_map::DeviceMapper, + ) { + unreachable!("MLlamaVision model cannot be quantized."); + } + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.pp("patch_embedding").add(&self.patch_embedding); + uvb.add_tensor("class_embedding", self.class_embedding.clone()); + + // gated_positional_embedding + uvb.pp("gated_positional_embedding") + .extend(self.gated_positional_embedding.residual_tensors()); + + // pre_tile_positional_embedding + uvb.pp("pre_tile_positional_embedding") + .extend(self.pre_tile_positional_embedding.residual_tensors()); + + // post_tile_positional_embedding + uvb.pp("post_tile_positional_embedding") + .extend(self.post_tile_positional_embedding.residual_tensors()); + + uvb.pp("layernorm_pre").add(&self.layernorm_pre); + uvb.pp("layernorm_post").add(&self.layernorm_post); + + // transformer + uvb.pp("transformer") + .extend(self.transformer.residual_tensors()); + + // global_transformer + uvb.pp("global_transformer") + .extend(self.global_transformer.residual_tensors()); + + uvb.to_safetensors() + } +} diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs index 99bf7ba01..2d0b663b7 100644 --- a/mistralrs-core/src/vision_models/phi3.rs +++ b/mistralrs-core/src/vision_models/phi3.rs @@ -28,7 +28,7 @@ use crate::{ Cache, IsqModel, NormalLoadingMetadata, VisionModel, }, serde_default_fn, - utils::progress::NiceProgressBar, + utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder}, vision_models::clip::{Activation, ClipConfig, ClipVisionTransformer}, AnyMoeConfig, AnyMoeExpertType, }; @@ -503,6 +503,7 @@ pub struct ImageEmbedding { hd_transform_order: String, use_hd_transform: bool, vocab_size: usize, + tensors: Vec<(String, Tensor)>, } impl ImageEmbedding { @@ -558,50 +559,52 @@ impl ImageEmbedding { .projection_cls .clone() .unwrap_or("linear".to_string()); + + let mut tensors = Vec::new(); let layers: Vec> = match (projection_cls.as_str(), use_hd_transform) { ("linear", _) => { - vec![Box::new(TryInto::::try_into(linear_b( - image_dim_out, - hidden_size, - true, - vb.pp("img_projection"), - )?)?)] + let a = linear_b(image_dim_out, hidden_size, true, vb.pp("img_projection"))?; + tensors.push(("img_projection.weight".to_string(), a.weight().clone())); + if let Some(b) = a.bias().cloned() { + tensors.push(("img_projection.bias".to_string(), b)); + } + vec![Box::new(TryInto::::try_into(a)?)] } ("mlp", true) => { let dim_proj = hidden_size; + let a = linear_b(image_dim_out * 4, dim_proj, true, vb.pp("img_projection.0"))?; + tensors.push(("img_projection.0.weight".to_string(), a.weight().clone())); + if let Some(b) = a.bias().cloned() { + tensors.push(("img_projection.0.bias".to_string(), b)); + } + let b = linear_b(dim_proj, dim_proj, true, vb.pp("img_projection.2"))?; + tensors.push(("img_projection.2.weight".to_string(), b.weight().clone())); + if let Some(b) = b.bias().cloned() { + tensors.push(("img_projection.2.bias".to_string(), b)); + } vec![ - Box::new(TryInto::::try_into(linear_b( - image_dim_out * 4, - dim_proj, - true, - vb.pp("img_projection.0"), - )?)?), + 
Box::new(TryInto::::try_into(a)?), Box::new(candle_nn::Activation::Gelu), - Box::new(TryInto::::try_into(linear_b( - dim_proj, - dim_proj, - true, - vb.pp("img_projection.2"), - )?)?), + Box::new(TryInto::::try_into(b)?), ] } ("mlp", false) => { let dim_proj = hidden_size; + let a = linear_b(image_dim_out, dim_proj, true, vb.pp("img_projection.0"))?; + tensors.push(("img_projection.0.weight".to_string(), a.weight().clone())); + if let Some(b) = a.bias().cloned() { + tensors.push(("img_projection.0.bias".to_string(), b)); + } + let b = linear_b(dim_proj, dim_proj, true, vb.pp("img_projection.2"))?; + tensors.push(("img_projection.2.weight".to_string(), b.weight().clone())); + if let Some(b) = b.bias().cloned() { + tensors.push(("img_projection.2.bias".to_string(), b)); + } vec![ - Box::new(TryInto::::try_into(linear_b( - image_dim_out, - dim_proj, - true, - vb.pp("img_projection.0"), - )?)?), + Box::new(TryInto::::try_into(a)?), Box::new(candle_nn::Activation::Gelu), - Box::new(TryInto::::try_into(linear_b( - dim_proj, - dim_proj, - true, - vb.pp("img_projection.2"), - )?)?), + Box::new(TryInto::::try_into(b)?), ] } _ => { @@ -629,6 +632,7 @@ impl ImageEmbedding { hd_transform_order, use_hd_transform, vocab_size: config.vocab_size, + tensors, }) } @@ -854,6 +858,22 @@ impl ImageEmbedding { Ok(hidden_states) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + if let Some(glb_gn) = self.glb_gn.clone() { + uvb.add_tensor("glb_GN", glb_gn); + } + if let Some(sub_gn) = self.sub_gn.clone() { + uvb.add_tensor("sub_GN", sub_gn); + } + uvb.extend(self.tensors.clone()); + uvb.pp("img_processor.vision_model") + .extend(self.image_processor.residual_tensors()); + + uvb.to_safetensors() + } } // =================== ============= =================== @@ -1063,6 +1083,27 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + let uvb_m = uvb.pp("model"); + uvb_m.pp("embed_tokens").add(&self.embed_tokens); + uvb_m.pp("norm").add(&self.norm); + uvb_m + .pp("vision_embed_tokens") + .extend(self.vision_embed_tokens.residual_tensors()); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let uvb_l = uvb_m.pp("layers").pp(layer_idx); + uvb_l.pp("input_layernorm").add(&layer.input_layernorm); + uvb_l + .pp("post_attention_layernorm") + .add(&layer.post_attention_layernorm); + } + + uvb.to_safetensors() + } } pub(crate) struct Phi3VisionSpecificArgs { diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index ed251bb21..7de3ec675 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, - layers::{RmsNorm, Sdpa}, + layers::{Activation, RmsNorm, Sdpa}, lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering}, paged_attention::ModelConfigMetadata, pipeline::{ @@ -39,7 +39,7 @@ struct MLP { gate_proj: Arc, up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, } impl MLP { @@ -802,6 +802,10 @@ impl IsqModel for XLoraModel { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for XLoraModel { diff --git a/mistralrs-core/src/xlora_models/gemma2.rs b/mistralrs-core/src/xlora_models/gemma2.rs index 
477bd0f9c..94a95a94d 100644 --- a/mistralrs-core/src/xlora_models/gemma2.rs +++ b/mistralrs-core/src/xlora_models/gemma2.rs @@ -12,7 +12,7 @@ use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, device_map::DeviceMapper, - layers::{CausalMasker, RmsNorm, Sdpa}, + layers::{Activation, CausalMasker, RmsNorm, Sdpa}, lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig}, models::gemma2::Config, paged_attention::ModelConfigMetadata, @@ -33,7 +33,7 @@ struct MLP { gate_proj: Arc, up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, } impl MLP { @@ -871,6 +871,10 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index 55baa44c4..adab7d661 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -764,6 +764,10 @@ impl IsqModel for XLoraLlama { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for XLoraLlama { diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index e449f5494..b22960a36 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -14,7 +14,7 @@ use crate::{ }; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -22,7 +22,7 @@ use tracing::info; use crate::{ device_map::DeviceMapper, - layers::{CausalMasker, RmsNorm}, + layers::{Activation, CausalMasker, RmsNorm}, models::mistral::Config, pipeline::{extract_logits, Cache, NormalModel}, }; @@ -797,6 +797,10 @@ impl IsqModel for XLoraModel { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for XLoraModel { diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 6468cb93d..67565f650 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -3,7 +3,7 @@ use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, - layers::Sdpa, + layers::{Activation, Sdpa}, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, paged_attention::ModelConfigMetadata, pipeline::{ @@ -16,7 +16,7 @@ use crate::{ /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Activation, RotaryEmbedding, VarBuilder}; +use candle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -933,6 +933,10 @@ impl IsqModel for XLoraModel { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for XLoraModel { diff --git a/mistralrs-core/src/xlora_models/phi2.rs 
b/mistralrs-core/src/xlora_models/phi2.rs index 96aa68402..c95baec20 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, - layers::Sdpa, + layers::{Activation, Sdpa}, lora::{linear, LinearLayerLike, LoraConfig, Ordering}, paged_attention::ModelConfigMetadata, pipeline::{ @@ -20,9 +20,7 @@ use crate::{ /// This corresponds to the model update made with the following commit: /// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{ - embedding, layer_norm, Activation, Embedding, LayerNorm, RotaryEmbedding, VarBuilder, -}; +use candle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use tqdm::Iter; use tracing::info; @@ -754,6 +752,10 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 1178bd928..e7335e580 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -5,7 +5,7 @@ use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, - layers::Sdpa, + layers::{Activation, Sdpa}, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, paged_attention::ModelConfigMetadata, pipeline::{ @@ -197,7 +197,7 @@ impl Attention { struct Mlp { gate_up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: Activation, i_size: usize, } @@ -712,6 +712,10 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for Model { diff --git a/mistralrs-core/src/xlora_models/starcoder2.rs b/mistralrs-core/src/xlora_models/starcoder2.rs index 504bfa1a1..097777142 100644 --- a/mistralrs-core/src/xlora_models/starcoder2.rs +++ b/mistralrs-core/src/xlora_models/starcoder2.rs @@ -11,7 +11,7 @@ use crate::{ amoe::AnyMoeBaseModelMixin, attention::SdpaParams, device_map::DeviceMapper, - layers::{CausalMasker, RotaryEmbedding, Sdpa}, + layers::{Activation, CausalMasker, RotaryEmbedding, Sdpa}, lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig}, models::starcoder2::Config, paged_attention::ModelConfigMetadata, @@ -31,7 +31,7 @@ use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraC struct MLP { c_fc: Arc, c_proj: Arc, - act: candle_nn::Activation, + act: Activation, } impl MLP { @@ -773,6 +773,10 @@ impl IsqModel for Model { } (tensors, &*self.mapper) } + + fn residual_tensors(&self) -> Vec<(String, Tensor)> { + panic!("Cannot generate UQFF for an adapter model.") + } } impl NormalModel for Model { diff --git a/mistralrs-paged-attn/Cargo.toml b/mistralrs-paged-attn/Cargo.toml index 9c65dfd58..5e16e57a0 100644 --- a/mistralrs-paged-attn/Cargo.toml +++ b/mistralrs-paged-attn/Cargo.toml @@ -14,6 +14,7 @@ homepage.workspace = true [dependencies] candle-core.workspace = true half.workspace = true +float8.workspace = true [build-dependencies] bindgen_cuda = {git = "https://github.com/guoqingbao/bindgen_cuda.git", version = "0.1.6"} diff --git a/mistralrs-paged-attn/src/backend/mod.rs 
b/mistralrs-paged-attn/src/backend/mod.rs index ad40a237c..579caf44d 100644 --- a/mistralrs-paged-attn/src/backend/mod.rs +++ b/mistralrs-paged-attn/src/backend/mod.rs @@ -33,6 +33,7 @@ pub fn get_or_load_func( DType::F16 => "_f16", DType::F32 => "_f32", DType::F64 => "_f64", + DType::F8E4M3 => "_f8_e4m3", }; let spec = if let Some(suffix) = suffix { spec.to_owned() + suffix diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index b75439c78..9fe241253 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -11,8 +11,6 @@ These are API docs for the `mistralrs` package. Each `*_model_id` may be a HF hub repo or a local path. For quantized GGUF models, a list is accepted if multiples files must be specified. -Additionally, for models without quantization, the model architecture should be provided as the `arch` parameter in contrast to GGUF models which encode the architecture in the file. It should be one of the following: - ### Architecture for plain models If you do not specify the architecture, an attempt will be made to use the model's config. If this fails, please raise an issue. diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 899313cb9..2f246049a 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.3.1", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", features=["$feature_name"] } +candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941", features=["$feature_name"] } indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index ca8a7b964..d8efd24de 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -98,7 +98,6 @@ class IsqOrganization(Enum): Default = "default" MoQE = "moqe" - @dataclass class ModelDType(Enum): Auto = "auto" diff --git a/mistralrs-quant/Cargo.toml b/mistralrs-quant/Cargo.toml index cc6ba5ef0..5435778a2 100644 --- a/mistralrs-quant/Cargo.toml +++ b/mistralrs-quant/Cargo.toml @@ -21,6 +21,8 @@ paste = "1.0.15" tracing.workspace = true rayon.workspace = true byteorder = "1.5.0" +float8.workspace = true +once_cell.workspace = true [features] cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"] diff --git a/mistralrs-quant/README.md b/mistralrs-quant/README.md index d6c8b9bda..5c18c5fe7 100644 --- a/mistralrs-quant/README.md +++ b/mistralrs-quant/README.md @@ -6,6 +6,7 @@ Currently supported: - GGUF: `GgufMatMul` - Gptq: `GptqLayer` - Hqq: `HqqLayer` +- FP8: `FP8Linear` - Unquantized (used for ISQ): `UnquantLinear` Some kernels are copied or based on implementations in: diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index d9e09f1c6..dfff14a0c 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -11,6 +11,7 @@ fn main() { "kernels/gptq/q_gemm.cu", "kernels/hqq/hqq.cu", "kernels/ops/ops.cu", + "kernels/marlin/marlin_kernel.cu", ]; for lib_file in lib_files.iter() { println!("cargo:rerun-if-changed={lib_file}"); diff --git a/mistralrs-quant/kernels/marlin/marlin/marlin.cuh b/mistralrs-quant/kernels/marlin/marlin/marlin.cuh new file mode 100644 index 
000000000..3300b5a79 --- /dev/null +++ b/mistralrs-quant/kernels/marlin/marlin/marlin.cuh @@ -0,0 +1,118 @@ +#pragma once +#include +#include +#include +#include +#include + +// #define CHECK(cond, ...) \ +// assert(cond); \ + +#define CHECK(cond, ...) + +namespace marlin { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. + +static constexpr int repack_threads = 256; +static constexpr int repack_stages = 8; +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +__device__ inline constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. 
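// Illustrative, standalone sketch (not from the patch): how the tile constants
// above translate into per-block work for the repack path. The 4096x11008
// weight shape and the 108-SM device are made-up example values.
#include <cstdio>

static constexpr int ex_tile_k_size = 16;      // mirrors tile_k_size above
static constexpr int ex_tile_n_size = 16 * 4;  // mirrors tile_n_size above

constexpr int ceildiv_host(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int size_k = 4096, size_n = 11008, sms = 108;
  // gptq_marlin_repack() further down asserts both remainders are zero.
  printf("k %% tile_k = %d, n %% tile_n = %d\n",
         size_k % ex_tile_k_size, size_n % ex_tile_n_size);
  printf("k tiles = %d, n tiles = %d\n",
         size_k / ex_tile_k_size, size_n / ex_tile_n_size);
  // Each block of the repack kernel handles ceildiv(k_tiles, gridDim.x) k-tiles.
  printf("k tiles per block on %d SMs = %d\n",
         sms, ceildiv_host(size_k / ex_tile_k_size, sms));
  return 0;
}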
+template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +} // namespace marlin diff --git a/mistralrs-quant/kernels/marlin/marlin/marlin_dtypes.cuh b/mistralrs-quant/kernels/marlin/marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..be06c09be --- /dev/null +++ b/mistralrs-quant/kernels/marlin/marlin/marlin_dtypes.cuh @@ -0,0 +1,79 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "marlin.cuh" +#include +#include + +namespace marlin { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace marlin + +#endif diff --git a/mistralrs-quant/kernels/marlin/marlin_kernel.cu b/mistralrs-quant/kernels/marlin/marlin_kernel.cu new file mode 100644 index 000000000..181d4be12 --- /dev/null +++ b/mistralrs-quant/kernels/marlin/marlin_kernel.cu @@ -0,0 +1,1246 @@ +/* + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH +#include +#include +#include +#include +#include "marlin/marlin_dtypes.cuh" +using namespace marlin; + +// m16n8k16 tensor core mma instruction with fp16/bf16 inputs and fp32 +// output/accumulation. 
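// Illustrative, standalone sketch (not from the patch): ScalarType<> above is a
// compile-time trait, so a single kernel body can call num2float/float2num for
// either half or nv_bfloat16 without branching. Host-only analogue below, with
// float/double standing in for the two GPU scalar types.
#include <cstdio>

template <typename T> struct HostScalar;
template <> struct HostScalar<float> {
  static double num2float(float x) { return x; }          // widen for accumulation
  static float float2num(double x) { return (float)x; }   // narrow for storage
};
template <> struct HostScalar<double> {
  static double num2float(double x) { return x; }
  static double float2num(double x) { return x; }
};

// One templated "kernel body" shared by both scalar types, the way Marlin()
// further down is shared by half and nv_bfloat16.
template <typename scalar_t>
scalar_t scaled_sum(const scalar_t* xs, int n, double alpha) {
  double acc = 0.0;  // always accumulate in the wide type (cf. the fp32 FragC)
  for (int i = 0; i < n; i++) acc += HostScalar<scalar_t>::num2float(xs[i]);
  return HostScalar<scalar_t>::float2num(alpha * acc);
}

int main() {
  float xf[3] = {1.0f, 2.0f, 3.0f};
  double xd[3] = {1.0, 2.0, 3.0};
  printf("%f %f\n", scaled_sum(xf, 3, 0.5), scaled_sum(xd, 3, 0.5));
  return 0;
}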
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + + +template +__device__ inline typename ScalarType::FragB dequant(int q); + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +template <> +__device__ inline typename ScalarType::FragB + dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB + dequant(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
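// Illustrative, standalone sketch (not from the patch): the arithmetic that the
// lop3-based dequant above implements, written as plain scalar code and
// ignoring the interleaved register-lane ordering used by the fused fp16 path.
#include <cstdio>

// Unpack eight 4-bit values from one packed 32-bit word and apply the symmetric
// zero point of 8 (a stored nibble q in [0, 15] maps to q - 8), then a scale.
// In the kernel the scale is applied separately, via scale() or at write-out.
void dequant_reference(unsigned int packed, float scale, float out[8]) {
  for (int i = 0; i < 8; i++) {
    int nibble = (packed >> (4 * i)) & 0xF;
    out[i] = scale * (float)(nibble - 8);
  }
}

int main() {
  float vals[8];
  dequant_reference(0x76543210u, 0.5f, vals);
  for (int i = 0; i < 8; i++) printf("%.1f ", vals[i]);  // -4.0 -3.5 ... -0.5
  printf("\n");
  return 0;
}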
+ + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. 
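// Illustrative, standalone sketch (not from the patch): the striped partitioning
// described in the comment above, reproduced on the host for its 3x3-tile /
// 5-SM example (parallel = 1; the group_blocks rounding just below is omitted).
#include <cstdio>

int main() {
  const int k_tiles = 3, n_tiles = 3, blocks = 5;
  const int iters = (k_tiles * n_tiles + blocks - 1) / blocks;  // ceildiv = 2
  for (int block = 0; block < blocks; block++) {
    int start = iters * block;
    for (int i = 0; i < iters && start + i < k_tiles * n_tiles; i++) {
      int slice_col = (start + i) / k_tiles;  // which column slice
      int slice_row = (start + i) % k_tiles;  // which k-tile inside that slice
      printf("block %d -> col %d, k-tile %d\n", block, slice_col, slice_row);
    }
  }
  // The printed assignment is exactly the 3x3 matrix from the comment above:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  return 0;
}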
+ if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 16 * thread_k_blocks / + 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 2 * ((threads / 32) / + (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + 
constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. 
+ int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if constexpr (group_blocks != -1) { + // This assumes group_blocks >= thread_k_blocks + // and would need to be modified to support smaller groups. + static_assert(group_blocks >= thread_k_blocks); + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
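// Illustrative, standalone sketch (not from the patch): the multi-stage
// global->shared pipeline driven by fetch_to_shared() / cp_async_wait<stages - 2>
// above, simulated on the host for stages = 4 and 8 k-tiles.
#include <cstdio>

int main() {
  const int stages = 4, total_tiles = 8;
  // Prologue (cf. start_pipes() further down): issue stages - 1 fetches up front.
  for (int t = 0; t < stages - 1; t++)
    printf("issue fetch of tile %d into slot %d\n", t, t % stages);
  // Steady state: while tile t is being consumed, the fetch for tile
  // t + stages - 1 is issued, and we only wait until at most stages - 2
  // fetch groups are still in flight before reading the freshly loaded slot.
  for (int t = 0; t < total_tiles; t++) {
    int next = t + stages - 1;
    if (next < total_tiles)
      printf("issue fetch of tile %d into slot %d\n", next, next % stages);
    printf("wait (<= %d groups pending), compute on tile %d from slot %d\n",
           stages - 2, t, t % stages);
  }
  return 0;
}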
+ #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
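// Illustrative, standalone sketch (not from the patch): the logarithmic
// cross-warp reduction that thread_block_reduce() above performs, as a serial
// simulation with 8 partial sums standing in for 8 warps (the kernel stages
// the partial sums through shared memory instead of a plain array).
#include <cstdio>

int main() {
  float partial[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // one partial sum per "warp"
  // Each round, the upper half of the still-active warps hands its partial sum
  // to the lower half; after log2(8) = 3 rounds, "warp" 0 holds the full sum.
  for (int i = 8 / 2; i > 0; i /= 2) {
    for (int w = 0; w < i; w++) partial[w] += partial[w + i];
    printf("offset %d done, partial[0] = %.1f\n", i, partial[0]);
  }
  printf("total = %.1f (expected 36.0)\n", partial[0]);
  return 0;
}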
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
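// Illustrative, standalone sketch (not from the patch): the lock-counter protocol
// (barrier_acquire / barrier_release from marlin.cuh) that serializes
// global_reduce() across the blocks of one column slice, modelled on the host
// with std::atomic so the ordering is visible. Block and value counts are made up.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  const int slice_count = 4;
  std::atomic<int> lock{0};  // plays the role of locks[slice_col]
  float c = 0.0f;            // plays the role of the output tile in C

  std::vector<std::thread> blocks;
  for (int slice_idx = 0; slice_idx < slice_count; slice_idx++) {
    blocks.emplace_back([&, slice_idx] {
      // barrier_acquire: spin until every lower-numbered block has released.
      while (lock.load(std::memory_order_acquire) != slice_idx) {
      }
      c += 1.0f;  // this block's contribution to the reduction
      bool last = slice_idx == slice_count - 1;
      printf("block %d reduced, c = %.1f%s\n", slice_idx, c,
             last ? " (and writes the result)" : "");
      // barrier_release: the last block resets the lock, the others bump it.
      if (last) lock.store(0, std::memory_order_release);
      else lock.fetch_add(1, std::memory_order_release);
    });
  }
  for (auto& t : blocks) t.join();
  return 0;
}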
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
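// Illustrative, standalone sketch (not from the patch): the usual reading of the
// "+ 1" in c_sh_stride inside write_result() above -- padding the shared-memory
// row stride by one element so column-wise accesses spread across the banks.
// The bank math below assumes 32 banks of 4-byte words, purely for illustration.
#include <cstdio>

int main() {
  const int banks = 32, rows = 8, col = 0;
  const int strides[2] = {32, 33};  // unpadded vs. padded row stride (in words)
  for (int s = 0; s < 2; s++) {
    printf("stride %2d: banks hit by column %d ->", strides[s], col);
    for (int r = 0; r < rows; r++) printf(" %d", (r * strides[s] + col) % banks);
    printf("\n");
  }
  // stride 32: all eight rows of the column land in bank 0 (8-way conflict);
  // stride 33: the same accesses land in banks 0..7 (conflict free).
  return 0;
}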
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) +static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // 
which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +template +void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, + int prob_n, void* workspace, int groupsize + ) { + + int dev = 0; + cudaStream_t stream = 0; + int thread_k = -1; + int thread_n = -1; + int sms = -1; + int max_par = 16; + + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + "Invalid thread config"); + } + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_m == 0 || prob_n == 0 || prob_k == 0) { + return; + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. 
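// Illustrative, standalone sketch (not from the patch): how
// determine_thread_config() above resolves for two made-up problem shapes,
// using the same priority tables and divisibility/minimum rules (every table
// entry already has thread_k in {64, 128}, so that check is folded in).
#include <cstdio>

struct Cfg { int thread_k, thread_n, threads; };

static Cfg pick(int m, int n, int k) {
  const Cfg small_cfgs[4] = {{128, 128, 256}, {128, 64, 128}, {64, 256, 256}, {64, 128, 128}};
  const Cfg large_cfgs[4] = {{64, 256, 256}, {128, 128, 256}, {64, 128, 128}, {128, 64, 128}};
  const Cfg* cfgs = (m <= 16) ? small_cfgs : large_cfgs;
  for (int i = 0; i < 4; i++) {
    const Cfg& c = cfgs[i];
    if (k % c.thread_k == 0 && n % c.thread_n == 0 &&
        c.thread_k >= 64 && c.thread_n >= 64 && c.threads >= 128)
      return c;  // first valid entry wins, i.e. priority order
  }
  return Cfg{-1, -1, -1};  // marlin_matmul() rejects this as an invalid config
}

int main() {
  // A decode-style call (m = 1) and a prefill-style call (m = 512) on a
  // hypothetical 4096x4096 layer.
  const int ms[2] = {1, 512};
  for (int i = 0; i < 2; i++) {
    Cfg c = pick(ms[i], 4096, 4096);
    printf("m = %3d -> thread_k = %d, thread_n = %d, threads = %d\n",
           ms[i], c.thread_k, c.thread_n, c.threads);
  }
  return 0;
}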
+ if (false) { + } + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + throw std::runtime_error("Unsupported shapes: MKN"); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +extern "C" void marlin_4bit_f16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, + int prob_n, void* workspace, int groupsize + ) { + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); +} + +extern "C" void marlin_4bit_bf16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, + int prob_n, void* workspace, int groupsize + ) { + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); +} + + +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = ceildiv(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? 
tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + + #pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + + #pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if 
constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; + #pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; + #pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; + #pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { + #pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + + #define CALL_IF2(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +extern "C" void gptq_marlin_repack(void* weight, void* perm, void* out, + int size_k, int size_n, + int num_bits) { + // Verify compatibility with marlin tile of 16x64 + assert(size_k % tile_k_size == 0); + assert(size_n % tile_n_size == 0); + assert(num_bits == 4 || num_bits == 8); + const int pack_factor = 32 / num_bits; + + bool has_perm = true; + int dev = 0; + // Get ptrs + uint32_t const* b_q_weight_ptr = + reinterpret_cast(weight); + uint32_t const* perm_ptr = reinterpret_cast(perm); + uint32_t* out_ptr = reinterpret_cast(out); + + // Get dev info + cudaStream_t stream = 0; + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + assert(max_shared_mem > 0); + + if (false) { + } + CALL_IF2(4, false) + CALL_IF2(4, true) + CALL_IF2(8, false) + CALL_IF2(8, true) + else { + assert(false); + } + +} + +#endif diff --git a/mistralrs-quant/src/cublaslt/api.rs b/mistralrs-quant/src/cublaslt/api.rs new file mode 100644 index 000000000..f0f600065 --- /dev/null +++ b/mistralrs-quant/src/cublaslt/api.rs @@ -0,0 +1,507 @@ +use candle_core::cuda::cudarc::driver::DevicePtr; +use float8::F8E4M3; +use std::ffi::c_int; + +use candle_core::backend::BackendStorage; +use candle_core::cuda_backend::WrapErr; +use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; +use half::{bf16, f16}; +use std::sync::Arc; + +use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig, OutSlice}; +use super::F8MatmulOutType; + +#[derive(Debug, Clone)] +pub struct CublasLt(Arc); + +impl CublasLt { + pub fn new(device: &Device) -> Result { + let dev = match device { + Device::Cuda(d) => d, + _ => candle_core::bail!("`device` must be a `cuda` 
device"), + }; + + let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); + + Ok(Self(Arc::new(inner))) + } +} + +pub struct CublasLTBatchMatmulF8 { + pub cublaslt: Arc, + pub act: Option, + pub c: Option, + pub alpha: Option, + pub beta: Option, + // Dequantize + pub a_scale: Tensor, + pub b_scale: Tensor, + // Quantize + pub d_scale: Tensor, + pub out_dtype: F8MatmulOutType, +} + +impl CublasLTBatchMatmulF8 { + pub fn fwd_f8e4m3( + &self, + a: &candle_core::CudaStorage, + a_l: &Layout, + b: &candle_core::CudaStorage, + b_l: &Layout, + bias: Option<&candle_core::CudaStorage>, + bias_l: Option<&Layout>, + ) -> Result<(candle_core::CudaStorage, Shape)> { + let dev = a.device(); + + // Assume TN + let (batch_size, m, k) = a_l.shape().dims3()?; + let (b_0, n, b_2) = b_l.shape().dims3()?; + + if b_2 != k { + candle_core::bail!("This layer only supports TN layout"); + } + + if b_0 != batch_size { + candle_core::bail!("`b` must have the same batch size as `a`") + } + + if !self.a_scale.dims().is_empty() || self.a_scale.dtype() != DType::F32 { + candle_core::bail!("`a_scale` must be a f32 scalar."); + } + if !self.b_scale.dims().is_empty() || self.b_scale.dtype() != DType::F32 { + candle_core::bail!("`b_scale` must be a f32 scalar."); + } + if !self.d_scale.dims().is_empty() || self.d_scale.dtype() != DType::F32 { + candle_core::bail!("`d_scale` must be a f32 scalar."); + } + let (a_s, _) = self.a_scale.storage_and_layout(); + let (b_s, _) = self.b_scale.storage_and_layout(); + let (d_s, _) = self.d_scale.storage_and_layout(); + + let a_scale = match &*a_s { + Storage::Cuda(scale) => scale.as_cuda_slice::()?, + _ => candle_core::bail!("`a_scale` must be a cuda tensor"), + }; + let b_scale = match &*b_s { + Storage::Cuda(scale) => scale.as_cuda_slice::()?, + _ => candle_core::bail!("`b_scale` must be a cuda tensor"), + }; + let d_scale = match &*d_s { + Storage::Cuda(scale) => scale.as_cuda_slice::()?, + _ => candle_core::bail!("`d_scale` must be a cuda tensor"), + }; + + let lda = k; + let ldb = k; + let ldc = m; + + let out_shape = Shape::from((batch_size, n, m)); + + let a = a.as_cuda_slice::()?.slice(a_l.start_offset()..); + let b = b.as_cuda_slice::()?.slice(b_l.start_offset()..); + + let (bias, bias_stride) = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { + if bias_l.dims().len() == 1 { + if bias_l.shape().dims1()? != m { + candle_core::bail!("Bias does not have the correct shape"); + } + ( + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)), + None, + ) + } else { + if bias_l.shape().dims2()?.1 != m { + candle_core::bail!("Bias does not have the correct shape"); + } + if bias_l.shape().dims2()?.0 != batch_size { + candle_core::bail!("Bias batch size must match batch size of `a`"); + } + let bias_stride = bias_l.stride()[0] as i64; + ( + Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)), + Some(bias_stride), + ) + } + } else { + (None, None) + }; + + let (c, stride_c) = if let Some(c) = &self.c { + let (c, c_l) = c.storage_and_layout(); + let c = match &*c { + Storage::Cuda(storage) => storage.as_cuda_slice::()?, + _ => candle_core::bail!("`c` must be a cuda tensor"), + }; + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + candle_core::bail!("`c` start offset must be 0"); + } + if o2 != out_shape.elem_count() { + candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + } + } + None => candle_core::bail!("`c` has to be contiguous"), + }; + + if c_l.shape().dims3()? 
!= (batch_size, n, m) { + candle_core::bail!("`c` does not have the correct shape"); + } + + // Set beta to 0.0 if it is not set + (c.clone(), c_l.stride()[0]) + } else { + // Allocate out tensor + ( + unsafe { dev.alloc::(out_shape.elem_count()).w()? }, + (n * m), + ) + }; + let (mut out, stride_c) = match self.out_dtype { + F8MatmulOutType::BF16 => ( + OutSlice::BF16(unsafe { dev.alloc::(out_shape.elem_count()).w()? }), + (n * m), + ), + F8MatmulOutType::F8 => ( + OutSlice::F8(unsafe { dev.alloc::(out_shape.elem_count()).w()? }), + (n * m), + ), + }; + + let cases = [ + k * std::mem::size_of::(), + k * std::mem::size_of::(), + m * std::mem::size_of::(), // C type size + lda * std::mem::size_of::(), // A type size + ldb * std::mem::size_of::(), // B type size + ldc * std::mem::size_of::(), // C type size + *a.device_ptr() as usize, + *b.device_ptr() as usize, + *c.device_ptr() as usize, + *a_scale.device_ptr() as usize, + *b_scale.device_ptr() as usize, + *d_scale.device_ptr() as usize, + ]; + + for case in cases { + if case % 16 != 0 { + candle_core::bail!("F8 cuBLASlt matmul must match all cases described here: https://docs.nvidia.com/cuda/cublas/#tensor-core-usage"); + } + } + + let config = MatmulConfig { + transa: true, + transb: false, + m: m as u64, + n: n as u64, + k: k as u64, + alpha: self.alpha.unwrap_or(1.0), + lda: lda as i64, + ldb: ldb as i64, + beta: self.beta.unwrap_or(0.0), + ldc: ldc as i64, + stride_a: Some(a_l.stride()[0] as i64), + stride_b: Some(b_l.stride()[0] as i64), + stride_c: Some(stride_c as i64), + stride_bias: bias_stride, + batch_size: Some(c_int::try_from(batch_size)?), + }; + + // let mut amaxd = unsafe { dev.alloc_zeros::(1).w()? }; + + unsafe { + self.cublaslt + .matmul_fp8_like( + config, + &a, + &b, + a_scale, + b_scale, + d_scale, + &c, + &mut out, + // &mut amaxd, + bias.as_ref(), + self.act.as_ref(), + ) + .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + } + + let out = match out { + OutSlice::BF16(s) => candle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), + OutSlice::F8(s) => candle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), + }; + + Ok((out, out_shape)) + } +} + +/// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes. +/// +/// # Arguments +/// +/// * `a` - Input tensor of size BxMxK +/// * `b` - Input tensor of size BxNxK +/// * `dequant_a_scale` - F32 scalar tensor, used to `a` the out tensor. +/// * `dequant_b_scale` - F32 scalar tensor, used to `b` the out tensor. +/// * `quantize_scale` - F32 scalar tensor, used to requantize. +/// * `out` - Optional Output tensor of size BxNxK. +/// If set and beta != 0, will be added to the end result of A*B before `act` +/// * `alpha` - Optional scaling factor for A*B +/// * `beta` - Optional scaling factor for C +/// * `bias` - Optional bias tensor of size M +/// * `act` - Optional Gelu or Relu activation. 
If set, will be added to the end result +/// * `cublaslt` - CublasLt handle +/// +/// The resulting tensor is of shape NxM +#[allow(clippy::too_many_arguments)] +pub fn fused_batch_matmul_f8( + a: &Tensor, + b: &Tensor, + dequant_a_scale: &Tensor, + dequant_b_scale: &Tensor, + quantize_scale: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + out_dtype: F8MatmulOutType, + cublaslt: CublasLt, +) -> Result { + let op = CublasLTBatchMatmulF8 { + act, + cublaslt: cublaslt.0, + c: out.cloned(), + alpha, + beta, + a_scale: dequant_a_scale.clone(), + b_scale: dequant_b_scale.clone(), + d_scale: quantize_scale.clone(), + out_dtype, + }; + + if let Some(bias) = bias { + a.apply_op3(b, bias, op) + } else { + a.apply_op2(b, op) + } +} + +impl candle_core::CustomOp2 for CublasLTBatchMatmulF8 { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul-f8" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("no cpu support for cublaslt-batch-matmul-f8") + } + + fn cuda_fwd( + &self, + a: &candle_core::CudaStorage, + a_l: &Layout, + b: &candle_core::CudaStorage, + b_l: &Layout, + ) -> Result<(candle_core::CudaStorage, Shape)> { + match a.dtype() { + candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, None, None), + dt => { + candle_core::bail!("cublaslt-batch-matmul is only supported for f8e4m3 ({dt:?})") + } + } + } +} + +impl candle_core::CustomOp3 for CublasLTBatchMatmulF8 { + fn name(&self) -> &'static str { + "cublaslt-batch-matmul-add-f8" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle_core::bail!("no cpu support for cublaslt-batch-matmul-add-f8") + } + + fn cuda_fwd( + &self, + a: &candle_core::CudaStorage, + a_l: &Layout, + b: &candle_core::CudaStorage, + b_l: &Layout, + bias: &candle_core::CudaStorage, + bias_l: &Layout, + ) -> Result<(candle_core::CudaStorage, Shape)> { + match a.dtype() { + candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => candle_core::bail!( + "cublaslt-batch-matmul-add is only supported for f8e4m3 ({dt:?})" + ), + } + } +} + +#[cfg(test)] +mod tests { + use std::f32::consts::PI; + + use super::*; + use candle_core::{DType, Device, IndexOp}; + + // The bias bit seems to trip the test up. Not really sure why; it may be something locally. + #[test] + #[ignore] + fn test_fused_batch_matmul_f8e4m3() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let c = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let bias = Tensor::randn(0., 1., 16, &device)?.to_dtype(DType::F32)?; + let dummy_scale = Tensor::new(1f32, &device)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_batch_matmul_f8( + &a.to_dtype(DType::F8E4M3)?, + &b.to_dtype(DType::F8E4M3)?, + &dummy_scale, + &dummy_scale, + &dummy_scale, + Some(&c.to_dtype(DType::BF16)?), + None, + Some(1.), + Some(&bias.to_dtype(DType::BF16)?), + None, + F8MatmulOutType::F8, + cublaslt, + )?; + let expected = b.matmul(&a.t()?)?.add(&c)?.broadcast_add(&bias)?; + + let abs_diff = (res.to_dtype(DType::F32)? 
- expected)?.abs()?; + let absmax = abs_diff.max(0)?.max(0)?.max(0)?.to_scalar::()?; + let abs_diff = abs_diff.to_vec3::()?; + let range = 3e-01; + assert!(abs_diff + .iter() + .all(|x| x.into_iter().all(|y| y.into_iter().all(|x| *x <= range)))); + Ok(()) + } + + #[test] + fn test_fused_batch_matmul_f8e4m3_nobias() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let c = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + + fn quantize(data: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> { + let data = data.to_dtype(DType::F32)?; + let mut absmax = data.clone(); + while !absmax.dims().is_empty() { + absmax = absmax.max(0)?; + } + let max_v = F8E4M3::MAX.to_f64().round(); + let scale = (max_v / absmax)?.clamp(1e-12, f64::INFINITY)?; + let qw = data.broadcast_mul(&scale)?.to_dtype(DType::F8E4M3)?; + Ok((qw, scale)) + } + let (qa, a_scale) = quantize(&a, DType::F8E4M3)?; + let (qb, b_scale) = quantize(&b, DType::F8E4M3)?; + println!("{a_scale}"); + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_batch_matmul_f8( + &qa, + &qb, + &a_scale.recip()?, + &b_scale.recip()?, + &a_scale, + Some(&c.to_dtype(DType::BF16)?), + None, + Some(1.), + None, + None, + F8MatmulOutType::BF16, + cublaslt, + )? + .i((0..2, 0..2, 0..2))?; + let expected = b.matmul(&a.t()?)?.add(&c)?.i((0..2, 0..2, 0..2))?; + + let abs_diff = (res.to_dtype(DType::F32)? - expected)?.abs()?; + let absmax = abs_diff.max(0)?.max(0)?.max(0)?.to_scalar::()?; + let abs_diff = abs_diff.to_vec3::()?; + let range = 3e-01; + assert!(abs_diff + .iter() + .all(|x| x.into_iter().all(|y| y.into_iter().all(|x| *x <= range)))); + Ok(()) + } + + #[test] + fn test_fused_batch_matmul_f8e4m3_out_bf16() -> Result<()> { + let device = Device::new_cuda(0)?; + + let a = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let b = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + let c = Tensor::randn(0., 1., (16, 16, 16), &device)?.to_dtype(DType::F32)?; + + fn quantize(data: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> { + let data = data.to_dtype(DType::F32)?; + let mut absmax = data.clone(); + while !absmax.dims().is_empty() { + absmax = absmax.max(0)?; + } + let max_v = F8E4M3::MAX.to_f64().round(); + let scale = (max_v / absmax)?.clamp(1e-12, f64::INFINITY)?; + let qw = data.broadcast_mul(&scale)?.to_dtype(DType::F8E4M3)?; + Ok((qw, scale)) + } + let (qa, a_scale) = quantize(&a, DType::F8E4M3)?; + let (qb, b_scale) = quantize(&b, DType::F8E4M3)?; + + let cublaslt = CublasLt::new(&device)?; + + let res = fused_batch_matmul_f8( + &qa, + &qb, + &a_scale.recip()?, + &b_scale.recip()?, + &a_scale, + Some(&c.to_dtype(DType::BF16)?), + None, + Some(1.), + None, + None, + F8MatmulOutType::BF16, + cublaslt, + )? + .i((0..2, 0..2, 0..2))?; + let expected = b.matmul(&a.t()?)?.add(&c)?.i((0..2, 0..2, 0..2))?; + + let abs_diff = (res.to_dtype(DType::F32)? 
- expected)?.abs()?; + let absmax = abs_diff.max(0)?.max(0)?.max(0)?.to_scalar::()?; + let abs_diff = abs_diff.to_vec3::()?; + + let range = 3e-01; + assert!(abs_diff + .iter() + .all(|x| x.into_iter().all(|y| y.into_iter().all(|x| *x <= range)))); + Ok(()) + } +} diff --git a/mistralrs-quant/src/cublaslt/matmul.rs b/mistralrs-quant/src/cublaslt/matmul.rs new file mode 100644 index 000000000..851f339c5 --- /dev/null +++ b/mistralrs-quant/src/cublaslt/matmul.rs @@ -0,0 +1,547 @@ +use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; +use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; +use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; +use candle_core::cuda::cudarc::driver::{ + CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, +}; +use core::ffi::c_int; +use core::mem; +use float8::F8E4M3; +use half::bf16; +use std::sync::Arc; + +/// Wrapper around [sys::cublasLtHandle_t] +/// +/// 1. Create with [CudaBlasLT::new()] +/// 2. Execute matmul kernel with matmul. f32 is supported. f16 and bf16 are supported +/// if feature `half` is activated +/// +/// Note: This maintains a instance of [`Arc`], so will prevent the device +/// from being dropped. Kernels will be launched on the device device default stream. +#[derive(Debug)] +pub struct CudaBlasLT { + handle: sys::cublasLtHandle_t, + workspace: Workspace, + device: Arc, +} + +unsafe impl Send for CudaBlasLT {} + +unsafe impl Sync for CudaBlasLT {} + +impl CudaBlasLT { + /// Creates a new cublasLt handle. + pub fn new(device: Arc) -> Result { + let handle = result::create_handle()?; + let workspace = Workspace::new(device.clone()).unwrap(); + + Ok(Self { + handle, + workspace, + device, + }) + } +} + +impl Drop for CudaBlasLT { + fn drop(&mut self) { + let handle = mem::replace(&mut self.handle, std::ptr::null_mut()); + if !handle.is_null() { + unsafe { result::destroy_handle(handle) }.unwrap(); + } + } +} + +/// User owned CublasLt workspace buffer. +/// The workspace is initialised following the Nvidia recommendations: +/// +/// 1. NVIDIA Hopper Architecture: 32 MiB +/// 2. Other: 4 MiB +#[derive(Debug, Clone)] +pub struct Workspace { + pub(crate) buffer: CudaSlice, + pub(crate) size: usize, +} + +impl Workspace { + /// Creates a CublasLt workspace buffer on the provided device + pub fn new(device: Arc) -> Result { + device.bind_to_thread()?; + + let major = + device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?; + let workspace_size = if major >= 9 { 33_554_432 } else { 4_194_304 }; + + let buffer = unsafe { device.alloc::(workspace_size)? 
}; + Ok(Self { + buffer, + size: workspace_size, + }) + } +} + +/// Available activation for kernel fusing in matmul +#[derive(Debug, Clone)] +pub enum Activation { + Relu, + Gelu, +} + +/// MatrixLayout helper type +struct MatrixLayout { + handle: sys::cublasLtMatrixLayout_t, +} + +impl MatrixLayout { + fn new( + matrix_type: sys::cudaDataType, + rows: u64, + cols: u64, + ld: i64, + ) -> Result { + let handle = result::create_matrix_layout(matrix_type, rows, cols, ld)?; + Ok(Self { handle }) + } + + fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> { + unsafe { + // Set batch size + set_matrix_layout_attribute( + self.handle, + sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + (&size) as *const _ as *const _, + mem::size_of::(), + )?; + // Set batch stride + set_matrix_layout_attribute( + self.handle, + sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + (&stride) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatrixLayout { + fn drop(&mut self) { + // panic on failure + unsafe { + result::destroy_matrix_layout(self.handle).expect("Unable to destroy matrix layout") + } + } +} + +enum Matrix { + A, + B, + C, + D, +} + +/// MatmulDesc helper type +struct MatmulDesc { + handle: sys::cublasLtMatmulDesc_t, +} + +impl MatmulDesc { + fn new( + compute_type: sys::cublasComputeType_t, + scale_type: sys::cudaDataType, + ) -> Result { + let handle = result::create_matmul_desc(compute_type, scale_type)?; + Ok(Self { handle }) + } + + fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> { + // Set transpose + // 1 == T, 0 == N + let transpose = transpose as i32; + let attr = match matrix { + Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA, + Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB, + Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC, + Matrix::D => unreachable!(), + }; + + unsafe { + result::set_matmul_desc_attribute( + self.handle, + attr, + (&transpose) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } + + fn set_scale_ptr(&self, device_ptr: &CUdeviceptr, matrix: Matrix) -> Result<(), CublasError> { + let attr = match matrix { + Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, + Matrix::D => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, + }; + + unsafe { + result::set_matmul_desc_attribute( + self.handle, + attr, + device_ptr as *const CUdeviceptr as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } + + // Epilogue system can be leveraged to fuse add and activation operations + fn set_epilogue( + &self, + act: Option<&Activation>, + bias_ptr: Option<&CUdeviceptr>, + stride_bias: Option, + ) -> Result<(), CublasError> { + let epilogue = if let Some(bias_ptr) = bias_ptr { + let epilogue = act + .map(|act| match act { + // Act + bias + Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS, + Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS, + }) + // Only bias + .unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS); + + // Set bias CUdeviceptr in matmul_desc + unsafe { + result::set_matmul_desc_attribute( + self.handle, + 
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER, + bias_ptr as *const CUdeviceptr as *const _, + mem::size_of::(), + )?; + } + + if let Some(stride_bias) = stride_bias { + // Set bias batch stride + unsafe { + result::set_matmul_desc_attribute( + self.handle, + sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE, + (&stride_bias) as *const _ as *const _, + mem::size_of::(), + )?; + } + } + epilogue + } else if let Some(act) = act { + // Only Act + match act { + Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU, + Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU, + } + } else { + // No epilogue + sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT + }; + + // Set epilogue + unsafe { + result::set_matmul_desc_attribute( + self.handle, + sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE, + (&epilogue) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatmulDesc { + fn drop(&mut self) { + unsafe { result::destroy_matmul_desc(self.handle).expect("Unable to destroy matmul desc") } + } +} + +/// MatmulPref helper type +struct MatmulPref { + handle: sys::cublasLtMatmulPreference_t, +} + +impl MatmulPref { + fn new() -> Result { + let handle = result::create_matmul_pref()?; + Ok(Self { handle }) + } + + fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> { + unsafe { + // Set workspace size + result::set_matmul_pref_attribute( + self.handle, + sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + (&size) as *const _ as *const _, + mem::size_of::(), + )?; + } + Ok(()) + } +} + +impl Drop for MatmulPref { + fn drop(&mut self) { + unsafe { result::destroy_matmul_pref(self.handle).expect("Unable to destroy matmul pref") } + } +} + +/// [Matmul] super-trait +pub trait MatmulShared { + /// Returns a reference to the underlying cublasLt handle. + fn handle(&self) -> &sys::cublasLtHandle_t; + + /// Returns a reference to the underlying cublasLt workspace + fn workspace(&self) -> &Workspace; + + /// Returns a reference to the underlying stream + fn stream(&self) -> &CUstream; +} + +/// Configuration for [Matmul] +#[derive(Debug, Copy, Clone)] +pub struct MatmulConfig { + pub transa: bool, + pub transb: bool, + pub m: u64, + pub n: u64, + pub k: u64, + pub alpha: f32, + pub lda: i64, + pub ldb: i64, + pub beta: f32, + pub ldc: i64, + pub stride_a: Option, + pub stride_b: Option, + pub stride_c: Option, + pub stride_bias: Option, + pub batch_size: Option, +} + +pub enum OutSlice, B: DevicePtrMut> { + F8(A), + BF16(B), +} + +/// Matrix matrix multiplication with elements of type `T`. +pub trait Matmul: MatmulShared { + /// Underlying CUDA Type for `T` + fn matrix_type() -> sys::cudaDataType; + + /// Underlying CUDA Compute Type for `T` + fn compute_type() -> sys::cublasComputeType_t; + + /// Matrix matrix multiplication. See + /// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul) + /// + /// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul + /// There are a few requirements: + /// - Compute type must be f32 (upheld) + /// - `transa && !transb` (upheld) + /// - Scale type must be (upheld) + /// - A and B must be f8e4m3, but C must be bf16 (upheld) + /// + /// # Safety + /// This is unsafe because improper arguments may lead to invalid + /// memory accesses. 
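As a rough, std-only illustration of the requirements listed above, the sketch below assembles a `MatmulConfig` (the struct defined earlier in this file) for a batched (B, M, K) x (B, N, K) problem and mirrors the 16-byte divisibility checks performed by the caller in `api.rs`. The concrete choices here (`lda = ldb = k`, `ldc = m`, 1-byte f8e4m3 inputs, 2-byte bf16 output) are assumptions drawn from the surrounding code, and `build_fp8_batched_config` is a hypothetical helper, not part of the patch.

fn build_fp8_batched_config(
    batch_size: usize,
    m: usize,
    n: usize,
    k: usize,
    // Raw device pointers of A, B, C and the three scale scalars,
    // i.e. `*x.device_ptr() as usize` for each operand.
    device_ptrs: &[usize],
) -> Result<MatmulConfig, String> {
    // Column-major leading dimensions for the TN layout required by FP8 (assumed).
    let (lda, ldb, ldc) = (k, k, m);

    // FP8 tensor-core kernels require 16-byte alignment of every pointer and of
    // the row sizes in bytes: f8e4m3 elements are 1 byte, the bf16 C/D rows are 2.
    let byte_sizes = [k, lda, ldb, m * 2, ldc * 2];
    for v in byte_sizes.iter().copied().chain(device_ptrs.iter().copied()) {
        if v % 16 != 0 {
            return Err(format!("{v} is not a multiple of 16 bytes"));
        }
    }

    Ok(MatmulConfig {
        transa: true, // A is consumed transposed; FP8 matmul needs `transa && !transb`
        transb: false,
        m: m as u64,
        n: n as u64,
        k: k as u64,
        alpha: 1.0,
        lda: lda as i64,
        ldb: ldb as i64,
        beta: 0.0,
        ldc: ldc as i64,
        stride_a: Some((m * k) as i64), // batch stride of A in elements
        stride_b: Some((n * k) as i64), // batch stride of B in elements
        stride_c: Some((n * m) as i64), // batch stride of C/D in elements
        stride_bias: None,
        batch_size: Some(batch_size as _), // c_int
    })
}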
+ #[allow(clippy::too_many_arguments)] + unsafe fn matmul_fp8_like< + I: DevicePtr, + C: DevicePtr, + OA: DevicePtrMut, + OB: DevicePtrMut, + S: DevicePtr, + B: DevicePtr, + >( + &self, + cfg: MatmulConfig, + a: &I, + b: &I, + scale_a: &S, + scale_b: &S, + scale_d: &S, + c: &C, + out: &mut OutSlice, + // amax_d: &mut A, + bias: Option<&B>, + act: Option<&Activation>, + ) -> Result<(), CublasError> { + let (a_rows, a_cols) = (cfg.k, cfg.m); + let (b_rows, b_cols) = (cfg.k, cfg.n); + assert!(cfg.transa); + assert!(!cfg.transb); + + // Matmul description + let matmul_desc = MatmulDesc::new( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, + sys::cudaDataType_t::CUDA_R_32F, + ) + .unwrap(); + + // Set transa + matmul_desc.set_transpose(cfg.transa, Matrix::A).unwrap(); + // Set transb + matmul_desc.set_transpose(cfg.transb, Matrix::B).unwrap(); + + // Creates matrix layouts + let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda).unwrap(); + if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) { + a_layout.set_batch(batch_size, stride_a)?; + } + + let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb).unwrap(); + if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) { + b_layout.set_batch(batch_size, stride_b)?; + } + + let c_layout = + MatrixLayout::new(sys::cudaDataType_t::CUDA_R_16BF, cfg.m, cfg.n, cfg.ldc).unwrap(); + if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) { + c_layout.set_batch(batch_size, stride_c)?; + } + + let out_ty = match &out { + OutSlice::F8(_) => Self::matrix_type(), + OutSlice::BF16(_) => sys::cudaDataType_t::CUDA_R_16BF, + }; + let d_layout = MatrixLayout::new(out_ty, cfg.m, cfg.n, cfg.ldc).unwrap(); + if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) { + d_layout.set_batch(batch_size, stride_c)?; + } + + // Set scale factors + matmul_desc + .set_scale_ptr(scale_a.device_ptr(), Matrix::A) + .unwrap(); + matmul_desc + .set_scale_ptr(scale_b.device_ptr(), Matrix::B) + .unwrap(); + matmul_desc + .set_scale_ptr(scale_d.device_ptr(), Matrix::D) + .unwrap(); + + // Pass amaxd ptr + // unsafe { + // result::set_matmul_desc_attribute( + // matmul_desc.handle, + // sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, + // amax_d.device_ptr_mut() as *const CUdeviceptr as *const _, + // mem::size_of::(), + // ) + // .unwrap(); + // } + + // Epilogue system can be leveraged to fuse add and activation operations + matmul_desc + .set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias) + .unwrap(); + + // Create matmul heuristic search preferences + let matmul_pref = MatmulPref::new().unwrap(); + + // Set workspace size + matmul_pref + .set_workspace_size(self.workspace().size) + .unwrap(); + + // Get heuristic given Config, bias, act and workspace size + let heuristic = result::get_matmul_algo_heuristic( + *self.handle(), + matmul_desc.handle, + a_layout.handle, + b_layout.handle, + c_layout.handle, + d_layout.handle, + matmul_pref.handle, + ) + .unwrap(); + + let out_ptr = match out { + OutSlice::BF16(s) => s.device_ptr_mut(), + OutSlice::F8(s) => s.device_ptr_mut(), + }; + + // Launch matmul kernel + result::matmul( + *self.handle(), + matmul_desc.handle, + (&cfg.alpha) as *const _ as *const _, + (&cfg.beta) as *const _ as *const _, + *a.device_ptr() as *const _, + a_layout.handle, + *b.device_ptr() as *const _, + b_layout.handle, + *c.device_ptr() as *const _, + c_layout.handle, + *out_ptr as *mut _, + 
d_layout.handle, + (&heuristic.algo) as *const _, + *self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _, + self.workspace().size, + *self.stream() as *mut _, + ) + } +} + +impl MatmulShared for CudaBlasLT { + fn handle(&self) -> &sys::cublasLtHandle_t { + &self.handle + } + + fn workspace(&self) -> &Workspace { + &self.workspace + } + + fn stream(&self) -> &CUstream { + self.device.cu_stream() + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_32F + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_16F + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_16BF + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + } +} + +impl Matmul for CudaBlasLT { + fn matrix_type() -> sys::cudaDataType { + sys::cudaDataType_t::CUDA_R_8F_E4M3 + } + + fn compute_type() -> sys::cublasComputeType_t { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + } +} diff --git a/mistralrs-quant/src/cublaslt/mod.rs b/mistralrs-quant/src/cublaslt/mod.rs new file mode 100644 index 000000000..8bacc8239 --- /dev/null +++ b/mistralrs-quant/src/cublaslt/mod.rs @@ -0,0 +1,129 @@ +// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs + +#![allow(unused_variables, unused_imports, dead_code)] + +use candle_core::{Device, Result, Tensor}; +use candle_nn::Activation as CandleActivation; +use once_cell::sync::Lazy; +use std::sync::{Mutex, Once}; + +#[cfg(feature = "cuda")] +mod api; +#[cfg(feature = "cuda")] +mod matmul; + +#[cfg(feature = "cuda")] +pub use api::{fused_batch_matmul_f8, CublasLt}; + +pub enum F8MatmulOutType { + F8, + BF16, +} + +static INIT: Once = Once::new(); +static mut CUBLASLT: Option = None; +pub static CUBLASLT_HANDLE: Lazy>> = + Lazy::new(|| Mutex::new(None)); + +pub fn maybe_init_cublas_lt_wrapper() { + unsafe { + INIT.call_once(|| { + #[cfg(not(feature = "cuda"))] + { + CUBLASLT = None; + } + + #[cfg(feature = "cuda")] + { + // Check if we can call the driver + // Then check if we can create a device + // Then check that the device is CUDA + use candle_core::cuda_backend::cudarc::driver; + CUBLASLT = driver::result::init() + .ok() + .and_then(|_| Device::cuda_if_available(0).ok()) + .and_then(|device| match device { + Device::Cuda(_) => Some(CublasLtWrapper { + cublaslt: CublasLt::new(&device).unwrap(), + }), + _ => None, + }); + } + let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); + *CUBLASLT_HANDLE.lock().unwrap() = cublaslt; + }); + } +} + +#[derive(Debug, Clone)] +pub struct CublasLtWrapper { + #[cfg(feature = "cuda")] + pub cublaslt: CublasLt, +} + +impl CublasLtWrapper { + /// Fused batch matmul + add + Relu/Gelu activation using CublasLt for F8 dtypes. + /// + /// # Arguments + /// + /// * `a` - Input tensor of size BxMxK + /// * `b` - Input tensor of size BxNxK + /// * `dequant_a_scale` - F32 scalar tensor, used to `a` the out tensor. + /// * `dequant_b_scale` - F32 scalar tensor, used to `b` the out tensor. + /// * `quantize_scale` - F32 scalar tensor, used to requantize. + /// * `out` - Optional Output tensor of size BxNxK. 
+ /// If set and beta != 0, will be added to the end result of A*B before `act` + /// * `alpha` - Optional scaling factor for A*B + /// * `beta` - Optional scaling factor for C + /// * `bias` - Optional bias tensor of size M + /// * `act` - Optional Gelu or Relu activation. If set, will be added to the end result + /// + /// The resulting tensor is of shape NxM + #[allow(clippy::too_many_arguments)] + pub fn batch_matmul( + &self, + a: &Tensor, + b: &Tensor, + dequant_a_scale: &Tensor, + dequant_b_scale: &Tensor, + quantize_scale: &Tensor, + out: Option<&Tensor>, + alpha: Option, + beta: Option, + bias: Option<&Tensor>, + act: Option, + out_dtype: F8MatmulOutType, + ) -> Result { + #[cfg(feature = "cuda")] + { + let inner_act = act.map(|a| match a { + CandleActivation::Relu => matmul::Activation::Relu, + CandleActivation::Gelu => matmul::Activation::Gelu, + _ => unreachable!("Unsupported activation in cublaslt matmul"), + }); + let mut result = fused_batch_matmul_f8( + a, + b, + dequant_a_scale, + dequant_b_scale, + quantize_scale, + out, + alpha, + beta, + bias, + inner_act, + out_dtype, + self.cublaslt.clone(), + )?; + + if Some(CandleActivation::Swiglu) == act { + result = candle_nn::ops::swiglu(&result)?; + } + Ok(result) + } + #[cfg(not(feature = "cuda"))] + { + candle_core::bail!("`cuda` feature is not enabled") + } + } +} diff --git a/mistralrs-quant/src/fp8/mod.rs b/mistralrs-quant/src/fp8/mod.rs new file mode 100644 index 000000000..524ba457a --- /dev/null +++ b/mistralrs-quant/src/fp8/mod.rs @@ -0,0 +1,284 @@ +use std::{ + borrow::Cow, + io::Cursor, + num::NonZeroUsize, + sync::{atomic::AtomicUsize, Arc}, +}; + +use byteorder::{LittleEndian, ReadBytesExt}; +use candle_core::{DType, Device, Result, Tensor, D}; +use candle_nn::{Linear, Module}; +use quantize::QuantizationResult; + +mod quantize; + +use crate::{ + cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE}, + utils::{ + deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype, + HQFF_VERSION, + }, + IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, QuantizedSerdeType, +}; + +#[derive(Debug)] +pub struct FP8Linear { + lin: Linear, + dequant_w_scale: Tensor, + dequant_x_scale: Tensor, + quant_scale: Tensor, + /// Quantized type + dtype: DType, +} + +impl QuantMethod for FP8Linear { + fn new(method: QuantMethodConfig) -> candle_core::Result + where + Self: Sized, + { + match method { + QuantMethodConfig::Gguf { .. } + | QuantMethodConfig::Gptq { .. } + | QuantMethodConfig::Hqq { .. } + | QuantMethodConfig::Dummy + | QuantMethodConfig::Unquantized(_) => unreachable!(), + QuantMethodConfig::FP8 { lin, dtype } => { + let QuantizationResult { + qw, + quantize_scale, + dequantize_scale, + } = Self::quantize(lin.weight(), dtype)?; + Ok(Self { + lin: Linear::new(qw, lin.bias().cloned()), + dequant_x_scale: dequantize_scale.clone(), // This is probably wrong! 
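// Note: at construction time only the weight has been quantized, so there is
// no activation scale yet and the weight's dequantize scale is used as a
// stand-in. `forward` recomputes `dequant_x_scale` whenever the incoming
// activation is not already F8E4M3, so this value is only consumed when the
// caller passes pre-quantized activations.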
+ dequant_w_scale: dequantize_scale, + quant_scale: quantize_scale, + dtype, + }) + } + } + } + + fn forward(&self, x: &Tensor) -> Result { + // Batch matrix multiplication + maybe_init_cublas_lt_wrapper(); + + match *CUBLASLT_HANDLE.lock().unwrap() { + Some(handle) => { + let n_dims = x.dims().len(); + if n_dims < 3 { + candle_core::bail!( + "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions" + ); + } + // Set up target shape + let mut tgt_shape = x.dims().to_vec(); + *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?; + + // Flatten for correct dims + let mut x = x.flatten_to(D::Minus(3))?; + + // Prepare the b tensor. If it is not quantized, quantize it + let mut dequant_x_scale = self.dequant_x_scale.clone(); + if !matches!(x.dtype(), DType::F8E4M3) { + let QuantizationResult { + qw, + quantize_scale: _, + dequantize_scale, + } = Self::quantize(&x, DType::F8E4M3)?; + x = qw; + dequant_x_scale = dequantize_scale; + } + + // Handle bias + let beta = match self.lin.bias().is_some() { + true => Some(1.0), + false => None, + }; + + // Naming + let a = self.lin.weight().unsqueeze(0)?; + let b = x; + + handle + .batch_matmul( + &a, + &b, + &self.dequant_w_scale, + &dequant_x_scale, + &self.quant_scale, + self.lin.bias(), + None, + beta, + None, + None, + F8MatmulOutType::BF16, // Output in bf16 to avoid manual dequant + )? + .reshape(tgt_shape) + } + None => { + // Dequantize matmul + let dequant_x = x.clone(); + let lin = self.dequantize(x.dtype())?; + lin.forward(&dequant_x) + } + } + } + + fn quantized_act_type(&self) -> Option { + None + } + + fn add_delta_w(&self, delta: &Tensor) -> Result> { + let dequant = self.dequantize(delta.dtype())?; + let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned()); + Ok(Arc::new(Self::new(QuantMethodConfig::FP8 { + lin: new, + dtype: self.dtype, + })?)) + } + + fn dtype_and_device(&self) -> (DType, candle_core::Device) { + (DType::F8E4M3, self.lin.weight().device().clone()) + } + + fn get_bias_mut(&mut self) -> Option<&mut Tensor> { + None + } + + fn apply_isq( + self: Arc, + _dtype: Option, + _device: Device, + _n_quantized: &AtomicUsize, + ) -> Result> { + todo!() + } + + fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option { + match dtype { + IsqType::F8E4M3 => None, + IsqType::Q2K + | IsqType::Q3K + | IsqType::Q4K + | IsqType::Q4_0 + | IsqType::Q4_1 + | IsqType::Q5K + | IsqType::Q5_0 + | IsqType::Q5_1 + | IsqType::Q6K + | IsqType::Q8K + | IsqType::Q8_0 + | IsqType::Q8_1 + | IsqType::HQQ4 + | IsqType::HQQ8 => None, + } + } +} + +// Serialization structure: +// +// ----------------------- +// HQFF version, u32, little endian +// ----------------------- +// ISQ type (3 for fp8), u8, little endian +// ----------------------- +// Whether bias data is included, u8 boolean +// ----------------------- +// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout. +// ----------------------- +// Dequant W scalar, f32, little endian +// ----------------------- +// Dequant X scalar, f32, little endian +// ----------------------- +// Quant scalar, f32, little endian +// ----------------------- +// Quantization type, u32, little endian +// ----------------------- +// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout. 
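To make the byte layout above concrete, here is a minimal sketch of the fixed-width header portion. It assumes the `HQFF_VERSION`, `QuantizedSerdeType` and `write_dtype` items already imported by this module; the two tensor blobs produced by `serialize_tensor` are elided, and `fp8_header_sketch` is a hypothetical helper rather than part of the patch.

// Hypothetical illustration of the header written by `FP8Linear::serialize`.
fn fp8_header_sketch(
    has_bias: bool,
    dequant_w_scale: f32,
    dequant_x_scale: f32,
    quant_scale: f32,
    dtype: DType,
) -> Vec<u8> {
    let mut buffer = Vec::new();
    buffer.extend(&HQFF_VERSION.to_le_bytes()); // HQFF version, u32, little endian
    buffer.push(QuantizedSerdeType::Fp8 as u8); // ISQ type tag (3 for fp8)
    buffer.push(has_bias as u8); // bias flag
    // ... weight tensor bytes from `serialize_tensor` go here ...
    buffer.extend(dequant_w_scale.to_le_bytes()); // dequant W scalar, f32, little endian
    buffer.extend(dequant_x_scale.to_le_bytes()); // dequant X scalar, f32, little endian
    buffer.extend(quant_scale.to_le_bytes()); // quant scalar, f32, little endian
    write_dtype(dtype, &mut buffer); // quantization dtype tag, u32, little endian
    // ... optional bias tensor bytes from `serialize_tensor` follow when `has_bias` ...
    buffer
}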
+// ----------------------- + +impl QuantizedSerde for FP8Linear { + fn isq_serde_supported(&self) -> bool { + true + } + fn name(&self) -> &'static str { + "fp8-linear" + } + fn serialize(&self) -> Result> { + let mut buffer = Vec::new(); + + buffer.extend(&HQFF_VERSION.to_le_bytes()); + + // ISQ type for fp8 is 3 + buffer.push(QuantizedSerdeType::Fp8 as u8); + + // Has bias + buffer.push(self.lin.bias().is_some() as u8); + + // Weight + serialize_tensor(&mut buffer, self.lin.weight())?; + + // Dequant a scale + buffer.extend(self.dequant_w_scale.to_scalar::()?.to_le_bytes()); + // Dequant b scale + buffer.extend(self.dequant_x_scale.to_scalar::()?.to_le_bytes()); + // Quant scale + buffer.extend(self.quant_scale.to_scalar::()?.to_le_bytes()); + + // DType + write_dtype(self.dtype, &mut buffer); + + if let Some(bias) = self.lin.bias() { + // Bias + serialize_tensor(&mut buffer, bias)?; + } + + Ok(Cow::from(buffer)) + } + + fn deserialize(data: Cow<[u8]>, device: &Device) -> Result> + where + Self: Sized, + { + let mut buffer = Cursor::new(data.to_vec()); + + let version = buffer.read_u32::()?; + if let Err(e) = version_is_compatible(version) { + return Err(candle_core::Error::wrap(e)); + } + + let isq_type = buffer.read_u8()? as usize; + if isq_type != QuantizedSerdeType::Fp8 as usize { + candle_core::bail!( + "ISQ type ({isq_type}) doesn't match expected type {}", + QuantizedSerdeType::Fp8 as usize + ); + } + + let has_bias = buffer.read_u8()? != 0; + + let w = deserialize_tensor(&mut buffer, device)?; + + let dequant_w_scale = Tensor::new(buffer.read_f32::()?, device)?; + let dequant_x_scale = Tensor::new(buffer.read_f32::()?, device)?; + let quant_scale = Tensor::new(buffer.read_f32::()?, device)?; + + // DType + let dtype = read_dtype(&mut buffer)?; + + let b = if has_bias { + Some(deserialize_tensor(&mut buffer, device)?) + } else { + None + }; + + Ok(Arc::new(Self { + lin: Linear::new(w, b), + dequant_w_scale, + dequant_x_scale, + quant_scale, + dtype, + })) + } +} diff --git a/mistralrs-quant/src/fp8/quantize.rs b/mistralrs-quant/src/fp8/quantize.rs new file mode 100644 index 000000000..26c25db59 --- /dev/null +++ b/mistralrs-quant/src/fp8/quantize.rs @@ -0,0 +1,146 @@ +use candle_core::{DType, Result, Tensor}; +use candle_nn::Linear; +use float8::F8E4M3; + +use super::FP8Linear; + +pub(super) struct QuantizationResult { + /// Quantized tensor (f8) + pub(super) qw: Tensor, + /// Scalar, f32 tensor. + /// + /// Convert unquantized to quantized tensor as follows: + /// `q = x * qs` + pub(super) quantize_scale: Tensor, + /// Scalar, f32 tensor. Reciprocal of `quantize_scale`. + /// + /// Convert unquantized to quantized tensor as follows: + /// `x = q * dqs` + pub(super) dequantize_scale: Tensor, +} + +impl FP8Linear { + pub(super) fn quantize(data: &Tensor, dtype: DType) -> Result { + let data = data.to_dtype(DType::BF16)?; + let mut absmax = data.clone(); + let mut absmin = data.clone(); + while !absmax.dims().is_empty() { + absmax = absmax.max(0)?; + absmin = absmin.min(0)?; + } + + let absmax = absmax.to_dtype(DType::F32)?.to_scalar::()?; + let absmin = absmin.to_dtype(DType::F32)?.to_scalar::()?; + let amax = f32::max(absmax.abs(), absmin.abs()); + + let max_v = F8E4M3::MAX.to_f32(); + let scale = (max_v / amax).clamp(F8E4M3::MIN.to_f32(), F8E4M3::MAX.to_f32()); + let scale = Tensor::new(scale, data.device())?; + let qw = data + .broadcast_mul(&scale.to_dtype(data.dtype())?)? 
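// `scale` here is `quantize_scale = F8E4M3::MAX / amax` (clamped), so this
// broadcast multiply maps the observed absolute range onto the representable
// f8e4m3 range; the reciprocal stored as `dequantize_scale` undoes it,
// x ≈ q * (amax / F8E4M3::MAX). For example, amax = 14.0 gives a quantize
// scale of 448.0 / 14.0 = 32.0 and a dequantize scale of 0.03125.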
+ .to_dtype(dtype)?; + Ok(QuantizationResult { + qw, + quantize_scale: scale.clone(), + dequantize_scale: scale.recip()?, + }) + } + + pub(super) fn dequantize(&self, dtype: DType) -> Result { + let dequant_w = self + .lin + .weight() + .to_dtype(dtype)? + .broadcast_mul(&self.dequant_w_scale.to_dtype(dtype)?)?; + Ok(Linear::new(dequant_w, self.lin.bias().cloned())) + } +} + +#[cfg(test)] +mod tests { + use candle_core::{ + quantized::{GgmlDType, QTensor}, + DType, Device, Result, Tensor, + }; + + use crate::fp8::FP8Linear; + + use super::QuantizationResult; + + #[test] + fn test_roundtrip_f8e4m3() -> Result<()> { + let dev = Device::cuda_if_available(0)?; + + let data = Tensor::rand(0., 1., (32, 32), &dev)?.to_dtype(DType::F32)?; + + let QuantizationResult { + qw, + quantize_scale: _, + dequantize_scale, + } = FP8Linear::quantize(&data, DType::F8E4M3)?; + + let dequant = qw.to_dtype(DType::F32)?.broadcast_mul(&dequantize_scale)?; + + let diff1 = (&data - dequant)?.abs()?.mean_all()?; + + println!("{diff1}"); + + let q8_0 = QTensor::quantize(&data, GgmlDType::Q8_0)?.dequantize(&dev)?; + let diff2 = (&data - q8_0)?.abs()?.mean_all()?; + + println!("{diff2}"); + Ok(()) + } + + #[test] + #[cfg(feature = "cuda")] + fn test_cublaslt_matmul() -> Result<()> { + use crate::cublaslt::{maybe_init_cublas_lt_wrapper, F8MatmulOutType, CUBLASLT_HANDLE}; + let dev = Device::new_cuda(0)?; + + let w = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?; + let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?; + + // Batch matrix multiplication + maybe_init_cublas_lt_wrapper(); + + let handle = CUBLASLT_HANDLE.lock().unwrap().unwrap(); + + let QuantizationResult { + qw, + quantize_scale: quant_scale, + dequantize_scale: dequant_a_scale, + } = FP8Linear::quantize(&w, DType::F8E4M3)?; + + let mut dequant_b_scale = dequant_a_scale.clone(); + if !matches!(x.dtype(), DType::F8E4M3) { + let QuantizationResult { + qw, + quantize_scale: _, + dequantize_scale, + } = FP8Linear::quantize(&x, DType::F8E4M3)?; + x = qw; + dequant_b_scale = dequantize_scale; + } + + let a = qw; + let b = x; + + // FP8 quantized matmul + let _res = handle.batch_matmul( + &a, + &b, + &dequant_a_scale, + &dequant_b_scale, + &quant_scale, + None, + None, + None, + None, + None, + F8MatmulOutType::BF16, + )?; + + Ok(()) + } +} diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs index bbb17a910..697b13567 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -37,7 +37,8 @@ impl QuantMethod for GgufMatMul { QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } - | QuantMethodConfig::Dummy => unreachable!(), + | QuantMethodConfig::Dummy + | QuantMethodConfig::FP8 { .. 
} => unreachable!(), } } diff --git a/mistralrs-quant/src/gptq/ffi.rs b/mistralrs-quant/src/gptq/ffi.rs index f1d605f8a..d93fc22a2 100644 --- a/mistralrs-quant/src/gptq/ffi.rs +++ b/mistralrs-quant/src/gptq/ffi.rs @@ -1,3 +1,5 @@ +use std::os::raw::c_void; + use half::f16; #[allow(dead_code)] @@ -53,4 +55,37 @@ extern "C" { k: i32, bit: i32, ); + + pub(crate) fn marlin_4bit_f16( + inputs: *const c_void, + weight: *const i32, + scales: *const c_void, + out: *const c_void, + m: i32, + k: i32, + n: i32, + workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero + groupsize: i32, + ); + + pub(crate) fn marlin_4bit_bf16( + inputs: *const c_void, + weight: *const i32, + scales: *const c_void, + out: *const c_void, + m: i32, + k: i32, + n: i32, + workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero + groupsize: i32, + ); + + pub(crate) fn gptq_marlin_repack( + weight: *const c_void, + perm: *const c_void, + result: *const c_void, + k: i32, + n: i32, + bits: i32, + ); } diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs index 8316779a2..ea66a243a 100644 --- a/mistralrs-quant/src/gptq/gptq_cpu.rs +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -1,5 +1,6 @@ -use crate::{IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde}; +use crate::{DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizedConfig, QuantizedSerde}; use candle_core::{DType, Device, Result, Tensor}; +use candle_nn::VarBuilder; use std::{ num::NonZeroUsize, sync::{atomic::AtomicUsize, Arc}, @@ -14,19 +15,12 @@ impl QuantMethod for GptqLayer { Self: Sized, { match method { - QuantMethodConfig::Gptq { - bits: _, - use_exllama: _, - q_weight: _, - gptq_qzeros: _, - gptq_scales: _, - g_idx: _, - bias: _, - } => candle_core::bail!("GPTQ is only supported on CUDA."), + QuantMethodConfig::Gptq { .. } => candle_core::bail!("GPTQ is only supported on CUDA."), QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } - | QuantMethodConfig::Dummy => { + | QuantMethodConfig::Dummy + | QuantMethodConfig::FP8 { .. } => { unreachable!() } } @@ -71,3 +65,65 @@ impl QuantizedSerde for GptqLayer { "gptq" } } + +macro_rules! pack_factor { + ($bits:expr) => { + 32 / $bits + }; +} + +pub fn gptq_linear( + in_dim: usize, + out_dim: usize, + config: &QuantizedConfig, + vb: VarBuilder, +) -> Result> { + // Handle the case where the layer is dummy (no tensors) + if !(vb.contains_tensor("qweight") + && vb.contains_tensor("qzeros") + && vb.contains_tensor("g_idx") + && vb.contains_tensor("scales")) + { + let layer = ::new(QuantMethodConfig::Dummy)?; + return Ok(Arc::new(layer) as Arc); + } + + let qweight = vb.get_with_hints_dtype( + (in_dim / pack_factor!(config.bits), out_dim), + "qweight", + Default::default(), + DType::I32, + )?; + let scale_and_zero_size = in_dim / config.group_size; + let qzeros = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim / pack_factor!(config.bits)), + "qzeros", + Default::default(), + DType::I32, + )?; + let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?; + let scales = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim), + "scales", + Default::default(), + DType::F16, + )?; + let bias = if vb.contains_tensor("bias") { + Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?) 
+ } else { + None + }; + + let config = QuantMethodConfig::Gptq { + bits: config.bits as i32, + use_exllama: false, + q_weight: qweight, + gptq_qzeros: Some(qzeros), + gptq_scales: scales, + g_idx: Some(g_idx), + bias, + workspace: None, + is_marlin: false, + }; + Ok(Arc::new(GptqLayer::new(config)?)) +} diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index 239d7411b..b30aed62f 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -12,14 +12,16 @@ use candle_core::{ }, CudaStorageSlice, WrapErr, }, - from_storage_no_op, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, D, + from_storage_no_op, Context, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, D, }; +use candle_nn::VarBuilder; use half::f16; use lazy_static::lazy_static; use crate::{ + gptq::marlin_backend::{gptq_marlin_matmul, gptq_weight_repack}, utils::{get_cuda_device, get_cuda_slice}, - IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, + DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizedConfig, QuantizedSerde, }; use super::ffi::{ @@ -37,19 +39,28 @@ lazy_static! { #[derive(Debug)] pub struct GptqLayer { - q_weight: Tensor, // u32 - gptq_qzeros: Tensor, // u32 - gptq_scales: Tensor, // f16 - bias: Tensor, // f16 - g_idx: Tensor, // i32 + q_weight: Tensor, // u32 + gptq_qzeros: Option, // u32 + gptq_scales: Tensor, // f16 + bias: Option, // f16 + g_idx: Option, // i32 bits: i32, use_exllama: bool, + workspace: Option, + is_marlin: bool, } impl GptqLayer { // https://github.com/vllm-project/vllm/blob/966fe72141e8365721840b7ababfb78601c23ead/csrc/quantization/gptq/q_gemm.cu#L1490 // https://github.com/vllm-project/vllm/blob/966fe72141e8365721840b7ababfb78601c23ead/csrc/quantization/gptq/q_gemm.cu#L1823 - fn gptq_gemm(&self, a: Tensor, groups: i32, use_exllama: bool) -> Result { + fn gptq_gemm( + &self, + a: Tensor, + g_idx: &Tensor, + gptq_qzeros: &Tensor, + groups: i32, + use_exllama: bool, + ) -> Result { if !a.is_contiguous() { candle_core::bail!( "Expected `a` to be contiguous, got strides {:?}", @@ -58,9 +69,9 @@ impl GptqLayer { } let a_ptr = get_cuda_slice::(&a)?; let b_q_weight = get_cuda_slice::(&self.q_weight)? as *const u32; - let b_gptq_qzeros = get_cuda_slice::(&self.gptq_qzeros)? as *const u32; + let b_gptq_qzeros = get_cuda_slice::(gptq_qzeros)? as *const u32; let b_gptq_scales = get_cuda_slice::(&self.gptq_scales)?; - let b_g_idx = get_cuda_slice::(&self.g_idx)?; + let b_g_idx = get_cuda_slice::(g_idx)?; let dev = get_cuda_device(&a)?; @@ -218,14 +229,18 @@ impl QuantMethod for GptqLayer { gptq_scales, g_idx, bias, + workspace, + is_marlin, } => { - let dev = get_cuda_device(&g_idx)?; - let len = (q_weight.dims()[0] * 32 / bits as usize) * q_weight.dims()[1]; - // SAFETY: used in the kernel as a tmp space, just preallocating it here. - if let std::collections::hash_map::Entry::Vacant(e) = - TMP_DQS.lock().unwrap().entry(len) - { - e.insert(unsafe { dev.alloc::(len).w()? }); + if workspace.is_none() { + let dev = get_cuda_device(&q_weight)?; + let len = (q_weight.dims()[0] * 32 / bits as usize) * q_weight.dims()[1]; + // SAFETY: used in the kernel as a tmp space, just preallocating it here. + if let std::collections::hash_map::Entry::Vacant(e) = + TMP_DQS.lock().unwrap().entry(len) + { + e.insert(unsafe { dev.alloc::(len).w()? 
}); + } } Ok(Self { q_weight, @@ -235,12 +250,15 @@ impl QuantMethod for GptqLayer { bits, use_exllama, bias, + workspace, + is_marlin, }) } QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } - | QuantMethodConfig::Dummy => { + | QuantMethodConfig::Dummy + | QuantMethodConfig::FP8 { .. } => { unreachable!() } } @@ -259,12 +277,36 @@ impl QuantMethod for GptqLayer { if !reshaped_a.device().is_cuda() { candle_core::bail!("Expected CUDA input to GptqLayer"); } - let out = self.gptq_gemm( - reshaped_a, - self.gptq_qzeros.dim(0)? as i32, - self.use_exllama, - )?; - out.reshape(out_shape)?.broadcast_add(&self.bias) + + let out = match ( + self.g_idx.as_ref(), + self.gptq_qzeros.as_ref(), + self.is_marlin, + ) { + (Some(g_idx), Some(gptq_qzeros), false) => self + .gptq_gemm( + reshaped_a, + g_idx, + gptq_qzeros, + gptq_qzeros.dim(0)? as i32, + self.use_exllama, + )? + .reshape(out_shape)?, + (_, _, true) => gptq_marlin_matmul( + a, + &self.q_weight, + &self.gptq_scales, + self.workspace.as_ref().context("Workspace required")?, + self.bits, + )?, + _ => unreachable!(), + }; + + if let Some(bias) = &self.bias { + out.broadcast_add(bias) + } else { + Ok(out) + } } fn quantized_act_type(&self) -> Option { @@ -302,3 +344,157 @@ impl QuantizedSerde for GptqLayer { "gptq" } } + +macro_rules! pack_factor { + ($bits:expr) => { + 32 / $bits + }; +} + +pub fn gptq_linear( + in_dim: usize, + out_dim: usize, + config: &QuantizedConfig, + vb: VarBuilder, +) -> Result> { + // Handle the case where the layer is dummy (no tensors) + if !(vb.contains_tensor("qweight") + && vb.contains_tensor("qzeros") + && vb.contains_tensor("g_idx") + && vb.contains_tensor("scales")) + { + let layer = ::new(QuantMethodConfig::Dummy)?; + return Ok(Arc::new(layer) as Arc); + } + + let marlin_compatible = config.bits == 4 || config.bits == 8; + let marlin_format = config + .checkpoint_format + .as_ref() + .is_some_and(|fmt| fmt == "marlin"); + + let qw_shape = if marlin_format { + (in_dim / pack_factor!(config.bits) / 2, out_dim * 2) + } else { + (in_dim / pack_factor!(config.bits), out_dim) + }; + let qweight = vb.get_with_hints_dtype( + qw_shape, + if marlin_format { "B" } else { "qweight" }, + Default::default(), + DType::I32, + )?; + let scale_and_zero_size = in_dim / config.group_size; + let scales = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim), + if marlin_format { "s" } else { "scales" }, + Default::default(), + DType::F16, + )?; + let bias = if vb.contains_tensor("bias") { + Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?) 
+ } else { + None + }; + let workspace = Tensor::zeros(out_dim / pack_factor!(config.bits), DType::U32, vb.device())?; + + let config = if marlin_format { + QuantMethodConfig::Gptq { + bits: config.bits as i32, + use_exllama: false, + q_weight: qweight, + gptq_qzeros: None, + gptq_scales: scales, + g_idx: None, + bias, + workspace: Some(workspace), + is_marlin: true, + } + } else { + fn get_scale_perms() -> (Vec, Vec) { + let mut scale_perm: Vec = Vec::new(); + for i in 0..8 { + scale_perm.extend((0..8).map(|j| i + 8 * j)); + } + let mut scale_perm_single: Vec = Vec::new(); + for i in 0..4 { + scale_perm_single.extend([0, 1, 8, 9, 16, 17, 24, 25].iter().map(|&j| 2 * i + j)); + } + (scale_perm, scale_perm_single) + } + + fn marlin_permute_scales( + s: &Tensor, + size_k: usize, + size_n: usize, + group_size: i32, + _num_bits: u32, + ) -> Result { + let (scale_perm, scale_perm_single) = get_scale_perms(); + let s = if (group_size as usize) < size_k && group_size != -1 { + let s = s.reshape(((), scale_perm.len()))?; + let scale_perm_tensor = + Tensor::from_slice(&scale_perm, scale_perm.len(), s.device())?; + s.index_select(&scale_perm_tensor, 1)? + } else { + let s = s.reshape(((), scale_perm_single.len()))?; + let scale_perm_single_tensor = + Tensor::from_slice(&scale_perm_single, scale_perm_single.len(), s.device())?; + s.index_select(&scale_perm_single_tensor, 1)? + }; + + let s = s.reshape(((), size_n))?.contiguous()?; + Ok(s) + } + + let qzeros = vb.get_with_hints_dtype( + (scale_and_zero_size, out_dim / pack_factor!(config.bits)), + "qzeros", + Default::default(), + DType::I32, + )?; + + let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?; + let perm = g_idx + .to_device(&Device::Cpu)? + .arg_sort_last_dim(true)? + .to_device(g_idx.device())?; + + // Repack to marlin format + let qweight = if marlin_compatible { + gptq_weight_repack(&qweight, &perm, in_dim, config.bits as i32)? + } else { + qweight + }; + + let scales = if marlin_compatible { + marlin_permute_scales( + &scales, + in_dim / pack_factor!(config.bits), + out_dim, + config.group_size as i32, + config.bits as u32, + )? 
+ } else { + scales + }; + let workspace = if marlin_compatible { + Some(workspace) + } else { + None + }; + + QuantMethodConfig::Gptq { + bits: config.bits as i32, + use_exllama: false, + q_weight: qweight, + gptq_qzeros: Some(qzeros), + gptq_scales: scales, + g_idx: Some(g_idx), + bias, + workspace, + is_marlin: marlin_compatible, + } + }; + Ok(Arc::new(GptqLayer::new(config)?)) +} diff --git a/mistralrs-quant/src/gptq/marlin_backend.rs b/mistralrs-quant/src/gptq/marlin_backend.rs new file mode 100644 index 000000000..e6ebd4b07 --- /dev/null +++ b/mistralrs-quant/src/gptq/marlin_backend.rs @@ -0,0 +1,243 @@ +use super::ffi::{gptq_marlin_repack, marlin_4bit_bf16, marlin_4bit_f16}; +use candle::backend::BackendStorage; +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::cuda_backend::WrapErr; +use candle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, Storage, Tensor}; +use candle_core as candle; +use half::{bf16, f16}; + +struct GPTQMatMul { + workspace: Tensor, + bits: i32, +} + +impl GPTQMatMul { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + x: &CudaStorage, + x_l: &Layout, + qweight: &CudaStorage, + qweight_l: &Layout, + scale: &CudaStorage, + scale_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + let dev = qweight.device(); + let x_shape = x_l.dims(); + let weight_shape = qweight_l.dims(); + // let zero_shape = self.qzeros.shape().dims(); + let scale_shape = scale_l.dims(); + + let pack_factor: usize = 32 / self.bits as usize; + let size_m = x_shape[0] * x_shape[1]; + let size_k = weight_shape[0] * pack_factor * 2; //marlin format + let size_n = weight_shape[1] / 2; //marlin format + + let mut out_shape: Vec = x_shape.to_vec(); + out_shape[x_shape.len() - 1] = size_n; + let oshape: Shape = out_shape.into(); + + // Get cuda slices for all tensors + let input = x.as_cuda_slice::()?; + let qw = qweight.as_cuda_slice::()?; + let qs = scale.as_cuda_slice::()?; + + // Get cuda views for all tensors + let input = input.slice(x_l.start_offset()..); + let qw = qw.slice(qweight_l.start_offset()..); + let qs = qs.slice(scale_l.start_offset()..); + + let elem_count = oshape.elem_count(); + let out = unsafe { dev.alloc::(elem_count) }.w()?; + + let out_ptr = *out.device_ptr() as *const core::ffi::c_void; + let in_ptr = *input.device_ptr() as *const core::ffi::c_void; + let qw_ptr = *qw.device_ptr() as *const core::ffi::c_void; + let qs_ptr = *qs.device_ptr() as *const core::ffi::c_void; + let workspace_ptr = { + let (workspace, workspace_l) = self.workspace.storage_and_layout(); + let workspace = match &*workspace { + Storage::Cuda(p) => p, + _ => candle::bail!("workspace must be a cuda tensor"), + }; + let workspace_ = workspace.as_cuda_slice::()?; + let workspace_ = workspace_.slice(workspace_l.start_offset()..); + *workspace_.device_ptr() as *const core::ffi::c_void + }; + + let groupsize: i32 = if scale_shape[0] == 1 { + -1i32 + } else { + (size_k / scale_shape[0]) as i32 + }; + if x.dtype() == DType::F16 { + unsafe { + marlin_4bit_f16( + in_ptr, + qw_ptr as *const i32, + qs_ptr, + out_ptr, + size_m as i32, + size_k as i32, + size_n as i32, + workspace_ptr, + groupsize, + ); + } + } else if x.dtype() == DType::BF16 { + unsafe { + marlin_4bit_bf16( + in_ptr, + qw_ptr as *const i32, + qs_ptr, + out_ptr, + size_m as i32, + size_k as i32, + size_n as i32, + workspace_ptr, + groupsize, + ); + } + } + + let out = CudaStorage::wrap_cuda_slice(out, dev.clone()); + Ok((out, oshape)) + } +} + +impl 
candle::CustomOp3 for GPTQMatMul { + fn name(&self) -> &'static str { + "GPTQMatMul" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for GPTQMatMul") + } + + fn cuda_fwd( + &self, + x: &CudaStorage, + x_l: &Layout, + qweight: &CudaStorage, + qweight_l: &Layout, + scale: &CudaStorage, + scale_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + match x.dtype() { + DType::F16 => self.cuda_fwd_t::(x, x_l, qweight, qweight_l, scale, scale_l), + DType::BF16 => self.cuda_fwd_t::(x, x_l, qweight, qweight_l, scale, scale_l), + dt => candle::bail!("GPTQMatMul is only supported for f16 and bf16 ({dt:?})"), + } + } +} + +pub fn gptq_marlin_matmul( + x: &Tensor, + qweight: &Tensor, + scale: &Tensor, + workspace: &Tensor, + bits: i32, +) -> Result { + let op = GPTQMatMul { + workspace: workspace.to_owned(), + bits, + }; + x.apply_op3(qweight, scale, op) +} + +struct GPTQRepack { + k: i32, + bits: i32, +} + +impl GPTQRepack { + fn cuda_fwd_t< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + qweight: &CudaStorage, + qweight_l: &Layout, + perm: &CudaStorage, + perm_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + let dev = qweight.device(); + let q_shape = qweight_l.dims(); + let mut out_shape: Vec = q_shape.to_vec(); + out_shape[0] /= 2; + out_shape[1] *= 2; + + let oshape: Shape = out_shape.into(); + + // Get cuda slices for all tensors + let q = qweight.as_cuda_slice::()?; + let perm_ = perm.as_cuda_slice::()?; + + // Get cuda views for all tensors + let q = q.slice(qweight_l.start_offset()..); + let perm_ = perm_.slice(perm_l.start_offset()..); + + let elem_count = oshape.elem_count(); + let out = unsafe { dev.alloc::(elem_count) }.w()?; + + let out_ptr = *out.device_ptr() as *const core::ffi::c_void; + let q_ptr = *q.device_ptr() as *const core::ffi::c_void; + let q_perm = *perm_.device_ptr() as *const core::ffi::c_void; + + unsafe { gptq_marlin_repack(q_ptr, q_perm, out_ptr, self.k, q_shape[1] as i32, self.bits) } + + let out = CudaStorage::wrap_cuda_slice(out, dev.clone()); + Ok((out, oshape)) + } +} + +impl candle::CustomOp2 for GPTQRepack { + fn name(&self) -> &'static str { + "GPTQRepack" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("no cpu support for GPTQRepack") + } + + fn cuda_fwd( + &self, + qweight: &CudaStorage, + qweight_l: &Layout, + perm: &CudaStorage, + perm_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + match qweight.dtype() { + DType::U32 => self.cuda_fwd_t::(qweight, qweight_l, perm, perm_l), + DType::I32 => self.cuda_fwd_t::(qweight, qweight_l, perm, perm_l), + dt => candle::bail!("GPTQRepack is only supported for i32/u32 weight ({dt:?})"), + } + } +} + +pub fn gptq_weight_repack( + qweight: &Tensor, + perm: &Tensor, + size_k: usize, + bits: i32, +) -> Result { + let op = GPTQRepack { + bits, + k: size_k as i32, + }; + qweight.apply_op2(perm, op) +} diff --git a/mistralrs-quant/src/gptq/mod.rs b/mistralrs-quant/src/gptq/mod.rs index 013c10fc3..18fca3594 100644 --- a/mistralrs-quant/src/gptq/mod.rs +++ b/mistralrs-quant/src/gptq/mod.rs @@ -4,8 +4,10 @@ mod ffi; mod gptq_cpu; #[cfg(feature = "cuda")] mod gptq_cuda; +#[cfg(feature = "cuda")] +mod marlin_backend; #[cfg(not(feature = "cuda"))] -pub use gptq_cpu::GptqLayer; +pub use gptq_cpu::{gptq_linear, GptqLayer}; #[cfg(feature 
= "cuda")] -pub use gptq_cuda::GptqLayer; +pub use gptq_cuda::{gptq_linear, GptqLayer}; diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs index bb4ee2af6..2430e6487 100644 --- a/mistralrs-quant/src/hqq/mod.rs +++ b/mistralrs-quant/src/hqq/mod.rs @@ -528,7 +528,8 @@ impl QuantMethod for HqqLayer { QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Gptq { .. } - | QuantMethodConfig::Dummy => { + | QuantMethodConfig::Dummy + | QuantMethodConfig::FP8 { .. } => { unreachable!() } QuantMethodConfig::Hqq { diff --git a/mistralrs-quant/src/hqq/quantize.rs b/mistralrs-quant/src/hqq/quantize.rs index 52207f2e3..3707ee0d8 100644 --- a/mistralrs-quant/src/hqq/quantize.rs +++ b/mistralrs-quant/src/hqq/quantize.rs @@ -123,9 +123,9 @@ mod test { }, )?; - let dequant = hqq.dequantize()?; + let _dequant = hqq.dequantize()?; - dbg!(&(&dequant - &data)?.abs()?.mean_all()?); + // dbg!(&(&dequant - &data)?.abs()?.mean_all()?); Ok(()) } } diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 6dd97679c..a312d3a32 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -10,7 +10,9 @@ use candle_core::{ DType, Device, Result, Tensor, }; +mod cublaslt; mod dummy; +mod fp8; mod gguf; mod gptq; mod hqq; @@ -18,15 +20,17 @@ mod unquantized; mod utils; pub use dummy::DummyLayer; +pub use fp8::FP8Linear; pub use gguf::GgufMatMul; +use gptq::gptq_linear; pub use gptq::GptqLayer; pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; pub use unquantized::UnquantLinear; use candle_nn::{Linear, VarBuilder}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub enum QuantMethodType { #[default] #[serde(rename = "gptq")] @@ -41,11 +45,12 @@ impl Display for QuantMethodType { } } -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct QuantizedConfig { pub bits: usize, pub quant_method: QuantMethodType, pub group_size: usize, + pub checkpoint_format: Option, } #[derive(Debug, Clone)] @@ -54,10 +59,12 @@ pub enum QuantMethodConfig { bits: i32, use_exllama: bool, q_weight: Tensor, - gptq_qzeros: Tensor, + gptq_qzeros: Option, gptq_scales: Tensor, - g_idx: Tensor, - bias: Tensor, + g_idx: Option, + bias: Option, + workspace: Option, + is_marlin: bool, }, Gguf { q_weight: Arc, @@ -75,6 +82,10 @@ pub enum QuantMethodConfig { bias: Option, }, Dummy, + FP8 { + lin: Linear, + dtype: DType, + }, } #[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)] @@ -96,6 +107,7 @@ pub enum IsqType { // HQQ3, // HQQ2, // HQQ1, + F8E4M3, } impl TryFrom for GgmlDType { @@ -143,6 +155,7 @@ pub enum QuantizedSerdeType { Gguf = 0, Unquant = 1, Hqq = 2, + Fp8 = 3, } impl TryFrom for QuantizedSerdeType { @@ -152,6 +165,7 @@ impl TryFrom for QuantizedSerdeType { 0 => Ok(Self::Gguf), 1 => Ok(Self::Unquant), 2 => Ok(Self::Hqq), + 3 => Ok(Self::Fp8), other => candle_core::bail!("QuantizedSerdeType {other} is invalid."), } } @@ -209,12 +223,10 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde { fn get_bias_mut(&mut self) -> Option<&mut Tensor>; fn get_max_isq_cpu_threads(&self, dtype: IsqType) -> Option; -} -macro_rules! 
pack_factor { - ($bits:expr) => { - 32 / $bits - }; + fn unquant_weight_bias(&self) -> Option<(Tensor, Option)> { + None + } } pub fn linear_no_bias( @@ -280,53 +292,3 @@ pub fn linear_b( linear_no_bias(in_dim, out_dim, config, vb) } } - -pub fn gptq_linear( - in_dim: usize, - out_dim: usize, - config: &QuantizedConfig, - vb: VarBuilder, -) -> Result> { - // Handle the case where the layer is dummy (no tensors) - if !(vb.contains_tensor("qweight") - && vb.contains_tensor("qzeros") - && vb.contains_tensor("g_idx") - && vb.contains_tensor("scales")) - { - let layer = ::new(QuantMethodConfig::Dummy)?; - return Ok(Arc::new(layer) as Arc); - } - - let qweight = vb.get_with_hints_dtype( - (in_dim / pack_factor!(config.bits), out_dim), - "qweight", - Default::default(), - DType::I32, - )?; - let scale_and_zero_size = in_dim / config.group_size; - let qzeros = vb.get_with_hints_dtype( - (scale_and_zero_size, out_dim / pack_factor!(config.bits)), - "qzeros", - Default::default(), - DType::I32, - )?; - let g_idx = vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?; - let scales = vb.get_with_hints_dtype( - (scale_and_zero_size, out_dim), - "scales", - Default::default(), - DType::F16, - )?; - let bias = vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?; - - let config = QuantMethodConfig::Gptq { - bits: config.bits as i32, - use_exllama: false, - q_weight: qweight, - gptq_qzeros: qzeros, - gptq_scales: scales, - g_idx, - bias, - }; - Ok(Arc::new(GptqLayer::new(config)?)) -} diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index 654040a8c..80e1469c5 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -13,7 +13,8 @@ use crate::{ generate_isq, hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE}, utils::{deserialize_tensor, serialize_tensor, version_is_compatible, HQFF_VERSION}, - GgufMatMul, IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, QuantizedSerdeType, + FP8Linear, GgufMatMul, IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde, + QuantizedSerdeType, }; #[derive(Debug)] @@ -28,7 +29,8 @@ impl QuantMethod for UnquantLinear { QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Gptq { .. } | QuantMethodConfig::Hqq { .. } - | QuantMethodConfig::Dummy => unreachable!(), + | QuantMethodConfig::Dummy + | QuantMethodConfig::FP8 { .. } => unreachable!(), QuantMethodConfig::Unquantized(l) => Ok(Self(l)), } } @@ -117,6 +119,18 @@ impl QuantMethod for UnquantLinear { .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()), })?)) } + Some(IsqType::F8E4M3) => { + let w = self.0.weight().to_device(&device)?; + let b = if let Some(b) = self.0.bias() { + Some(b.to_device(&device)?) 
+                } else {
+                    None
+                };
+                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
+                    lin: Linear::new(w, b),
+                    dtype: DType::F8E4M3,
+                })?))
+            }
             None => {
                 let w = self.0.weight().to_device(&device)?;
                 let b = if let Some(b) = self.0.bias() {
@@ -138,6 +152,7 @@ impl QuantMethod for UnquantLinear {
                 // Use 1 because our HQQ quantizes on the GPU
                 Some(1.try_into().unwrap())
             }
+            IsqType::F8E4M3 => None,
             IsqType::Q2K
             | IsqType::Q3K
             | IsqType::Q4K
@@ -152,6 +167,10 @@ impl QuantMethod for UnquantLinear {
             | IsqType::Q8_1 => None,
         }
     }
+
+    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
+        Some((self.0.weight().clone(), self.0.bias().cloned()))
+    }
 }

 // Serialization structure:
diff --git a/mistralrs-quant/src/utils/mod.rs b/mistralrs-quant/src/utils/mod.rs
index 417eb3d39..51650c7f9 100644
--- a/mistralrs-quant/src/utils/mod.rs
+++ b/mistralrs-quant/src/utils/mod.rs
@@ -6,7 +6,10 @@ mod ops;
 mod uqff;

 pub use ops::{BitWiseOp, LeftshiftOp};
-pub(crate) use uqff::{deserialize_tensor, serialize_tensor, version_is_compatible, HQFF_VERSION};
+pub(crate) use uqff::{
+    deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
+    HQFF_VERSION,
+};

 #[cfg(feature = "cuda")]
 use candle_core::{
diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs
index de59be5d0..3b38d4e69 100644
--- a/mistralrs-quant/src/utils/ops.rs
+++ b/mistralrs-quant/src/utils/ops.rs
@@ -70,6 +70,7 @@ impl CustomOp2 for BitWiseOr {
             CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise-or")),
             CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise-or")),
             CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or")),
+            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or")),
         }
     }
     #[cfg(feature = "cuda")]
@@ -141,6 +142,9 @@ impl CustomOp2 for BitWiseOr {
             DType::F64 => {
                 return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise-or"));
             }
+            DType::F8E4M3 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise-or"));
+            }
         };
         let dst = match s1.dtype() {
             DType::U8 => {
@@ -226,6 +230,7 @@ impl CustomOp1 for Leftshift {
             CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "leftshifr")),
             CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "leftshifr")),
             CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshifr")),
+            CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshifr")),
         }
     }
     #[cfg(feature = "cuda")]
@@ -269,6 +274,9 @@ impl CustomOp1 for Leftshift {
             DType::F64 => {
                 return Err(Error::UnsupportedDTypeForOp(DType::F64, "leftshift"));
             }
+            DType::F8E4M3 => {
+                return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "leftshift"));
+            }
         };
         let dst = match s1.dtype() {
             DType::U8 => {
diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs
index 1494357aa..bf5b0e7ae 100644
--- a/mistralrs-quant/src/utils/uqff.rs
+++ b/mistralrs-quant/src/utils/uqff.rs
@@ -1,11 +1,16 @@
 use byteorder::{LittleEndian, ReadBytesExt};
 use candle_core::{DType, Device, Result, Tensor, WithDType};
+use float8::F8E4M3;
 use half::{bf16, f16};

+// v0.1.0: initial release
+// v0.1.1: add i16 dtype
+// v0.1.2: add F8E4M3
+
 const HQFF_VERSION_MAJOR: u32 = 0;
 const HQFF_VERSION_MINOR: u32 = 1;
-const HQFF_VERSION_PATCH: u32 = 1;
+const HQFF_VERSION_PATCH: u32 = 2;

 /// Format 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ]
 pub(crate) const HQFF_VERSION: u32 =
@@ -24,6 +29,43 @@ pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
     Ok(())
 }

+// -----------------------
+// Tensor dtype, u32, little endian
+// -----------------------
+pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
+    let dtype: u32 = match dtype {
+        DType::U8 => 0,
+        DType::U32 => 1,
+        DType::I32 => 2,
+        DType::I64 => 3,
+        DType::F16 => 4,
+        DType::BF16 => 5,
+        DType::F32 => 6,
+        DType::F64 => 7,
+        DType::I16 => 8,
+        DType::F8E4M3 => 9,
+    };
+    buffer.extend(&dtype.to_le_bytes());
+}
+
+pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
+    let dtype = buffer.read_u32::<LittleEndian>()?;
+    let dtype = match dtype {
+        0 => DType::U8,
+        1 => DType::U32,
+        2 => DType::I32,
+        3 => DType::I64,
+        4 => DType::F16,
+        5 => DType::BF16,
+        6 => DType::F32,
+        7 => DType::F64,
+        8 => DType::I16,
+        9 => DType::F8E4M3,
+        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
+    };
+    Ok(dtype)
+}
+
 // -----------------------
 // Tensor data length, u32, little endian
 // -----------------------
@@ -54,21 +96,12 @@ pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
         DType::BF16 => data_to_bytes::<bf16>(tensor.to_vec1()?),
         DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
         DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
+        DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
     };
     buffer.extend(&(bias.len() as u32).to_le_bytes());

-    let dtype: u32 = match tensor.dtype() {
-        DType::U8 => 0,
-        DType::U32 => 1,
-        DType::I32 => 2,
-        DType::I64 => 3,
-        DType::F16 => 4,
-        DType::BF16 => 5,
-        DType::F32 => 6,
-        DType::F64 => 7,
-        DType::I16 => 8,
-    };
-    buffer.extend(&dtype.to_le_bytes());
+    // DType
+    write_dtype(tensor.dtype(), buffer);

     // Shape
     buffer.extend((b_shape.len() as u32).to_le_bytes());
@@ -87,19 +120,8 @@ pub(crate) fn deserialize_tensor<R: std::io::Read>(
 ) -> Result<Tensor> {
     let data_len = buffer.read_u32::<LittleEndian>()? as usize;

-    let dtype = buffer.read_u32::<LittleEndian>()?;
-    let dtype = match dtype {
-        0 => DType::U8,
-        1 => DType::U32,
-        2 => DType::I32,
-        3 => DType::I64,
-        4 => DType::F16,
-        5 => DType::BF16,
-        6 => DType::F32,
-        7 => DType::F64,
-        8 => DType::I16,
-        _ => candle_core::bail!("unknown dtype for quantized bias tensor {dtype}"),
-    };
+    // DType
+    let dtype = read_dtype(buffer)?;

     let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

@@ -121,6 +143,7 @@ pub(crate) fn deserialize_tensor<R: std::io::Read>(
         DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
         DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
         DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
+        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
     }
 }
diff --git a/mistralrs/src/model.rs b/mistralrs/src/model.rs
index 163278ffc..d35267401 100644
--- a/mistralrs/src/model.rs
+++ b/mistralrs/src/model.rs
@@ -151,4 +151,9 @@ impl Model {
         Ok(self.runner.get_sender()?.send(request).await?)
     }
+
+    /// Retrieve some information about this model.
+    pub fn config(&self) -> &MistralRsConfig {
+        self.runner.config()
+    }
 }
diff --git a/mistralrs/src/text_model.rs b/mistralrs/src/text_model.rs
index 2fe7130ce..c934b541c 100644
--- a/mistralrs/src/text_model.rs
+++ b/mistralrs/src/text_model.rs
@@ -212,6 +212,25 @@ impl TextModelBuilder {
         self
     }
+
+    /// Path to read a UQFF file from.
+    pub fn from_uqff(mut self, path: PathBuf) -> Self {
+        self.from_uqff = Some(path);
+        self
+    }
+
+    /// Path to write a UQFF file to.
+    ///
+    /// The parent directory (the path without the filename) determines where any other generated
+    /// files are written. These can be used to load UQFF models standalone, and may include:
+    /// - `residual.safetensors`
+    /// - `tokenizer.json`
+    /// - `config.json`
+    /// - And others
+    pub fn write_uqff(mut self, path: PathBuf) -> Self {
+        self.write_uqff = Some(path);
+        self
+    }
+
     pub async fn build(self) -> anyhow::Result<Model> {
         let config = NormalSpecificConfig {
             use_flash_attn: self.use_flash_attn,
diff --git a/mistralrs/src/vision_model.rs b/mistralrs/src/vision_model.rs
index 0d9296c04..df3158f34 100644
--- a/mistralrs/src/vision_model.rs
+++ b/mistralrs/src/vision_model.rs
@@ -128,6 +128,25 @@ impl VisionModelBuilder {
         self
     }
+
+    /// Path to read a UQFF file from.
+    pub fn from_uqff(mut self, path: PathBuf) -> Self {
+        self.from_uqff = Some(path);
+        self
+    }
+
+    /// Path to write a UQFF file to.
+    ///
+    /// The parent directory (the path without the filename) determines where any other generated
+    /// files are written. These can be used to load UQFF models standalone, and may include:
+    /// - `residual.safetensors`
+    /// - `tokenizer.json`
+    /// - `config.json`
+    /// - And others
+    pub fn write_uqff(mut self, path: PathBuf) -> Self {
+        self.write_uqff = Some(path);
+        self
+    }
+
     pub async fn build(self) -> anyhow::Result<Model> {
         let config = VisionSpecificConfig {
             use_flash_attn: self.use_flash_attn,
diff --git a/scripts/generate_uqff_card.py b/scripts/generate_uqff_card.py
index d180c37cd..ab5544cbb 100644
--- a/scripts/generate_uqff_card.py
+++ b/scripts/generate_uqff_card.py
@@ -9,6 +9,9 @@ display_model_id = input(
     "Please enter the model ID where this model card will be displayed: "
 )

+is_vision = input("Is this a vision model (yes/no): ").strip().lower() == "yes"
+if is_vision:
+    arch = input("What is the vision model architecture?: ").strip().lower()

 output = f"""---
 tags:
@@ -39,9 +42,9 @@
     " NOTE: If multiple quantizations were used: enter the quantization names, and then in the next prompt, the topology file used."
 )

-output += f"## Files\n\n"
+output += "\n## Examples\n"

-output += "|Name|Quantization type(s)|Example|\n|--|--|--|\n"
+output += "|Quantization type(s)|Example|\n|--|--|\n"

 topologies = {}
@@ -52,7 +55,6 @@
             f" NOTE: Next file. Have processed {n} files. Press CTRL-C now if there are no more."
         )
         file = input("Enter UQFF filename (with extension): ").strip()
-        output += f"|{file}|"

         quants = input(
             "Enter quantization NAMES used to make that file (single quantization name, OR if multiple, comma delimited): "
@@ -63,11 +65,17 @@
                 "Enter topology used to make UQFF with multiple quantizations: "
             )
             topologies[file] = topology
-            output += f"{",".join(quants)} (see topology for this file)|"
+            output += f"|{','.join(quants)} (see topology for this file)|"
         else:
-            output += f"{quants.strip().upper()}|"
-            # This interactive mode only will work for text models...
-            output += f"`./mistralrs-server -i plain -m {model_id} --from-uqff {display_model_id}/{file}`|\n"
+            output += f"|{quants.strip().upper()}|"
+
+        if is_vision:
+            cmd = "vision-plain"
+            arch_flag = f"-a {arch}"
+        else:
+            cmd = "plain"
+            arch_flag = ""
+
+        output += f"`./mistralrs-server -i {cmd} -m {display_model_id} {arch_flag} --from-uqff {file}`|\n"
         n += 1
         print()
 except KeyboardInterrupt:
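
Usage sketch for the new UQFF builder hooks above (illustrative only): the `TextModelBuilder::new` and `with_isq` calls, the model ID, and the paths are assumptions about the surrounding `mistralrs` API, not part of this diff.

    use std::path::PathBuf;
    use mistralrs::{IsqType, TextModelBuilder};

    async fn quantize_then_reload() -> anyhow::Result<()> {
        // First run: apply in-situ quantization (ISQ) and write the UQFF file; the
        // companion files (residual.safetensors, tokenizer.json, ...) land in `uqff-out/`.
        let _model = TextModelBuilder::new("mistralai/Mistral-7B-Instruct-v0.3")
            .with_isq(IsqType::Q4K)
            .write_uqff(PathBuf::from("uqff-out/model.uqff"))
            .build()
            .await?;

        // Later runs: skip re-quantization by loading the pre-quantized UQFF file.
        let _model = TextModelBuilder::new("mistralai/Mistral-7B-Instruct-v0.3")
            .from_uqff(PathBuf::from("uqff-out/model.uqff"))
            .build()
            .await?;
        Ok(())
    }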