From 9126572acab9e5ecb13f21365445bd041cdeefca Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 1 Apr 2025 23:08:25 +0000 Subject: [PATCH 1/5] Avoid flux gpu<>cpu sync with full --- flue-core/src/flux/sampling.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flue-core/src/flux/sampling.rs b/flue-core/src/flux/sampling.rs index afa0008..7dca10e 100644 --- a/flue-core/src/flux/sampling.rs +++ b/flue-core/src/flux/sampling.rs @@ -107,13 +107,14 @@ pub fn denoise( let b_sz = img.dim(0)?; let dev = img.device(); let guidance = Tensor::full(guidance as f32, b_sz, dev)?; + let t_vec_one = Tensor::full(1f32, b_sz, dev)?; let mut img = img.clone(); for window in timesteps.windows(2) { let (t_curr, t_prev) = match window { [a, b] => (a, b), _ => continue, }; - let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?; + let t_vec = (&t_vec_one * *t_curr as f64)?; let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?; img = (img + pred * (t_prev - t_curr))? } From a936de833ff83f864c931432a73733352ceaabe2 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 1 Apr 2025 23:19:58 +0000 Subject: [PATCH 2/5] Avoid all from_vec --- flue-core/src/flux/model.rs | 11 +++-------- flue-core/src/flux/sampling.rs | 8 +++++++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/flue-core/src/flux/model.rs b/flue-core/src/flux/model.rs index 29cc422..a0dae81 100644 --- a/flue-core/src/flux/model.rs +++ b/flue-core/src/flux/model.rs @@ -149,7 +149,7 @@ pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result let dev = t.device(); let half = dim / 2; let t = (t * TIME_FACTOR)?; - let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?; + let arange = (Tensor::ones(half, DType::F32, dev)?.cumsum(0)? - 1.)?; let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; let args = t .unsqueeze(1)? @@ -553,7 +553,7 @@ pub struct Flux { time_in: MlpEmbedder, vector_in: MlpEmbedder, guidance_in: Option, - pe_embedder: EmbedNd, + pub pe_embedder: EmbedNd, double_blocks: Vec, single_blocks: Vec, final_layer: LastLayer, @@ -604,9 +604,8 @@ impl Flux { pub fn forward( &self, img: &Tensor, - img_ids: &Tensor, txt: &Tensor, - txt_ids: &Tensor, + pe: &Tensor, timesteps: &Tensor, y: &Tensor, guidance: Option<&Tensor>, @@ -618,10 +617,6 @@ impl Flux { candle_core::bail!("unexpected shape for img {:?}", img.shape()) } let dtype = img.dtype(); - let pe = { - let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; - ids.apply(&self.pe_embedder)? - }; let mut txt = txt.apply(&self.txt_in)?; let mut img = img.apply(&self.img_in)?; let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; diff --git a/flue-core/src/flux/sampling.rs b/flue-core/src/flux/sampling.rs index 7dca10e..059c525 100644 --- a/flue-core/src/flux/sampling.rs +++ b/flue-core/src/flux/sampling.rs @@ -109,13 +109,19 @@ pub fn denoise( let guidance = Tensor::full(guidance as f32, b_sz, dev)?; let t_vec_one = Tensor::full(1f32, b_sz, dev)?; let mut img = img.clone(); + + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&model.pe_embedder)? + }; + for window in timesteps.windows(2) { let (t_curr, t_prev) = match window { [a, b] => (a, b), _ => continue, }; let t_vec = (&t_vec_one * *t_curr as f64)?; - let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?; + let pred = model.forward(&img, txt, &pe, &t_vec, vec_, Some(&guidance))?; img = (img + pred * (t_prev - t_curr))? } Ok(img) From 3964195445c42f39f0434d23d921d6bd33e3f191 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 2 Apr 2025 00:47:29 +0000 Subject: [PATCH 3/5] Remove candle-transformers --- Cargo.lock | 48 -------------------------------------------- Cargo.toml | 1 - flue-core/Cargo.toml | 9 ++++----- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b0fedc2..3d8f31e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -270,21 +270,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec", -] - -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - [[package]] name = "bit_field" version = "0.10.2" @@ -445,27 +430,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "candle-transformers" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020" -dependencies = [ - "accelerate-src", - "byteorder", - "candle-core", - "candle-nn", - "fancy-regex", - "intel-mkl-src", - "num-traits", - "rand 0.9.0", - "rayon", - "serde", - "serde_json", - "serde_plain", - "tracing", -] - [[package]] name = "cc" version = "1.2.17" @@ -894,17 +858,6 @@ dependencies = [ "zune-inflate", ] -[[package]] -name = "fancy-regex" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" -dependencies = [ - "bit-set", - "regex-automata", - "regex-syntax", -] - [[package]] name = "fdeflate" version = "0.3.7" @@ -944,7 +897,6 @@ dependencies = [ "anyhow", "candle-core", "candle-nn", - "candle-transformers", "flue-flash-attn-v2", "flue-flash-attn-v3", "hf-hub", diff --git a/Cargo.toml b/Cargo.toml index 9b6bb53..ab58d6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,6 @@ candle-core = { version = "0.8.4" } flue-flash-attn-v2 = { path = "./flue-flash-attn-v2", version = "0.8.0" } flue-flash-attn-v3 = { path = "./flue-flash-attn-v3", version = "0.8.0" } candle-nn = { version = "0.8.4" } -candle-transformers = { version = "0.8.4" } clap = { version = "4.5.34", features = ["derive"] } hf-hub = { version = "0.4.2", default-features = false, features = ["ureq", "tokio", "rustls-tls"] } image = "0.25.6" diff --git a/flue-core/Cargo.toml b/flue-core/Cargo.toml index d8741b5..dd308b3 100644 --- a/flue-core/Cargo.toml +++ b/flue-core/Cargo.toml @@ -17,7 +17,6 @@ candle-core = { workspace = true } flue-flash-attn-v2 = { workspace = true, optional = true } flue-flash-attn-v3 = { workspace = true, optional = true } candle-nn = { workspace = true } -candle-transformers = { workspace = true } hf-hub = { workspace = true } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } @@ -29,11 +28,11 @@ tokio = { workspace = true } serde_plain = { workspace = true } [features] -cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +cuda = ["candle-core/cuda", "candle-nn/cuda"] cudnn = ["candle-core/cudnn"] -metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] +metal = ["candle-core/metal", "candle-nn/metal"] flash-attn-v2 = ["cuda", "flue-flash-attn-v2"] flash-attn-v3 = ["cuda", "flue-flash-attn-v3"] -accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate", "dep:accelerate-src"] -mkl = ["candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl", "dep:intel-mkl-src"] +accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "dep:accelerate-src"] +mkl = ["candle-core/mkl", "candle-nn/mkl", "dep:intel-mkl-src"] From 5296de5d901d77228e11030b69a054f23b5c51c0 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sat, 12 Apr 2025 14:52:49 -0400 Subject: [PATCH 4/5] Clippy --- flue-core/src/flux/model.rs | 4 ++-- flue-core/src/flux/sampling.rs | 10 +++++----- flue-flash-attn-v2/cutlass | 2 +- flue-flash-attn-v3/cutlass | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flue-core/src/flux/model.rs b/flue-core/src/flux/model.rs index a0dae81..316672d 100644 --- a/flue-core/src/flux/model.rs +++ b/flue-core/src/flux/model.rs @@ -630,12 +630,12 @@ impl Flux { // Double blocks for block in self.double_blocks.iter() { - (img, txt) = block.forward(&img, &txt, &vec_, &pe)? + (img, txt) = block.forward(&img, &txt, &vec_, pe)? } // Single blocks let mut img = Tensor::cat(&[&txt, &img], 1)?; for block in self.single_blocks.iter() { - img = block.forward(&img, &vec_, &pe)?; + img = block.forward(&img, &vec_, pe)?; } let img = img.i((.., txt.dim(1)?..))?; self.final_layer.forward(&img, &vec_) diff --git a/flue-core/src/flux/sampling.rs b/flue-core/src/flux/sampling.rs index 059c525..6446941 100644 --- a/flue-core/src/flux/sampling.rs +++ b/flue-core/src/flux/sampling.rs @@ -8,8 +8,8 @@ pub fn get_noise( width: usize, device: &Device, ) -> Result { - let height = (height + 15) / 16 * 2; - let width = (width + 15) / 16 * 2; + let height = height.div_ceil(16) * 2; + let width = width.div_ceil(16) * 2; Tensor::randn(0f32, 1., (num_samples, 16, height, width), device) } @@ -86,8 +86,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec Result { let (b, _h_w, c_ph_pw) = xs.dims3()?; - let height = (height + 15) / 16; - let width = (width + 15) / 16; + let height = height.div_ceil(16); + let width = width.div_ceil(16); xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw) .permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw) .reshape((b, c_ph_pw / 4, height * 2, width * 2)) @@ -120,7 +120,7 @@ pub fn denoise( [a, b] => (a, b), _ => continue, }; - let t_vec = (&t_vec_one * *t_curr as f64)?; + let t_vec = (&t_vec_one * { *t_curr })?; let pred = model.forward(&img, txt, &pe, &t_vec, vec_, Some(&guidance))?; img = (img + pred * (t_prev - t_curr))? } diff --git a/flue-flash-attn-v2/cutlass b/flue-flash-attn-v2/cutlass index afa1772..62750a2 160000 --- a/flue-flash-attn-v2/cutlass +++ b/flue-flash-attn-v2/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 62750a2b75c802660e4894434dc55e839f322277 diff --git a/flue-flash-attn-v3/cutlass b/flue-flash-attn-v3/cutlass index 4c42f73..62750a2 160000 --- a/flue-flash-attn-v3/cutlass +++ b/flue-flash-attn-v3/cutlass @@ -1 +1 @@ -Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d +Subproject commit 62750a2b75c802660e4894434dc55e839f322277 From d18d4d985dd9f16b7bfeca2730fcee28e2c0de84 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Sat, 12 Apr 2025 14:54:00 -0400 Subject: [PATCH 5/5] Done --- Cargo.lock | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 91232b8..ab2bad0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -270,6 +270,21 @@ dependencies = [ "rayon", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bit_field" version = "0.10.2" @@ -885,6 +900,17 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + [[package]] name = "fdeflate" version = "0.3.7" @@ -925,6 +951,7 @@ dependencies = [ "candle-core", "candle-flash-attn", "candle-nn", + "candle-transformers", "flue-flash-attn-v2", "flue-flash-attn-v3", "hf-hub",